feat: extend task executor (#1302)

This commit is contained in:
Matthias Seitz
2023-02-13 10:42:48 +01:00
committed by GitHub
parent 73ffc425a3
commit 37351df585
4 changed files with 156 additions and 26 deletions

1
Cargo.lock generated
View File

@ -4773,6 +4773,7 @@ name = "reth-tasks"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"futures-util", "futures-util",
"thiserror",
"tokio", "tokio",
"tracing", "tracing",
"tracing-futures", "tracing-futures",

View File

@ -170,7 +170,7 @@ impl SessionManager {
F: Future<Output = ()> + Send + 'static, F: Future<Output = ()> + Send + 'static,
{ {
if let Some(ref executor) = self.executor { if let Some(ref executor) = self.executor {
executor.spawn(async move { f.await }) executor.spawn(async move { f.await });
} else { } else {
tokio::task::spawn(async move { f.await }); tokio::task::spawn(async move { f.await });
} }

View File

@ -12,6 +12,7 @@ tokio = { version = "1", features = ["sync", "rt"] }
tracing-futures = "0.2" tracing-futures = "0.2"
tracing = { version = "0.1", default-features = false } tracing = { version = "0.1", default-features = false }
futures-util = "0.3" futures-util = "0.3"
thiserror = "1.0"
[dev-dependencies] [dev-dependencies]
tokio = { version = "1", features = ["sync", "rt", "rt-multi-thread", "time", "macros"] } tokio = { version = "1", features = ["sync", "rt", "rt-multi-thread", "time", "macros"] }

View File

@ -8,14 +8,15 @@
//! reth task management //! reth task management
use crate::shutdown::{signal, Shutdown, Signal}; use crate::shutdown::{signal, Shutdown, Signal};
use futures_util::{future::select, pin_mut, Future, FutureExt, Stream}; use futures_util::{future::select, pin_mut, Future, FutureExt};
use std::{ use std::{
pin::Pin, pin::Pin,
task::{Context, Poll}, task::{ready, Context, Poll},
}; };
use tokio::{ use tokio::{
runtime::Handle, runtime::Handle,
sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
task::JoinHandle,
}; };
use tracing::error; use tracing::error;
use tracing_futures::Instrument; use tracing_futures::Instrument;
@ -40,9 +41,9 @@ pub struct TaskManager {
/// See [`Handle`] docs. /// See [`Handle`] docs.
handle: Handle, handle: Handle,
/// Sender half for sending panic signals to this type /// Sender half for sending panic signals to this type
panicked_tasks_tx: UnboundedSender<String>, panicked_tasks_tx: UnboundedSender<&'static str>,
/// Listens for panicked tasks /// Listens for panicked tasks
panicked_tasks_rx: UnboundedReceiver<String>, panicked_tasks_rx: UnboundedReceiver<&'static str>,
/// The [Signal] to fire when all tasks should be shutdown. /// The [Signal] to fire when all tasks should be shutdown.
/// ///
/// This is fired on drop. /// This is fired on drop.
@ -72,17 +73,23 @@ impl TaskManager {
} }
} }
/// A stream that yields the name of panicked tasks. /// An endless future that resolves if a critical task panicked.
/// ///
/// See [`TaskExecutor::spawn_critical`] /// See [`TaskExecutor::spawn_critical`]
impl Stream for TaskManager { impl Future for TaskManager {
type Item = String; type Output = PanickedTaskError;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.get_mut().panicked_tasks_rx.poll_recv(cx) let err = ready!(self.get_mut().panicked_tasks_rx.poll_recv(cx));
Poll::Ready(err.map(PanickedTaskError).expect("stream can not end"))
} }
} }
/// Error with the name of the task that panicked.
#[derive(Debug, thiserror::Error)]
#[error("Critical task panicked {0}")]
pub struct PanickedTaskError(&'static str);
/// A type that can spawn new tokio tasks /// A type that can spawn new tokio tasks
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct TaskExecutor { pub struct TaskExecutor {
@ -93,16 +100,38 @@ pub struct TaskExecutor {
/// Receiver of the shutdown signal. /// Receiver of the shutdown signal.
on_shutdown: Shutdown, on_shutdown: Shutdown,
/// Sender half for sending panic signals to this type /// Sender half for sending panic signals to this type
panicked_tasks_tx: UnboundedSender<String>, panicked_tasks_tx: UnboundedSender<&'static str>,
} }
// === impl TaskExecutor === // === impl TaskExecutor ===
impl TaskExecutor { impl TaskExecutor {
/// Spawns the task onto the runtime. /// Returns the [Handle] to the tokio runtime.
/// pub fn handle(&self) -> &Handle {
/// See also [`Handle::spawn`]. &self.handle
pub fn spawn<F>(&self, fut: F) }
/// Returns the receiver of the shutdown signal.
pub fn on_shutdown_signal(&self) -> &Shutdown {
&self.on_shutdown
}
/// Spawns a future on the tokio runtime depending on the [TaskKind]
fn spawn_on_rt<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
where
F: Future<Output = ()> + Send + 'static,
{
match task_kind {
TaskKind::Default => self.handle.spawn(fut),
TaskKind::Blocking => {
let handle = self.handle.clone();
self.handle.spawn_blocking(move || handle.block_on(fut))
}
}
}
/// Spawns a regular task depending on the given [TaskKind]
fn spawn_task_as<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
where where
F: Future<Output = ()> + Send + 'static, F: Future<Output = ()> + Send + 'static,
{ {
@ -114,13 +143,54 @@ impl TaskExecutor {
} }
.in_current_span(); .in_current_span();
self.handle.spawn(task); self.spawn_on_rt(task, task_kind)
} }
/// This spawns a critical task onto the runtime. /// Spawns the task onto the runtime.
/// The given future resolves as soon as the [Shutdown] signal is received.
/// ///
/// If this task panics, the [`TaskManager`] is notified. /// See also [`Handle::spawn`].
pub fn spawn_critical<F>(&self, name: &'static str, fut: F) pub fn spawn<F>(&self, fut: F) -> JoinHandle<()>
where
F: Future<Output = ()> + Send + 'static,
{
self.spawn_task_as(fut, TaskKind::Default)
}
/// Spawns a blocking task onto the runtime.
/// The given future resolves as soon as the [Shutdown] signal is received.
///
/// See also [`Handle::spawn_blocking`].
pub fn spawn_blocking<F>(&self, fut: F) -> JoinHandle<()>
where
F: Future<Output = ()> + Send + 'static,
{
self.spawn_task_as(fut, TaskKind::Blocking)
}
/// Spawns the task onto the runtime.
/// The given future resolves as soon as the [Shutdown] signal is received.
///
/// See also [`Handle::spawn`].
pub fn spawn_with_signal<F>(&self, f: impl FnOnce(Shutdown) -> F) -> JoinHandle<()>
where
F: Future<Output = ()> + Send + 'static,
{
let on_shutdown = self.on_shutdown.clone();
let fut = f(on_shutdown);
let task = fut.in_current_span();
self.handle.spawn(task)
}
/// Spawns a critical task depending on the given [TaskKind]
fn spawn_critical_as<F>(
&self,
name: &'static str,
fut: F,
task_kind: TaskKind,
) -> JoinHandle<()>
where where
F: Future<Output = ()> + Send + 'static, F: Future<Output = ()> + Send + 'static,
{ {
@ -132,28 +202,86 @@ impl TaskExecutor {
.catch_unwind() .catch_unwind()
.map(move |res| { .map(move |res| {
error!("Critical task `{name}` panicked: {res:?}"); error!("Critical task `{name}` panicked: {res:?}");
let _ = panicked_tasks_tx.send(name.to_string()); let _ = panicked_tasks_tx.send(name);
}) })
.in_current_span(); .in_current_span();
self.handle.spawn(async move { let task = async move {
pin_mut!(task); pin_mut!(task);
let _ = select(on_shutdown, task).await; let _ = select(on_shutdown, task).await;
}); };
self.spawn_on_rt(task, task_kind)
} }
/// This spawns a critical blocking task onto the runtime.
/// The given future resolves as soon as the [Shutdown] signal is received.
///
/// If this task panics, the [`TaskManager`] is notified.
pub fn spawn_critical_blocking<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
where
F: Future<Output = ()> + Send + 'static,
{
self.spawn_critical_as(name, fut, TaskKind::Blocking)
}
/// This spawns a critical task onto the runtime.
/// The given future resolves as soon as the [Shutdown] signal is received.
///
/// If this task panics, the [`TaskManager`] is notified.
pub fn spawn_critical<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
where
F: Future<Output = ()> + Send + 'static,
{
self.spawn_critical_as(name, fut, TaskKind::Default)
}
/// This spawns a critical task onto the runtime.
///
/// If this task panics, the [`TaskManager`] is notified.
pub fn spawn_critical_with_signal<F>(
&self,
name: &'static str,
f: impl FnOnce(Shutdown) -> F,
) -> JoinHandle<()>
where
F: Future<Output = ()> + Send + 'static,
{
let panicked_tasks_tx = self.panicked_tasks_tx.clone();
let on_shutdown = self.on_shutdown.clone();
let fut = f(on_shutdown);
// wrap the task in catch unwind
let task = std::panic::AssertUnwindSafe(fut)
.catch_unwind()
.map(move |res| {
error!("Critical task `{name}` panicked: {res:?}");
let _ = panicked_tasks_tx.send(name);
})
.in_current_span();
self.handle.spawn(task)
}
}
/// Determines how a task is spawned
enum TaskKind {
/// Spawn the task to the default executor [Handle::spawn]
Default,
/// Spawn the task to the blocking executor [Handle::spawn_blocking]
Blocking,
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use futures_util::StreamExt;
use std::time::Duration; use std::time::Duration;
#[test] #[test]
fn test_critical() { fn test_critical() {
let runtime = tokio::runtime::Runtime::new().unwrap(); let runtime = tokio::runtime::Runtime::new().unwrap();
let handle = runtime.handle().clone(); let handle = runtime.handle().clone();
let mut manager = TaskManager::new(handle); let manager = TaskManager::new(handle);
let executor = manager.executor(); let executor = manager.executor();
executor.spawn_critical( executor.spawn_critical(
@ -162,8 +290,8 @@ mod tests {
); );
runtime.block_on(async move { runtime.block_on(async move {
let panicked_task = manager.next().await.unwrap(); let err = manager.await;
assert_eq!(panicked_task, "this is a critical task"); assert_eq!(err.0, "this is a critical task");
}) })
} }