diff --git a/bin/reth/src/config.rs b/bin/reth/src/config.rs index f0fcd5458..0030f58f3 100644 --- a/bin/reth/src/config.rs +++ b/bin/reth/src/config.rs @@ -47,6 +47,8 @@ impl Config { pub struct StageConfig { /// Header stage configuration. pub headers: HeadersConfig, + /// Total difficulty stage configuration + pub total_difficulty: TotalDifficultyConfig, /// Body stage configuration. pub bodies: BodiesConfig, /// Sender recovery stage configuration. @@ -70,6 +72,20 @@ impl Default for HeadersConfig { } } +/// Total difficulty stage configuration +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct TotalDifficultyConfig { + /// The maximum number of total difficulty entries to sum up before committing progress to the + /// database. + pub commit_threshold: u64, +} + +impl Default for TotalDifficultyConfig { + fn default() -> Self { + Self { commit_threshold: 100_000 } + } +} + /// Body stage configuration. #[derive(Debug, Clone, Deserialize, Serialize)] pub struct BodiesConfig { diff --git a/bin/reth/src/node/mod.rs b/bin/reth/src/node/mod.rs index fe774946b..7041105ed 100644 --- a/bin/reth/src/node/mod.rs +++ b/bin/reth/src/node/mod.rs @@ -25,7 +25,7 @@ use reth_stages::{ metrics::HeaderMetrics, stages::{ bodies::BodyStage, execution::ExecutionStage, headers::HeaderStage, - sender_recovery::SenderRecoveryStage, + sender_recovery::SenderRecoveryStage, total_difficulty::TotalDifficultyStage, }, }; use std::{net::SocketAddr, path::Path, sync::Arc}; @@ -130,6 +130,9 @@ impl Command { commit_threshold: config.stages.headers.commit_threshold, metrics: HeaderMetrics::default(), }) + .push(TotalDifficultyStage { + commit_threshold: config.stages.total_difficulty.commit_threshold, + }) .push(BodyStage { downloader: Arc::new( bodies::concurrent::ConcurrentDownloader::new( diff --git a/crates/stages/src/stages/headers.rs b/crates/stages/src/stages/headers.rs index 918e3948b..b092a7ac8 100644 --- a/crates/stages/src/stages/headers.rs +++ b/crates/stages/src/stages/headers.rs @@ -36,11 +36,9 @@ const HEADERS: StageId = StageId("Headers"); /// - [`HeaderNumbers`][reth_interfaces::db::tables::HeaderNumbers] /// - [`Headers`][reth_interfaces::db::tables::Headers] /// - [`CanonicalHeaders`][reth_interfaces::db::tables::CanonicalHeaders] -/// - [`HeaderTD`][reth_interfaces::db::tables::HeaderTD] /// -/// NOTE: This stage commits the header changes to the database (everything except the changes to -/// [`HeaderTD`][reth_interfaces::db::tables::HeaderTD] table). The stage does not return the -/// control flow to the pipeline in order to preserve the context of the chain tip. +/// NOTE: This stage downloads headers in reverse. Upon returning the control flow to the pipeline, +/// the stage progress is not updated unless this stage is done. #[derive(Debug)] pub struct HeaderStage { /// Strategy for downloading the headers @@ -101,9 +99,6 @@ impl(tx, res).await?.unwrap_or_default(); if self.is_stage_done(tx, current_progress).await? { - // Update total difficulty values after we have reached fork choice - debug!(target: "sync::stages::headers", head = ?head.hash(), "Writing total difficulty"); - self.write_td::(tx, &head)?; let stage_progress = current_progress.max( tx.cursor::()? .last()? @@ -147,7 +142,6 @@ impl(input.unwind_to)?; tx.unwind_table_by_num_hash::(input.unwind_to)?; - tx.unwind_table_by_num_hash::(input.unwind_to)?; Ok(UnwindOutput { stage_progress: input.unwind_to }) } } @@ -280,33 +274,6 @@ impl } Ok(latest) } - - /// Iterate over inserted headers and write td entries - fn write_td( - &self, - tx: &Transaction<'_, DB>, - head: &SealedHeader, - ) -> Result<(), StageError> { - // Acquire cursor over total difficulty table - let mut cursor_td = tx.cursor_mut::()?; - - // Get latest total difficulty - let last_entry = cursor_td - .seek_exact(head.num_hash().into())? - .ok_or(DatabaseIntegrityError::TotalDifficulty { number: head.number })?; - let mut td: U256 = last_entry.1.into(); - - // Start at first inserted block during this iteration - let start_key = tx.get_block_numhash(head.number + 1)?; - - // Walk over newly inserted headers, update & insert td - for entry in tx.cursor::()?.walk(start_key)? { - let (key, header) = entry?; - td += header.difficulty; - cursor_td.append(key, td.into())?; - } - Ok(()) - } } #[cfg(test)] @@ -472,7 +439,11 @@ mod tests { }, ExecInput, ExecOutput, UnwindInput, }; - use reth_db::{models::blocks::BlockNumHash, tables, transaction::DbTx}; + use reth_db::{ + models::blocks::BlockNumHash, + tables, + transaction::{DbTx, DbTxMut}, + }; use reth_downloaders::headers::linear::{LinearDownloadBuilder, LinearDownloader}; use reth_interfaces::{ p2p::headers::downloader::HeaderDownloader, @@ -533,6 +504,10 @@ mod tests { let start = input.stage_progress.unwrap_or_default(); let head = random_header(start, None); self.tx.insert_headers(std::iter::once(&head))?; + // patch td table for `update_head` call + self.tx.commit(|tx| { + tx.put::(head.num_hash().into(), U256::zero().into()) + })?; // use previous progress as seed size let end = input.previous_stage.map(|(_, num)| num).unwrap_or_default() + 1; @@ -571,18 +546,6 @@ mod tests { assert!(header.is_some()); let header = header.unwrap().seal(); assert_eq!(header.hash(), hash); - - // validate td consistency in the database - if header.number > initial_stage_progress { - let parent_td = tx.get::( - (header.number - 1, header.parent_hash).into(), - )?; - let td: U256 = *tx.get::(key)?.unwrap(); - assert_eq!( - parent_td.map(|td| *td + header.difficulty), - Some(td) - ); - } } Ok(()) })?; @@ -639,7 +602,6 @@ mod tests { .check_no_entry_above_by_value::(block, |val| val)?; self.tx.check_no_entry_above::(block, |key| key)?; self.tx.check_no_entry_above::(block, |key| key.number())?; - self.tx.check_no_entry_above::(block, |key| key.number())?; Ok(()) } } diff --git a/crates/stages/src/stages/mod.rs b/crates/stages/src/stages/mod.rs index 80b07365b..4bc1860a8 100644 --- a/crates/stages/src/stages/mod.rs +++ b/crates/stages/src/stages/mod.rs @@ -6,3 +6,5 @@ pub mod execution; pub mod headers; /// The sender recovery stage. pub mod sender_recovery; +/// The total difficulty stage +pub mod total_difficulty; diff --git a/crates/stages/src/stages/total_difficulty.rs b/crates/stages/src/stages/total_difficulty.rs new file mode 100644 index 000000000..59ea5fed1 --- /dev/null +++ b/crates/stages/src/stages/total_difficulty.rs @@ -0,0 +1,205 @@ +use crate::{ + db::Transaction, DatabaseIntegrityError, ExecInput, ExecOutput, Stage, StageError, StageId, + UnwindInput, UnwindOutput, +}; +use reth_db::{ + cursor::{DbCursorRO, DbCursorRW}, + database::Database, + tables, + transaction::DbTxMut, +}; +use reth_primitives::U256; +use tracing::*; + +const TOTAL_DIFFICULTY: StageId = StageId("TotalDifficulty"); + +/// The total difficulty stage. +/// +/// This stage walks over inserted headers and computes total difficulty +/// at each block. The entries are inserted into [`HeaderTD`][reth_interfaces::db::tables::HeaderTD] +/// table. +#[derive(Debug)] +pub struct TotalDifficultyStage { + /// The number of table entries to commit at once + pub commit_threshold: u64, +} + +#[async_trait::async_trait] +impl Stage for TotalDifficultyStage { + /// Return the id of the stage + fn id(&self) -> StageId { + TOTAL_DIFFICULTY + } + + /// Write total difficulty entries + async fn execute( + &mut self, + tx: &mut Transaction<'_, DB>, + input: ExecInput, + ) -> Result { + let stage_progress = input.stage_progress.unwrap_or_default(); + let previous_stage_progress = input.previous_stage_progress(); + + let start_block = stage_progress + 1; + let end_block = previous_stage_progress.min(start_block + self.commit_threshold); + + if start_block > end_block { + info!(target: "sync::stages::total_difficulty", stage_progress, "Target block already reached"); + return Ok(ExecOutput { stage_progress, done: true }) + } + + debug!(target: "sync::stages::total_difficulty", start_block, end_block, "Commencing sync"); + + // Acquire cursor over total difficulty and headers tables + let mut cursor_td = tx.cursor_mut::()?; + let mut cursor_headers = tx.cursor_mut::()?; + + // Get latest total difficulty + let last_header_key = tx.get_block_numhash(stage_progress)?; + let last_entry = cursor_td + .seek_exact(last_header_key)? + .ok_or(DatabaseIntegrityError::TotalDifficulty { number: last_header_key.number() })?; + + let mut td: U256 = last_entry.1.into(); + debug!(target: "sync::stages::total_difficulty", ?td, block_number = last_header_key.number(), "Last total difficulty entry"); + + let start_key = tx.get_block_numhash(start_block)?; + let walker = cursor_headers + .walk(start_key)? + .take_while(|e| e.as_ref().map(|(_, h)| h.number <= end_block).unwrap_or_default()); + // Walk over newly inserted headers, update & insert td + for entry in walker { + let (key, header) = entry?; + td += header.difficulty; + cursor_td.append(key, td.into())?; + } + + let done = end_block >= previous_stage_progress; + info!(target: "sync::stages::total_difficulty", stage_progress = end_block, done, "Sync iteration finished"); + Ok(ExecOutput { done, stage_progress: end_block }) + } + + /// Unwind the stage. + async fn unwind( + &mut self, + tx: &mut Transaction<'_, DB>, + input: UnwindInput, + ) -> Result> { + tx.unwind_table_by_num_hash::(input.unwind_to)?; + Ok(UnwindOutput { stage_progress: input.unwind_to }) + } +} + +#[cfg(test)] +mod tests { + use reth_db::transaction::DbTx; + use reth_interfaces::test_utils::generators::{random_header, random_header_range}; + use reth_primitives::{BlockNumber, SealedHeader}; + + use super::*; + use crate::test_utils::{ + stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, TestRunnerError, + TestTransaction, UnwindStageTestRunner, + }; + + stage_test_suite_ext!(TotalDifficultyTestRunner); + + #[derive(Default)] + struct TotalDifficultyTestRunner { + tx: TestTransaction, + } + + impl StageTestRunner for TotalDifficultyTestRunner { + type S = TotalDifficultyStage; + + fn tx(&self) -> &TestTransaction { + &self.tx + } + + fn stage(&self) -> Self::S { + TotalDifficultyStage { commit_threshold: 500 } + } + } + + #[async_trait::async_trait] + impl ExecuteStageTestRunner for TotalDifficultyTestRunner { + type Seed = Vec; + + fn seed_execution(&mut self, input: ExecInput) -> Result { + let start = input.stage_progress.unwrap_or_default(); + let head = random_header(start, None); + self.tx.insert_headers(std::iter::once(&head))?; + self.tx.commit(|tx| { + let td: U256 = tx + .cursor::()? + .last()? + .map(|(_, v)| v) + .unwrap_or_default() + .into(); + tx.put::(head.num_hash().into(), (td + head.difficulty).into()) + })?; + + // use previous progress as seed size + let end = input.previous_stage.map(|(_, num)| num).unwrap_or_default() + 1; + + if start + 1 >= end { + return Ok(Vec::default()) + } + + let mut headers = random_header_range(start + 1..end, head.hash()); + self.tx.insert_headers(headers.iter())?; + headers.insert(0, head); + Ok(headers) + } + + /// Validate stored headers + fn validate_execution( + &self, + input: ExecInput, + output: Option, + ) -> Result<(), TestRunnerError> { + let initial_stage_progress = input.stage_progress.unwrap_or_default(); + match output { + Some(output) if output.stage_progress > initial_stage_progress => { + self.tx.query(|tx| { + let start_hash = tx + .get::(initial_stage_progress)? + .expect("no initial header hash"); + let start_key = (initial_stage_progress, start_hash).into(); + let mut header_cursor = tx.cursor::()?; + let (_, mut current_header) = + header_cursor.seek_exact(start_key)?.expect("no initial header"); + let mut td: U256 = + tx.get::(start_key)?.expect("no initial td").into(); + + while let Some((next_key, next_header)) = header_cursor.next()? { + assert_eq!(current_header.number + 1, next_header.number); + td += next_header.difficulty; + assert_eq!( + tx.get::(next_key)?.map(Into::into), + Some(td) + ); + current_header = next_header; + } + Ok(()) + })?; + } + _ => self.check_no_td_above(initial_stage_progress)?, + }; + Ok(()) + } + } + + impl UnwindStageTestRunner for TotalDifficultyTestRunner { + fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> { + self.check_no_td_above(input.unwind_to) + } + } + + impl TotalDifficultyTestRunner { + fn check_no_td_above(&self, block: BlockNumber) -> Result<(), TestRunnerError> { + self.tx.check_no_entry_above::(block, |key| key.number())?; + Ok(()) + } + } +} diff --git a/crates/stages/src/test_utils/test_db.rs b/crates/stages/src/test_utils/test_db.rs index 349da4068..1db5cf33a 100644 --- a/crates/stages/src/test_utils/test_db.rs +++ b/crates/stages/src/test_utils/test_db.rs @@ -7,7 +7,7 @@ use reth_db::{ transaction::{DbTx, DbTxMut}, Error as DbError, }; -use reth_primitives::{BlockNumber, SealedHeader, U256}; +use reth_primitives::{BlockNumber, SealedHeader}; use std::{borrow::Borrow, sync::Arc}; use crate::db::Transaction; @@ -173,18 +173,12 @@ impl TestTransaction { self.commit(|tx| { let headers = headers.collect::>(); - let mut td: U256 = - tx.cursor::()?.last()?.map(|(_, v)| v).unwrap_or_default().into(); - for header in headers { let key: BlockNumHash = (header.number, header.hash()).into(); tx.put::(header.number, header.hash())?; tx.put::(header.hash(), header.number)?; tx.put::(key, header.clone().unseal())?; - - td += header.difficulty; - tx.put::(key, td.into())?; } Ok(())