feat(tasks): pass downcasted error from panicked task (#3319)

This commit is contained in:
Alexey Shekhirin
2023-06-22 15:21:35 +01:00
committed by GitHub
parent 68b93a88de
commit 0ebdba6c64

View File

@ -17,6 +17,8 @@ use futures_util::{
pin_mut, Future, FutureExt, TryFutureExt,
};
use std::{
any::Any,
fmt::{Display, Formatter},
pin::Pin,
task::{ready, Context, Poll},
};
@ -136,9 +138,9 @@ pub struct TaskManager {
/// See [`Handle`] docs.
handle: Handle,
/// Sender half for sending panic signals to this type
panicked_tasks_tx: UnboundedSender<&'static str>,
panicked_tasks_tx: UnboundedSender<PanickedTaskError>,
/// Listens for panicked tasks
panicked_tasks_rx: UnboundedReceiver<&'static str>,
panicked_tasks_rx: UnboundedReceiver<PanickedTaskError>,
/// The [Signal] to fire when all tasks should be shutdown.
///
/// This is fired on drop.
@ -177,14 +179,41 @@ impl Future for TaskManager {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let err = ready!(self.get_mut().panicked_tasks_rx.poll_recv(cx));
Poll::Ready(err.map(PanickedTaskError).expect("stream can not end"))
Poll::Ready(err.expect("stream can not end"))
}
}
/// Error with the name of the task that panicked.
/// Error with the name of the task that panicked and an error downcasted to string, if possible.
#[derive(Debug, thiserror::Error)]
#[error("Critical task panicked: `{0}`")]
pub struct PanickedTaskError(&'static str);
pub struct PanickedTaskError {
task_name: &'static str,
error: Option<String>,
}
impl Display for PanickedTaskError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let task_name = self.task_name;
if let Some(error) = &self.error {
write!(f, "Critical task `{task_name}` panicked: `{error}`")
} else {
write!(f, "Critical task `{task_name}` panicked")
}
}
}
impl PanickedTaskError {
fn new(task_name: &'static str, error: Box<dyn Any>) -> Self {
let error = match error.downcast::<String>() {
Ok(value) => Some(*value),
Err(error) => match error.downcast::<&str>() {
Ok(value) => Some(value.to_string()),
Err(_) => None,
},
};
Self { task_name, error }
}
}
/// A type that can spawn new tokio tasks
#[derive(Debug, Clone)]
@ -196,7 +225,7 @@ pub struct TaskExecutor {
/// Receiver of the shutdown signal.
on_shutdown: Shutdown,
/// Sender half for sending panic signals to this type
panicked_tasks_tx: UnboundedSender<&'static str>,
panicked_tasks_tx: UnboundedSender<PanickedTaskError>,
// Task Executor Metrics
metrics: TaskExecutorMetrics,
}
@ -298,9 +327,10 @@ impl TaskExecutor {
// wrap the task in catch unwind
let task = std::panic::AssertUnwindSafe(fut)
.catch_unwind()
.inspect_err(move |res| {
error!("Critical task `{name}` panicked: {res:?}");
let _ = panicked_tasks_tx.send(name);
.map_err(move |error| {
let task_error = PanickedTaskError::new(name, error);
error!("{task_error}");
let _ = panicked_tasks_tx.send(task_error);
})
.in_current_span();
@ -352,9 +382,10 @@ impl TaskExecutor {
// wrap the task in catch unwind
let task = std::panic::AssertUnwindSafe(fut)
.catch_unwind()
.inspect_err(move |res| {
error!("Critical task `{name}` panicked: {res:?}");
let _ = panicked_tasks_tx.send(name);
.map_err(move |error| {
let task_error = PanickedTaskError::new(name, error);
error!("{task_error}");
let _ = panicked_tasks_tx.send(task_error);
})
.map(|_| ())
.in_current_span();
@ -428,7 +459,8 @@ mod tests {
runtime.block_on(async move {
let err = manager.await;
assert_eq!(err.0, "this is a critical task");
assert_eq!(err.task_name, "this is a critical task");
assert_eq!(err.error, Some("intentionally panic".to_string()));
})
}