mirror of
https://github.com/hl-archive-node/nanoreth.git
synced 2025-12-06 10:59:55 +00:00
feat: add shutdown signal to TaskManager (#831)
This commit is contained in:
@ -14,4 +14,4 @@ tracing = { version = "0.1", default-features = false }
|
|||||||
futures-util = "0.3"
|
futures-util = "0.3"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tokio = { version = "1", features = ["sync", "rt", "rt-multi-thread"] }
|
tokio = { version = "1", features = ["sync", "rt", "rt-multi-thread", "time", "macros"] }
|
||||||
@ -7,7 +7,8 @@
|
|||||||
|
|
||||||
//! reth task management
|
//! reth task management
|
||||||
|
|
||||||
use futures_util::{Future, FutureExt, Stream};
|
use crate::shutdown::{signal, Shutdown, Signal};
|
||||||
|
use futures_util::{future::select, pin_mut, Future, FutureExt, Stream};
|
||||||
use std::{
|
use std::{
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
@ -19,6 +20,8 @@ use tokio::{
|
|||||||
use tracing::error;
|
use tracing::error;
|
||||||
use tracing_futures::Instrument;
|
use tracing_futures::Instrument;
|
||||||
|
|
||||||
|
pub mod shutdown;
|
||||||
|
|
||||||
/// Many reth components require to spawn tasks for long-running jobs. For example `discovery`
|
/// Many reth components require to spawn tasks for long-running jobs. For example `discovery`
|
||||||
/// spawns tasks to handle egress and ingress of udp traffic or `network` that spawns session tasks
|
/// spawns tasks to handle egress and ingress of udp traffic or `network` that spawns session tasks
|
||||||
/// that handle the traffic to and from a peer.
|
/// that handle the traffic to and from a peer.
|
||||||
@ -40,6 +43,12 @@ pub struct TaskManager {
|
|||||||
panicked_tasks_tx: UnboundedSender<String>,
|
panicked_tasks_tx: UnboundedSender<String>,
|
||||||
/// Listens for panicked tasks
|
/// Listens for panicked tasks
|
||||||
panicked_tasks_rx: UnboundedReceiver<String>,
|
panicked_tasks_rx: UnboundedReceiver<String>,
|
||||||
|
/// The [Signal] to fire when all tasks should be shutdown.
|
||||||
|
///
|
||||||
|
/// This is fired on drop.
|
||||||
|
_signal: Signal,
|
||||||
|
/// Receiver of the shutdown signal.
|
||||||
|
on_shutdown: Shutdown,
|
||||||
}
|
}
|
||||||
|
|
||||||
// === impl TaskManager ===
|
// === impl TaskManager ===
|
||||||
@ -48,7 +57,8 @@ impl TaskManager {
|
|||||||
/// Create a new instance connected to the given handle's tokio runtime.
|
/// Create a new instance connected to the given handle's tokio runtime.
|
||||||
pub fn new(handle: Handle) -> Self {
|
pub fn new(handle: Handle) -> Self {
|
||||||
let (panicked_tasks_tx, panicked_tasks_rx) = unbounded_channel();
|
let (panicked_tasks_tx, panicked_tasks_rx) = unbounded_channel();
|
||||||
Self { handle, panicked_tasks_tx, panicked_tasks_rx }
|
let (_signal, on_shutdown) = signal();
|
||||||
|
Self { handle, panicked_tasks_tx, panicked_tasks_rx, _signal, on_shutdown }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a new [`TaskExecutor`] that can spawn new tasks onto the tokio runtime this type is
|
/// Returns a new [`TaskExecutor`] that can spawn new tasks onto the tokio runtime this type is
|
||||||
@ -56,6 +66,7 @@ impl TaskManager {
|
|||||||
pub fn executor(&self) -> TaskExecutor {
|
pub fn executor(&self) -> TaskExecutor {
|
||||||
TaskExecutor {
|
TaskExecutor {
|
||||||
handle: self.handle.clone(),
|
handle: self.handle.clone(),
|
||||||
|
on_shutdown: self.on_shutdown.clone(),
|
||||||
panicked_tasks_tx: self.panicked_tasks_tx.clone(),
|
panicked_tasks_tx: self.panicked_tasks_tx.clone(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -79,6 +90,8 @@ pub struct TaskExecutor {
|
|||||||
///
|
///
|
||||||
/// See [`Handle`] docs.
|
/// See [`Handle`] docs.
|
||||||
handle: Handle,
|
handle: Handle,
|
||||||
|
/// Receiver of the shutdown signal.
|
||||||
|
on_shutdown: Shutdown,
|
||||||
/// Sender half for sending panic signals to this type
|
/// Sender half for sending panic signals to this type
|
||||||
panicked_tasks_tx: UnboundedSender<String>,
|
panicked_tasks_tx: UnboundedSender<String>,
|
||||||
}
|
}
|
||||||
@ -93,7 +106,14 @@ impl TaskExecutor {
|
|||||||
where
|
where
|
||||||
F: Future<Output = ()> + Send + 'static,
|
F: Future<Output = ()> + Send + 'static,
|
||||||
{
|
{
|
||||||
let task = async move { fut.await }.in_current_span();
|
let on_shutdown = self.on_shutdown.clone();
|
||||||
|
|
||||||
|
let task = async move {
|
||||||
|
pin_mut!(fut);
|
||||||
|
let _ = select(on_shutdown, fut).await;
|
||||||
|
}
|
||||||
|
.in_current_span();
|
||||||
|
|
||||||
self.handle.spawn(task);
|
self.handle.spawn(task);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -105,6 +125,7 @@ impl TaskExecutor {
|
|||||||
F: Future<Output = ()> + Send + 'static,
|
F: Future<Output = ()> + Send + 'static,
|
||||||
{
|
{
|
||||||
let panicked_tasks_tx = self.panicked_tasks_tx.clone();
|
let panicked_tasks_tx = self.panicked_tasks_tx.clone();
|
||||||
|
let on_shutdown = self.on_shutdown.clone();
|
||||||
|
|
||||||
// wrap the task in catch unwind
|
// wrap the task in catch unwind
|
||||||
let task = std::panic::AssertUnwindSafe(fut)
|
let task = std::panic::AssertUnwindSafe(fut)
|
||||||
@ -114,7 +135,11 @@ impl TaskExecutor {
|
|||||||
let _ = panicked_tasks_tx.send(name.to_string());
|
let _ = panicked_tasks_tx.send(name.to_string());
|
||||||
})
|
})
|
||||||
.in_current_span();
|
.in_current_span();
|
||||||
self.handle.spawn(task);
|
|
||||||
|
self.handle.spawn(async move {
|
||||||
|
pin_mut!(task);
|
||||||
|
let _ = select(on_shutdown, task).await;
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,6 +147,7 @@ impl TaskExecutor {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use futures_util::StreamExt;
|
use futures_util::StreamExt;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_critical() {
|
fn test_critical() {
|
||||||
@ -140,4 +166,47 @@ mod tests {
|
|||||||
assert_eq!(panicked_task, "this is a critical task");
|
assert_eq!(panicked_task, "this is a critical task");
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tests that spawned tasks are terminated if the `TaskManager` drops
|
||||||
|
#[test]
|
||||||
|
fn test_manager_shutdown_critical() {
|
||||||
|
let runtime = tokio::runtime::Runtime::new().unwrap();
|
||||||
|
let handle = runtime.handle().clone();
|
||||||
|
let manager = TaskManager::new(handle.clone());
|
||||||
|
let executor = manager.executor();
|
||||||
|
|
||||||
|
let (signal, shutdown) = signal();
|
||||||
|
|
||||||
|
executor.spawn_critical(
|
||||||
|
"this is a critical task",
|
||||||
|
Box::pin(async move {
|
||||||
|
tokio::time::sleep(Duration::from_millis(200)).await;
|
||||||
|
drop(signal);
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
drop(manager);
|
||||||
|
|
||||||
|
handle.block_on(shutdown);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests that spawned tasks are terminated if the `TaskManager` drops
|
||||||
|
#[test]
|
||||||
|
fn test_manager_shutdown() {
|
||||||
|
let runtime = tokio::runtime::Runtime::new().unwrap();
|
||||||
|
let handle = runtime.handle().clone();
|
||||||
|
let manager = TaskManager::new(handle.clone());
|
||||||
|
let executor = manager.executor();
|
||||||
|
|
||||||
|
let (signal, shutdown) = signal();
|
||||||
|
|
||||||
|
executor.spawn(Box::pin(async move {
|
||||||
|
tokio::time::sleep(Duration::from_millis(200)).await;
|
||||||
|
drop(signal);
|
||||||
|
}));
|
||||||
|
|
||||||
|
drop(manager);
|
||||||
|
|
||||||
|
handle.block_on(shutdown);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
100
crates/tasks/src/shutdown.rs
Normal file
100
crates/tasks/src/shutdown.rs
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
//! Helper for shutdown signals
|
||||||
|
|
||||||
|
use futures_util::{
|
||||||
|
future::{FusedFuture, Shared},
|
||||||
|
FutureExt,
|
||||||
|
};
|
||||||
|
use std::{
|
||||||
|
future::Future,
|
||||||
|
pin::Pin,
|
||||||
|
task::{Context, Poll},
|
||||||
|
};
|
||||||
|
use tokio::sync::oneshot;
|
||||||
|
|
||||||
|
/// A Future that resolves when the shutdown event has been fired.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Shutdown(Shared<oneshot::Receiver<()>>);
|
||||||
|
|
||||||
|
impl Future for Shutdown {
|
||||||
|
type Output = ();
|
||||||
|
|
||||||
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
|
let pin = self.get_mut();
|
||||||
|
if pin.0.is_terminated() || pin.0.poll_unpin(cx).is_ready() {
|
||||||
|
Poll::Ready(())
|
||||||
|
} else {
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shutdown signal that fires either manually or on drop by closing the channel
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Signal(oneshot::Sender<()>);
|
||||||
|
|
||||||
|
impl Signal {
|
||||||
|
/// Fire the signal manually.
|
||||||
|
pub fn fire(self) {
|
||||||
|
let _ = self.0.send(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a channel pair that's used to propagate shutdown event
|
||||||
|
pub fn signal() -> (Signal, Shutdown) {
|
||||||
|
let (sender, receiver) = oneshot::channel();
|
||||||
|
(Signal(sender), Shutdown(receiver.shared()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use futures_util::future::join_all;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread")]
|
||||||
|
async fn test_shutdown() {
|
||||||
|
let (_signal, _shutdown) = signal();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread")]
|
||||||
|
async fn test_drop_signal() {
|
||||||
|
let (signal, shutdown) = signal();
|
||||||
|
|
||||||
|
tokio::task::spawn(async move {
|
||||||
|
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||||
|
drop(signal)
|
||||||
|
});
|
||||||
|
|
||||||
|
shutdown.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread")]
|
||||||
|
async fn test_multi_shutdowns() {
|
||||||
|
let (signal, shutdown) = signal();
|
||||||
|
|
||||||
|
let mut tasks = Vec::with_capacity(100);
|
||||||
|
for _ in 0..100 {
|
||||||
|
let shutdown = shutdown.clone();
|
||||||
|
let task = tokio::task::spawn(async move {
|
||||||
|
shutdown.await;
|
||||||
|
});
|
||||||
|
tasks.push(task);
|
||||||
|
}
|
||||||
|
|
||||||
|
drop(signal);
|
||||||
|
|
||||||
|
join_all(tasks).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread")]
|
||||||
|
async fn test_drop_signal_from_thread() {
|
||||||
|
let (signal, shutdown) = signal();
|
||||||
|
|
||||||
|
let _thread = std::thread::spawn(|| {
|
||||||
|
std::thread::sleep(Duration::from_millis(500));
|
||||||
|
drop(signal)
|
||||||
|
});
|
||||||
|
|
||||||
|
shutdown.await;
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user