diff --git a/crates/rpc/rpc-testing-util/src/trace.rs b/crates/rpc/rpc-testing-util/src/trace.rs index 5d9eeda55..575383b66 100644 --- a/crates/rpc/rpc-testing-util/src/trace.rs +++ b/crates/rpc/rpc-testing-util/src/trace.rs @@ -4,7 +4,10 @@ use jsonrpsee::core::Error as RpcError; use reth_primitives::{BlockId, Bytes, TxHash, B256}; use reth_rpc_api::clients::TraceApiClient; use reth_rpc_types::{ - trace::parity::{LocalizedTransactionTrace, TraceResults, TraceType}, + trace::{ + filter::TraceFilter, + parity::{LocalizedTransactionTrace, TraceResults, TraceType}, + }, CallRequest, Index, }; use std::{ @@ -31,6 +34,9 @@ pub type CallManyTraceResult = Result< /// index. pub type TraceGetResult = Result<(Option, B256, Vec), (RpcError, B256, Vec)>; +/// Represents a result type for the `trace_filter` stream extension. +pub type TraceFilterResult = + Result<(Vec, TraceFilter), (RpcError, TraceFilter)>; /// An extension trait for the Trace API. #[async_trait::async_trait] @@ -86,6 +92,33 @@ pub trait TraceApiExt { fn trace_get_stream(&self, hash: B256, indices: I) -> TraceGetStream<'_> where I: IntoIterator; + + /// Returns a new stream that yields traces for given filters. + fn trace_filter_stream(&self, filters: I) -> TraceFilterStream<'_> + where + I: IntoIterator; +} + +/// Represents a stream that asynchronously yields the results of the `trace_filter` method. +#[must_use = "streams do nothing unless polled"] +pub struct TraceFilterStream<'a> { + stream: Pin + 'a>>, +} + +impl<'a> Stream for TraceFilterStream<'a> { + type Item = TraceFilterResult; + + /// Attempts to pull out the next value of the stream. + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.stream.as_mut().poll_next(cx) + } +} + +impl<'a> std::fmt::Debug for TraceFilterStream<'a> { + /// Provides a debug representation of the `TraceFilterStream`. + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TraceFilterStream").finish_non_exhaustive() + } } /// A stream that asynchronously yields the results of the `trace_get` method for a given /// transaction hash and a series of indices. @@ -274,6 +307,21 @@ impl TraceApiExt for T { .buffered(10); TraceGetStream { stream: Box::pin(stream) } } + + fn trace_filter_stream(&self, filters: I) -> TraceFilterStream<'_> + where + I: IntoIterator, + { + let filter_list = filters.into_iter().collect::>(); + let stream = futures::stream::iter(filter_list.into_iter().map(move |filter| async move { + match self.trace_filter(filter.clone()).await { + Ok(result) => Ok((result, filter)), + Err(err) => Err((err, filter)), + } + })) + .buffered(10); + TraceFilterStream { stream: Box::pin(stream) } + } } /// A stream that yields the traces for the requested blocks.