feat: add shutdown signal to TaskManager (#831)

This commit is contained in:
Matthias Seitz
2023-01-13 10:34:22 +01:00
committed by GitHub
parent 7767b216bc
commit e6ca4c56c6
3 changed files with 174 additions and 5 deletions

View File

@ -7,7 +7,8 @@
//! reth task management
use futures_util::{Future, FutureExt, Stream};
use crate::shutdown::{signal, Shutdown, Signal};
use futures_util::{future::select, pin_mut, Future, FutureExt, Stream};
use std::{
pin::Pin,
task::{Context, Poll},
@ -19,6 +20,8 @@ use tokio::{
use tracing::error;
use tracing_futures::Instrument;
pub mod shutdown;
/// Many reth components require to spawn tasks for long-running jobs. For example `discovery`
/// spawns tasks to handle egress and ingress of udp traffic or `network` that spawns session tasks
/// that handle the traffic to and from a peer.
@ -40,6 +43,12 @@ pub struct TaskManager {
panicked_tasks_tx: UnboundedSender<String>,
/// Listens for panicked tasks
panicked_tasks_rx: UnboundedReceiver<String>,
/// The [Signal] to fire when all tasks should be shutdown.
///
/// This is fired on drop.
_signal: Signal,
/// Receiver of the shutdown signal.
on_shutdown: Shutdown,
}
// === impl TaskManager ===
@ -48,7 +57,8 @@ impl TaskManager {
/// Create a new instance connected to the given handle's tokio runtime.
pub fn new(handle: Handle) -> Self {
let (panicked_tasks_tx, panicked_tasks_rx) = unbounded_channel();
Self { handle, panicked_tasks_tx, panicked_tasks_rx }
let (_signal, on_shutdown) = signal();
Self { handle, panicked_tasks_tx, panicked_tasks_rx, _signal, on_shutdown }
}
/// Returns a new [`TaskExecutor`] that can spawn new tasks onto the tokio runtime this type is
@ -56,6 +66,7 @@ impl TaskManager {
pub fn executor(&self) -> TaskExecutor {
TaskExecutor {
handle: self.handle.clone(),
on_shutdown: self.on_shutdown.clone(),
panicked_tasks_tx: self.panicked_tasks_tx.clone(),
}
}
@ -79,6 +90,8 @@ pub struct TaskExecutor {
///
/// See [`Handle`] docs.
handle: Handle,
/// Receiver of the shutdown signal.
on_shutdown: Shutdown,
/// Sender half for sending panic signals to this type
panicked_tasks_tx: UnboundedSender<String>,
}
@ -93,7 +106,14 @@ impl TaskExecutor {
where
F: Future<Output = ()> + Send + 'static,
{
let task = async move { fut.await }.in_current_span();
let on_shutdown = self.on_shutdown.clone();
let task = async move {
pin_mut!(fut);
let _ = select(on_shutdown, fut).await;
}
.in_current_span();
self.handle.spawn(task);
}
@ -105,6 +125,7 @@ impl TaskExecutor {
F: Future<Output = ()> + Send + 'static,
{
let panicked_tasks_tx = self.panicked_tasks_tx.clone();
let on_shutdown = self.on_shutdown.clone();
// wrap the task in catch unwind
let task = std::panic::AssertUnwindSafe(fut)
@ -114,7 +135,11 @@ impl TaskExecutor {
let _ = panicked_tasks_tx.send(name.to_string());
})
.in_current_span();
self.handle.spawn(task);
self.handle.spawn(async move {
pin_mut!(task);
let _ = select(on_shutdown, task).await;
});
}
}
@ -122,6 +147,7 @@ impl TaskExecutor {
mod tests {
use super::*;
use futures_util::StreamExt;
use std::time::Duration;
#[test]
fn test_critical() {
@ -140,4 +166,47 @@ mod tests {
assert_eq!(panicked_task, "this is a critical task");
})
}
// Tests that spawned tasks are terminated if the `TaskManager` drops
#[test]
fn test_manager_shutdown_critical() {
let runtime = tokio::runtime::Runtime::new().unwrap();
let handle = runtime.handle().clone();
let manager = TaskManager::new(handle.clone());
let executor = manager.executor();
let (signal, shutdown) = signal();
executor.spawn_critical(
"this is a critical task",
Box::pin(async move {
tokio::time::sleep(Duration::from_millis(200)).await;
drop(signal);
}),
);
drop(manager);
handle.block_on(shutdown);
}
// Tests that spawned tasks are terminated if the `TaskManager` drops
#[test]
fn test_manager_shutdown() {
let runtime = tokio::runtime::Runtime::new().unwrap();
let handle = runtime.handle().clone();
let manager = TaskManager::new(handle.clone());
let executor = manager.executor();
let (signal, shutdown) = signal();
executor.spawn(Box::pin(async move {
tokio::time::sleep(Duration::from_millis(200)).await;
drop(signal);
}));
drop(manager);
handle.block_on(shutdown);
}
}

View File

@ -0,0 +1,100 @@
//! Helper for shutdown signals
use futures_util::{
future::{FusedFuture, Shared},
FutureExt,
};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio::sync::oneshot;
/// A Future that resolves when the shutdown event has been fired.
#[derive(Debug, Clone)]
pub struct Shutdown(Shared<oneshot::Receiver<()>>);
impl Future for Shutdown {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let pin = self.get_mut();
if pin.0.is_terminated() || pin.0.poll_unpin(cx).is_ready() {
Poll::Ready(())
} else {
Poll::Pending
}
}
}
/// Shutdown signal that fires either manually or on drop by closing the channel
#[derive(Debug)]
pub struct Signal(oneshot::Sender<()>);
impl Signal {
/// Fire the signal manually.
pub fn fire(self) {
let _ = self.0.send(());
}
}
/// Create a channel pair that's used to propagate shutdown event
pub fn signal() -> (Signal, Shutdown) {
let (sender, receiver) = oneshot::channel();
(Signal(sender), Shutdown(receiver.shared()))
}
#[cfg(test)]
mod tests {
use super::*;
use futures_util::future::join_all;
use std::time::Duration;
#[tokio::test(flavor = "multi_thread")]
async fn test_shutdown() {
let (_signal, _shutdown) = signal();
}
#[tokio::test(flavor = "multi_thread")]
async fn test_drop_signal() {
let (signal, shutdown) = signal();
tokio::task::spawn(async move {
tokio::time::sleep(Duration::from_millis(500)).await;
drop(signal)
});
shutdown.await;
}
#[tokio::test(flavor = "multi_thread")]
async fn test_multi_shutdowns() {
let (signal, shutdown) = signal();
let mut tasks = Vec::with_capacity(100);
for _ in 0..100 {
let shutdown = shutdown.clone();
let task = tokio::task::spawn(async move {
shutdown.await;
});
tasks.push(task);
}
drop(signal);
join_all(tasks).await;
}
#[tokio::test(flavor = "multi_thread")]
async fn test_drop_signal_from_thread() {
let (signal, shutdown) = signal();
let _thread = std::thread::spawn(|| {
std::thread::sleep(Duration::from_millis(500));
drop(signal)
});
shutdown.await;
}
}