diff --git a/Cargo.lock b/Cargo.lock index 01933c1dc..257a1f8c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3280,6 +3280,7 @@ version = "0.1.0" dependencies = [ "assert_matches", "async-trait", + "futures", "once_cell", "rand 0.8.5", "reth-interfaces", diff --git a/crates/db/src/kv/mod.rs b/crates/db/src/kv/mod.rs index d97876ae0..f23d2410f 100644 --- a/crates/db/src/kv/mod.rs +++ b/crates/db/src/kv/mod.rs @@ -354,7 +354,7 @@ mod gat_tests { tokio::spawn(async move { let mut container = DBContainer::new(&db).unwrap(); let mut stage = MyStage(&db); - let _ = stage.run(&mut container); + stage.run(&mut container).await; }); } } diff --git a/crates/interfaces/src/db/container.rs b/crates/interfaces/src/db/container.rs index f3a55b34a..0e250d00e 100644 --- a/crates/interfaces/src/db/container.rs +++ b/crates/interfaces/src/db/container.rs @@ -94,7 +94,7 @@ mod tests { tokio::spawn(async move { let mut container = DBContainer::new(&db).unwrap(); let mut stage = MyStage(&db); - let _ = stage.run(&mut container); + stage.run(&mut container).await; }); } } diff --git a/crates/interfaces/src/p2p/headers/client.rs b/crates/interfaces/src/p2p/headers/client.rs index 5ff872887..3d10b0146 100644 --- a/crates/interfaces/src/p2p/headers/client.rs +++ b/crates/interfaces/src/p2p/headers/client.rs @@ -1,31 +1,8 @@ -use crate::p2p::MessageStream; - -use reth_primitives::{Header, H256, H512}; - +use crate::p2p::error::RequestResult; use async_trait::async_trait; -use reth_primitives::BlockHashOrNumber; -use std::{collections::HashSet, fmt::Debug}; - -/// Each peer returns a list of headers and the request id corresponding -/// to these headers. This allows clients to make multiple requests in parallel -/// and multiplex the responses accordingly. -pub type HeadersStream = MessageStream; - -/// The item contained in each [`MessageStream`] when used to fetch [`Header`]s via -/// [`HeadersClient`]. -#[derive(Clone, Debug)] -pub struct HeadersResponse { - /// The request id associated with this response. - pub id: u64, - /// The headers the peer replied with. - pub headers: Vec
, -} - -impl From<(u64, Vec
)> for HeadersResponse { - fn from((id, headers): (u64, Vec
)) -> Self { - HeadersResponse { id, headers } - } -} +pub use reth_eth_wire::BlockHeaders; +use reth_primitives::{BlockHashOrNumber, H256}; +use std::fmt::Debug; /// The header request struct to be sent to connected peers, which /// will proceed to ask them to stream the requested headers to us. @@ -47,12 +24,9 @@ pub trait HeadersClient: Send + Sync + Debug { /// Update the node's Status message. /// /// The updated Status message will be used during any new eth/65 handshakes. - async fn update_status(&self, height: u64, hash: H256, td: H256); + fn update_status(&self, height: u64, hash: H256, td: H256); - /// Sends the header request to the p2p network. - // TODO: What does this return? - async fn send_header_request(&self, id: u64, request: HeadersRequest) -> HashSet; - - /// Stream the header response messages - async fn stream_headers(&self) -> HeadersStream; + /// Sends the header request to the p2p network and returns the header response received from a + /// peer. + async fn get_headers(&self, request: HeadersRequest) -> RequestResult; } diff --git a/crates/interfaces/src/p2p/headers/downloader.rs b/crates/interfaces/src/p2p/headers/downloader.rs index 97fbe9c5b..cb6c03991 100644 --- a/crates/interfaces/src/p2p/headers/downloader.rs +++ b/crates/interfaces/src/p2p/headers/downloader.rs @@ -1,18 +1,31 @@ -use super::client::{HeadersClient, HeadersRequest, HeadersStream}; -use crate::{consensus::Consensus, p2p::headers::error::DownloadError}; -use async_trait::async_trait; -use reth_primitives::{BlockHashOrNumber, Header, SealedHeader}; +use super::client::HeadersClient; +use crate::{ + consensus::Consensus, + p2p::{headers::error::DownloadError, traits::BatchDownload}, +}; + +use reth_primitives::SealedHeader; use reth_rpc_types::engine::ForkchoiceState; -use std::time::Duration; -use tokio_stream::StreamExt; +use std::{pin::Pin, time::Duration}; + +/// A Future for downloading a batch of headers. +pub type HeaderBatchDownload<'a> = Pin< + Box< + dyn BatchDownload< + Ok = SealedHeader, + Error = DownloadError, + Output = Result, DownloadError>, + > + Send + + 'a, + >, +>; /// A downloader capable of fetching block headers. /// /// A downloader represents a distinct strategy for submitting requests to download block headers, /// while a [HeadersClient] represents a client capable of fulfilling these requests. -#[async_trait] #[auto_impl::auto_impl(&, Arc, Box)] -pub trait HeaderDownloader: Sync + Send { +pub trait HeaderDownloader: Sync + Send + Unpin { /// The Consensus used to verify block validity when /// downloading type Consensus: Consensus; @@ -30,55 +43,41 @@ pub trait HeaderDownloader: Sync + Send { fn client(&self) -> &Self::Client; /// Download the headers - async fn download( - &self, - head: &SealedHeader, - forkchoice: &ForkchoiceState, - ) -> Result, DownloadError>; - - /// Perform a header request and returns the headers. - // TODO: Isn't this effectively blocking per request per downloader? - // Might be fine, given we can spawn multiple downloaders? - // TODO: Rethink this function, I don't really like the `stream: &mut HeadersStream` - // in the signature. Why can we not call `self.client.stream_headers()`? Gives lifetime error. - async fn download_headers( - &self, - stream: &mut HeadersStream, - start: BlockHashOrNumber, - limit: u64, - ) -> Result, DownloadError> { - let request_id = rand::random(); - let request = HeadersRequest { start, limit, reverse: true }; - let _ = self.client().send_header_request(request_id, request).await; - - // Filter stream by request id and non empty headers content - let stream = stream - .filter(|resp| request_id == resp.id && !resp.headers.is_empty()) - .timeout(self.timeout()); - - // Pop the first item. - match Box::pin(stream).try_next().await { - Ok(Some(item)) => Ok(item.headers), - _ => return Err(DownloadError::Timeout { request_id }), - } - } + fn download(&self, head: SealedHeader, forkchoice: ForkchoiceState) -> HeaderBatchDownload<'_>; /// Validate whether the header is valid in relation to it's parent /// /// Returns Ok(false) if the fn validate(&self, header: &SealedHeader, parent: &SealedHeader) -> Result<(), DownloadError> { - if !(parent.hash() == header.parent_hash && parent.number + 1 == header.number) { - return Err(DownloadError::MismatchedHeaders { - header_number: header.number.into(), - parent_number: parent.number.into(), - header_hash: header.hash(), - parent_hash: parent.hash(), - }) - } - - self.consensus() - .validate_header(header, parent) - .map_err(|error| DownloadError::HeaderValidation { hash: parent.hash(), error })?; + validate_header_download(self.consensus(), header, parent)?; Ok(()) } } + +/// Validate whether the header is valid in relation to it's parent +/// +/// Returns Ok(false) if the +pub fn validate_header_download( + consensus: &C, + header: &SealedHeader, + parent: &SealedHeader, +) -> Result<(), DownloadError> { + ensure_parent(header, parent)?; + consensus + .validate_header(header, parent) + .map_err(|error| DownloadError::HeaderValidation { hash: parent.hash(), error })?; + Ok(()) +} + +/// Ensures that the given `parent` header is the actual parent of the `header` +pub fn ensure_parent(header: &SealedHeader, parent: &SealedHeader) -> Result<(), DownloadError> { + if !(parent.hash() == header.parent_hash && parent.number + 1 == header.number) { + return Err(DownloadError::MismatchedHeaders { + header_number: header.number.into(), + parent_number: parent.number.into(), + header_hash: header.hash(), + parent_hash: parent.hash(), + }) + } + Ok(()) +} diff --git a/crates/interfaces/src/p2p/headers/error.rs b/crates/interfaces/src/p2p/headers/error.rs index 035727f1a..2ceaa21e5 100644 --- a/crates/interfaces/src/p2p/headers/error.rs +++ b/crates/interfaces/src/p2p/headers/error.rs @@ -1,4 +1,4 @@ -use crate::consensus; +use crate::{consensus, p2p::error::RequestError}; use reth_primitives::{rpc::BlockNumber, H256}; use thiserror::Error; @@ -15,11 +15,8 @@ pub enum DownloadError { error: consensus::Error, }, /// Timed out while waiting for request id response. - #[error("Timed out while getting headers for request {request_id}.")] - Timeout { - /// The request id that timed out - request_id: u64, - }, + #[error("Timed out while getting headers for request.")] + Timeout, /// Error when checking that the current [`Header`] has the parent's hash as the parent_hash /// field, and that they have sequential block numbers. #[error("Headers did not match, current number: {header_number} / current hash: {header_hash}, parent number: {parent_number} / parent_hash: {parent_hash}")] @@ -33,6 +30,9 @@ pub enum DownloadError { /// The parent hash being evaluated parent_hash: H256, }, + /// Error while executing the request. + #[error(transparent)] + RequestError(#[from] RequestError), } impl DownloadError { diff --git a/crates/interfaces/src/p2p/mod.rs b/crates/interfaces/src/p2p/mod.rs index ecdad5970..5676c9975 100644 --- a/crates/interfaces/src/p2p/mod.rs +++ b/crates/interfaces/src/p2p/mod.rs @@ -13,8 +13,5 @@ pub mod headers; /// interacting with the network implementation pub mod error; -use futures::Stream; -use std::pin::Pin; - -/// The stream of responses from the connected peers, generic over the response type. -pub type MessageStream = Pin + Send>>; +/// Commonly used traits when implementing clients. +pub mod traits; diff --git a/crates/interfaces/src/p2p/traits.rs b/crates/interfaces/src/p2p/traits.rs new file mode 100644 index 000000000..ae66cf2cb --- /dev/null +++ b/crates/interfaces/src/p2p/traits.rs @@ -0,0 +1,21 @@ +use futures::Stream; +use std::future::Future; + +/// Abstraction for downloading several items at once. +/// +/// A [`BatchDownload`] is a [`Future`] that represents a collection of download futures and +/// resolves once all of them finished. +/// +/// This is similar to the [`futures::future::join_all`] function, but it's open to implementers how +/// this Future behaves exactly. +/// +/// It is expected that the underlying futures return a [`Result`]. +pub trait BatchDownload: Future, Self::Error>> { + /// The `Ok` variant of the futures output in this batch. + type Ok; + /// The `Err` variant of the futures output in this batch. + type Error; + + /// Consumes the batch future and returns a [`Stream`] that yields results as they become ready. + fn into_stream_unordered(self) -> Box>>; +} diff --git a/crates/interfaces/src/test_utils/headers.rs b/crates/interfaces/src/test_utils/headers.rs index 087594c09..6af929f81 100644 --- a/crates/interfaces/src/test_utils/headers.rs +++ b/crates/interfaces/src/test_utils/headers.rs @@ -1,36 +1,43 @@ //! Testing support for headers related interfaces. use crate::{ consensus::{self, Consensus}, - p2p::headers::{ - client::{HeadersClient, HeadersRequest, HeadersResponse, HeadersStream}, - downloader::HeaderDownloader, - error::DownloadError, + p2p::{ + error::{RequestError, RequestResult}, + headers::{ + client::{HeadersClient, HeadersRequest}, + downloader::{HeaderBatchDownload, HeaderDownloader}, + error::DownloadError, + }, + traits::BatchDownload, }, }; -use reth_primitives::{BlockLocked, Header, SealedHeader, H256, H512}; +use futures::{Future, FutureExt, Stream}; +use reth_eth_wire::BlockHeaders; +use reth_primitives::{BlockLocked, Header, SealedHeader, H256}; use reth_rpc_types::engine::ForkchoiceState; use std::{ - collections::HashSet, + pin::Pin, sync::{ atomic::{AtomicBool, Ordering}, Arc, }, + task::{ready, Context, Poll}, time::Duration, }; -use tokio::sync::{broadcast, mpsc, watch}; -use tokio_stream::{wrappers::BroadcastStream, StreamExt}; +use tokio::sync::{watch, Mutex}; /// A test downloader which just returns the values that have been pushed to it. #[derive(Debug)] pub struct TestHeaderDownloader { client: Arc, consensus: Arc, + limit: u64, } impl TestHeaderDownloader { /// Instantiates the downloader with the mock responses - pub fn new(client: Arc, consensus: Arc) -> Self { - Self { client, consensus } + pub fn new(client: Arc, consensus: Arc, limit: u64) -> Self { + Self { client, consensus, limit } } } @@ -51,99 +58,104 @@ impl HeaderDownloader for TestHeaderDownloader { &self.client } - async fn download( + fn download( &self, - _: &SealedHeader, - _: &ForkchoiceState, - ) -> Result, DownloadError> { - // call consensus stub first. fails if the flag is set + _head: SealedHeader, + _forkchoice: ForkchoiceState, + ) -> HeaderBatchDownload<'_> { + Box::pin(TestDownload { + client: Arc::clone(&self.client), + consensus: Arc::clone(&self.consensus), + limit: self.limit, + }) + } +} + +struct TestDownload { + client: Arc, + consensus: Arc, + limit: u64, +} + +impl Future for TestDownload { + type Output = Result, DownloadError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let empty = SealedHeader::default(); - self.consensus - .validate_header(&empty, &empty) - .map_err(|error| DownloadError::HeaderValidation { hash: empty.hash(), error })?; - - let stream = self.client.stream_headers().await; - let stream = stream.timeout(Duration::from_secs(1)); - - match Box::pin(stream).try_next().await { - Ok(Some(res)) => { - let mut headers = res.headers.iter().map(|h| h.clone().seal()).collect::>(); - if !headers.is_empty() { - headers.sort_unstable_by_key(|h| h.number); - headers.remove(0); // remove head from response - headers.reverse(); - } - Ok(headers) - } - _ => Err(DownloadError::Timeout { request_id: 0 }), + if let Err(error) = self.consensus.validate_header(&empty, &empty) { + return Poll::Ready(Err(DownloadError::HeaderValidation { hash: empty.hash(), error })) } + + let request = HeadersRequest { + limit: self.limit, + reverse: true, + start: reth_primitives::BlockHashOrNumber::Number(0), // ignored + }; + match ready!(self.client.get_headers(request).poll_unpin(cx)) { + Ok(resp) => { + let mut headers = resp.0.into_iter().skip(1).map(|h| h.seal()).collect::>(); + headers.sort_unstable_by_key(|h| h.number); + Poll::Ready(Ok(headers)) + } + Err(err) => Poll::Ready(Err(match err { + RequestError::Timeout => DownloadError::Timeout, + _ => DownloadError::RequestError(err), + })), + } + } +} + +impl Stream for TestDownload { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + todo!() + } +} + +impl BatchDownload for TestDownload { + type Ok = SealedHeader; + type Error = DownloadError; + + fn into_stream_unordered(self) -> Box>> { + Box::new(self) } } /// A test client for fetching headers -#[derive(Debug)] +#[derive(Debug, Default)] pub struct TestHeadersClient { - req_tx: mpsc::Sender<(u64, HeadersRequest)>, - req_rx: Arc>>, - res_tx: broadcast::Sender, - res_rx: broadcast::Receiver, -} - -impl Default for TestHeadersClient { - /// Construct a new test header downloader. - fn default() -> Self { - let (req_tx, req_rx) = mpsc::channel(1); - let (res_tx, res_rx) = broadcast::channel(1); - Self { req_tx, req_rx: Arc::new(tokio::sync::Mutex::new(req_rx)), res_tx, res_rx } - } + responses: Arc>>, + error: Arc>>, } impl TestHeadersClient { - /// Helper for interacting with the environment on each request, allowing the client - /// to also reply to messages. - pub async fn on_header_request(&self, mut count: usize, mut f: F) -> Vec - where - F: FnMut(u64, HeadersRequest) -> T, - { - let mut rx = self.req_rx.lock().await; - let mut results = vec![]; - while let Some((id, req)) = rx.recv().await { - results.push(f(id, req)); - count -= 1; - if count == 0 { - break - } - } - results + /// Adds headers to the set. + pub async fn extend(&self, headers: impl IntoIterator) { + let mut lock = self.responses.lock().await; + lock.extend(headers); } - /// Helper for pushing responses to the client - pub fn send_header_response(&self, id: u64, headers: Vec
) { - self.res_tx.send((id, headers).into()).expect("failed to send header response"); - } - - /// Helper for pushing responses to the client - pub async fn send_header_response_delayed(&self, id: u64, headers: Vec
, secs: u64) { - tokio::time::sleep(Duration::from_secs(secs)).await; - self.send_header_response(id, headers); + /// Set repsonse error + pub async fn set_error(&self, err: RequestError) { + let mut lock = self.error.lock().await; + lock.replace(err); } } #[async_trait::async_trait] impl HeadersClient for TestHeadersClient { - // noop - async fn update_status(&self, _height: u64, _hash: H256, _td: H256) {} + fn update_status(&self, _height: u64, _hash: H256, _td: H256) {} - async fn send_header_request(&self, id: u64, request: HeadersRequest) -> HashSet { - self.req_tx.send((id, request)).await.expect("failed to send request"); - HashSet::default() - } - - async fn stream_headers(&self) -> HeadersStream { - if !self.res_rx.is_empty() { - println!("WARNING: broadcast receiver already contains messages.") + async fn get_headers(&self, request: HeadersRequest) -> RequestResult { + if let Some(err) = &mut *self.error.lock().await { + return Err(err.clone()) } - Box::pin(BroadcastStream::new(self.res_rx.resubscribe()).filter_map(|e| e.ok())) + + let mut lock = self.responses.lock().await; + let len = lock.len().min(request.limit as usize); + let resp = lock.drain(..len).collect(); + return Ok(BlockHeaders(resp)) } } diff --git a/crates/net/headers-downloaders/Cargo.toml b/crates/net/headers-downloaders/Cargo.toml index b781442e7..505bca6ca 100644 --- a/crates/net/headers-downloaders/Cargo.toml +++ b/crates/net/headers-downloaders/Cargo.toml @@ -8,11 +8,15 @@ readme = "README.md" description = "Implementations of various header downloader" [dependencies] -async-trait = "0.1.58" +# reth reth-interfaces = { path = "../../interfaces" } reth-primitives = { path = "../../primitives" } reth-rpc-types = { path = "../rpc-types" } +# async +async-trait = "0.1.58" +futures = "0.3" + [dev-dependencies] assert_matches = "1.5.0" once_cell = "1.15.0" diff --git a/crates/net/headers-downloaders/src/linear.rs b/crates/net/headers-downloaders/src/linear.rs index 1dc2dae6d..61c95944d 100644 --- a/crates/net/headers-downloaders/src/linear.rs +++ b/crates/net/headers-downloaders/src/linear.rs @@ -1,16 +1,26 @@ -use std::{borrow::Borrow, sync::Arc, time::Duration}; - -use async_trait::async_trait; +use futures::{stream::Stream, FutureExt}; use reth_interfaces::{ consensus::Consensus, - p2p::headers::{ - client::{HeadersClient, HeadersStream}, - downloader::HeaderDownloader, - error::DownloadError, + p2p::{ + error::{RequestError, RequestResult}, + headers::{ + client::{BlockHeaders, HeadersClient, HeadersRequest}, + downloader::{validate_header_download, HeaderBatchDownload, HeaderDownloader}, + error::DownloadError, + }, + traits::BatchDownload, }, }; -use reth_primitives::SealedHeader; +use reth_primitives::{SealedHeader, H256}; use reth_rpc_types::engine::ForkchoiceState; +use std::{ + borrow::Borrow, + future::Future, + pin::Pin, + sync::Arc, + task::{ready, Context, Poll}, + time::Duration, +}; /// Download headers in batches #[derive(Debug)] @@ -27,11 +37,19 @@ pub struct LinearDownloader { pub request_retries: usize, } -#[async_trait] -impl HeaderDownloader for LinearDownloader { +impl HeaderDownloader for LinearDownloader +where + C: Consensus + 'static, + H: HeadersClient + 'static, +{ type Consensus = C; type Client = H; + /// The request timeout + fn timeout(&self) -> Duration { + self.request_timeout + } + fn consensus(&self) -> &Self::Consensus { self.consensus.borrow() } @@ -40,105 +58,234 @@ impl HeaderDownloader for LinearDownloader self.client.borrow() } - /// The request timeout - fn timeout(&self) -> Duration { - self.request_timeout + fn download(&self, head: SealedHeader, forkchoice: ForkchoiceState) -> HeaderBatchDownload<'_> { + Box::pin(HeadersDownload { + head, + forkchoice, + buffered: vec![], + request: Default::default(), + consensus: Arc::clone(&self.consensus), + request_retries: self.request_retries, + batch_size: self.batch_size, + client: Arc::clone(&self.client), + }) } +} - /// Download headers in batches with retries. - /// Returns the header collection in sorted descending - /// order from chain tip to local head - async fn download( - &self, - head: &SealedHeader, - forkchoice: &ForkchoiceState, - ) -> Result, DownloadError> { - let mut stream = self.client().stream_headers().await; - let mut retries = self.request_retries; - - // Header order will be preserved during inserts - let mut out = vec![]; - loop { - let result = self.download_batch(&mut stream, forkchoice, head, out.last()).await; - match result { - Ok(result) => match result { - LinearDownloadResult::Batch(mut headers) => { - out.append(&mut headers); - } - LinearDownloadResult::Finished(mut headers) => { - out.append(&mut headers); - return Ok(out) - } - LinearDownloadResult::Ignore => (), - }, - Err(e) if e.is_retryable() && retries > 1 => { - retries -= 1; - } - Err(e) => return Err(e), - } +impl Clone for LinearDownloader { + fn clone(&self) -> Self { + Self { + consensus: Arc::clone(&self.consensus), + client: Arc::clone(&self.client), + batch_size: self.batch_size, + request_timeout: self.request_timeout, + request_retries: self.request_retries, } } } -/// The intermediate download result -#[derive(Debug)] -pub enum LinearDownloadResult { - /// Downloaded last batch up to tip - Finished(Vec), - /// Downloaded batch - Batch(Vec), - /// Ignore this batch - Ignore, +type HeadersFut = Pin> + Send>>; + +/// A retryable future that returns a list of [`BlockHeaders`] on success. +struct HeadersRequestFuture { + request: HeadersRequest, + fut: HeadersFut, + retries: usize, + max_retries: usize, } -impl LinearDownloader { - async fn download_batch( - &self, - stream: &mut HeadersStream, - forkchoice: &ForkchoiceState, - head: &SealedHeader, - earliest: Option<&SealedHeader>, - ) -> Result { - // Request headers starting from tip or earliest cached - let start = earliest.map_or(forkchoice.head_block_hash, |h| h.parent_hash); - let mut headers = self.download_headers(stream, start.into(), self.batch_size).await?; - headers.sort_unstable_by_key(|h| h.number); +impl HeadersRequestFuture { + /// Returns true if the request can be retried. + fn is_retryable(&self) -> bool { + self.retries < self.max_retries + } - let mut out = Vec::with_capacity(headers.len()); - // Iterate headers in reverse - for parent in headers.into_iter().rev() { - let parent = parent.seal(); + /// Increments the retry counter and returns whether the request can still be retried. + fn inc_err(&mut self) -> bool { + self.retries += 1; + self.is_retryable() + } +} - if head.hash() == parent.hash() { - // We've reached the target - return Ok(LinearDownloadResult::Finished(out)) - } +impl Future for HeadersRequestFuture { + type Output = RequestResult; - match out.last().or(earliest) { - Some(header) => { - match self.validate(header, &parent) { - // ignore mismatched headers - Err(DownloadError::MismatchedHeaders { .. }) => { - return Ok(LinearDownloadResult::Ignore) - } - // propagate any other error if any - Err(e) => return Err(e), - // proceed to insert if validation is successful - _ => (), - }; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.get_mut().fut.poll_unpin(cx) + } +} + +/// An in progress headers download. +pub struct HeadersDownload { + /// The local head of the chain. + head: SealedHeader, + forkchoice: ForkchoiceState, + /// Buffered results + buffered: Vec, + /// Contains the request that's currently in progress. + /// + /// TODO(mattsse): this could be converted into a `FuturesOrdered` where batching is done via + /// `skip` so we don't actually need to know the start hash + request: Option, + /// Downloader used to issue new requests. + consensus: Arc, + /// Downloader used to issue new requests. + client: Arc, + /// The number of headers to request in one call + batch_size: u64, + /// The number of retries for downloading + request_retries: usize, +} + +impl HeadersDownload +where + C: Consensus + 'static, + H: HeadersClient + 'static, +{ + /// Returns the start hash for a new request. + fn request_start(&self) -> H256 { + self.buffered.last().map_or(self.forkchoice.head_block_hash, |h| h.parent_hash) + } + + fn headers_request(&self) -> HeadersRequest { + HeadersRequest { start: self.request_start().into(), limit: self.batch_size, reverse: true } + } + + /// Tries to fuse the future with a new request + /// + /// Returns an `Err` if the request exhausted all retries + fn try_fuse_request_fut(&self, fut: &mut HeadersRequestFuture) -> Result<(), ()> { + if !fut.inc_err() { + return Err(()) + } + let req = self.headers_request(); + fut.request = req.clone(); + let client = Arc::clone(&self.client); + fut.fut = Box::pin(async move { client.get_headers(req).await }); + Ok(()) + } + + /// Validate whether the header is valid in relation to it's parent + /// + /// Returns Ok(false) if the + fn validate(&self, header: &SealedHeader, parent: &SealedHeader) -> Result<(), DownloadError> { + validate_header_download(&self.consensus, header, parent)?; + Ok(()) + } +} + +impl Future for HeadersDownload +where + C: Consensus + 'static, + H: HeadersClient + 'static, +{ + type Output = Result, DownloadError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + 'outer: loop { + let mut fut = match this.request.take() { + Some(fut) => fut, + None => { + // queue in the first request + let client = Arc::clone(&this.client); + let req = this.headers_request(); + HeadersRequestFuture { + request: req.clone(), + fut: Box::pin(async move { client.get_headers(req).await }), + retries: 0, + max_retries: this.request_retries, + } } - // The buffer is empty and the first header does not match the tip, discard - // TODO: penalize the peer? - None if parent.hash() != forkchoice.head_block_hash => { - return Ok(LinearDownloadResult::Ignore) - } - _ => (), }; - out.push(parent); - } + match ready!(fut.poll_unpin(cx)) { + Ok(resp) => { + let mut headers = resp.0; + headers.sort_unstable_by_key(|h| h.number); - Ok(LinearDownloadResult::Batch(out)) + if headers.is_empty() { + if this.try_fuse_request_fut(&mut fut).is_err() { + return Poll::Ready(Err(RequestError::BadResponse.into())) + } else { + this.request = Some(fut); + continue + } + } + + // Iterate headers in reverse + for parent in headers.into_iter().rev() { + let parent = parent.seal(); + + if this.head.hash() == parent.hash() { + // We've reached the target + let headers = + std::mem::take(&mut this.buffered).into_iter().rev().collect(); + return Poll::Ready(Ok(headers)) + } + + if let Some(header) = this.buffered.last() { + match this.validate(header, &parent) { + Ok(_) => { + // record new parent + this.buffered.push(parent); + } + Err(err) => { + if this.try_fuse_request_fut(&mut fut).is_err() { + return Poll::Ready(Err(err)) + } + this.request = Some(fut); + continue 'outer + } + } + } else { + // The buffer is empty and the first header does not match the tip, + // discard + if parent.hash() != this.forkchoice.head_block_hash { + if this.try_fuse_request_fut(&mut fut).is_err() { + return Poll::Ready(Err(RequestError::BadResponse.into())) + } + this.request = Some(fut); + continue 'outer + } + this.buffered.push(parent); + } + } + } + Err(err) => { + if this.try_fuse_request_fut(&mut fut).is_err() { + return Poll::Ready(Err(DownloadError::RequestError(err))) + } + this.request = Some(fut); + } + } + } + } +} + +impl Stream for HeadersDownload +where + C: Consensus + 'static, + H: HeadersClient + 'static, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + todo!() + } +} + +impl BatchDownload for HeadersDownload +where + C: Consensus + 'static, + H: HeadersClient + 'static, +{ + type Ok = SealedHeader; + type Error = DownloadError; + + fn into_stream_unordered(self) -> Box>> { + Box::new(self) } } @@ -199,200 +346,86 @@ impl LinearDownloadBuilder { #[cfg(test)] mod tests { use super::*; - use reth_interfaces::{ - p2p::headers::client::HeadersRequest, - test_utils::{ - generators::{random_header, random_header_range}, - TestConsensus, TestHeadersClient, - }, - }; - use reth_primitives::{BlockHashOrNumber, SealedHeader}; - - use assert_matches::assert_matches; use once_cell::sync::Lazy; - use serial_test::serial; - use tokio::sync::oneshot::{self, error::TryRecvError}; + use reth_interfaces::test_utils::{TestConsensus, TestHeadersClient}; + use reth_primitives::SealedHeader; static CONSENSUS: Lazy> = Lazy::new(|| Arc::new(TestConsensus::default())); - static CONSENSUS_FAIL: Lazy> = Lazy::new(|| { - let consensus = TestConsensus::default(); - consensus.set_fail_validation(true); - Arc::new(consensus) - }); - static CLIENT: Lazy> = - Lazy::new(|| Arc::new(TestHeadersClient::default())); - - #[tokio::test] - #[serial] - async fn download_timeout() { - let retries = 5; - let (tx, rx) = oneshot::channel(); - tokio::spawn(async move { - let downloader = LinearDownloadBuilder::default() - .retries(retries) - .build(CONSENSUS.clone(), CLIENT.clone()); - let result = - downloader.download(&SealedHeader::default(), &ForkchoiceState::default()).await; - tx.send(result).expect("failed to forward download response"); - }); - - let mut requests = vec![]; - CLIENT - .on_header_request(retries, |_id, req| { - requests.push(req); - }) - .await; - assert_eq!(requests.len(), retries); - assert_matches!(rx.await, Ok(Err(DownloadError::Timeout { .. }))); + fn child_header(parent: &SealedHeader) -> SealedHeader { + let mut child = parent.as_ref().clone(); + child.number += 1; + child.parent_hash = parent.hash_slow(); + let hash = child.hash_slow(); + SealedHeader::new(child, hash) } #[tokio::test] - #[serial] - async fn download_timeout_on_invalid_messages() { - let retries = 5; - let (tx, rx) = oneshot::channel(); - tokio::spawn(async move { - let downloader = LinearDownloadBuilder::default() - .retries(retries) - .build(CONSENSUS.clone(), CLIENT.clone()); - let result = - downloader.download(&SealedHeader::default(), &ForkchoiceState::default()).await; - tx.send(result).expect("failed to forward download response"); - }); + async fn download_empty() { + let client = Arc::new(TestHeadersClient::default()); + let downloader = + LinearDownloadBuilder::default().build(CONSENSUS.clone(), Arc::clone(&client)); - let mut num_of_reqs = 0; - let mut last_req_id: Option = None; - - CLIENT - .on_header_request(retries, |id, _req| { - num_of_reqs += 1; - last_req_id = Some(id); - CLIENT.send_header_response(id.saturating_add(id % 2), vec![]); - }) - .await; - - assert_eq!(num_of_reqs, retries); - assert_matches!( - rx.await, - Ok(Err(DownloadError::Timeout { request_id })) if request_id == last_req_id.unwrap() - ); + let result = downloader.download(SealedHeader::default(), ForkchoiceState::default()).await; + assert!(result.is_err()); } #[tokio::test] - #[serial] - async fn download_propagates_consensus_validation_error() { - let tip_parent = random_header(1, None); - let tip = random_header(2, Some(tip_parent.hash())); - let tip_hash = tip.hash(); + async fn download_at_fork_head() { + let client = Arc::new(TestHeadersClient::default()); + let downloader = LinearDownloadBuilder::default() + .batch_size(3) + .build(CONSENSUS.clone(), Arc::clone(&client)); - let (tx, rx) = oneshot::channel(); - tokio::spawn(async move { - let downloader = - LinearDownloadBuilder::default().build(CONSENSUS_FAIL.clone(), CLIENT.clone()); - let forkchoice = ForkchoiceState { head_block_hash: tip_hash, ..Default::default() }; - let result = downloader.download(&SealedHeader::default(), &forkchoice).await; - tx.send(result).expect("failed to forward download response"); - }); + let p3 = SealedHeader::default(); + let p2 = child_header(&p3); + let p1 = child_header(&p2); + let p0 = child_header(&p1); - let requests = CLIENT.on_header_request(1, |id, req| (id, req)).await; - let request = requests.last(); - assert_matches!( - request, - Some((_, HeadersRequest { start, .. })) - if matches!(start, BlockHashOrNumber::Hash(hash) if *hash == tip_hash) - ); + client + .extend(vec![ + p0.as_ref().clone(), + p1.as_ref().clone(), + p2.as_ref().clone(), + p3.as_ref().clone(), + ]) + .await; - let request = request.unwrap(); - CLIENT.send_header_response( - request.0, - vec![tip_parent.clone().unseal(), tip.clone().unseal()], - ); + let fork = ForkchoiceState { head_block_hash: p0.hash_slow(), ..Default::default() }; - assert_matches!( - rx.await, - Ok(Err(DownloadError::HeaderValidation { hash, .. })) if hash == tip_parent.hash() - ); + let result = downloader.download(p0, fork).await; + let headers = result.unwrap(); + assert!(headers.is_empty()); } #[tokio::test] - #[serial] - async fn download_starts_with_chain_tip() { - let head = random_header(1, None); - let tip = random_header(2, Some(head.hash())); + async fn download_exact() { + let client = Arc::new(TestHeadersClient::default()); + let downloader = LinearDownloadBuilder::default() + .batch_size(3) + .build(CONSENSUS.clone(), Arc::clone(&client)); - let tip_hash = tip.hash(); - let chain_head = head.clone(); - let (tx, mut rx) = oneshot::channel(); - tokio::spawn(async move { - let downloader = - LinearDownloadBuilder::default().build(CONSENSUS.clone(), CLIENT.clone()); - let forkchoice = ForkchoiceState { head_block_hash: tip_hash, ..Default::default() }; - let result = downloader.download(&chain_head, &forkchoice).await; - tx.send(result).expect("failed to forward download response"); - }); + let p3 = SealedHeader::default(); + let p2 = child_header(&p3); + let p1 = child_header(&p2); + let p0 = child_header(&p1); - CLIENT - .on_header_request(1, |id, _req| { - let mut corrupted_tip = tip.clone().unseal(); - corrupted_tip.nonce = rand::random(); - CLIENT.send_header_response(id, vec![corrupted_tip, head.clone().unseal()]) - }) - .await; - assert_matches!(rx.try_recv(), Err(TryRecvError::Empty)); - - CLIENT - .on_header_request(1, |id, _req| { - CLIENT.send_header_response(id, vec![tip.clone().unseal(), head.clone().unseal()]) - }) + client + .extend(vec![ + p0.as_ref().clone(), + p1.as_ref().clone(), + p2.as_ref().clone(), + p3.as_ref().clone(), + ]) .await; - let result = rx.await; - assert_matches!(result, Ok(Ok(ref val)) if val.len() == 1); - assert_eq!(*result.unwrap().unwrap().first().unwrap(), tip); - } + let fork = ForkchoiceState { head_block_hash: p0.hash_slow(), ..Default::default() }; - #[tokio::test] - #[serial] - async fn download_returns_headers_desc() { - let (start, end) = (100, 200); - let head = random_header(start, None); - let mut headers = random_header_range(start + 1..end, head.hash()); - headers.reverse(); - - let tip_hash = headers.first().unwrap().hash(); - let chain_head = head.clone(); - let (tx, rx) = oneshot::channel(); - tokio::spawn(async move { - let downloader = - LinearDownloadBuilder::default().build(CONSENSUS.clone(), CLIENT.clone()); - let forkchoice = ForkchoiceState { head_block_hash: tip_hash, ..Default::default() }; - let result = downloader.download(&chain_head, &forkchoice).await; - tx.send(result).expect("failed to forward download response"); - }); - - let mut idx = 0; - let chunk_size = 10; - // `usize::div_ceil` is unstable. ref: https://github.com/rust-lang/rust/issues/88581 - let count = (headers.len() + chunk_size - 1) / chunk_size; - CLIENT - .on_header_request(count + 1, |id, _req| { - let mut chunk = - headers.iter().skip(chunk_size * idx).take(chunk_size).cloned().peekable(); - idx += 1; - if chunk.peek().is_some() { - let headers: Vec<_> = chunk.map(|h| h.unseal()).collect(); - CLIENT.send_header_response(id, headers); - } else { - CLIENT.send_header_response(id, vec![head.clone().unseal()]) - } - }) - .await; - - let result = rx.await; - assert_matches!(result, Ok(Ok(_))); - let result = result.unwrap().unwrap(); - assert_eq!(result.len(), headers.len()); - assert_eq!(result, headers); + let result = downloader.download(p3, fork).await; + let headers = result.unwrap(); + assert_eq!(headers.len(), 3); + assert_eq!(headers[0], p2); + assert_eq!(headers[1], p1); + assert_eq!(headers[2], p0); } } diff --git a/crates/primitives/src/header.rs b/crates/primitives/src/header.rs index ab4465edd..4f329f63e 100644 --- a/crates/primitives/src/header.rs +++ b/crates/primitives/src/header.rs @@ -203,7 +203,7 @@ impl Decodable for Header { /// A [`Header`] that is sealed at a precalculated hash, use [`SealedHeader::unseal()`] if you want /// to modify header. -#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct SealedHeader { /// Locked Header fields. header: Header, @@ -211,6 +211,14 @@ pub struct SealedHeader { hash: BlockHash, } +impl Default for SealedHeader { + fn default() -> Self { + let header = Header::default(); + let hash = header.hash_slow(); + Self { header, hash } + } +} + impl Encodable for SealedHeader { fn encode(&self, out: &mut dyn BufMut) { self.header.encode(out); diff --git a/crates/stages/src/stages/headers.rs b/crates/stages/src/stages/headers.rs index 52abdf488..39382e625 100644 --- a/crates/stages/src/stages/headers.rs +++ b/crates/stages/src/stages/headers.rs @@ -73,15 +73,15 @@ impl Stage { // TODO: validate the result order? // at least check if it attaches (first == tip && last == last_hash) res } Err(e) => match e { - DownloadError::Timeout { request_id } => { - warn!("no response for header request {request_id}"); + DownloadError::Timeout => { + warn!("no response for header request"); return Ok(ExecOutput { stage_progress: last_block_num, reached_tip: false, @@ -92,10 +92,7 @@ impl Stage { - // return Err(StageError::Validation { block: last_block_num }) - // } + // TODO: handle unreachable _ => unreachable!(), }, }; @@ -134,7 +131,7 @@ impl HeaderStage { .get::(height)? .ok_or(DatabaseIntegrityError::CanonicalHeader { number: height })?; let td: Vec = tx.get::((height, hash).into())?.unwrap(); // TODO: - self.client.update_status(height, hash, H256::from_slice(&td)).await; + self.client.update_status(height, hash, H256::from_slice(&td)); Ok(()) } @@ -163,7 +160,7 @@ impl HeaderStage { let mut latest = None; // Since the headers were returned in descending order, // iterate them in the reverse order - for header in headers.into_iter().rev() { + for header in headers.into_iter() { if header.number == 0 { continue } @@ -194,6 +191,7 @@ mod tests { stage_test_suite, ExecuteStageTestRunner, UnwindStageTestRunner, PREV_STAGE_ID, }; use assert_matches::assert_matches; + use reth_interfaces::p2p::error::RequestError; use test_runner::HeadersTestRunner; stage_test_suite!(HeadersTestRunner); @@ -203,15 +201,20 @@ mod tests { #[tokio::test] // Validate that the execution does not fail on timeout async fn execute_timeout() { + let (previous_stage, stage_progress) = (500, 100); let mut runner = HeadersTestRunner::default(); - let input = ExecInput::default(); + let input = ExecInput { + previous_stage: Some((PREV_STAGE_ID, previous_stage)), + stage_progress: Some(stage_progress), + }; runner.seed_execution(input).expect("failed to seed execution"); + runner.client.set_error(RequestError::Timeout).await; let rx = runner.execute(input); runner.consensus.update_tip(H256::from_low_u64_be(1)); let result = rx.await.unwrap(); assert_matches!( result, - Ok(ExecOutput { done: false, reached_tip: false, stage_progress: 0 }) + Ok(ExecOutput { done: false, reached_tip: false, stage_progress: 100 }) ); assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed"); } @@ -242,23 +245,17 @@ mod tests { let headers = runner.seed_execution(input).expect("failed to seed execution"); let rx = runner.execute(input); + runner.client.extend(headers.iter().rev().map(|h| h.clone().unseal())).await; + // skip `after_execution` hook for linear downloader let tip = headers.last().unwrap(); runner.consensus.update_tip(tip.hash()); - let download_result = headers.clone(); - runner - .client - .on_header_request(1, |id, _| { - let response = download_result.iter().map(|h| h.clone().unseal()).collect(); - runner.client.send_header_response(id, response) - }) - .await; - let result = rx.await.unwrap(); assert_matches!( result, - Ok(ExecOutput { done: true, reached_tip: true, stage_progress }) if stage_progress == tip.number + Ok(ExecOutput { done: true, reached_tip: true, stage_progress }) + if stage_progress == tip.number ); assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed"); } @@ -298,7 +295,7 @@ mod tests { Self { client: client.clone(), consensus: consensus.clone(), - downloader: Arc::new(TestHeaderDownloader::new(client, consensus)), + downloader: Arc::new(TestHeaderDownloader::new(client, consensus, 1000)), db: StageTestDB::default(), } } @@ -341,23 +338,6 @@ mod tests { Ok(headers) } - async fn after_execution(&self, headers: Self::Seed) -> Result<(), TestRunnerError> { - let tip = if !headers.is_empty() { - headers.last().unwrap().hash() - } else { - H256::from_low_u64_be(rand::random()) - }; - self.consensus.update_tip(tip); - self.client - .send_header_response_delayed( - 0, - headers.into_iter().map(|h| h.unseal()).collect(), - 1, - ) - .await; - Ok(()) - } - /// Validate stored headers fn validate_execution( &self, @@ -405,6 +385,17 @@ mod tests { }; Ok(()) } + + async fn after_execution(&self, headers: Self::Seed) -> Result<(), TestRunnerError> { + self.client.extend(headers.iter().map(|h| h.clone().unseal())).await; + let tip = if !headers.is_empty() { + headers.last().unwrap().hash() + } else { + H256::from_low_u64_be(rand::random()) + }; + self.consensus.update_tip(tip); + Ok(()) + } } impl UnwindStageTestRunner for HeadersTestRunner { @@ -414,6 +405,7 @@ mod tests { } impl HeadersTestRunner> { + #[allow(unused)] pub(crate) fn with_linear_downloader() -> Self { let client = Arc::new(TestHeadersClient::default()); let consensus = Arc::new(TestConsensus::default()); diff --git a/crates/stages/src/stages/senders.rs b/crates/stages/src/stages/senders.rs index 3350735ac..f258f5989 100644 --- a/crates/stages/src/stages/senders.rs +++ b/crates/stages/src/stages/senders.rs @@ -218,9 +218,9 @@ mod tests { let start_hash = tx.get::(start_block)?.unwrap(); let mut body_cursor = tx.cursor::()?; - let mut body_walker = body_cursor.walk((start_block, start_hash).into())?; + let body_walker = body_cursor.walk((start_block, start_hash).into())?; - while let Some(entry) = body_walker.next() { + for entry in body_walker { let (_, body) = entry?; for tx_id in body.base_tx_id..body.base_tx_id + body.tx_amount { let transaction = tx diff --git a/crates/stages/src/stages/tx_index.rs b/crates/stages/src/stages/tx_index.rs index d66272775..cf81f042d 100644 --- a/crates/stages/src/stages/tx_index.rs +++ b/crates/stages/src/stages/tx_index.rs @@ -170,9 +170,9 @@ mod tests { let mut tx_count_walker = tx_count_cursor.walk((start, start_hash).into())?; let mut count = tx_count_walker.next().unwrap()?.1; let mut last_num = start; - while let Some(entry) = tx_count_walker.next() { + for entry in tx_count_walker { let (key, db_count) = entry?; - count += tx.get::(key)?.unwrap().tx_amount as u64; + count += tx.get::(key)?.unwrap().tx_amount; assert_eq!(db_count, count); last_num = key.number(); }