diff --git a/crates/stages/src/stages/headers.rs b/crates/stages/src/stages/headers.rs index 3ffa711c4..d30cfa7b8 100644 --- a/crates/stages/src/stages/headers.rs +++ b/crates/stages/src/stages/headers.rs @@ -1,12 +1,13 @@ use crate::{ + util::unwind::{unwind_table_by_num, unwind_table_by_num_hash}, DatabaseIntegrityError, ExecInput, ExecOutput, Stage, StageError, StageId, UnwindInput, UnwindOutput, }; use reth_interfaces::{ consensus::{Consensus, ForkchoiceState}, db::{ - self, models::blocks::BlockNumHash, tables, DBContainer, Database, DatabaseGAT, DbCursorRO, - DbCursorRW, DbTx, DbTxMut, Table, + models::blocks::BlockNumHash, tables, DBContainer, Database, DatabaseGAT, DbCursorRO, + DbCursorRW, DbTx, DbTxMut, }, p2p::headers::{ client::HeadersClient, @@ -17,7 +18,7 @@ use reth_primitives::{rpc::BigEndianHash, BlockNumber, HeaderLocked, H256, U256} use std::{fmt::Debug, sync::Arc}; use tracing::*; -const HEADERS: StageId = StageId("HEADERS"); +const HEADERS: StageId = StageId("Headers"); /// The headers stage implementation for staged sync #[derive(Debug)] @@ -47,20 +48,17 @@ impl Stage input: ExecInput, ) -> Result { let tx = db.get_mut(); - let last_block_num = - input.previous_stage.as_ref().map(|(_, block)| *block).unwrap_or_default(); + let last_block_num = input.stage_progress.unwrap_or_default(); self.update_head::(tx, last_block_num).await?; // download the headers // TODO: handle input.max_block - let last_hash = - tx.get::(last_block_num)?.ok_or_else(|| -> StageError { - DatabaseIntegrityError::CannonicalHash { number: last_block_num }.into() - })?; - let last_header = tx - .get::((last_block_num, last_hash).into())? - .ok_or_else(|| -> StageError { - DatabaseIntegrityError::Header { number: last_block_num, hash: last_hash }.into() + let last_hash = tx + .get::(last_block_num)? + .ok_or(DatabaseIntegrityError::CannonicalHash { number: last_block_num })?; + let last_header = + tx.get::((last_block_num, last_hash).into())?.ok_or({ + DatabaseIntegrityError::Header { number: last_block_num, hash: last_hash } })?; let head = HeaderLocked::new(last_header, last_hash); @@ -105,11 +103,11 @@ impl Stage input: UnwindInput, ) -> Result> { // TODO: handle bad block - let tx = &mut db.get_mut(); - self.unwind_table::(tx, input.unwind_to, |num| num)?; - self.unwind_table::(tx, input.unwind_to, |key| key.0 .0)?; - self.unwind_table::(tx, input.unwind_to, |key| key.0 .0)?; - self.unwind_table::(tx, input.unwind_to, |key| key.0 .0)?; + let tx = db.get_mut(); + unwind_table_by_num::(tx, input.unwind_to)?; + unwind_table_by_num_hash::(tx, input.unwind_to)?; + unwind_table_by_num_hash::(tx, input.unwind_to)?; + unwind_table_by_num_hash::(tx, input.unwind_to)?; Ok(UnwindOutput { stage_progress: input.unwind_to }) } } @@ -120,9 +118,9 @@ impl HeaderStage { tx: &mut >::TXMut, height: BlockNumber, ) -> Result<(), StageError> { - let hash = tx.get::(height)?.ok_or_else(|| -> StageError { - DatabaseIntegrityError::CannonicalHeader { number: height }.into() - })?; + let hash = tx + .get::(height)? + .ok_or(DatabaseIntegrityError::CannonicalHeader { number: height })?; let td: Vec = tx.get::((height, hash).into())?.unwrap(); // TODO: self.client.update_status(height, hash, H256::from_slice(&td)).await; Ok(()) @@ -168,63 +166,30 @@ impl HeaderStage { // TODO: investigate default write flags cursor_header_number.append(key, header.number)?; cursor_header.append(key, header)?; - cursor_canonical.append(key.0 .0, key.0 .1)?; + cursor_canonical.append(key.number(), key.hash())?; cursor_td.append(key, H256::from_uint(&td).as_bytes().to_vec())?; } Ok(latest) } - - /// Unwind the table to a provided block - fn unwind_table( - &self, - tx: &mut >::TXMut, - block: BlockNumber, - mut selector: F, - ) -> Result<(), db::Error> - where - DB: Database, - T: Table, - F: FnMut(T::Key) -> BlockNumber, - { - let mut cursor = tx.cursor_mut::()?; - let mut entry = cursor.last()?; - while let Some((key, _)) = entry { - if selector(key) <= block { - break - } - cursor.delete_current()?; - entry = cursor.prev()?; - } - Ok(()) - } } #[cfg(test)] -pub mod tests { +mod tests { use super::*; + use crate::util::test_utils::StageTestRunner; use assert_matches::assert_matches; - use reth_db::{kv::Env, mdbx::WriteMap}; - use reth_headers_downloaders::linear::LinearDownloadBuilder; - use reth_interfaces::{ - db::DBContainer, - test_utils::{ - gen_random_header, gen_random_header_range, TestConsensus, TestHeadersClient, - }, - }; - use std::{borrow::Borrow, sync::Arc}; - use test_utils::HeadersDB; - use tokio::sync::oneshot; + use reth_interfaces::test_utils::{gen_random_header, gen_random_header_range}; + use test_utils::{HeadersTestRunner, TestDownloader}; - const TEST_STAGE: StageId = StageId("HEADERS"); + const TEST_STAGE: StageId = StageId("Headers"); #[tokio::test] // Check that the execution errors on empty database or // prev progress missing from the database. - async fn headers_execute_empty_db() { - let db = HeadersDB::default(); - let input = ExecInput { previous_stage: None, stage_progress: None }; - let rx = execute_stage(db.inner(), input, H256::zero(), Ok(vec![])); + async fn execute_empty_db() { + let runner = HeadersTestRunner::default(); + let rx = runner.execute(ExecInput::default()); assert_matches!( rx.await.unwrap(), Err(StageError::DatabaseIntegrity(DatabaseIntegrityError::CannonicalHeader { .. })) @@ -233,301 +198,271 @@ pub mod tests { #[tokio::test] // Check that the execution exits on downloader timeout. - async fn headers_execute_timeout() { + async fn execute_timeout() { let head = gen_random_header(0, None); - let db = HeadersDB::default(); - db.insert_header(&head).expect("failed to insert header"); - - let input = ExecInput { previous_stage: None, stage_progress: None }; - let rx = execute_stage( - db.inner(), - input, - H256::from_low_u64_be(1), - Err(DownloadError::Timeout { request_id: 0 }), - ); + let runner = + HeadersTestRunner::with_downloader(TestDownloader::new(Err(DownloadError::Timeout { + request_id: 0, + }))); + runner.insert_header(&head).expect("failed to insert header"); + let rx = runner.execute(ExecInput::default()); + runner.consensus.update_tip(H256::from_low_u64_be(1)); assert_matches!(rx.await.unwrap(), Ok(ExecOutput { done, .. }) if !done); } #[tokio::test] // Check that validation error is propagated during the execution. - async fn headers_execute_validation_error() { + async fn execute_validation_error() { let head = gen_random_header(0, None); - let db = HeadersDB::default(); - db.insert_header(&head).expect("failed to insert header"); - - let input = ExecInput { previous_stage: None, stage_progress: None }; - let rx = execute_stage( - db.inner(), - input, - H256::from_low_u64_be(1), - Err(DownloadError::HeaderValidation { hash: H256::zero(), details: "".to_owned() }), - ); + let runner = HeadersTestRunner::with_downloader(TestDownloader::new(Err( + DownloadError::HeaderValidation { hash: H256::zero(), details: "".to_owned() }, + ))); + runner.insert_header(&head).expect("failed to insert header"); + let rx = runner.execute(ExecInput::default()); + runner.consensus.update_tip(H256::from_low_u64_be(1)); assert_matches!(rx.await.unwrap(), Err(StageError::Validation { block }) if block == 0); } #[tokio::test] // Validate that all necessary tables are updated after the // header download on no previous progress. - async fn headers_execute_no_progress() { + async fn execute_no_progress() { let (start, end) = (0, 100); let head = gen_random_header(start, None); let headers = gen_random_header_range(start + 1..end, head.hash()); - let db = HeadersDB::default(); - db.insert_header(&head).expect("failed to insert header"); - let result: Vec<_> = headers.iter().rev().cloned().collect(); - let input = ExecInput { previous_stage: None, stage_progress: None }; + let result = headers.iter().rev().cloned().collect::>(); + let runner = HeadersTestRunner::with_downloader(TestDownloader::new(Ok(result))); + runner.insert_header(&head).expect("failed to insert header"); + + let rx = runner.execute(ExecInput::default()); let tip = headers.last().unwrap(); - let rx = execute_stage(db.inner(), input, tip.hash(), Ok(result)); + runner.consensus.update_tip(tip.hash()); - let result = rx.await.unwrap(); - assert_matches!(result, Ok(ExecOutput { .. })); - let result = result.unwrap(); - assert!(result.done && result.reached_tip); - assert_eq!(result.stage_progress, tip.number); - - for header in headers { - assert!(db.validate_db_header(&header).is_ok()); - } + assert_matches!( + rx.await.unwrap(), + Ok(ExecOutput { done, reached_tip, stage_progress }) + if done && reached_tip && stage_progress == tip.number + ); + assert!(headers.iter().try_for_each(|h| runner.validate_db_header(&h)).is_ok()); } #[tokio::test] // Validate that all necessary tables are updated after the // header download with some previous progress. - async fn headers_stage_prev_progress() { + async fn execute_prev_progress() { let (start, end) = (10000, 10241); let head = gen_random_header(start, None); let headers = gen_random_header_range(start + 1..end, head.hash()); - let db = HeadersDB::default(); - db.insert_header(&head).expect("failed to insert header"); - let result: Vec<_> = headers.iter().rev().cloned().collect(); - let input = ExecInput { + let result = headers.iter().rev().cloned().collect::>(); + let runner = HeadersTestRunner::with_downloader(TestDownloader::new(Ok(result))); + runner.insert_header(&head).expect("failed to insert header"); + + let rx = runner.execute(ExecInput { previous_stage: Some((TEST_STAGE, head.number)), stage_progress: Some(head.number), - }; + }); let tip = headers.last().unwrap(); - let rx = execute_stage(db.inner(), input, tip.hash(), Ok(result)); + runner.consensus.update_tip(tip.hash()); - let result = rx.await.unwrap(); - assert_matches!(result, Ok(ExecOutput { .. })); - let result = result.unwrap(); - assert!(result.done && result.reached_tip); - assert_eq!(result.stage_progress, tip.number); - - for header in headers { - assert!(db.validate_db_header(&header).is_ok()); - } + assert_matches!( + rx.await.unwrap(), + Ok(ExecOutput { done, reached_tip, stage_progress }) + if done && reached_tip && stage_progress == tip.number + ); + assert!(headers.iter().try_for_each(|h| runner.validate_db_header(&h)).is_ok()); } #[tokio::test] // Execute the stage with linear downloader - async fn headers_execute_linear() { - // TODO: set bigger range once `MDBX_EKEYMISMATCH` issue is resolved - let (start, end) = (1000, 1024); + async fn execute_with_linear_downloader() { + let (start, end) = (1000, 1200); let head = gen_random_header(start, None); let headers = gen_random_header_range(start + 1..end, head.hash()); - let db = HeadersDB::default(); - db.insert_header(&head).expect("failed to insert header"); - let input = ExecInput { + let runner = HeadersTestRunner::with_linear_downloader(); + runner.insert_header(&head).expect("failed to insert header"); + let rx = runner.execute(ExecInput { previous_stage: Some((TEST_STAGE, head.number)), stage_progress: Some(head.number), - }; + }); + let tip = headers.last().unwrap(); + runner.consensus.update_tip(tip.hash()); + let mut download_result = headers.clone(); download_result.insert(0, head); - let rx = execute_stage_linear(db.inner(), input, tip.hash(), download_result).await; + runner + .client + .on_header_request(1, |id, _| { + runner.client.send_header_response( + id, + download_result.clone().into_iter().map(|h| h.unlock()).collect(), + ) + }) + .await; - let result = rx.await.unwrap(); - assert_matches!(result, Ok(ExecOutput { .. })); - let result = result.unwrap(); - assert!(result.done && result.reached_tip); - assert_eq!(result.stage_progress, tip.number); - - for header in headers { - assert!(db.validate_db_header(&header).is_ok()); - } + assert_matches!( + rx.await.unwrap(), + Ok(ExecOutput { done, reached_tip, stage_progress }) + if done && reached_tip && stage_progress == tip.number + ); + assert!(headers.iter().try_for_each(|h| runner.validate_db_header(&h)).is_ok()); } #[tokio::test] // Check that unwind does not panic on empty database. - async fn headers_unwind_empty_db() { - let db = HeadersDB::default(); - let input = UnwindInput { bad_block: None, stage_progress: 100, unwind_to: 100 }; - let rx = unwind_stage(db.inner(), input); + async fn unwind_empty_db() { + let unwind_to = 100; + let runner = HeadersTestRunner::default(); + let rx = + runner.unwind(UnwindInput { bad_block: None, stage_progress: unwind_to, unwind_to }); assert_matches!( rx.await.unwrap(), - Ok(UnwindOutput {stage_progress} ) if stage_progress == input.unwind_to + Ok(UnwindOutput {stage_progress} ) if stage_progress == unwind_to ); } #[tokio::test] // Check that unwind can remove headers across gaps - async fn headers_unwind_db_gaps() { + async fn unwind_db_gaps() { + let runner = HeadersTestRunner::default(); let head = gen_random_header(0, None); let first_range = gen_random_header_range(1..20, head.hash()); let second_range = gen_random_header_range(50..100, H256::zero()); - let db = HeadersDB::default(); - db.insert_header(&head).expect("failed to insert header"); - for header in first_range.iter().chain(second_range.iter()) { - db.insert_header(header).expect("failed to insert header"); - } + runner.insert_header(&head).expect("failed to insert header"); + runner + .insert_headers(first_range.iter().chain(second_range.iter())) + .expect("failed to insert headers"); - let input = UnwindInput { bad_block: None, stage_progress: 15, unwind_to: 15 }; - let rx = unwind_stage(db.inner(), input); + let unwind_to = 15; + let rx = + runner.unwind(UnwindInput { bad_block: None, stage_progress: unwind_to, unwind_to }); assert_matches!( rx.await.unwrap(), - Ok(UnwindOutput {stage_progress} ) if stage_progress == input.unwind_to + Ok(UnwindOutput {stage_progress} ) if stage_progress == unwind_to ); - db.check_no_entry_above::(input.unwind_to, |key| key) + runner + .db() + .check_no_entry_above::(unwind_to, |key| key) .expect("failed to check cannonical headers"); - db.check_no_entry_above::(input.unwind_to, |key| key.0 .0) + runner + .db() + .check_no_entry_above::(unwind_to, |key| key.number()) .expect("failed to check header numbers"); - db.check_no_entry_above::(input.unwind_to, |key| key.0 .0) + runner + .db() + .check_no_entry_above::(unwind_to, |key| key.number()) .expect("failed to check headers"); - db.check_no_entry_above::(input.unwind_to, |key| key.0 .0) + runner + .db() + .check_no_entry_above::(unwind_to, |key| key.number()) .expect("failed to check td"); } - // A helper function to run [HeaderStage::execute] - // with default consensus, client & test downloader - fn execute_stage( - db: Arc>, - input: ExecInput, - tip: H256, - download_result: Result, DownloadError>, - ) -> oneshot::Receiver> { - let (tx, rx) = oneshot::channel(); - - let client = Arc::new(TestHeadersClient::default()); - let consensus = Arc::new(TestConsensus::default()); - let downloader = test_utils::TestDownloader::new(download_result); - - let mut stage = HeaderStage { consensus: consensus.clone(), client, downloader }; - tokio::spawn(async move { - let mut db = DBContainer::>::new(db.borrow()).unwrap(); - let result = stage.execute(&mut db, input).await; - db.commit().expect("failed to commit"); - tx.send(result).expect("failed to send result"); - }); - consensus.update_tip(tip); - rx - } - - // A helper function to run [HeaderStage::execute] - // with linear downloader - async fn execute_stage_linear( - db: Arc>, - input: ExecInput, - tip: H256, - headers: Vec, - ) -> oneshot::Receiver> { - let (tx, rx) = oneshot::channel(); - - let consensus = Arc::new(TestConsensus::default()); - let client = Arc::new(TestHeadersClient::default()); - let downloader = LinearDownloadBuilder::new().build(consensus.clone(), client.clone()); - - let mut stage = - HeaderStage { consensus: consensus.clone(), client: client.clone(), downloader }; - tokio::spawn(async move { - let mut db = DBContainer::>::new(db.borrow()).unwrap(); - let result = stage.execute(&mut db, input).await; - db.commit().expect("failed to commit"); - tx.send(result).expect("failed to send result"); - }); - - consensus.update_tip(tip); - client - .on_header_request(1, |id, _| { - client.send_header_response( - id, - headers.clone().into_iter().map(|h| h.unlock()).collect(), - ) - }) - .await; - rx - } - - // A helper function to run [HeaderStage::unwind] - fn unwind_stage( - db: Arc>, - input: UnwindInput, - ) -> oneshot::Receiver>> { - let (tx, rx) = oneshot::channel(); - let mut stage = HeaderStage { - client: Arc::new(TestHeadersClient::default()), - consensus: Arc::new(TestConsensus::default()), - downloader: test_utils::TestDownloader::new(Ok(vec![])), + mod test_utils { + use crate::{ + stages::headers::HeaderStage, + util::test_utils::{StageTestDB, StageTestRunner}, }; - tokio::spawn(async move { - let mut db = DBContainer::>::new(db.borrow()).unwrap(); - let result = stage.unwind(&mut db, input).await; - db.commit().expect("failed to commit"); - tx.send(result).expect("failed to send result"); - }); - rx - } - - pub(crate) mod test_utils { use async_trait::async_trait; - use reth_db::{ - kv::{test_utils::create_test_db, Env, EnvKind}, - mdbx, - mdbx::WriteMap, - }; + use reth_headers_downloaders::linear::{LinearDownloadBuilder, LinearDownloader}; use reth_interfaces::{ consensus::ForkchoiceState, - db::{ - self, models::blocks::BlockNumHash, tables, DBContainer, DbCursorRO, DbCursorRW, - DbTx, DbTxMut, Table, - }, + db::{self, models::blocks::BlockNumHash, tables, DbTx}, p2p::headers::downloader::{DownloadError, Downloader}, test_utils::{TestConsensus, TestHeadersClient}, }; - use reth_primitives::{rpc::BigEndianHash, BlockNumber, HeaderLocked, H256, U256}; - use std::{borrow::Borrow, sync::Arc, time::Duration}; + use reth_primitives::{rpc::BigEndianHash, HeaderLocked, H256, U256}; + use std::{ops::Deref, sync::Arc, time::Duration}; - pub(crate) struct HeadersDB { - db: Arc>, + pub(crate) struct HeadersTestRunner { + pub(crate) consensus: Arc, + pub(crate) client: Arc, + downloader: Arc, + db: StageTestDB, } - impl Default for HeadersDB { + impl Default for HeadersTestRunner { fn default() -> Self { - Self { db: Arc::new(create_test_db::(EnvKind::RW)) } + Self { + client: Arc::new(TestHeadersClient::default()), + consensus: Arc::new(TestConsensus::default()), + downloader: Arc::new(TestDownloader::new(Ok(Vec::default()))), + db: StageTestDB::default(), + } } } - impl HeadersDB { - pub(crate) fn inner(&self) -> Arc> { - self.db.clone() + impl StageTestRunner for HeadersTestRunner { + type S = HeaderStage, TestConsensus, TestHeadersClient>; + + fn db(&self) -> &StageTestDB { + &self.db } - fn container(&self) -> DBContainer<'_, Env> { - DBContainer::new(self.db.borrow()).expect("failed to create db container") + fn stage(&self) -> Self::S { + HeaderStage { + consensus: self.consensus.clone(), + client: self.client.clone(), + downloader: self.downloader.clone(), + } + } + } + + impl HeadersTestRunner> { + pub(crate) fn with_linear_downloader() -> Self { + let client = Arc::new(TestHeadersClient::default()); + let consensus = Arc::new(TestConsensus::default()); + let downloader = + Arc::new(LinearDownloadBuilder::new().build(consensus.clone(), client.clone())); + Self { client, consensus, downloader, db: StageTestDB::default() } + } + } + + impl HeadersTestRunner { + pub(crate) fn with_downloader(downloader: D) -> Self { + HeadersTestRunner { + client: Arc::new(TestHeadersClient::default()), + consensus: Arc::new(TestConsensus::default()), + downloader: Arc::new(downloader), + db: StageTestDB::default(), + } } /// Insert header into tables pub(crate) fn insert_header(&self, header: &HeaderLocked) -> Result<(), db::Error> { - let mut db = self.container(); - let tx = db.get_mut(); + self.insert_headers(std::iter::once(header)) + } - let key: BlockNumHash = (header.number, header.hash()).into(); - tx.put::(key, header.number)?; - tx.put::(key, header.clone().unlock())?; - tx.put::(header.number, header.hash())?; + /// Insert headers into tables + pub(crate) fn insert_headers<'a, I>(&self, headers: I) -> Result<(), db::Error> + where + I: Iterator, + { + let headers = headers.collect::>(); + self.db.map_put::(&headers, |h| { + (BlockNumHash((h.number, h.hash())), h.number) + })?; + self.db.map_put::(&headers, |h| { + (BlockNumHash((h.number, h.hash())), h.deref().clone().unlock()) + })?; + self.db.map_put::(&headers, |h| { + (h.number, h.hash()) + })?; - let mut cursor_td = tx.cursor_mut::()?; - let td = - U256::from_big_endian(&cursor_td.last()?.map(|(_, v)| v).unwrap_or_default()); - cursor_td - .append(key, H256::from_uint(&(td + header.difficulty)).as_bytes().to_vec())?; + self.db.transform_append::(&headers, |prev, h| { + let prev_td = U256::from_big_endian(&prev.clone().unwrap_or_default()); + ( + BlockNumHash((h.number, h.hash())), + H256::from_uint(&(prev_td + h.difficulty)).as_bytes().to_vec(), + ) + })?; - db.commit()?; Ok(()) } @@ -536,7 +471,7 @@ pub mod tests { &self, header: &HeaderLocked, ) -> Result<(), db::Error> { - let db = self.container(); + let db = self.db.container(); let tx = db.get(); let key: BlockNumHash = (header.number, header.hash()).into(); @@ -561,27 +496,6 @@ pub mod tests { Ok(()) } - - /// Check there there is no table entry above given block - pub(crate) fn check_no_entry_above( - &self, - block: BlockNumber, - mut selector: F, - ) -> Result<(), db::Error> - where - T: Table, - F: FnMut(T::Key) -> BlockNumber, - { - let db = self.container(); - let tx = db.get(); - - let mut cursor = tx.cursor::()?; - if let Some((key, _)) = cursor.last()? { - assert!(selector(key) <= block); - } - - Ok(()) - } } #[derive(Debug)] diff --git a/crates/stages/src/util.rs b/crates/stages/src/util.rs index 6ad50e0bd..c205aabe3 100644 --- a/crates/stages/src/util.rs +++ b/crates/stages/src/util.rs @@ -68,7 +68,21 @@ pub(crate) mod unwind { }; use reth_primitives::BlockNumber; + /// Unwind table by block number key + #[inline] + pub(crate) fn unwind_table_by_num( + tx: &mut >::TXMut, + block: BlockNumber, + ) -> Result<(), Error> + where + DB: Database, + T: Table, + { + unwind_table::(tx, block, |key| key) + } + /// Unwind table by composite block number hash key + #[inline] pub(crate) fn unwind_table_by_num_hash( tx: &mut >::TXMut, block: BlockNumber,