From 72968d758b2e3d8d012e3489a0fa3f8e992974b8 Mon Sep 17 00:00:00 2001 From: Bjerg Date: Mon, 10 Oct 2022 05:49:45 +0200 Subject: [PATCH] refactor: pipeline refactor (#23) * refactor: split stage crate into modules * refactor: clean up pipeline events * refactor: internal fn for executing single stage * refactor: pipeline error * refactor: move errors to own module * refactor: misc cleanup * test: add tests for `opt::max` and `opt::min` * chore: db errors (#26) * refactor: replace ext trait with wrapper type Co-authored-by: Roman Krasiuk --- crates/stages/src/error.rs | 41 +++++ crates/stages/src/id.rs | 50 +++++++ crates/stages/src/lib.rs | 153 +------------------ crates/stages/src/pipeline.rs | 225 ++++++++++++---------------- crates/stages/src/pipeline/event.rs | 46 ++++++ crates/stages/src/stage.rs | 80 ++++++++++ crates/stages/src/util.rs | 63 ++++++++ 7 files changed, 382 insertions(+), 276 deletions(-) create mode 100644 crates/stages/src/error.rs create mode 100644 crates/stages/src/id.rs create mode 100644 crates/stages/src/pipeline/event.rs create mode 100644 crates/stages/src/stage.rs create mode 100644 crates/stages/src/util.rs diff --git a/crates/stages/src/error.rs b/crates/stages/src/error.rs new file mode 100644 index 000000000..bd323fb93 --- /dev/null +++ b/crates/stages/src/error.rs @@ -0,0 +1,41 @@ +use crate::pipeline::PipelineEvent; +use reth_db::mdbx; +use reth_primitives::BlockNumber; +use thiserror::Error; +use tokio::sync::mpsc::error::SendError; + +/// A stage execution error. +#[derive(Error, Debug)] +pub enum StageError { + /// The stage encountered a state validation error. + /// + /// TODO: This depends on the consensus engine and should include the validation failure reason + #[error("Stage encountered a validation error in block {block}.")] + Validation { + /// The block that failed validation. + block: BlockNumber, + }, + /// The stage encountered a database error. + #[error("A database error occurred.")] + Database(#[from] mdbx::Error), + /// The stage encountered an internal error. + #[error(transparent)] + Internal(Box), +} + +/// A pipeline execution error. +#[derive(Error, Debug)] +pub enum PipelineError { + /// The pipeline encountered an irrecoverable error in one of the stages. + #[error("A stage encountered an irrecoverable error.")] + Stage(#[from] StageError), + /// The pipeline encountered a database error. + #[error("A database error occurred.")] + Database(#[from] mdbx::Error), + /// The pipeline encountered an error while trying to send an event. + #[error("The pipeline encountered an error while trying to send an event.")] + Channel(#[from] SendError), + /// The stage encountered an internal error. + #[error(transparent)] + Internal(Box), +} diff --git a/crates/stages/src/id.rs b/crates/stages/src/id.rs new file mode 100644 index 000000000..cbc981068 --- /dev/null +++ b/crates/stages/src/id.rs @@ -0,0 +1,50 @@ +use reth_db::mdbx; +use reth_primitives::BlockNumber; +use std::fmt::Display; + +/// The ID of a stage. +/// +/// Each stage ID must be unique. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct StageId(pub &'static str); + +impl Display for StageId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl StageId { + /// Get the last committed progress of this stage. + pub fn get_progress<'db, K, E>( + &self, + tx: &mdbx::Transaction<'db, K, E>, + ) -> Result, mdbx::Error> + where + K: mdbx::TransactionKind, + E: mdbx::EnvironmentKind, + { + // TODO: Clean up when we get better database abstractions + let bytes: Option> = tx.get(&tx.open_db(Some("SyncStage"))?, self.0.as_ref())?; + + Ok(bytes.map(|b| BlockNumber::from_be_bytes(b.try_into().expect("Database corrupt")))) + } + + /// Save the progress of this stage. + pub fn save_progress<'db, E>( + &self, + tx: &mdbx::Transaction<'db, mdbx::RW, E>, + block: BlockNumber, + ) -> Result<(), mdbx::Error> + where + E: mdbx::EnvironmentKind, + { + // TODO: Clean up when we get better database abstractions + tx.put( + &tx.open_db(Some("SyncStage"))?, + self.0, + block.to_be_bytes(), + mdbx::WriteFlags::UPSERT, + ) + } +} diff --git a/crates/stages/src/lib.rs b/crates/stages/src/lib.rs index ca3c413ee..7e4ce6525 100644 --- a/crates/stages/src/lib.rs +++ b/crates/stages/src/lib.rs @@ -8,150 +8,13 @@ //! //! See [Stage] and [Pipeline]. -use async_trait::async_trait; -use reth_db::mdbx; -use reth_primitives::BlockNumber; -use std::fmt::Display; -use thiserror::Error; - +mod error; +mod id; mod pipeline; +mod stage; +mod util; + +pub use error::*; +pub use id::*; pub use pipeline::*; - -/// Stage execution input, see [Stage::execute]. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct ExecInput { - /// The stage that was run before the current stage and the block number it reached. - pub previous_stage: Option<(StageId, BlockNumber)>, - /// The progress of this stage the last time it was executed. - pub stage_progress: Option, -} - -/// Stage unwind input, see [Stage::unwind]. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct UnwindInput { - /// The current highest block of the stage. - pub stage_progress: BlockNumber, - /// The block to unwind to. - pub unwind_to: BlockNumber, - /// The bad block that caused the unwind, if any. - pub bad_block: Option, -} - -/// The output of a stage execution. -#[derive(Debug, PartialEq, Eq, Clone)] -pub struct ExecOutput { - /// How far the stage got. - pub stage_progress: BlockNumber, - /// Whether or not the stage is done. - pub done: bool, - /// Whether or not the stage reached the tip of the chain. - pub reached_tip: bool, -} - -/// The output of a stage unwinding. -#[derive(Debug, PartialEq, Eq, Clone)] -pub struct UnwindOutput { - /// The block at which the stage has unwound to. - pub stage_progress: BlockNumber, -} - -/// A stage execution error. -#[derive(Error, Debug)] -pub enum StageError { - /// The stage encountered a state validation error. - /// - /// TODO: This depends on the consensus engine and should include the validation failure reason - #[error("Stage encountered a validation error in block {block}.")] - Validation { - /// The block that failed validation. - block: BlockNumber, - }, - /// The stage encountered an internal error. - #[error(transparent)] - Internal(Box), -} - -/// A stage is a segmented part of the syncing process of the node. -/// -/// Each stage takes care of a well-defined task, such as downloading headers or executing -/// transactions, and persist their results to a database. -/// -/// Stages must have a unique [ID][StageId] and implement a way to "roll forwards" -/// ([Stage::execute]) and a way to "roll back" ([Stage::unwind]). -/// -/// Stages are executed as part of a pipeline where they are executed serially. -#[async_trait] -pub trait Stage<'db, E>: Send + Sync -where - E: mdbx::EnvironmentKind, -{ - /// Get the ID of the stage. - /// - /// Stage IDs must be unique. - fn id(&self) -> StageId; - - /// Execute the stage. - async fn execute<'tx>( - &mut self, - tx: &mut mdbx::Transaction<'tx, mdbx::RW, E>, - input: ExecInput, - ) -> Result - where - 'db: 'tx; - - /// Unwind the stage. - async fn unwind<'tx>( - &mut self, - tx: &mut mdbx::Transaction<'tx, mdbx::RW, E>, - input: UnwindInput, - ) -> Result> - where - 'db: 'tx; -} - -/// The ID of a stage. -/// -/// Each stage ID must be unique. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub struct StageId(pub &'static str); - -impl Display for StageId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -impl StageId { - /// Get the last committed progress of this stage. - pub fn get_progress<'db, K, E>( - &self, - tx: &mdbx::Transaction<'db, K, E>, - ) -> Result, mdbx::Error> - where - K: mdbx::TransactionKind, - E: mdbx::EnvironmentKind, - { - // TODO: Clean up when we get better database abstractions - let bytes: Option> = tx.get(&tx.open_db(Some("SyncStage"))?, self.0.as_ref())?; - - Ok(bytes.map(|b| BlockNumber::from_be_bytes(b.try_into().expect("Database corrupt")))) - } - - /// Save the progress of this stage. - pub fn save_progress<'db, E>( - &self, - tx: &mdbx::Transaction<'db, mdbx::RW, E>, - block: BlockNumber, - ) -> Result<(), mdbx::Error> - where - E: mdbx::EnvironmentKind, - { - // TODO: Clean up when we get better database abstractions - tx.put( - &tx.open_db(Some("SyncStage"))?, - self.0, - block.to_be_bytes(), - mdbx::WriteFlags::UPSERT, - ) - } -} +pub use stage::*; diff --git a/crates/stages/src/pipeline.rs b/crates/stages/src/pipeline.rs index ed11f773f..f1c0c7b36 100644 --- a/crates/stages/src/pipeline.rs +++ b/crates/stages/src/pipeline.rs @@ -1,10 +1,17 @@ -use crate::{ExecInput, ExecOutput, Stage, StageError, StageId, UnwindInput, UnwindOutput}; +use crate::{ + error::*, + util::opt::{self, MaybeSender}, + ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput, +}; use reth_db::mdbx; use reth_primitives::BlockNumber; use std::fmt::{Debug, Formatter}; use tokio::sync::mpsc::Sender; use tracing::*; +mod event; +pub use event::*; + struct QueuedStage<'db, E> where E: mdbx::EnvironmentKind, @@ -42,7 +49,7 @@ where { stages: Vec>, max_block: Option, - events_sender: Option>, + events_sender: MaybeSender, } impl<'db, E> Default for Pipeline<'db, E> @@ -50,7 +57,7 @@ where E: mdbx::EnvironmentKind, { fn default() -> Self { - Self { stages: Vec::new(), max_block: None, events_sender: None } + Self { stages: Vec::new(), max_block: None, events_sender: MaybeSender::new(None) } } } @@ -113,15 +120,12 @@ where /// Set a channel the pipeline will transmit events over (see [PipelineEvent]). pub fn set_channel(mut self, sender: Sender) -> Self { - self.events_sender = Some(sender); + self.events_sender.set(Some(sender)); self } /// Run the pipeline. - pub async fn run( - &mut self, - db: &'db mdbx::Environment, - ) -> Result<(), Box> { + pub async fn run(&mut self, db: &'db mdbx::Environment) -> Result<(), PipelineError> { let mut previous_stage = None; let mut minimum_progress: Option = None; let mut maximum_progress: Option = None; @@ -129,80 +133,58 @@ where 'run: loop { let mut tx = db.begin_rw_txn()?; - for (_, QueuedStage { stage, require_tip, .. }) in self.stages.iter_mut().enumerate() { - let stage_id = stage.id(); + for (_, queued_stage) in self.stages.iter_mut().enumerate() { + let stage_id = queued_stage.stage.id(); let block_reached = loop { let prev_progress = stage_id.get_progress(&tx)?; + self.events_sender + .send(PipelineEvent::Running { stage_id, stage_progress: prev_progress }) + .await?; - if let Some(rx) = &self.events_sender { - rx.send(PipelineEvent::Running { stage_id, stage_progress: prev_progress }) - .await? - } - - let reached_virtual_tip = maximum_progress + // Whether any stage has reached the maximum block, which also counts as having + // reached the tip for stages that have reached the tip + let reached_max_block = maximum_progress .zip(self.max_block) .map_or(false, |(progress, target)| progress >= target); - // Execute stage - let output = async { - if !reached_tip_flag && *require_tip && !reached_virtual_tip { - info!("Tip not reached, skipping."); + // Whether this stage reached the max block + let stage_reached_max_block = prev_progress + .zip(self.max_block) + .map_or(false, |(prev_progress, target)| prev_progress >= target); - // Stage requires us to reach the tip of the chain first, but we have - // not. - Ok(ExecOutput { - stage_progress: prev_progress.unwrap_or_default(), - done: true, - reached_tip: false, - }) - } else if prev_progress - .zip(self.max_block) - .map_or(false, |(prev_progress, target)| prev_progress >= target) - { - info!("Stage reached maximum block, skipping."); - // We reached the maximum block, so we skip the stage - Ok(ExecOutput { - stage_progress: prev_progress.unwrap_or_default(), - done: true, - reached_tip: true, - }) - } else { - stage - .execute( - &mut tx, - ExecInput { previous_stage, stage_progress: prev_progress }, - ) - .await - } - } + // Execute stage + let output = Self::execute_stage( + &mut tx, + queued_stage, + ExecInput { previous_stage, stage_progress: prev_progress }, + reached_tip_flag || reached_max_block, + stage_reached_max_block, + ) .instrument(info_span!("Running", stage = %stage_id)) .await; + if output.is_err() { + self.events_sender + .send(PipelineEvent::Ran { stage_id, result: None }) + .await?; + } + match output { Ok(out @ ExecOutput { stage_progress, done, reached_tip }) => { debug!(stage = %stage_id, %stage_progress, %done, "Stage made progress"); stage_id.save_progress(&tx, stage_progress)?; - if let Some(rx) = &self.events_sender { - rx.send(PipelineEvent::Ran { stage_id, result: Some(out.clone()) }) - .await? - } + self.events_sender + .send(PipelineEvent::Ran { stage_id, result: Some(out.clone()) }) + .await?; // TODO: Make the commit interval configurable tx.commit()?; tx = db.begin_rw_txn()?; - // TODO: Clean up - if let Some(min) = &mut minimum_progress { - *min = std::cmp::min(*min, stage_progress); - } else { - minimum_progress = Some(stage_progress); - } - if let Some(max) = &mut maximum_progress { - *max = std::cmp::max(*max, stage_progress); - } else { - maximum_progress = Some(stage_progress); - } + // Update our minimum and maximum stage progress + minimum_progress = opt::min(minimum_progress, stage_progress); + maximum_progress = opt::max(maximum_progress, stage_progress); if done { reached_tip_flag = reached_tip; @@ -212,10 +194,6 @@ where Err(StageError::Validation { block }) => { debug!(stage = %stage_id, bad_block = %block, "Stage encountered a validation error."); - if let Some(rx) = &self.events_sender { - rx.send(PipelineEvent::Ran { stage_id, result: None }).await? - } - // We unwind because of a validation error. If the unwind itself fails, // we bail entirely, otherwise we restart the execution loop from the // beginning. @@ -234,13 +212,7 @@ where Err(e) => return Err(e), } } - Err(StageError::Internal(e)) => { - if let Some(rx) = &self.events_sender { - rx.send(PipelineEvent::Ran { stage_id, result: None }).await? - } - - return Err(e) - } + Err(e) => return Err(PipelineError::Stage(e)), } }; @@ -267,7 +239,7 @@ where db: &'db mdbx::Environment, to: BlockNumber, bad_block: Option, - ) -> Result<(), Box> { + ) -> Result<(), PipelineError> { // Sort stages by unwind priority let mut unwind_pipeline = { let mut stages: Vec<_> = self.stages.iter_mut().enumerate().collect(); @@ -288,26 +260,22 @@ where let stage_id = stage.id(); let mut stage_progress = stage_id.get_progress(&tx)?.unwrap_or_default(); - let unwind: Result<(), Box> = async { + let unwind: Result<(), PipelineError> = async { if stage_progress < to { debug!(from = %stage_progress, %to, "Unwind point too far for stage"); - if let Some(rx) = &self.events_sender { - rx.send(PipelineEvent::Unwound { + self.events_sender + .send(PipelineEvent::Unwound { stage_id, result: Some(UnwindOutput { stage_progress }), }) - .await? - } - + .await?; return Ok(()) } debug!(from = %stage_progress, %to, ?bad_block, "Starting unwind"); while stage_progress > to { let input = UnwindInput { stage_progress, unwind_to: to, bad_block }; - if let Some(rx) = &self.events_sender { - rx.send(PipelineEvent::Unwinding { stage_id, input }).await? - } + self.events_sender.send(PipelineEvent::Unwinding { stage_id, input }).await?; let output = stage.unwind(&mut tx, input).await; match output { @@ -315,20 +283,18 @@ where stage_progress = unwind_output.stage_progress; stage_id.save_progress(&tx, stage_progress)?; - if let Some(rx) = &self.events_sender { - rx.send(PipelineEvent::Unwound { + self.events_sender + .send(PipelineEvent::Unwound { stage_id, result: Some(unwind_output), }) - .await? - } + .await?; } Err(err) => { - if let Some(rx) = &self.events_sender { - rx.send(PipelineEvent::Unwound { stage_id, result: None }).await? - } - - return Err(err) + self.events_sender + .send(PipelineEvent::Unwound { stage_id, result: None }) + .await?; + return Err(PipelineError::Stage(StageError::Internal(err))) } } } @@ -345,45 +311,42 @@ where } } -/// An event emitted by a [Pipeline]. -#[derive(Debug, PartialEq, Eq, Clone)] -pub enum PipelineEvent { - /// Emitted when a stage is about to be run. - Running { - /// The stage that is about to be run. - stage_id: StageId, - /// The previous checkpoint of the stage. - stage_progress: Option, - }, - /// Emitted when a stage has run a single time. - /// - /// It is possible for multiple of these events to be emitted over the duration of a pipeline's - /// execution: - /// - If the pipeline loops, the stage will be run again at some point - /// - If the stage exits early but has acknowledged that it is not entirely done - Ran { - /// The stage that was run. - stage_id: StageId, - /// The result of executing the stage. If it is None then an error was encountered. - result: Option, - }, - /// Emitted when a stage is about to be unwound. - Unwinding { - /// The stage that is about to be unwound. - stage_id: StageId, - /// The unwind parameters. - input: UnwindInput, - }, - /// Emitted when a stage has been unwound. - /// - /// It is possible for multiple of these events to be emitted over the duration of a pipeline's - /// execution, since other stages may ask the pipeline to unwind. - Unwound { - /// The stage that was unwound. - stage_id: StageId, - /// The result of unwinding the stage. If it is None then an error was encountered. - result: Option, - }, +impl<'db, E> Pipeline<'db, E> +where + E: mdbx::EnvironmentKind, +{ + async fn execute_stage<'tx>( + tx: &mut mdbx::Transaction<'tx, mdbx::RW, E>, + QueuedStage { stage, require_tip, .. }: &mut QueuedStage<'db, E>, + input: ExecInput, + reached_tip: bool, + stage_reached_max_block: bool, + ) -> Result + where + 'db: 'tx, + { + if !reached_tip && *require_tip { + info!("Tip not reached, skipping."); + + // Stage requires us to reach the tip of the chain first, but we have + // not. + Ok(ExecOutput { + stage_progress: input.stage_progress.unwrap_or_default(), + done: true, + reached_tip: false, + }) + } else if stage_reached_max_block { + info!("Stage reached maximum block, skipping."); + // We reached the maximum block, so we skip the stage + Ok(ExecOutput { + stage_progress: input.stage_progress.unwrap_or_default(), + done: true, + reached_tip: true, + }) + } else { + stage.execute(tx, input).await + } + } } #[cfg(test)] diff --git a/crates/stages/src/pipeline/event.rs b/crates/stages/src/pipeline/event.rs new file mode 100644 index 000000000..2aaf26ddf --- /dev/null +++ b/crates/stages/src/pipeline/event.rs @@ -0,0 +1,46 @@ +use crate::{ + id::StageId, + stage::{ExecOutput, UnwindInput, UnwindOutput}, +}; +use reth_primitives::BlockNumber; + +/// An event emitted by a [Pipeline]. +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum PipelineEvent { + /// Emitted when a stage is about to be run. + Running { + /// The stage that is about to be run. + stage_id: StageId, + /// The previous checkpoint of the stage. + stage_progress: Option, + }, + /// Emitted when a stage has run a single time. + /// + /// It is possible for multiple of these events to be emitted over the duration of a pipeline's + /// execution: + /// - If the pipeline loops, the stage will be run again at some point + /// - If the stage exits early but has acknowledged that it is not entirely done + Ran { + /// The stage that was run. + stage_id: StageId, + /// The result of executing the stage. If it is None then an error was encountered. + result: Option, + }, + /// Emitted when a stage is about to be unwound. + Unwinding { + /// The stage that is about to be unwound. + stage_id: StageId, + /// The unwind parameters. + input: UnwindInput, + }, + /// Emitted when a stage has been unwound. + /// + /// It is possible for multiple of these events to be emitted over the duration of a pipeline's + /// execution, since other stages may ask the pipeline to unwind. + Unwound { + /// The stage that was unwound. + stage_id: StageId, + /// The result of unwinding the stage. If it is None then an error was encountered. + result: Option, + }, +} diff --git a/crates/stages/src/stage.rs b/crates/stages/src/stage.rs new file mode 100644 index 000000000..4d2b75962 --- /dev/null +++ b/crates/stages/src/stage.rs @@ -0,0 +1,80 @@ +use crate::{error::StageError, id::StageId}; +use async_trait::async_trait; +use reth_db::mdbx; +use reth_primitives::BlockNumber; + +/// Stage execution input, see [Stage::execute]. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct ExecInput { + /// The stage that was run before the current stage and the block number it reached. + pub previous_stage: Option<(StageId, BlockNumber)>, + /// The progress of this stage the last time it was executed. + pub stage_progress: Option, +} + +/// Stage unwind input, see [Stage::unwind]. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct UnwindInput { + /// The current highest block of the stage. + pub stage_progress: BlockNumber, + /// The block to unwind to. + pub unwind_to: BlockNumber, + /// The bad block that caused the unwind, if any. + pub bad_block: Option, +} + +/// The output of a stage execution. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct ExecOutput { + /// How far the stage got. + pub stage_progress: BlockNumber, + /// Whether or not the stage is done. + pub done: bool, + /// Whether or not the stage reached the tip of the chain. + pub reached_tip: bool, +} + +/// The output of a stage unwinding. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct UnwindOutput { + /// The block at which the stage has unwound to. + pub stage_progress: BlockNumber, +} + +/// A stage is a segmented part of the syncing process of the node. +/// +/// Each stage takes care of a well-defined task, such as downloading headers or executing +/// transactions, and persist their results to a database. +/// +/// Stages must have a unique [ID][StageId] and implement a way to "roll forwards" +/// ([Stage::execute]) and a way to "roll back" ([Stage::unwind]). +/// +/// Stages are executed as part of a pipeline where they are executed serially. +#[async_trait] +pub trait Stage<'db, E>: Send + Sync +where + E: mdbx::EnvironmentKind, +{ + /// Get the ID of the stage. + /// + /// Stage IDs must be unique. + fn id(&self) -> StageId; + + /// Execute the stage. + async fn execute<'tx>( + &mut self, + tx: &mut mdbx::Transaction<'tx, mdbx::RW, E>, + input: ExecInput, + ) -> Result + where + 'db: 'tx; + + /// Unwind the stage. + async fn unwind<'tx>( + &mut self, + tx: &mut mdbx::Transaction<'tx, mdbx::RW, E>, + input: UnwindInput, + ) -> Result> + where + 'db: 'tx; +} diff --git a/crates/stages/src/util.rs b/crates/stages/src/util.rs new file mode 100644 index 000000000..9cdfd2010 --- /dev/null +++ b/crates/stages/src/util.rs @@ -0,0 +1,63 @@ +pub(crate) mod opt { + use tokio::sync::mpsc::{error::SendError, Sender}; + + /// Get an [Option] with the maximum value, compared between the passed in value and the inner + /// value of the [Option]. If the [Option] is `None`, then an option containing the passed in + /// value will be returned. + pub(crate) fn max(a: Option, b: T) -> Option { + a.map_or(Some(b), |v| Some(std::cmp::max(v, b))) + } + + /// Get an [Option] with the minimum value, compared between the passed in value and the inner + /// value of the [Option]. If the [Option] is `None`, then an option containing the passed in + /// value will be returned. + pub(crate) fn min(a: Option, b: T) -> Option { + a.map_or(Some(b), |v| Some(std::cmp::min(v, b))) + } + + /// The producing side of a [tokio::mpsc] channel that may or may not be set. + #[derive(Default)] + pub(crate) struct MaybeSender { + inner: Option>, + } + + impl MaybeSender { + /// Create a new [MaybeSender] + pub(crate) fn new(sender: Option>) -> Self { + Self { inner: sender } + } + + /// Send a value over the channel if an internal sender has been set. + pub(crate) async fn send(&self, value: T) -> Result<(), SendError> { + if let Some(rx) = &self.inner { + rx.send(value).await + } else { + Ok(()) + } + } + + /// Set or unset the internal sender. + pub(crate) fn set(&mut self, sender: Option>) { + self.inner = sender; + } + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn opt_max() { + assert_eq!(max(None, 5), Some(5)); + assert_eq!(max(Some(1), 5), Some(5)); + assert_eq!(max(Some(10), 5), Some(10)); + } + + #[test] + fn opt_min() { + assert_eq!(min(None, 5), Some(5)); + assert_eq!(min(Some(1), 5), Some(1)); + assert_eq!(min(Some(10), 5), Some(5)); + } + } +}