diff --git a/Cargo.lock b/Cargo.lock index a23afc6fd..44ee12ae3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4773,6 +4773,7 @@ name = "reth-tasks" version = "0.1.0" dependencies = [ "futures-util", + "thiserror", "tokio", "tracing", "tracing-futures", diff --git a/crates/net/network/src/session/mod.rs b/crates/net/network/src/session/mod.rs index 07d785706..d5201246e 100644 --- a/crates/net/network/src/session/mod.rs +++ b/crates/net/network/src/session/mod.rs @@ -170,7 +170,7 @@ impl SessionManager { F: Future + Send + 'static, { if let Some(ref executor) = self.executor { - executor.spawn(async move { f.await }) + executor.spawn(async move { f.await }); } else { tokio::task::spawn(async move { f.await }); } diff --git a/crates/tasks/Cargo.toml b/crates/tasks/Cargo.toml index 4fecc9e3b..03beb9f03 100644 --- a/crates/tasks/Cargo.toml +++ b/crates/tasks/Cargo.toml @@ -12,6 +12,7 @@ tokio = { version = "1", features = ["sync", "rt"] } tracing-futures = "0.2" tracing = { version = "0.1", default-features = false } futures-util = "0.3" +thiserror = "1.0" [dev-dependencies] tokio = { version = "1", features = ["sync", "rt", "rt-multi-thread", "time", "macros"] } \ No newline at end of file diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs index dafcbc668..0f80b027a 100644 --- a/crates/tasks/src/lib.rs +++ b/crates/tasks/src/lib.rs @@ -8,14 +8,15 @@ //! reth task management use crate::shutdown::{signal, Shutdown, Signal}; -use futures_util::{future::select, pin_mut, Future, FutureExt, Stream}; +use futures_util::{future::select, pin_mut, Future, FutureExt}; use std::{ pin::Pin, - task::{Context, Poll}, + task::{ready, Context, Poll}, }; use tokio::{ runtime::Handle, sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, + task::JoinHandle, }; use tracing::error; use tracing_futures::Instrument; @@ -40,9 +41,9 @@ pub struct TaskManager { /// See [`Handle`] docs. handle: Handle, /// Sender half for sending panic signals to this type - panicked_tasks_tx: UnboundedSender, + panicked_tasks_tx: UnboundedSender<&'static str>, /// Listens for panicked tasks - panicked_tasks_rx: UnboundedReceiver, + panicked_tasks_rx: UnboundedReceiver<&'static str>, /// The [Signal] to fire when all tasks should be shutdown. /// /// This is fired on drop. @@ -72,17 +73,23 @@ impl TaskManager { } } -/// A stream that yields the name of panicked tasks. +/// An endless future that resolves if a critical task panicked. /// /// See [`TaskExecutor::spawn_critical`] -impl Stream for TaskManager { - type Item = String; +impl Future for TaskManager { + type Output = PanickedTaskError; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.get_mut().panicked_tasks_rx.poll_recv(cx) + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let err = ready!(self.get_mut().panicked_tasks_rx.poll_recv(cx)); + Poll::Ready(err.map(PanickedTaskError).expect("stream can not end")) } } +/// Error with the name of the task that panicked. +#[derive(Debug, thiserror::Error)] +#[error("Critical task panicked {0}")] +pub struct PanickedTaskError(&'static str); + /// A type that can spawn new tokio tasks #[derive(Debug, Clone)] pub struct TaskExecutor { @@ -93,16 +100,38 @@ pub struct TaskExecutor { /// Receiver of the shutdown signal. on_shutdown: Shutdown, /// Sender half for sending panic signals to this type - panicked_tasks_tx: UnboundedSender, + panicked_tasks_tx: UnboundedSender<&'static str>, } // === impl TaskExecutor === impl TaskExecutor { - /// Spawns the task onto the runtime. - /// - /// See also [`Handle::spawn`]. - pub fn spawn(&self, fut: F) + /// Returns the [Handle] to the tokio runtime. + pub fn handle(&self) -> &Handle { + &self.handle + } + + /// Returns the receiver of the shutdown signal. + pub fn on_shutdown_signal(&self) -> &Shutdown { + &self.on_shutdown + } + + /// Spawns a future on the tokio runtime depending on the [TaskKind] + fn spawn_on_rt(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()> + where + F: Future + Send + 'static, + { + match task_kind { + TaskKind::Default => self.handle.spawn(fut), + TaskKind::Blocking => { + let handle = self.handle.clone(); + self.handle.spawn_blocking(move || handle.block_on(fut)) + } + } + } + + /// Spawns a regular task depending on the given [TaskKind] + fn spawn_task_as(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()> where F: Future + Send + 'static, { @@ -114,13 +143,54 @@ impl TaskExecutor { } .in_current_span(); - self.handle.spawn(task); + self.spawn_on_rt(task, task_kind) } - /// This spawns a critical task onto the runtime. + /// Spawns the task onto the runtime. + /// The given future resolves as soon as the [Shutdown] signal is received. /// - /// If this task panics, the [`TaskManager`] is notified. - pub fn spawn_critical(&self, name: &'static str, fut: F) + /// See also [`Handle::spawn`]. + pub fn spawn(&self, fut: F) -> JoinHandle<()> + where + F: Future + Send + 'static, + { + self.spawn_task_as(fut, TaskKind::Default) + } + + /// Spawns a blocking task onto the runtime. + /// The given future resolves as soon as the [Shutdown] signal is received. + /// + /// See also [`Handle::spawn_blocking`]. + pub fn spawn_blocking(&self, fut: F) -> JoinHandle<()> + where + F: Future + Send + 'static, + { + self.spawn_task_as(fut, TaskKind::Blocking) + } + + /// Spawns the task onto the runtime. + /// The given future resolves as soon as the [Shutdown] signal is received. + /// + /// See also [`Handle::spawn`]. + pub fn spawn_with_signal(&self, f: impl FnOnce(Shutdown) -> F) -> JoinHandle<()> + where + F: Future + Send + 'static, + { + let on_shutdown = self.on_shutdown.clone(); + let fut = f(on_shutdown); + + let task = fut.in_current_span(); + + self.handle.spawn(task) + } + + /// Spawns a critical task depending on the given [TaskKind] + fn spawn_critical_as( + &self, + name: &'static str, + fut: F, + task_kind: TaskKind, + ) -> JoinHandle<()> where F: Future + Send + 'static, { @@ -132,28 +202,86 @@ impl TaskExecutor { .catch_unwind() .map(move |res| { error!("Critical task `{name}` panicked: {res:?}"); - let _ = panicked_tasks_tx.send(name.to_string()); + let _ = panicked_tasks_tx.send(name); }) .in_current_span(); - self.handle.spawn(async move { + let task = async move { pin_mut!(task); let _ = select(on_shutdown, task).await; - }); + }; + + self.spawn_on_rt(task, task_kind) } + + /// This spawns a critical blocking task onto the runtime. + /// The given future resolves as soon as the [Shutdown] signal is received. + /// + /// If this task panics, the [`TaskManager`] is notified. + pub fn spawn_critical_blocking(&self, name: &'static str, fut: F) -> JoinHandle<()> + where + F: Future + Send + 'static, + { + self.spawn_critical_as(name, fut, TaskKind::Blocking) + } + + /// This spawns a critical task onto the runtime. + /// The given future resolves as soon as the [Shutdown] signal is received. + /// + /// If this task panics, the [`TaskManager`] is notified. + pub fn spawn_critical(&self, name: &'static str, fut: F) -> JoinHandle<()> + where + F: Future + Send + 'static, + { + self.spawn_critical_as(name, fut, TaskKind::Default) + } + + /// This spawns a critical task onto the runtime. + /// + /// If this task panics, the [`TaskManager`] is notified. + pub fn spawn_critical_with_signal( + &self, + name: &'static str, + f: impl FnOnce(Shutdown) -> F, + ) -> JoinHandle<()> + where + F: Future + Send + 'static, + { + let panicked_tasks_tx = self.panicked_tasks_tx.clone(); + let on_shutdown = self.on_shutdown.clone(); + let fut = f(on_shutdown); + + // wrap the task in catch unwind + let task = std::panic::AssertUnwindSafe(fut) + .catch_unwind() + .map(move |res| { + error!("Critical task `{name}` panicked: {res:?}"); + let _ = panicked_tasks_tx.send(name); + }) + .in_current_span(); + + self.handle.spawn(task) + } +} + +/// Determines how a task is spawned +enum TaskKind { + /// Spawn the task to the default executor [Handle::spawn] + Default, + /// Spawn the task to the blocking executor [Handle::spawn_blocking] + Blocking, } #[cfg(test)] mod tests { use super::*; - use futures_util::StreamExt; use std::time::Duration; #[test] fn test_critical() { let runtime = tokio::runtime::Runtime::new().unwrap(); let handle = runtime.handle().clone(); - let mut manager = TaskManager::new(handle); + let manager = TaskManager::new(handle); let executor = manager.executor(); executor.spawn_critical( @@ -162,8 +290,8 @@ mod tests { ); runtime.block_on(async move { - let panicked_task = manager.next().await.unwrap(); - assert_eq!(panicked_task, "this is a critical task"); + let err = manager.await; + assert_eq!(err.0, "this is a critical task"); }) }