mirror of
https://github.com/hl-archive-node/nanoreth.git
synced 2025-12-06 10:59:55 +00:00
feat: extend task executor (#1302)
This commit is contained in:
@ -8,14 +8,15 @@
|
||||
//! reth task management
|
||||
|
||||
use crate::shutdown::{signal, Shutdown, Signal};
|
||||
use futures_util::{future::select, pin_mut, Future, FutureExt, Stream};
|
||||
use futures_util::{future::select, pin_mut, Future, FutureExt};
|
||||
use std::{
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
task::{ready, Context, Poll},
|
||||
};
|
||||
use tokio::{
|
||||
runtime::Handle,
|
||||
sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
|
||||
task::JoinHandle,
|
||||
};
|
||||
use tracing::error;
|
||||
use tracing_futures::Instrument;
|
||||
@ -40,9 +41,9 @@ pub struct TaskManager {
|
||||
/// See [`Handle`] docs.
|
||||
handle: Handle,
|
||||
/// Sender half for sending panic signals to this type
|
||||
panicked_tasks_tx: UnboundedSender<String>,
|
||||
panicked_tasks_tx: UnboundedSender<&'static str>,
|
||||
/// Listens for panicked tasks
|
||||
panicked_tasks_rx: UnboundedReceiver<String>,
|
||||
panicked_tasks_rx: UnboundedReceiver<&'static str>,
|
||||
/// The [Signal] to fire when all tasks should be shutdown.
|
||||
///
|
||||
/// This is fired on drop.
|
||||
@ -72,17 +73,23 @@ impl TaskManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// A stream that yields the name of panicked tasks.
|
||||
/// An endless future that resolves if a critical task panicked.
|
||||
///
|
||||
/// See [`TaskExecutor::spawn_critical`]
|
||||
impl Stream for TaskManager {
|
||||
type Item = String;
|
||||
impl Future for TaskManager {
|
||||
type Output = PanickedTaskError;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
self.get_mut().panicked_tasks_rx.poll_recv(cx)
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let err = ready!(self.get_mut().panicked_tasks_rx.poll_recv(cx));
|
||||
Poll::Ready(err.map(PanickedTaskError).expect("stream can not end"))
|
||||
}
|
||||
}
|
||||
|
||||
/// Error with the name of the task that panicked.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("Critical task panicked {0}")]
|
||||
pub struct PanickedTaskError(&'static str);
|
||||
|
||||
/// A type that can spawn new tokio tasks
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TaskExecutor {
|
||||
@ -93,16 +100,38 @@ pub struct TaskExecutor {
|
||||
/// Receiver of the shutdown signal.
|
||||
on_shutdown: Shutdown,
|
||||
/// Sender half for sending panic signals to this type
|
||||
panicked_tasks_tx: UnboundedSender<String>,
|
||||
panicked_tasks_tx: UnboundedSender<&'static str>,
|
||||
}
|
||||
|
||||
// === impl TaskExecutor ===
|
||||
|
||||
impl TaskExecutor {
|
||||
/// Spawns the task onto the runtime.
|
||||
///
|
||||
/// See also [`Handle::spawn`].
|
||||
pub fn spawn<F>(&self, fut: F)
|
||||
/// Returns the [Handle] to the tokio runtime.
|
||||
pub fn handle(&self) -> &Handle {
|
||||
&self.handle
|
||||
}
|
||||
|
||||
/// Returns the receiver of the shutdown signal.
|
||||
pub fn on_shutdown_signal(&self) -> &Shutdown {
|
||||
&self.on_shutdown
|
||||
}
|
||||
|
||||
/// Spawns a future on the tokio runtime depending on the [TaskKind]
|
||||
fn spawn_on_rt<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
|
||||
where
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
match task_kind {
|
||||
TaskKind::Default => self.handle.spawn(fut),
|
||||
TaskKind::Blocking => {
|
||||
let handle = self.handle.clone();
|
||||
self.handle.spawn_blocking(move || handle.block_on(fut))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Spawns a regular task depending on the given [TaskKind]
|
||||
fn spawn_task_as<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
|
||||
where
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
@ -114,13 +143,54 @@ impl TaskExecutor {
|
||||
}
|
||||
.in_current_span();
|
||||
|
||||
self.handle.spawn(task);
|
||||
self.spawn_on_rt(task, task_kind)
|
||||
}
|
||||
|
||||
/// This spawns a critical task onto the runtime.
|
||||
/// Spawns the task onto the runtime.
|
||||
/// The given future resolves as soon as the [Shutdown] signal is received.
|
||||
///
|
||||
/// If this task panics, the [`TaskManager`] is notified.
|
||||
pub fn spawn_critical<F>(&self, name: &'static str, fut: F)
|
||||
/// See also [`Handle::spawn`].
|
||||
pub fn spawn<F>(&self, fut: F) -> JoinHandle<()>
|
||||
where
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
self.spawn_task_as(fut, TaskKind::Default)
|
||||
}
|
||||
|
||||
/// Spawns a blocking task onto the runtime.
|
||||
/// The given future resolves as soon as the [Shutdown] signal is received.
|
||||
///
|
||||
/// See also [`Handle::spawn_blocking`].
|
||||
pub fn spawn_blocking<F>(&self, fut: F) -> JoinHandle<()>
|
||||
where
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
self.spawn_task_as(fut, TaskKind::Blocking)
|
||||
}
|
||||
|
||||
/// Spawns the task onto the runtime.
|
||||
/// The given future resolves as soon as the [Shutdown] signal is received.
|
||||
///
|
||||
/// See also [`Handle::spawn`].
|
||||
pub fn spawn_with_signal<F>(&self, f: impl FnOnce(Shutdown) -> F) -> JoinHandle<()>
|
||||
where
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
let on_shutdown = self.on_shutdown.clone();
|
||||
let fut = f(on_shutdown);
|
||||
|
||||
let task = fut.in_current_span();
|
||||
|
||||
self.handle.spawn(task)
|
||||
}
|
||||
|
||||
/// Spawns a critical task depending on the given [TaskKind]
|
||||
fn spawn_critical_as<F>(
|
||||
&self,
|
||||
name: &'static str,
|
||||
fut: F,
|
||||
task_kind: TaskKind,
|
||||
) -> JoinHandle<()>
|
||||
where
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
@ -132,28 +202,86 @@ impl TaskExecutor {
|
||||
.catch_unwind()
|
||||
.map(move |res| {
|
||||
error!("Critical task `{name}` panicked: {res:?}");
|
||||
let _ = panicked_tasks_tx.send(name.to_string());
|
||||
let _ = panicked_tasks_tx.send(name);
|
||||
})
|
||||
.in_current_span();
|
||||
|
||||
self.handle.spawn(async move {
|
||||
let task = async move {
|
||||
pin_mut!(task);
|
||||
let _ = select(on_shutdown, task).await;
|
||||
});
|
||||
};
|
||||
|
||||
self.spawn_on_rt(task, task_kind)
|
||||
}
|
||||
|
||||
/// This spawns a critical blocking task onto the runtime.
|
||||
/// The given future resolves as soon as the [Shutdown] signal is received.
|
||||
///
|
||||
/// If this task panics, the [`TaskManager`] is notified.
|
||||
pub fn spawn_critical_blocking<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
|
||||
where
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
self.spawn_critical_as(name, fut, TaskKind::Blocking)
|
||||
}
|
||||
|
||||
/// This spawns a critical task onto the runtime.
|
||||
/// The given future resolves as soon as the [Shutdown] signal is received.
|
||||
///
|
||||
/// If this task panics, the [`TaskManager`] is notified.
|
||||
pub fn spawn_critical<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
|
||||
where
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
self.spawn_critical_as(name, fut, TaskKind::Default)
|
||||
}
|
||||
|
||||
/// This spawns a critical task onto the runtime.
|
||||
///
|
||||
/// If this task panics, the [`TaskManager`] is notified.
|
||||
pub fn spawn_critical_with_signal<F>(
|
||||
&self,
|
||||
name: &'static str,
|
||||
f: impl FnOnce(Shutdown) -> F,
|
||||
) -> JoinHandle<()>
|
||||
where
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
let panicked_tasks_tx = self.panicked_tasks_tx.clone();
|
||||
let on_shutdown = self.on_shutdown.clone();
|
||||
let fut = f(on_shutdown);
|
||||
|
||||
// wrap the task in catch unwind
|
||||
let task = std::panic::AssertUnwindSafe(fut)
|
||||
.catch_unwind()
|
||||
.map(move |res| {
|
||||
error!("Critical task `{name}` panicked: {res:?}");
|
||||
let _ = panicked_tasks_tx.send(name);
|
||||
})
|
||||
.in_current_span();
|
||||
|
||||
self.handle.spawn(task)
|
||||
}
|
||||
}
|
||||
|
||||
/// Determines how a task is spawned
|
||||
enum TaskKind {
|
||||
/// Spawn the task to the default executor [Handle::spawn]
|
||||
Default,
|
||||
/// Spawn the task to the blocking executor [Handle::spawn_blocking]
|
||||
Blocking,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use futures_util::StreamExt;
|
||||
use std::time::Duration;
|
||||
|
||||
#[test]
|
||||
fn test_critical() {
|
||||
let runtime = tokio::runtime::Runtime::new().unwrap();
|
||||
let handle = runtime.handle().clone();
|
||||
let mut manager = TaskManager::new(handle);
|
||||
let manager = TaskManager::new(handle);
|
||||
let executor = manager.executor();
|
||||
|
||||
executor.spawn_critical(
|
||||
@ -162,8 +290,8 @@ mod tests {
|
||||
);
|
||||
|
||||
runtime.block_on(async move {
|
||||
let panicked_task = manager.next().await.unwrap();
|
||||
assert_eq!(panicked_task, "this is a critical task");
|
||||
let err = manager.await;
|
||||
assert_eq!(err.0, "this is a critical task");
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user