From 9628d03871fce44deb8e6d8c4f59feb2e091b899 Mon Sep 17 00:00:00 2001 From: Roman Krasiuk Date: Tue, 6 Dec 2022 08:28:11 +0200 Subject: [PATCH] feat(sync): headers commit threshold (#296) * headers stream init * fix tests * return header if available regardless of control flow * proper stream termination & docs * upd headers stage to consume stream * adjust response validation for stream * use cursor.insert for headers * wrap poll_next in a loop to bypass poking waker * fix typo * fix last td lookup * Apply suggestions from code review Co-authored-by: Georgios Konstantopoulos * misc * remove waker ref * dedup response handling logic * clippy * add docs to poll Co-authored-by: Georgios Konstantopoulos --- .../interfaces/src/p2p/headers/downloader.rs | 8 + crates/interfaces/src/test_utils/headers.rs | 123 ++++--- crates/net/headers-downloaders/src/linear.rs | 312 ++++++++++++------ crates/primitives/src/header.rs | 5 + crates/stages/src/stages/headers.rs | 124 ++++--- 5 files changed, 385 insertions(+), 187 deletions(-) diff --git a/crates/interfaces/src/p2p/headers/downloader.rs b/crates/interfaces/src/p2p/headers/downloader.rs index cb6c03991..c64d707b8 100644 --- a/crates/interfaces/src/p2p/headers/downloader.rs +++ b/crates/interfaces/src/p2p/headers/downloader.rs @@ -4,6 +4,7 @@ use crate::{ p2p::{headers::error::DownloadError, traits::BatchDownload}, }; +use futures::Stream; use reth_primitives::SealedHeader; use reth_rpc_types::engine::ForkchoiceState; use std::{pin::Pin, time::Duration}; @@ -20,6 +21,10 @@ pub type HeaderBatchDownload<'a> = Pin< >, >; +/// A stream for downloading headers. +pub type HeaderDownloadStream = + Pin> + Send>>; + /// A downloader capable of fetching block headers. /// /// A downloader represents a distinct strategy for submitting requests to download block headers, @@ -45,6 +50,9 @@ pub trait HeaderDownloader: Sync + Send + Unpin { /// Download the headers fn download(&self, head: SealedHeader, forkchoice: ForkchoiceState) -> HeaderBatchDownload<'_>; + /// Stream the headers + fn stream(&self, head: SealedHeader, forkchoice: ForkchoiceState) -> HeaderDownloadStream; + /// Validate whether the header is valid in relation to it's parent /// /// Returns Ok(false) if the diff --git a/crates/interfaces/src/test_utils/headers.rs b/crates/interfaces/src/test_utils/headers.rs index 284317f57..35452e5d3 100644 --- a/crates/interfaces/src/test_utils/headers.rs +++ b/crates/interfaces/src/test_utils/headers.rs @@ -5,7 +5,7 @@ use crate::{ error::{RequestError, RequestResult}, headers::{ client::{HeadersClient, HeadersRequest}, - downloader::{HeaderBatchDownload, HeaderDownloader}, + downloader::{HeaderBatchDownload, HeaderDownloadStream, HeaderDownloader}, error::DownloadError, }, traits::BatchDownload, @@ -39,6 +39,17 @@ impl TestHeaderDownloader { pub fn new(client: Arc, consensus: Arc, limit: u64) -> Self { Self { client, consensus, limit } } + + fn create_download(&self) -> TestDownload { + TestDownload { + client: Arc::clone(&self.client), + consensus: Arc::clone(&self.consensus), + limit: self.limit, + fut: None, + buffer: vec![], + done: false, + } + } } #[async_trait::async_trait] @@ -63,39 +74,55 @@ impl HeaderDownloader for TestHeaderDownloader { _head: SealedHeader, _forkchoice: ForkchoiceState, ) -> HeaderBatchDownload<'_> { - Box::pin(TestDownload { - client: Arc::clone(&self.client), - consensus: Arc::clone(&self.consensus), - limit: self.limit, - }) + Box::pin(self.create_download()) + } + + fn stream(&self, _head: SealedHeader, _forkchoice: ForkchoiceState) -> HeaderDownloadStream { + Box::pin(self.create_download()) } } +type TestHeadersFut = Pin> + Send>>; + struct TestDownload { client: Arc, consensus: Arc, limit: u64, + fut: Option, + buffer: Vec, + done: bool, +} + +impl TestDownload { + fn get_or_init_fut(&mut self) -> &mut TestHeadersFut { + if self.fut.is_none() { + let request = HeadersRequest { + limit: self.limit, + direction: HeadersDirection::Rising, + start: reth_primitives::BlockHashOrNumber::Number(0), // ignored + }; + let client = Arc::clone(&self.client); + self.fut = Some(Box::pin(async move { client.get_headers(request).await })); + } + self.fut.as_mut().unwrap() + } } impl Future for TestDownload { type Output = Result, DownloadError>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let empty = SealedHeader::default(); if let Err(error) = self.consensus.validate_header(&empty, &empty) { return Poll::Ready(Err(DownloadError::HeaderValidation { hash: empty.hash(), error })) } - let request = HeadersRequest { - limit: self.limit, - direction: HeadersDirection::Rising, - start: reth_primitives::BlockHashOrNumber::Number(0), // ignored - }; - match ready!(self.client.get_headers(request).poll_unpin(cx)) { + match ready!(self.get_or_init_fut().poll_unpin(cx)) { Ok(resp) => { + // Skip head and seal headers let mut headers = resp.0.into_iter().skip(1).map(|h| h.seal()).collect::>(); headers.sort_unstable_by_key(|h| h.number); - Poll::Ready(Ok(headers)) + Poll::Ready(Ok(headers.into_iter().rev().collect())) } Err(err) => Poll::Ready(Err(match err { RequestError::Timeout => DownloadError::Timeout, @@ -108,8 +135,44 @@ impl Future for TestDownload { impl Stream for TestDownload { type Item = Result; - fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - todo!() + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + loop { + if let Some(header) = this.buffer.pop() { + return Poll::Ready(Some(Ok(header))) + } else if this.done { + return Poll::Ready(None) + } + + let empty = SealedHeader::default(); + if let Err(error) = this.consensus.validate_header(&empty, &empty) { + this.done = true; + return Poll::Ready(Some(Err(DownloadError::HeaderValidation { + hash: empty.hash(), + error, + }))) + } + + match ready!(this.get_or_init_fut().poll_unpin(cx)) { + Ok(resp) => { + // Skip head and seal headers + let mut headers = + resp.0.into_iter().skip(1).map(|h| h.seal()).collect::>(); + headers.sort_unstable_by_key(|h| h.number); + headers.into_iter().for_each(|h| this.buffer.push(h)); + this.done = true; + continue + } + Err(err) => { + this.done = true; + return Poll::Ready(Some(Err(match err { + RequestError::Timeout => DownloadError::Timeout, + _ => DownloadError::RequestError(err), + }))) + } + } + } } } @@ -120,34 +183,6 @@ impl BatchDownload for TestDownload { fn into_stream_unordered(self) -> Box>> { Box::new(self) } - - // async fn download( - // &self, - // _: &SealedHeader, - // _: &ForkchoiceState, - // ) -> Result, DownloadError> { - // // call consensus stub first. fails if the flag is set - // let empty = SealedHeader::default(); - // self.consensus - // .validate_header(&empty, &empty) - // .map_err(|error| DownloadError::HeaderValidation { hash: empty.hash(), error })?; - // - // let stream = self.client.stream_headers().await; - // let stream = stream.timeout(Duration::from_secs(1)); - // - // match Box::pin(stream).try_next().await { - // Ok(Some(res)) => { - // let mut headers = res.headers.iter().map(|h| - // h.clone().seal()).collect::>(); if !headers.is_empty() { - // headers.sort_unstable_by_key(|h| h.number); - // headers.remove(0); // remove head from response - // headers.reverse(); - // } - // Ok(headers) - // } - // _ => Err(DownloadError::Timeout { request_id: 0 }), - // } - // } } /// A test client for fetching headers diff --git a/crates/net/headers-downloaders/src/linear.rs b/crates/net/headers-downloaders/src/linear.rs index 991b88dbf..28c6656d2 100644 --- a/crates/net/headers-downloaders/src/linear.rs +++ b/crates/net/headers-downloaders/src/linear.rs @@ -5,7 +5,10 @@ use reth_interfaces::{ error::{RequestError, RequestResult}, headers::{ client::{BlockHeaders, HeadersClient, HeadersRequest}, - downloader::{validate_header_download, HeaderBatchDownload, HeaderDownloader}, + downloader::{ + validate_header_download, HeaderBatchDownload, HeaderDownloadStream, + HeaderDownloader, + }, error::DownloadError, }, traits::BatchDownload, @@ -15,6 +18,7 @@ use reth_primitives::{HeadersDirection, SealedHeader, H256}; use reth_rpc_types::engine::ForkchoiceState; use std::{ borrow::Borrow, + collections::VecDeque, future::Future, pin::Pin, sync::Arc, @@ -59,16 +63,11 @@ where } fn download(&self, head: SealedHeader, forkchoice: ForkchoiceState) -> HeaderBatchDownload<'_> { - Box::pin(HeadersDownload { - head, - forkchoice, - buffered: vec![], - request: Default::default(), - consensus: Arc::clone(&self.consensus), - request_retries: self.request_retries, - batch_size: self.batch_size, - client: Arc::clone(&self.client), - }) + Box::pin(self.new_download(head, forkchoice)) + } + + fn stream(&self, head: SealedHeader, forkchoice: ForkchoiceState) -> HeaderDownloadStream { + Box::pin(self.new_download(head, forkchoice)) } } @@ -84,6 +83,26 @@ impl Clone for LinearDownloader { } } +impl LinearDownloader { + fn new_download( + &self, + head: SealedHeader, + forkchoice: ForkchoiceState, + ) -> HeadersDownload { + HeadersDownload { + head, + forkchoice, + buffered: VecDeque::default(), + request: Default::default(), + consensus: Arc::clone(&self.consensus), + request_retries: self.request_retries, + batch_size: self.batch_size, + client: Arc::clone(&self.client), + done: false, + } + } +} + type HeadersFut = Pin> + Send>>; /// A retryable future that returns a list of [`BlockHeaders`] on success. @@ -121,7 +140,7 @@ pub struct HeadersDownload { head: SealedHeader, forkchoice: ForkchoiceState, /// Buffered results - buffered: Vec, + buffered: VecDeque, /// Contains the request that's currently in progress. /// /// TODO(mattsse): this could be converted into a `FuturesOrdered` where batching is done via @@ -135,6 +154,9 @@ pub struct HeadersDownload { batch_size: u64, /// The number of retries for downloading request_retries: usize, + /// The flag indicating whether the downloader has finished + /// or the retries have been exhausted + done: bool, } impl HeadersDownload @@ -142,11 +164,17 @@ where C: Consensus + 'static, H: HeadersClient + 'static, { - /// Returns the start hash for a new request. - fn request_start(&self) -> H256 { - self.buffered.last().map_or(self.forkchoice.head_block_hash, |h| h.parent_hash) + /// Returns the first header from the vector of buffered headers + fn earliest_header(&self) -> Option<&SealedHeader> { + self.buffered.back() } + /// Returns the start hash for a new request. + fn request_start(&self) -> H256 { + self.earliest_header().map_or(self.forkchoice.head_block_hash, |h| h.parent_hash) + } + + /// Get the headers request to dispatch fn headers_request(&self) -> HeadersRequest { HeadersRequest { start: self.request_start().into(), @@ -155,7 +183,30 @@ where } } - /// Tries to fuse the future with a new request + /// Insert the header into buffer + fn push_header_into_buffer(&mut self, header: SealedHeader) { + self.buffered.push_back(header); + } + + /// Get a current future or instantiate a new one + fn get_or_init_fut(&mut self) -> Option { + match self.request.take() { + None if !self.done => { + // queue in the first request + let client = Arc::clone(&self.client); + let req = self.headers_request(); + Some(HeadersRequestFuture { + request: req.clone(), + fut: Box::pin(async move { client.get_headers(req).await }), + retries: 0, + max_retries: self.request_retries, + }) + } + fut => fut, + } + } + + /// Tries to fuse the future with a new request. /// /// Returns an `Err` if the request exhausted all retries fn try_fuse_request_fut(&self, fut: &mut HeadersRequestFuture) -> Result<(), ()> { @@ -176,6 +227,48 @@ where validate_header_download(&self.consensus, header, parent)?; Ok(()) } + + fn process_header_response( + &mut self, + response: Result, + ) -> Result<(), DownloadError> { + match response { + Ok(res) => { + let mut headers = res.0; + headers.sort_unstable_by_key(|h| h.number); + + if headers.is_empty() { + return Err(RequestError::BadResponse.into()) + } + + // Iterate headers in reverse + for parent in headers.into_iter().rev() { + let parent = parent.seal(); + + if self.head.hash() == parent.hash() { + // We've reached the target, stop buffering headers + self.done = true; + break + } + + if let Some(header) = self.earliest_header() { + // Proceed to insert. If there is a validation error re-queue + // the future. + self.validate(header, &parent)?; + } else if parent.hash() != self.forkchoice.head_block_hash { + // The buffer is empty and the first header does not match the + // tip, requeue the future + return Err(RequestError::BadResponse.into()) + } + + // Record new parent + self.push_header_into_buffer(parent); + } + Ok(()) + } + Err(err) => Err(err.into()), + } + } } impl Future for HeadersDownload @@ -185,84 +278,30 @@ where { type Output = Result, DownloadError>; + /// Linear header download implemented as a [Future]. The downloader + /// aggregates all of the header responses in a local buffer until the + /// previous head is reached. + /// + /// Upon encountering an error, the downloader will try to resend the request. + /// Returns the error if all of the request retries have been exhausted. fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); - 'outer: loop { - let mut fut = match this.request.take() { - Some(fut) => fut, - None => { - // queue in the first request - let client = Arc::clone(&this.client); - let req = this.headers_request(); - HeadersRequestFuture { - request: req.clone(), - fut: Box::pin(async move { client.get_headers(req).await }), - retries: 0, - max_retries: this.request_retries, - } + // Safe to unwrap, because the future is `done` + // only upon returning the result + let mut fut = this.get_or_init_fut().expect("fut exists; qed"); + let response = ready!(fut.poll_unpin(cx)); + if let Err(err) = this.process_header_response(response) { + if this.try_fuse_request_fut(&mut fut).is_err() { + this.done = true; + return Poll::Ready(Err(err)) } - }; + this.request = Some(fut); + continue 'outer + } - match ready!(fut.poll_unpin(cx)) { - Ok(resp) => { - let mut headers = resp.0; - headers.sort_unstable_by_key(|h| h.number); - - if headers.is_empty() { - if this.try_fuse_request_fut(&mut fut).is_err() { - return Poll::Ready(Err(RequestError::BadResponse.into())) - } else { - this.request = Some(fut); - continue - } - } - - // Iterate headers in reverse - for parent in headers.into_iter().rev() { - let parent = parent.seal(); - - if this.head.hash() == parent.hash() { - // We've reached the target - let headers = - std::mem::take(&mut this.buffered).into_iter().rev().collect(); - return Poll::Ready(Ok(headers)) - } - - if let Some(header) = this.buffered.last() { - match this.validate(header, &parent) { - Ok(_) => { - // record new parent - this.buffered.push(parent); - } - Err(err) => { - if this.try_fuse_request_fut(&mut fut).is_err() { - return Poll::Ready(Err(err)) - } - this.request = Some(fut); - continue 'outer - } - } - } else { - // The buffer is empty and the first header does not match the tip, - // discard - if parent.hash() != this.forkchoice.head_block_hash { - if this.try_fuse_request_fut(&mut fut).is_err() { - return Poll::Ready(Err(RequestError::BadResponse.into())) - } - this.request = Some(fut); - continue 'outer - } - this.buffered.push(parent); - } - } - } - Err(err) => { - if this.try_fuse_request_fut(&mut fut).is_err() { - return Poll::Ready(Err(DownloadError::RequestError(err))) - } - this.request = Some(fut); - } + if this.done { + return Poll::Ready(Ok(std::mem::take(&mut this.buffered).into())) } } } @@ -275,8 +314,50 @@ where { type Item = Result; - fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - todo!() + /// Linear header downloader implemented as a [Stream]. The downloader sends header + /// requests until the head is reached and buffers the responses. If the request future + /// is still pending, the downloader will return a buffered header if any is available. + /// + /// Internally, the stream is terminated if the `done` flag has been set and there are no + /// more headers available in the buffer. + /// + /// Upon encountering an error, the downloader will attempt to retry the failed request. + /// If the number of retries is exhausted, the downloader will stream an error, set the `done` + /// flag to true and clear the buffered headers, thus resulting in stream termination. + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + 'outer: loop { + if let Some(mut fut) = this.get_or_init_fut() { + if let Poll::Ready(result) = fut.poll_unpin(cx) { + if let Err(err) = this.process_header_response(result) { + if this.try_fuse_request_fut(&mut fut).is_err() { + // We exhausted all of the retries. Stream must terminate + this.done = true; + this.buffered.clear(); + return Poll::Ready(Some(Err(err))) + } + this.request = Some(fut); + continue 'outer + } + } + } + + if !this.done && this.buffered.len() > 1 { + if let Some(header) = this.buffered.pop_front() { + // Stream buffered header + return Poll::Ready(Some(Ok(header))) + } + } else if this.done { + if let Some(header) = this.buffered.pop_front() { + // Stream buffered header + return Poll::Ready(Some(Ok(header))) + } else { + // Polling finished, we've reached the target + return Poll::Ready(None) + } + } + } } } @@ -350,6 +431,7 @@ impl LinearDownloadBuilder { #[cfg(test)] mod tests { use super::*; + use futures::TryStreamExt; use once_cell::sync::Lazy; use reth_interfaces::test_utils::{TestConsensus, TestHeadersClient}; use reth_primitives::SealedHeader; @@ -428,8 +510,52 @@ mod tests { let result = downloader.download(p3, fork).await; let headers = result.unwrap(); assert_eq!(headers.len(), 3); - assert_eq!(headers[0], p2); + assert_eq!(headers[0], p0); assert_eq!(headers[1], p1); - assert_eq!(headers[2], p0); + assert_eq!(headers[2], p2); + } + + #[tokio::test] + async fn download_empty_stream() { + let client = Arc::new(TestHeadersClient::default()); + let downloader = + LinearDownloadBuilder::default().build(CONSENSUS.clone(), Arc::clone(&client)); + + let result = downloader + .stream(SealedHeader::default(), ForkchoiceState::default()) + .try_collect::>() + .await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn download_stream() { + let client = Arc::new(TestHeadersClient::default()); + let downloader = LinearDownloadBuilder::default() + .batch_size(3) + .build(CONSENSUS.clone(), Arc::clone(&client)); + + let p3 = SealedHeader::default(); + let p2 = child_header(&p3); + let p1 = child_header(&p2); + let p0 = child_header(&p1); + + client + .extend(vec![ + p0.as_ref().clone(), + p1.as_ref().clone(), + p2.as_ref().clone(), + p3.as_ref().clone(), + ]) + .await; + + let fork = ForkchoiceState { head_block_hash: p0.hash_slow(), ..Default::default() }; + + let result = downloader.stream(p3, fork).try_collect::>().await; + let headers = result.unwrap(); + assert_eq!(headers.len(), 3); + assert_eq!(headers[0], p0); + assert_eq!(headers[1], p1); + assert_eq!(headers[2], p2); } } diff --git a/crates/primitives/src/header.rs b/crates/primitives/src/header.rs index c9acdf7f6..7d7362320 100644 --- a/crates/primitives/src/header.rs +++ b/crates/primitives/src/header.rs @@ -266,6 +266,11 @@ impl SealedHeader { pub fn hash(&self) -> BlockHash { self.hash } + + /// Return the number hash tuple. + pub fn num_hash(&self) -> (BlockNumber, BlockHash) { + (self.number, self.hash) + } } /// Represents the direction for a headers request depending on the `reverse` field of the request. diff --git a/crates/stages/src/stages/headers.rs b/crates/stages/src/stages/headers.rs index 3df83486a..e44a6dc31 100644 --- a/crates/stages/src/stages/headers.rs +++ b/crates/stages/src/stages/headers.rs @@ -2,6 +2,7 @@ use crate::{ db::StageDB, DatabaseIntegrityError, ExecInput, ExecOutput, Stage, StageError, StageId, UnwindInput, UnwindOutput, }; +use futures_util::StreamExt; use reth_interfaces::{ consensus::{Consensus, ForkchoiceState}, db::{models::blocks::BlockNumHash, tables, Database, DbCursorRO, DbCursorRW, DbTx, DbTxMut}, @@ -36,6 +37,8 @@ pub struct HeaderStage { pub consensus: Arc, /// Downloader client implementation pub client: Arc, + /// The minimum number of block headers to commit at once + pub commit_threshold: usize, } #[async_trait::async_trait] @@ -74,33 +77,43 @@ impl Stage { - // Perform basic response validation - self.validate_header_response(&res, head, forkchoice)?; - res + while let Some(headers) = stream.next().await { + match headers.into_iter().collect::, _>>() { + Ok(res) => { + // Perform basic response validation + self.validate_header_response(&res)?; + let write_progress = + self.write_headers::(db, res).await?.unwrap_or_default(); + current_progress = current_progress.max(write_progress); + } + Err(e) => match e { + DownloadError::Timeout => { + warn!("No response for header request"); + return Ok(ExecOutput { stage_progress, reached_tip: false, done: false }) + } + DownloadError::HeaderValidation { hash, error } => { + warn!("Validation error for header {hash}: {error}"); + return Err(StageError::Validation { block: stage_progress, error }) + } + error => { + warn!("Unexpected error occurred: {error}"); + return Err(StageError::Download(error.to_string())) + } + }, } - Err(e) => match e { - DownloadError::Timeout => { - warn!("No response for header request"); - return Ok(ExecOutput { stage_progress, reached_tip: false, done: false }) - } - DownloadError::HeaderValidation { hash, error } => { - warn!("Validation error for header {hash}: {error}"); - return Err(StageError::Validation { block: stage_progress, error }) - } - error => { - warn!("Unexpected error occurred: {error}"); - return Err(StageError::Download(error.to_string())) - } - }, - }; - let stage_progress = self.write_headers::(db, headers).await?.unwrap_or(stage_progress); - Ok(ExecOutput { stage_progress, reached_tip: true, done: true }) + } + + // Write total difficulty values after all headers have been inserted + self.write_td::(db, &head)?; + + Ok(ExecOutput { stage_progress: current_progress, reached_tip: true, done: true }) } /// Unwind the stage. @@ -146,25 +159,13 @@ impl HeaderStage { } /// Perform basic header response validation - fn validate_header_response( - &self, - headers: &[SealedHeader], - head: SealedHeader, - forkchoice: ForkchoiceState, - ) -> Result<(), StageError> { - // The response must include at least head and tip - if headers.len() < 2 { - return Err(StageError::Download("Not enough headers".to_owned())) - } - - let mut headers_iter = headers.iter().rev().peekable(); - if headers_iter.peek().unwrap().hash() != forkchoice.head_block_hash { - return Err(StageError::Download("Response must end with tip".to_owned())) - } - + fn validate_header_response(&self, headers: &[SealedHeader]) -> Result<(), StageError> { + let mut headers_iter = headers.iter().peekable(); while let Some(header) = headers_iter.next() { - ensure_parent(header, headers_iter.peek().unwrap_or(&&head)) - .map_err(|err| StageError::Download(err.to_string()))?; + if let Some(parent) = headers_iter.peek() { + ensure_parent(header, parent) + .map_err(|err| StageError::Download(err.to_string()))?; + } } Ok(()) @@ -178,13 +179,11 @@ impl HeaderStage { ) -> Result, StageError> { let mut cursor_header = db.cursor_mut::()?; let mut cursor_canonical = db.cursor_mut::()?; - let mut cursor_td = db.cursor_mut::()?; - let mut td: U256 = cursor_td.last()?.map(|(_, v)| v).unwrap().into(); let mut latest = None; // Since the headers were returned in descending order, // iterate them in the reverse order - for header in headers.into_iter() { + for header in headers.into_iter().rev() { if header.number == 0 { continue } @@ -194,17 +193,41 @@ impl HeaderStage { let header = header.unseal(); latest = Some(header.number); - td += header.difficulty; - // NOTE: HeaderNumbers are not sorted and can't be inserted with cursor. db.put::(block_hash, header.number)?; - cursor_header.append(key, header)?; - cursor_canonical.append(key.number(), key.hash())?; - cursor_td.append(key, td.into())?; + cursor_header.insert(key, header)?; + cursor_canonical.insert(key.number(), key.hash())?; } Ok(latest) } + + /// Iterate over inserted headers and write td entries + fn write_td( + &self, + db: &StageDB<'_, DB>, + head: &SealedHeader, + ) -> Result<(), StageError> { + // Acquire cursor over total difficulty table + let mut cursor_td = db.cursor_mut::()?; + + // Get latest total difficulty + let last_entry = cursor_td + .seek_exact(head.num_hash().into())? + .ok_or(DatabaseIntegrityError::TotalDifficulty { number: head.number })?; + let mut td: U256 = last_entry.1.into(); + + // Start at first inserted block during this iteration + let start_key = db.get_block_numhash(head.number + 1)?; + + // Walk over newly inserted headers, update & insert td + for entry in db.cursor::()?.walk(start_key)? { + let (key, header) = entry?; + td += header.difficulty; + cursor_td.append(key, td.into())?; + } + Ok(()) + } } #[cfg(test)] @@ -262,7 +285,7 @@ mod tests { /// Check that unexpected download errors are caught #[tokio::test] - async fn executed_download_error() { + async fn execute_download_error() { let mut runner = HeadersTestRunner::default(); let (stage_progress, previous_stage) = (1000, 1200); let input = ExecInput { @@ -363,6 +386,7 @@ mod tests { consensus: self.consensus.clone(), client: self.client.clone(), downloader: self.downloader.clone(), + commit_threshold: 100, } } }