From e6ca4c56c6a314e5553656e5b62be3bca6801b85 Mon Sep 17 00:00:00 2001 From: Matthias Seitz Date: Fri, 13 Jan 2023 10:34:22 +0100 Subject: [PATCH] feat: add shutdown signal to TaskManager (#831) --- crates/tasks/Cargo.toml | 2 +- crates/tasks/src/lib.rs | 77 +++++++++++++++++++++++++-- crates/tasks/src/shutdown.rs | 100 +++++++++++++++++++++++++++++++++++ 3 files changed, 174 insertions(+), 5 deletions(-) create mode 100644 crates/tasks/src/shutdown.rs diff --git a/crates/tasks/Cargo.toml b/crates/tasks/Cargo.toml index 6019d7a21..4fecc9e3b 100644 --- a/crates/tasks/Cargo.toml +++ b/crates/tasks/Cargo.toml @@ -14,4 +14,4 @@ tracing = { version = "0.1", default-features = false } futures-util = "0.3" [dev-dependencies] -tokio = { version = "1", features = ["sync", "rt", "rt-multi-thread"] } \ No newline at end of file +tokio = { version = "1", features = ["sync", "rt", "rt-multi-thread", "time", "macros"] } \ No newline at end of file diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs index 8531f090d..dafcbc668 100644 --- a/crates/tasks/src/lib.rs +++ b/crates/tasks/src/lib.rs @@ -7,7 +7,8 @@ //! reth task management -use futures_util::{Future, FutureExt, Stream}; +use crate::shutdown::{signal, Shutdown, Signal}; +use futures_util::{future::select, pin_mut, Future, FutureExt, Stream}; use std::{ pin::Pin, task::{Context, Poll}, @@ -19,6 +20,8 @@ use tokio::{ use tracing::error; use tracing_futures::Instrument; +pub mod shutdown; + /// Many reth components require to spawn tasks for long-running jobs. For example `discovery` /// spawns tasks to handle egress and ingress of udp traffic or `network` that spawns session tasks /// that handle the traffic to and from a peer. @@ -40,6 +43,12 @@ pub struct TaskManager { panicked_tasks_tx: UnboundedSender, /// Listens for panicked tasks panicked_tasks_rx: UnboundedReceiver, + /// The [Signal] to fire when all tasks should be shutdown. + /// + /// This is fired on drop. + _signal: Signal, + /// Receiver of the shutdown signal. + on_shutdown: Shutdown, } // === impl TaskManager === @@ -48,7 +57,8 @@ impl TaskManager { /// Create a new instance connected to the given handle's tokio runtime. pub fn new(handle: Handle) -> Self { let (panicked_tasks_tx, panicked_tasks_rx) = unbounded_channel(); - Self { handle, panicked_tasks_tx, panicked_tasks_rx } + let (_signal, on_shutdown) = signal(); + Self { handle, panicked_tasks_tx, panicked_tasks_rx, _signal, on_shutdown } } /// Returns a new [`TaskExecutor`] that can spawn new tasks onto the tokio runtime this type is @@ -56,6 +66,7 @@ impl TaskManager { pub fn executor(&self) -> TaskExecutor { TaskExecutor { handle: self.handle.clone(), + on_shutdown: self.on_shutdown.clone(), panicked_tasks_tx: self.panicked_tasks_tx.clone(), } } @@ -79,6 +90,8 @@ pub struct TaskExecutor { /// /// See [`Handle`] docs. handle: Handle, + /// Receiver of the shutdown signal. + on_shutdown: Shutdown, /// Sender half for sending panic signals to this type panicked_tasks_tx: UnboundedSender, } @@ -93,7 +106,14 @@ impl TaskExecutor { where F: Future + Send + 'static, { - let task = async move { fut.await }.in_current_span(); + let on_shutdown = self.on_shutdown.clone(); + + let task = async move { + pin_mut!(fut); + let _ = select(on_shutdown, fut).await; + } + .in_current_span(); + self.handle.spawn(task); } @@ -105,6 +125,7 @@ impl TaskExecutor { F: Future + Send + 'static, { let panicked_tasks_tx = self.panicked_tasks_tx.clone(); + let on_shutdown = self.on_shutdown.clone(); // wrap the task in catch unwind let task = std::panic::AssertUnwindSafe(fut) @@ -114,7 +135,11 @@ impl TaskExecutor { let _ = panicked_tasks_tx.send(name.to_string()); }) .in_current_span(); - self.handle.spawn(task); + + self.handle.spawn(async move { + pin_mut!(task); + let _ = select(on_shutdown, task).await; + }); } } @@ -122,6 +147,7 @@ impl TaskExecutor { mod tests { use super::*; use futures_util::StreamExt; + use std::time::Duration; #[test] fn test_critical() { @@ -140,4 +166,47 @@ mod tests { assert_eq!(panicked_task, "this is a critical task"); }) } + + // Tests that spawned tasks are terminated if the `TaskManager` drops + #[test] + fn test_manager_shutdown_critical() { + let runtime = tokio::runtime::Runtime::new().unwrap(); + let handle = runtime.handle().clone(); + let manager = TaskManager::new(handle.clone()); + let executor = manager.executor(); + + let (signal, shutdown) = signal(); + + executor.spawn_critical( + "this is a critical task", + Box::pin(async move { + tokio::time::sleep(Duration::from_millis(200)).await; + drop(signal); + }), + ); + + drop(manager); + + handle.block_on(shutdown); + } + + // Tests that spawned tasks are terminated if the `TaskManager` drops + #[test] + fn test_manager_shutdown() { + let runtime = tokio::runtime::Runtime::new().unwrap(); + let handle = runtime.handle().clone(); + let manager = TaskManager::new(handle.clone()); + let executor = manager.executor(); + + let (signal, shutdown) = signal(); + + executor.spawn(Box::pin(async move { + tokio::time::sleep(Duration::from_millis(200)).await; + drop(signal); + })); + + drop(manager); + + handle.block_on(shutdown); + } } diff --git a/crates/tasks/src/shutdown.rs b/crates/tasks/src/shutdown.rs new file mode 100644 index 000000000..6264841ae --- /dev/null +++ b/crates/tasks/src/shutdown.rs @@ -0,0 +1,100 @@ +//! Helper for shutdown signals + +use futures_util::{ + future::{FusedFuture, Shared}, + FutureExt, +}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::sync::oneshot; + +/// A Future that resolves when the shutdown event has been fired. +#[derive(Debug, Clone)] +pub struct Shutdown(Shared>); + +impl Future for Shutdown { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let pin = self.get_mut(); + if pin.0.is_terminated() || pin.0.poll_unpin(cx).is_ready() { + Poll::Ready(()) + } else { + Poll::Pending + } + } +} + +/// Shutdown signal that fires either manually or on drop by closing the channel +#[derive(Debug)] +pub struct Signal(oneshot::Sender<()>); + +impl Signal { + /// Fire the signal manually. + pub fn fire(self) { + let _ = self.0.send(()); + } +} + +/// Create a channel pair that's used to propagate shutdown event +pub fn signal() -> (Signal, Shutdown) { + let (sender, receiver) = oneshot::channel(); + (Signal(sender), Shutdown(receiver.shared())) +} + +#[cfg(test)] +mod tests { + use super::*; + use futures_util::future::join_all; + use std::time::Duration; + + #[tokio::test(flavor = "multi_thread")] + async fn test_shutdown() { + let (_signal, _shutdown) = signal(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_drop_signal() { + let (signal, shutdown) = signal(); + + tokio::task::spawn(async move { + tokio::time::sleep(Duration::from_millis(500)).await; + drop(signal) + }); + + shutdown.await; + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_multi_shutdowns() { + let (signal, shutdown) = signal(); + + let mut tasks = Vec::with_capacity(100); + for _ in 0..100 { + let shutdown = shutdown.clone(); + let task = tokio::task::spawn(async move { + shutdown.await; + }); + tasks.push(task); + } + + drop(signal); + + join_all(tasks).await; + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_drop_signal_from_thread() { + let (signal, shutdown) = signal(); + + let _thread = std::thread::spawn(|| { + std::thread::sleep(Duration::from_millis(500)); + drop(signal) + }); + + shutdown.await; + } +}