diff --git a/Cargo.lock b/Cargo.lock index 26f1f74ba..38d1f8f61 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3275,8 +3275,10 @@ name = "reth-stages" version = "0.1.0" dependencies = [ "aquamarine", + "assert_matches", "async-trait", "metrics", + "once_cell", "reth-db", "reth-interfaces", "reth-primitives", diff --git a/crates/interfaces/src/consensus.rs b/crates/interfaces/src/consensus.rs index 10aef4926..8de389b5f 100644 --- a/crates/interfaces/src/consensus.rs +++ b/crates/interfaces/src/consensus.rs @@ -1,9 +1,11 @@ use async_trait::async_trait; use reth_primitives::Header; -use reth_rpc_types::engine::ForkchoiceState; use thiserror::Error; use tokio::sync::watch::Receiver; +/// Re-export forkchoice state +pub use reth_rpc_types::engine::ForkchoiceState; + /// Consensus is a protocol that chooses canonical chain. /// We are checking validity of block header here. #[async_trait] diff --git a/crates/interfaces/src/db/models/blocks.rs b/crates/interfaces/src/db/models/blocks.rs index 534a18f3a..9aabcf3a7 100644 --- a/crates/interfaces/src/db/models/blocks.rs +++ b/crates/interfaces/src/db/models/blocks.rs @@ -25,8 +25,8 @@ pub type HeaderHash = H256; /// element as BlockNumber, helps out with querying/sorting. /// /// Since it's used as a key, the `BlockNumber` is not compressed when encoding it. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] -pub struct BlockNumHash((BlockNumber, BlockHash)); +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +pub struct BlockNumHash(pub (BlockNumber, BlockHash)); impl BlockNumHash { /// Consumes `Self` and returns [`BlockNumber`], [`BlockHash`] diff --git a/crates/interfaces/src/p2p/headers/downloader.rs b/crates/interfaces/src/p2p/headers/downloader.rs index 680867b3c..ccbc2630a 100644 --- a/crates/interfaces/src/p2p/headers/downloader.rs +++ b/crates/interfaces/src/p2p/headers/downloader.rs @@ -22,12 +22,6 @@ pub enum DownloadError { /// The details of validation failure details: String, }, - /// No headers reponse received - #[error("Failed to get headers for request {request_id}.")] - NoHeaderResponse { - /// The last request ID - request_id: u64, - }, /// Timed out while waiting for request id response. #[error("Timed out while getting headers for request {request_id}.")] Timeout { @@ -53,7 +47,7 @@ impl DownloadError { /// Returns bool indicating whether this error is retryable or fatal, in the cases /// where the peer responds with no headers, or times out. pub fn is_retryable(&self) -> bool { - matches!(self, DownloadError::NoHeaderResponse { .. } | DownloadError::Timeout { .. }) + matches!(self, DownloadError::Timeout { .. }) } } @@ -106,7 +100,7 @@ pub trait Downloader: Sync + Send { // Pop the first item. match Box::pin(stream).try_next().await { Ok(Some(item)) => Ok(item.headers), - _ => return Err(DownloadError::NoHeaderResponse { request_id }), + _ => return Err(DownloadError::Timeout { request_id }), } } diff --git a/crates/interfaces/src/test_utils.rs b/crates/interfaces/src/test_utils.rs index 2703f878b..dcc351812 100644 --- a/crates/interfaces/src/test_utils.rs +++ b/crates/interfaces/src/test_utils.rs @@ -7,7 +7,7 @@ use crate::{ }; use std::{collections::HashSet, sync::Arc, time::Duration}; -use reth_primitives::{Header, HeaderLocked, H256, H512}; +use reth_primitives::{Header, HeaderLocked, H256, H512, U256}; use reth_rpc_types::engine::ForkchoiceState; use tokio::sync::{broadcast, mpsc, watch}; @@ -134,7 +134,7 @@ impl Default for TestConsensus { impl TestConsensus { /// Update the forkchoice state - pub fn update_tip(&mut self, tip: H256) { + pub fn update_tip(&self, tip: H256) { let state = ForkchoiceState { head_block_hash: tip, finalized_block_hash: H256::zero(), @@ -163,3 +163,28 @@ impl Consensus for TestConsensus { } } } + +/// Generate a range of random header. The parent hash of the first header +/// in the result will be equal to head +pub fn gen_random_header_range(rng: std::ops::Range, head: H256) -> Vec { + let mut headers = Vec::with_capacity(rng.end.saturating_sub(rng.start) as usize); + for idx in rng { + headers.push(gen_random_header( + idx, + Some(headers.last().map(|h: &HeaderLocked| h.hash()).unwrap_or(head)), + )); + } + headers +} + +/// Generate a random header +pub fn gen_random_header(number: u64, parent: Option) -> HeaderLocked { + let header = reth_primitives::Header { + number, + nonce: rand::random(), + difficulty: U256::from(rand::random::()), + parent_hash: parent.unwrap_or_default(), + ..Default::default() + }; + header.lock() +} diff --git a/crates/net/headers-downloaders/src/linear.rs b/crates/net/headers-downloaders/src/linear.rs index 589ad0d35..5f61b4c69 100644 --- a/crates/net/headers-downloaders/src/linear.rs +++ b/crates/net/headers-downloaders/src/linear.rs @@ -206,9 +206,11 @@ mod tests { use super::*; use reth_interfaces::{ p2p::headers::client::HeadersRequest, - test_utils::{TestConsensus, TestHeadersClient}, + test_utils::{ + gen_random_header, gen_random_header_range, TestConsensus, TestHeadersClient, + }, }; - use reth_primitives::{rpc::BlockId, HeaderLocked, H256}; + use reth_primitives::{rpc::BlockId, HeaderLocked}; use assert_matches::assert_matches; use once_cell::sync::Lazy; @@ -246,7 +248,7 @@ mod tests { }) .await; assert_eq!(requests.len(), retries); - assert_matches!(rx.await, Ok(Err(DownloadError::NoHeaderResponse { .. }))); + assert_matches!(rx.await, Ok(Err(DownloadError::Timeout { .. }))); } #[tokio::test] @@ -277,7 +279,7 @@ mod tests { assert_eq!(num_of_reqs, retries); assert_matches!( rx.await, - Ok(Err(DownloadError::NoHeaderResponse { request_id })) if request_id == last_req_id.unwrap() + Ok(Err(DownloadError::Timeout { request_id })) if request_id == last_req_id.unwrap() ); } @@ -358,7 +360,7 @@ mod tests { async fn download_returns_headers_desc() { let (start, end) = (100, 200); let head = gen_random_header(start, None); - let mut headers = gen_block_range(start + 1..end, head.hash()); + let mut headers = gen_random_header_range(start + 1..end, head.hash()); headers.reverse(); let tip_hash = headers.first().unwrap().hash(); @@ -395,25 +397,4 @@ mod tests { assert_eq!(result.len(), headers.len()); assert_eq!(result, headers); } - - pub(crate) fn gen_block_range(rng: std::ops::Range, head: H256) -> Vec { - let mut headers = Vec::with_capacity(rng.end.saturating_sub(rng.start) as usize); - for idx in rng { - headers.push(gen_random_header( - idx, - Some(headers.last().map(|h: &HeaderLocked| h.hash()).unwrap_or(head)), - )); - } - headers - } - - pub(crate) fn gen_random_header(number: u64, parent: Option) -> HeaderLocked { - let header = reth_primitives::Header { - number, - nonce: rand::random(), - parent_hash: parent.unwrap_or_default(), - ..Default::default() - }; - header.lock() - } } diff --git a/crates/stages/Cargo.toml b/crates/stages/Cargo.toml index f0bca4da0..1ac219b20 100644 --- a/crates/stages/Cargo.toml +++ b/crates/stages/Cargo.toml @@ -20,7 +20,10 @@ aquamarine = "0.1.12" metrics = "0.20.1" [dev-dependencies] +reth-db = { path = "../db", features = ["test-utils"] } +reth-interfaces = { path = "../interfaces", features = ["test-utils"] } tokio = { version = "*", features = ["rt", "sync", "macros"] } tokio-stream = "0.1.10" +once_cell = "1.15.0" tempfile = "3.3.0" -reth-db = { path = "../db", features = ["test-utils"] } \ No newline at end of file +assert_matches = "1.5.0" diff --git a/crates/stages/src/error.rs b/crates/stages/src/error.rs index 3370d8fdd..73d6f3bbd 100644 --- a/crates/stages/src/error.rs +++ b/crates/stages/src/error.rs @@ -1,4 +1,4 @@ -use crate::pipeline::PipelineEvent; +use crate::{pipeline::PipelineEvent, stages::headers::HeaderStageError}; use reth_interfaces::db::Error as DbError; use reth_primitives::BlockNumber; use thiserror::Error; @@ -18,6 +18,9 @@ pub enum StageError { /// The stage encountered a database error. #[error("A database error occurred.")] Database(#[from] DbError), + /// The headers stage encountered an error. + #[error("Headers stage error.")] + HeadersStage(#[from] HeaderStageError), /// The stage encountered an internal error. #[error(transparent)] Internal(Box), diff --git a/crates/stages/src/lib.rs b/crates/stages/src/lib.rs index 5735240c4..697e2f525 100644 --- a/crates/stages/src/lib.rs +++ b/crates/stages/src/lib.rs @@ -20,6 +20,9 @@ mod pipeline; mod stage; mod util; +/// Implementations of stages. +pub mod stages; + pub use error::*; pub use id::*; pub use pipeline::*; diff --git a/crates/stages/src/stages/headers.rs b/crates/stages/src/stages/headers.rs new file mode 100644 index 000000000..587dd7a0b --- /dev/null +++ b/crates/stages/src/stages/headers.rs @@ -0,0 +1,580 @@ +use crate::{ExecInput, ExecOutput, Stage, StageError, StageId, UnwindInput, UnwindOutput}; +use reth_interfaces::{ + consensus::{Consensus, ForkchoiceState}, + db::{ + self, models::blocks::BlockNumHash, tables, DBContainer, Database, DatabaseGAT, DbCursorRO, + DbCursorRW, DbTx, DbTxMut, Table, + }, + p2p::headers::{ + client::HeadersClient, + downloader::{DownloadError, Downloader}, + }, +}; +use reth_primitives::{rpc::BigEndianHash, BlockNumber, HeaderLocked, H256, U256}; +use std::fmt::Debug; +use thiserror::Error; +use tracing::*; + +const HEADERS: StageId = StageId("HEADERS"); + +/// The headers stage implementation for staged sync +#[derive(Debug)] +pub struct HeaderStage { + /// Strategy for downloading the headers + pub downloader: D, + /// Consensus client implementation + pub consensus: C, + /// Downloader client implementation + pub client: H, +} + +/// The header stage error +#[derive(Error, Debug)] +pub enum HeaderStageError { + /// Cannonical hash is missing from db + #[error("no cannonical hash for block #{number}")] + NoCannonicalHash { + /// The block number key + number: BlockNumber, + }, + /// Cannonical header is missing from db + #[error("no cannonical hash for block #{number}")] + NoCannonicalHeader { + /// The block number key + number: BlockNumber, + }, + /// Header is missing from db + #[error("no header for block #{number} ({hash})")] + NoHeader { + /// The block number key + number: BlockNumber, + /// The block hash key + hash: H256, + }, +} + +#[async_trait::async_trait] +impl Stage + for HeaderStage +{ + /// Return the id of the stage + fn id(&self) -> StageId { + HEADERS + } + + /// Download the headers in reverse order + /// starting from the tip + async fn execute( + &mut self, + db: &mut DBContainer<'_, DB>, + input: ExecInput, + ) -> Result { + let tx = db.get_mut(); + let last_block_num = + input.previous_stage.as_ref().map(|(_, block)| *block).unwrap_or_default(); + self.update_head::(tx, last_block_num).await?; + + // download the headers + // TODO: handle input.max_block + let last_hash = + tx.get::(last_block_num)?.ok_or_else(|| -> StageError { + HeaderStageError::NoCannonicalHash { number: last_block_num }.into() + })?; + let last_header = tx + .get::((last_block_num, last_hash).into())? + .ok_or_else(|| -> StageError { + HeaderStageError::NoHeader { number: last_block_num, hash: last_hash }.into() + })?; + let head = HeaderLocked::new(last_header, last_hash); + + let forkchoice = self.next_fork_choice_state(&head.hash()).await; + // The stage relies on the downloader to return the headers + // in descending order starting from the tip down to + // the local head (latest block in db) + let headers = match self.downloader.download(&head, &forkchoice).await { + Ok(res) => { + // TODO: validate the result order? + // at least check if it attaches (first == tip && last == last_hash) + res + } + Err(e) => match e { + DownloadError::Timeout { request_id } => { + warn!("no response for header request {request_id}"); + return Ok(ExecOutput { + stage_progress: last_block_num, + reached_tip: false, + done: false, + }) + } + DownloadError::HeaderValidation { hash, details } => { + warn!("validation error for header {hash}: {details}"); + return Err(StageError::Validation { block: last_block_num }) + } + // TODO: this error is never propagated, clean up + DownloadError::MismatchedHeaders { .. } => { + return Err(StageError::Validation { block: last_block_num }) + } + }, + }; + + let stage_progress = self.write_headers::(tx, headers).await?.unwrap_or(last_block_num); + Ok(ExecOutput { stage_progress, reached_tip: true, done: true }) + } + + /// Unwind the stage. + async fn unwind( + &mut self, + db: &mut DBContainer<'_, DB>, + input: UnwindInput, + ) -> Result> { + // TODO: handle bad block + let tx = &mut db.get_mut(); + self.unwind_table::(tx, input.unwind_to, |num| num)?; + self.unwind_table::(tx, input.unwind_to, |key| key.0 .0)?; + self.unwind_table::(tx, input.unwind_to, |key| key.0 .0)?; + self.unwind_table::(tx, input.unwind_to, |key| key.0 .0)?; + Ok(UnwindOutput { stage_progress: input.unwind_to }) + } +} + +impl HeaderStage { + async fn update_head( + &self, + tx: &mut >::TXMut, + height: BlockNumber, + ) -> Result<(), StageError> { + let hash = tx.get::(height)?.ok_or_else(|| -> StageError { + HeaderStageError::NoCannonicalHeader { number: height }.into() + })?; + let td: Vec = tx.get::((height, hash).into())?.unwrap(); // TODO: + self.client.update_status(height, hash, H256::from_slice(&td)).await; + Ok(()) + } + + async fn next_fork_choice_state(&self, head: &H256) -> ForkchoiceState { + let mut state_rcv = self.consensus.fork_choice_state(); + loop { + let _ = state_rcv.changed().await; + let forkchoice = state_rcv.borrow(); + if !forkchoice.head_block_hash.is_zero() && forkchoice.head_block_hash != *head { + return forkchoice.clone() + } + } + } + + /// Write downloaded headers to the database + async fn write_headers( + &self, + tx: &mut >::TXMut, + headers: Vec, + ) -> Result, StageError> { + let mut cursor_header_number = tx.cursor_mut::()?; + let mut cursor_header = tx.cursor_mut::()?; + let mut cursor_canonical = tx.cursor_mut::()?; + let mut cursor_td = tx.cursor_mut::()?; + let mut td = U256::from_big_endian(&cursor_td.last()?.map(|(_, v)| v).unwrap()); + + let mut latest = None; + // Since the headers were returned in descending order, + // iterate them in the reverse order + for header in headers.into_iter().rev() { + if header.number == 0 { + continue + } + + let key: BlockNumHash = (header.number, header.hash()).into(); + let header = header.unlock(); + latest = Some(header.number); + + td += header.difficulty; + + // TODO: investigate default write flags + cursor_header_number.append(key, header.number)?; + cursor_header.append(key, header)?; + cursor_canonical.append(key.0 .0, key.0 .1)?; + cursor_td.append(key, H256::from_uint(&td).as_bytes().to_vec())?; + } + + Ok(latest) + } + + /// Unwind the table to a provided block + fn unwind_table( + &self, + tx: &mut >::TXMut, + block: BlockNumber, + mut selector: F, + ) -> Result<(), db::Error> + where + DB: Database, + T: Table, + F: FnMut(T::Key) -> BlockNumber, + { + let mut cursor = tx.cursor_mut::()?; + let mut entry = cursor.last()?; + while let Some((key, _)) = entry { + if selector(key) <= block { + break + } + cursor.delete_current()?; + entry = cursor.prev()?; + } + Ok(()) + } +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + use assert_matches::assert_matches; + use once_cell::sync::Lazy; + use reth_db::{kv::Env, mdbx::WriteMap}; + use reth_interfaces::{ + db::DBContainer, + test_utils::{ + gen_random_header, gen_random_header_range, TestConsensus, TestHeadersClient, + }, + }; + use std::{borrow::Borrow, sync::Arc}; + use test_utils::HeadersDB; + use tokio::sync::oneshot; + + const TEST_STAGE: StageId = StageId("HEADERS"); + static CONSENSUS: Lazy = Lazy::new(|| TestConsensus::default()); + static CLIENT: Lazy = Lazy::new(|| TestHeadersClient::default()); + + #[tokio::test] + // Check that the execution errors on empty database or + // prev progress missing from the database. + async fn headers_execute_empty_db() { + let db = HeadersDB::default(); + let input = ExecInput { previous_stage: None, stage_progress: None }; + let rx = execute_stage(db.inner(), input, Ok(vec![])); + assert_matches!( + rx.await.unwrap(), + Err(StageError::HeadersStage(HeaderStageError::NoCannonicalHeader { .. })) + ); + } + + #[tokio::test] + // Check that the execution exits on downloader timeout. + async fn headers_execute_timeout() { + let head = gen_random_header(0, None); + let db = HeadersDB::default(); + db.insert_header(&head).expect("failed to insert header"); + + let input = ExecInput { previous_stage: None, stage_progress: None }; + let rx = execute_stage(db.inner(), input, Err(DownloadError::Timeout { request_id: 0 })); + CONSENSUS.update_tip(H256::from_low_u64_be(1)); + assert_matches!(rx.await.unwrap(), Ok(ExecOutput { done, .. }) if !done); + } + + #[tokio::test] + // Check that validation error is propagated during the execution. + async fn headers_execute_validation_error() { + let head = gen_random_header(0, None); + let db = HeadersDB::default(); + db.insert_header(&head).expect("failed to insert header"); + + let input = ExecInput { previous_stage: None, stage_progress: None }; + let rx = execute_stage( + db.inner(), + input, + Err(DownloadError::HeaderValidation { hash: H256::zero(), details: "".to_owned() }), + ); + CONSENSUS.update_tip(H256::from_low_u64_be(1)); + + assert_matches!(rx.await.unwrap(), Err(StageError::Validation { block }) if block == 0); + } + + #[tokio::test] + // Validate that all necessary tables are updated after the + // header download on no previous progress. + async fn headers_execute_no_progress() { + let (start, end) = (0, 100); + let head = gen_random_header(start, None); + let headers = gen_random_header_range(start + 1..end, head.hash()); + let db = HeadersDB::default(); + db.insert_header(&head).expect("failed to insert header"); + + let result: Vec<_> = headers.iter().rev().cloned().collect(); + let input = ExecInput { previous_stage: None, stage_progress: None }; + let rx = execute_stage(db.inner(), input, Ok(result)); + let tip = headers.last().unwrap(); + CONSENSUS.update_tip(tip.hash()); + + let result = rx.await.unwrap(); + assert_matches!(result, Ok(ExecOutput { .. })); + let result = result.unwrap(); + assert!(result.done && result.reached_tip); + assert_eq!(result.stage_progress, tip.number); + + for header in headers { + assert!(db.validate_db_header(&header).is_ok()); + } + } + + #[tokio::test] + // Validate that all necessary tables are updated after the + // header download with some previous progress. + async fn headers_stage_prev_progress() { + // TODO: set bigger range once `MDBX_EKEYMISMATCH` issue is resolved + let (start, end) = (10000, 10240); + let head = gen_random_header(start, None); + let headers = gen_random_header_range(start + 1..end, head.hash()); + let db = HeadersDB::default(); + db.insert_header(&head).expect("failed to insert header"); + + let result: Vec<_> = headers.iter().rev().cloned().collect(); + let input = ExecInput { + previous_stage: Some((TEST_STAGE, head.number)), + stage_progress: Some(head.number), + }; + let rx = execute_stage(db.inner(), input, Ok(result)); + let tip = headers.last().unwrap(); + CONSENSUS.update_tip(tip.hash()); + + let result = rx.await.unwrap(); + assert_matches!(result, Ok(ExecOutput { .. })); + let result = result.unwrap(); + assert!(result.done && result.reached_tip); + assert_eq!(result.stage_progress, tip.number); + + for header in headers { + assert!(db.validate_db_header(&header).is_ok()); + } + } + + #[tokio::test] + // Check that unwind does not panic on empty database. + async fn headers_unwind_empty_db() { + let db = HeadersDB::default(); + let input = UnwindInput { bad_block: None, stage_progress: 100, unwind_to: 100 }; + let rx = unwind_stage(db.inner(), input); + assert_matches!( + rx.await.unwrap(), + Ok(UnwindOutput {stage_progress} ) if stage_progress == input.unwind_to + ); + } + + #[tokio::test] + // Check that unwind can remove headers across gaps + async fn headers_unwind_db_gaps() { + let head = gen_random_header(0, None); + let first_range = gen_random_header_range(1..20, head.hash()); + let second_range = gen_random_header_range(50..100, H256::zero()); + let db = HeadersDB::default(); + db.insert_header(&head).expect("failed to insert header"); + for header in first_range.iter().chain(second_range.iter()) { + db.insert_header(&header).expect("failed to insert header"); + } + + let input = UnwindInput { bad_block: None, stage_progress: 15, unwind_to: 15 }; + let rx = unwind_stage(db.inner(), input); + assert_matches!( + rx.await.unwrap(), + Ok(UnwindOutput {stage_progress} ) if stage_progress == input.unwind_to + ); + + db.check_no_entry_above::(input.unwind_to, |key| key) + .expect("failed to check cannonical headers"); + db.check_no_entry_above::(input.unwind_to, |key| key.0 .0) + .expect("failed to check header numbers"); + db.check_no_entry_above::(input.unwind_to, |key| key.0 .0) + .expect("failed to check headers"); + db.check_no_entry_above::(input.unwind_to, |key| key.0 .0) + .expect("failed to check td"); + } + + // A helper function to run [HeaderStage::execute] + fn execute_stage( + db: Arc>, + input: ExecInput, + download_result: Result, DownloadError>, + ) -> oneshot::Receiver> { + let (tx, rx) = oneshot::channel(); + tokio::spawn(async move { + let db = db.clone(); + let mut db = DBContainer::>::new(db.borrow()).unwrap(); + let mut stage = HeaderStage { + client: &*CLIENT, + consensus: &*CONSENSUS, + downloader: test_utils::TestDownloader::new(download_result), + }; + let result = stage.execute(&mut db, input).await; + db.commit().expect("failed to commit"); + tx.send(result).expect("failed to send result"); + }); + rx + } + + // A helper function to run [HeaderStage::unwind] + fn unwind_stage( + db: Arc>, + input: UnwindInput, + ) -> oneshot::Receiver>> { + let (tx, rx) = oneshot::channel(); + tokio::spawn(async move { + let db = db.clone(); + let mut db = DBContainer::>::new(db.borrow()).unwrap(); + let mut stage = HeaderStage { + client: &*CLIENT, + consensus: &*CONSENSUS, + downloader: test_utils::TestDownloader::new(Ok(vec![])), + }; + let result = stage.unwind(&mut db, input).await; + db.commit().expect("failed to commit"); + tx.send(result).expect("failed to send result"); + }); + rx + } + + pub(crate) mod test_utils { + use async_trait::async_trait; + use reth_db::{ + kv::{test_utils::create_test_db, Env, EnvKind}, + mdbx, + mdbx::WriteMap, + }; + use reth_interfaces::{ + consensus::ForkchoiceState, + db::{ + self, models::blocks::BlockNumHash, tables, DBContainer, DbCursorRO, DbCursorRW, + DbTx, DbTxMut, Table, + }, + p2p::headers::downloader::{DownloadError, Downloader}, + test_utils::{TestConsensus, TestHeadersClient}, + }; + use reth_primitives::{rpc::BigEndianHash, BlockNumber, HeaderLocked, H256, U256}; + use std::{borrow::Borrow, sync::Arc, time::Duration}; + + pub(crate) struct HeadersDB { + db: Arc>, + } + + impl Default for HeadersDB { + fn default() -> Self { + Self { db: Arc::new(create_test_db::(EnvKind::RW)) } + } + } + + impl HeadersDB { + pub(crate) fn inner(&self) -> Arc> { + self.db.clone() + } + + fn container(&self) -> DBContainer<'_, Env> { + DBContainer::new(self.db.borrow()).expect("failed to create db container") + } + + /// Insert header into tables + pub(crate) fn insert_header(&self, header: &HeaderLocked) -> Result<(), db::Error> { + let mut db = self.container(); + let tx = db.get_mut(); + + let key: BlockNumHash = (header.number, header.hash()).into(); + tx.put::(key, header.number)?; + tx.put::(key, header.clone().unlock())?; + tx.put::(header.number, header.hash())?; + + let mut cursor_td = tx.cursor_mut::()?; + let td = + U256::from_big_endian(&cursor_td.last()?.map(|(_, v)| v).unwrap_or(vec![])); + cursor_td + .append(key, H256::from_uint(&(td + header.difficulty)).as_bytes().to_vec())?; + + db.commit()?; + Ok(()) + } + + /// Validate stored header against provided + pub(crate) fn validate_db_header( + &self, + header: &HeaderLocked, + ) -> Result<(), db::Error> { + let db = self.container(); + let tx = db.get(); + let key: BlockNumHash = (header.number, header.hash()).into(); + + let db_number = tx.get::(key)?; + assert_eq!(db_number, Some(header.number)); + + let db_header = tx.get::(key)?; + assert_eq!(db_header, Some(header.clone().unlock())); + + let db_canonical_header = tx.get::(header.number)?; + assert_eq!(db_canonical_header, Some(header.hash())); + + if header.number != 0 { + let parent_key: BlockNumHash = (header.number - 1, header.parent_hash).into(); + let parent_td = tx.get::(parent_key)?; + let td = U256::from_big_endian(&tx.get::(key)?.unwrap()); + assert_eq!( + parent_td.map(|td| U256::from_big_endian(&td) + header.difficulty), + Some(td) + ); + } + + Ok(()) + } + + /// Check there there is no table entry above given block + pub(crate) fn check_no_entry_above( + &self, + block: BlockNumber, + mut selector: F, + ) -> Result<(), db::Error> + where + T: Table, + F: FnMut(T::Key) -> BlockNumber, + { + let db = self.container(); + let tx = db.get(); + + let mut cursor = tx.cursor::()?; + if let Some((key, _)) = cursor.last()? { + assert!(selector(key) <= block); + } + + Ok(()) + } + } + + #[derive(Debug)] + pub(crate) struct TestDownloader { + result: Result, DownloadError>, + } + + impl TestDownloader { + pub(crate) fn new(result: Result, DownloadError>) -> Self { + Self { result } + } + } + + #[async_trait] + impl Downloader for TestDownloader { + type Consensus = TestConsensus; + type Client = TestHeadersClient; + + fn timeout(&self) -> Duration { + Duration::from_secs(1) + } + + fn consensus(&self) -> &Self::Consensus { + unimplemented!() + } + + fn client(&self) -> &Self::Client { + unimplemented!() + } + + async fn download( + &self, + _: &HeaderLocked, + _: &ForkchoiceState, + ) -> Result, DownloadError> { + self.result.clone() + } + } + } +} diff --git a/crates/stages/src/stages/mod.rs b/crates/stages/src/stages/mod.rs new file mode 100644 index 000000000..9afee3d00 --- /dev/null +++ b/crates/stages/src/stages/mod.rs @@ -0,0 +1,2 @@ +/// The headers stage. +pub mod headers;