From db5d01e328b2cabc37837c508665b454bafb04fc Mon Sep 17 00:00:00 2001 From: Bjerg Date: Fri, 17 Nov 2023 22:12:12 +0100 Subject: [PATCH] refactor: split async/sync work in stages (#4636) Co-authored-by: Roman Krasiuk --- bin/reth/src/chain/import.rs | 17 +- bin/reth/src/debug_cmd/execution.rs | 8 +- bin/reth/src/debug_cmd/merkle.rs | 69 ++-- bin/reth/src/node/mod.rs | 9 +- bin/reth/src/stage/dump/execution.rs | 33 +- bin/reth/src/stage/dump/hashing_account.rs | 44 ++- bin/reth/src/stage/dump/hashing_storage.rs | 44 ++- bin/reth/src/stage/dump/merkle.rs | 57 ++-- bin/reth/src/stage/run.rs | 27 +- .../consensus/beacon/src/engine/test_utils.rs | 9 +- crates/interfaces/src/provider.rs | 3 + crates/stages/benches/criterion.rs | 7 +- crates/stages/benches/setup/mod.rs | 23 +- crates/stages/src/error.rs | 4 + crates/stages/src/lib.rs | 10 +- crates/stages/src/pipeline/mod.rs | 203 ++++++------ crates/stages/src/sets.rs | 50 +-- crates/stages/src/stage.rs | 41 ++- crates/stages/src/stages/bodies.rs | 88 ++--- crates/stages/src/stages/execution.rs | 14 +- crates/stages/src/stages/finish.rs | 5 +- crates/stages/src/stages/hashing_account.rs | 13 +- crates/stages/src/stages/hashing_storage.rs | 5 +- crates/stages/src/stages/headers.rs | 302 +++++------------- .../src/stages/index_account_history.rs | 41 ++- .../src/stages/index_storage_history.rs | 41 ++- crates/stages/src/stages/merkle.rs | 5 +- crates/stages/src/stages/mod.rs | 10 +- crates/stages/src/stages/sender_recovery.rs | 14 +- crates/stages/src/stages/total_difficulty.rs | 5 +- crates/stages/src/stages/tx_lookup.rs | 5 +- crates/stages/src/test_utils/runner.rs | 13 +- crates/stages/src/test_utils/stage.rs | 5 +- crates/storage/provider/src/lib.rs | 10 +- .../provider/src/providers/database/mod.rs | 101 +++++- .../src/providers/database/provider.rs | 65 +++- .../provider/src/traits/header_sync_gap.rs | 50 +++ crates/storage/provider/src/traits/mod.rs | 3 + testing/ef-tests/src/cases/blockchain_test.rs | 3 +- 39 files changed, 775 insertions(+), 681 deletions(-) create mode 100644 crates/storage/provider/src/traits/header_sync_gap.rs diff --git a/bin/reth/src/chain/import.rs b/bin/reth/src/chain/import.rs index 984a34f8c..572f8c0ee 100644 --- a/bin/reth/src/chain/import.rs +++ b/bin/reth/src/chain/import.rs @@ -1,4 +1,8 @@ use crate::{ + args::{ + utils::{chain_help, genesis_value_parser, SUPPORTED_CHAINS}, + DatabaseArgs, + }, dirs::{DataDirPath, MaybePlatformPath}, init::init_genesis, node::events::{handle_events, NodeEvent}, @@ -8,12 +12,6 @@ use clap::Parser; use eyre::Context; use futures::{Stream, StreamExt}; use reth_beacon_consensus::BeaconConsensus; -use reth_provider::{ProviderFactory, StageCheckpointReader}; - -use crate::args::{ - utils::{chain_help, genesis_value_parser, SUPPORTED_CHAINS}, - DatabaseArgs, -}; use reth_config::Config; use reth_db::{database::Database, init_db}; use reth_downloaders::{ @@ -22,12 +20,10 @@ use reth_downloaders::{ }; use reth_interfaces::consensus::Consensus; use reth_primitives::{stage::StageId, ChainSpec, B256}; +use reth_provider::{HeaderSyncMode, ProviderFactory, StageCheckpointReader}; use reth_stages::{ prelude::*, - stages::{ - ExecutionStage, ExecutionStageThresholds, HeaderSyncMode, SenderRecoveryStage, - TotalDifficultyStage, - }, + stages::{ExecutionStage, ExecutionStageThresholds, SenderRecoveryStage, TotalDifficultyStage}, }; use std::{path::PathBuf, sync::Arc}; use tokio::sync::watch; @@ -164,6 +160,7 @@ impl ImportCommand { .with_max_block(max_block) .add_stages( DefaultStages::new( + ProviderFactory::new(db.clone(), self.chain.clone()), HeaderSyncMode::Tip(tip_rx), consensus.clone(), header_downloader, diff --git a/bin/reth/src/debug_cmd/execution.rs b/bin/reth/src/debug_cmd/execution.rs index fee6390d2..83c554945 100644 --- a/bin/reth/src/debug_cmd/execution.rs +++ b/bin/reth/src/debug_cmd/execution.rs @@ -27,13 +27,10 @@ use reth_interfaces::{ use reth_network::{NetworkEvents, NetworkHandle}; use reth_network_api::NetworkInfo; use reth_primitives::{fs, stage::StageId, BlockHashOrNumber, BlockNumber, ChainSpec, B256}; -use reth_provider::{BlockExecutionWriter, ProviderFactory, StageCheckpointReader}; +use reth_provider::{BlockExecutionWriter, HeaderSyncMode, ProviderFactory, StageCheckpointReader}; use reth_stages::{ sets::DefaultStages, - stages::{ - ExecutionStage, ExecutionStageThresholds, HeaderSyncMode, SenderRecoveryStage, - TotalDifficultyStage, - }, + stages::{ExecutionStage, ExecutionStageThresholds, SenderRecoveryStage, TotalDifficultyStage}, Pipeline, StageSet, }; use reth_tasks::TaskExecutor; @@ -118,6 +115,7 @@ impl Command { .with_tip_sender(tip_tx) .add_stages( DefaultStages::new( + ProviderFactory::new(db.clone(), self.chain.clone()), header_mode, Arc::clone(&consensus), header_downloader, diff --git a/bin/reth/src/debug_cmd/merkle.rs b/bin/reth/src/debug_cmd/merkle.rs index dc5f98e59..765d1f866 100644 --- a/bin/reth/src/debug_cmd/merkle.rs +++ b/bin/reth/src/debug_cmd/merkle.rs @@ -222,53 +222,42 @@ impl Command { None }; - execution_stage - .execute( - &provider_rw, - ExecInput { - target: Some(block), - checkpoint: block.checked_sub(1).map(StageCheckpoint::new), - }, - ) - .await?; + execution_stage.execute( + &provider_rw, + ExecInput { + target: Some(block), + checkpoint: block.checked_sub(1).map(StageCheckpoint::new), + }, + )?; let mut account_hashing_done = false; while !account_hashing_done { - let output = account_hashing_stage - .execute( - &provider_rw, - ExecInput { - target: Some(block), - checkpoint: progress.map(StageCheckpoint::new), - }, - ) - .await?; - account_hashing_done = output.done; - } - - let mut storage_hashing_done = false; - while !storage_hashing_done { - let output = storage_hashing_stage - .execute( - &provider_rw, - ExecInput { - target: Some(block), - checkpoint: progress.map(StageCheckpoint::new), - }, - ) - .await?; - storage_hashing_done = output.done; - } - - let incremental_result = merkle_stage - .execute( + let output = account_hashing_stage.execute( &provider_rw, ExecInput { target: Some(block), checkpoint: progress.map(StageCheckpoint::new), }, - ) - .await; + )?; + account_hashing_done = output.done; + } + + let mut storage_hashing_done = false; + while !storage_hashing_done { + let output = storage_hashing_stage.execute( + &provider_rw, + ExecInput { + target: Some(block), + checkpoint: progress.map(StageCheckpoint::new), + }, + )?; + storage_hashing_done = output.done; + } + + let incremental_result = merkle_stage.execute( + &provider_rw, + ExecInput { target: Some(block), checkpoint: progress.map(StageCheckpoint::new) }, + ); if incremental_result.is_err() { tracing::warn!(target: "reth::cli", block, "Incremental calculation failed, retrying from scratch"); @@ -285,7 +274,7 @@ impl Command { let clean_input = ExecInput { target: Some(block), checkpoint: None }; loop { - let clean_result = merkle_stage.execute(&provider_rw, clean_input).await; + let clean_result = merkle_stage.execute(&provider_rw, clean_input); assert!(clean_result.is_ok(), "Clean state root calculation failed"); if clean_result.unwrap().done { break diff --git a/bin/reth/src/node/mod.rs b/bin/reth/src/node/mod.rs index 3db510564..a144a3bce 100644 --- a/bin/reth/src/node/mod.rs +++ b/bin/reth/src/node/mod.rs @@ -61,7 +61,7 @@ use reth_primitives::{ }; use reth_provider::{ providers::BlockchainProvider, BlockHashReader, BlockReader, CanonStateSubscriptions, - HeaderProvider, ProviderFactory, StageCheckpointReader, + HeaderProvider, HeaderSyncMode, ProviderFactory, StageCheckpointReader, }; use reth_prune::{segments::SegmentSet, Pruner}; use reth_revm::Factory; @@ -71,9 +71,9 @@ use reth_snapshot::HighestSnapshotsTracker; use reth_stages::{ prelude::*, stages::{ - AccountHashingStage, ExecutionStage, ExecutionStageThresholds, HeaderSyncMode, - IndexAccountHistoryStage, IndexStorageHistoryStage, MerkleStage, SenderRecoveryStage, - StorageHashingStage, TotalDifficultyStage, TransactionLookupStage, + AccountHashingStage, ExecutionStage, ExecutionStageThresholds, IndexAccountHistoryStage, + IndexStorageHistoryStage, MerkleStage, SenderRecoveryStage, StorageHashingStage, + TotalDifficultyStage, TransactionLookupStage, }, }; use reth_tasks::TaskExecutor; @@ -896,6 +896,7 @@ impl NodeCommand { .with_metrics_tx(metrics_tx.clone()) .add_stages( DefaultStages::new( + ProviderFactory::new(db.clone(), self.chain.clone()), header_mode, Arc::clone(&consensus), header_downloader, diff --git a/bin/reth/src/stage/dump/execution.rs b/bin/reth/src/stage/dump/execution.rs index 67eda8033..5bc301bf8 100644 --- a/bin/reth/src/stage/dump/execution.rs +++ b/bin/reth/src/stage/dump/execution.rs @@ -100,16 +100,14 @@ async fn unwind_and_copy( let mut exec_stage = ExecutionStage::new_with_factory(Factory::new(db_tool.chain.clone())); - exec_stage - .unwind( - &provider, - UnwindInput { - unwind_to: from, - checkpoint: StageCheckpoint::new(tip_block_number), - bad_block: None, - }, - ) - .await?; + exec_stage.unwind( + &provider, + UnwindInput { + unwind_to: from, + checkpoint: StageCheckpoint::new(tip_block_number), + bad_block: None, + }, + )?; let unwind_inner_tx = provider.into_tx(); @@ -131,20 +129,13 @@ async fn dry_run( info!(target: "reth::cli", "Executing stage. [dry-run]"); let factory = ProviderFactory::new(&output_db, chain.clone()); - let provider = factory.provider_rw()?; let mut exec_stage = ExecutionStage::new_with_factory(Factory::new(chain.clone())); - exec_stage - .execute( - &provider, - reth_stages::ExecInput { - target: Some(to), - checkpoint: Some(StageCheckpoint::new(from)), - }, - ) - .await?; + let input = + reth_stages::ExecInput { target: Some(to), checkpoint: Some(StageCheckpoint::new(from)) }; + exec_stage.execute(&factory.provider_rw()?, input)?; - info!(target: "reth::cli", "Success."); + info!(target: "reth::cli", "Success"); Ok(()) } diff --git a/bin/reth/src/stage/dump/hashing_account.rs b/bin/reth/src/stage/dump/hashing_account.rs index 2a947d013..7fe723257 100644 --- a/bin/reth/src/stage/dump/hashing_account.rs +++ b/bin/reth/src/stage/dump/hashing_account.rs @@ -22,7 +22,7 @@ pub(crate) async fn dump_hashing_account_stage( tx.import_table_with_range::(&db_tool.db.tx()?, Some(from), to) })??; - unwind_and_copy(db_tool, from, tip_block_number, &output_db).await?; + unwind_and_copy(db_tool, from, tip_block_number, &output_db)?; if should_run { dry_run(db_tool.chain.clone(), output_db, to, from).await?; @@ -32,7 +32,7 @@ pub(crate) async fn dump_hashing_account_stage( } /// Dry-run an unwind to FROM block and copy the necessary table data to the new database. -async fn unwind_and_copy( +fn unwind_and_copy( db_tool: &DbTool<'_, DB>, from: u64, tip_block_number: u64, @@ -42,16 +42,14 @@ async fn unwind_and_copy( let provider = factory.provider_rw()?; let mut exec_stage = AccountHashingStage::default(); - exec_stage - .unwind( - &provider, - UnwindInput { - unwind_to: from, - checkpoint: StageCheckpoint::new(tip_block_number), - bad_block: None, - }, - ) - .await?; + exec_stage.unwind( + &provider, + UnwindInput { + unwind_to: from, + checkpoint: StageCheckpoint::new(tip_block_number), + bad_block: None, + }, + )?; let unwind_inner_tx = provider.into_tx(); output_db.update(|tx| tx.import_table::(&unwind_inner_tx))??; @@ -70,23 +68,19 @@ async fn dry_run( let factory = ProviderFactory::new(&output_db, chain); let provider = factory.provider_rw()?; - let mut exec_stage = AccountHashingStage { + let mut stage = AccountHashingStage { clean_threshold: 1, // Forces hashing from scratch ..Default::default() }; - let mut exec_output = false; - while !exec_output { - exec_output = exec_stage - .execute( - &provider, - reth_stages::ExecInput { - target: Some(to), - checkpoint: Some(StageCheckpoint::new(from)), - }, - ) - .await? - .done; + loop { + let input = reth_stages::ExecInput { + target: Some(to), + checkpoint: Some(StageCheckpoint::new(from)), + }; + if stage.execute(&provider, input)?.done { + break + } } info!(target: "reth::cli", "Success."); diff --git a/bin/reth/src/stage/dump/hashing_storage.rs b/bin/reth/src/stage/dump/hashing_storage.rs index 0a8df0a6e..373818072 100644 --- a/bin/reth/src/stage/dump/hashing_storage.rs +++ b/bin/reth/src/stage/dump/hashing_storage.rs @@ -17,7 +17,7 @@ pub(crate) async fn dump_hashing_storage_stage( ) -> Result<()> { let (output_db, tip_block_number) = setup(from, to, output_db, db_tool)?; - unwind_and_copy(db_tool, from, tip_block_number, &output_db).await?; + unwind_and_copy(db_tool, from, tip_block_number, &output_db)?; if should_run { dry_run(db_tool.chain.clone(), output_db, to, from).await?; @@ -27,7 +27,7 @@ pub(crate) async fn dump_hashing_storage_stage( } /// Dry-run an unwind to FROM block and copy the necessary table data to the new database. -async fn unwind_and_copy( +fn unwind_and_copy( db_tool: &DbTool<'_, DB>, from: u64, tip_block_number: u64, @@ -38,16 +38,14 @@ async fn unwind_and_copy( let mut exec_stage = StorageHashingStage::default(); - exec_stage - .unwind( - &provider, - UnwindInput { - unwind_to: from, - checkpoint: StageCheckpoint::new(tip_block_number), - bad_block: None, - }, - ) - .await?; + exec_stage.unwind( + &provider, + UnwindInput { + unwind_to: from, + checkpoint: StageCheckpoint::new(tip_block_number), + bad_block: None, + }, + )?; let unwind_inner_tx = provider.into_tx(); // TODO optimize we can actually just get the entries we need for both these tables @@ -69,23 +67,19 @@ async fn dry_run( let factory = ProviderFactory::new(&output_db, chain); let provider = factory.provider_rw()?; - let mut exec_stage = StorageHashingStage { + let mut stage = StorageHashingStage { clean_threshold: 1, // Forces hashing from scratch ..Default::default() }; - let mut exec_output = false; - while !exec_output { - exec_output = exec_stage - .execute( - &provider, - reth_stages::ExecInput { - target: Some(to), - checkpoint: Some(StageCheckpoint::new(from)), - }, - ) - .await? - .done; + loop { + let input = reth_stages::ExecInput { + target: Some(to), + checkpoint: Some(StageCheckpoint::new(from)), + }; + if stage.execute(&provider, input)?.done { + break + } } info!(target: "reth::cli", "Success."); diff --git a/bin/reth/src/stage/dump/merkle.rs b/bin/reth/src/stage/dump/merkle.rs index 55eef819f..4615b884c 100644 --- a/bin/reth/src/stage/dump/merkle.rs +++ b/bin/reth/src/stage/dump/merkle.rs @@ -61,10 +61,10 @@ async fn unwind_and_copy( // Unwind hashes all the way to FROM - StorageHashingStage::default().unwind(&provider, unwind).await.unwrap(); - AccountHashingStage::default().unwind(&provider, unwind).await.unwrap(); + StorageHashingStage::default().unwind(&provider, unwind).unwrap(); + AccountHashingStage::default().unwind(&provider, unwind).unwrap(); - MerkleStage::default_unwind().unwind(&provider, unwind).await?; + MerkleStage::default_unwind().unwind(&provider, unwind)?; // Bring Plainstate to TO (hashing stage execution requires it) let mut exec_stage = ExecutionStage::new( @@ -78,26 +78,21 @@ async fn unwind_and_copy( PruneModes::all(), ); - exec_stage - .unwind( - &provider, - UnwindInput { - unwind_to: to, - checkpoint: StageCheckpoint::new(tip_block_number), - bad_block: None, - }, - ) - .await?; + exec_stage.unwind( + &provider, + UnwindInput { + unwind_to: to, + checkpoint: StageCheckpoint::new(tip_block_number), + bad_block: None, + }, + )?; // Bring hashes to TO - AccountHashingStage { clean_threshold: u64::MAX, commit_threshold: u64::MAX } .execute(&provider, execute_input) - .await .unwrap(); StorageHashingStage { clean_threshold: u64::MAX, commit_threshold: u64::MAX } .execute(&provider, execute_input) - .await .unwrap(); let unwind_inner_tx = provider.into_tx(); @@ -123,25 +118,23 @@ async fn dry_run( info!(target: "reth::cli", "Executing stage."); let factory = ProviderFactory::new(&output_db, chain); let provider = factory.provider_rw()?; - let mut exec_output = false; - while !exec_output { - exec_output = MerkleStage::Execution { - clean_threshold: u64::MAX, /* Forces updating the root instead of calculating - * from - * scratch */ + + let mut stage = MerkleStage::Execution { + // Forces updating the root instead of calculating from scratch + clean_threshold: u64::MAX, + }; + + loop { + let input = reth_stages::ExecInput { + target: Some(to), + checkpoint: Some(StageCheckpoint::new(from)), + }; + if stage.execute(&provider, input)?.done { + break } - .execute( - &provider, - reth_stages::ExecInput { - target: Some(to), - checkpoint: Some(StageCheckpoint::new(from)), - }, - ) - .await? - .done; } - info!(target: "reth::cli", "Success."); + info!(target: "reth::cli", "Success"); Ok(()) } diff --git a/bin/reth/src/stage/run.rs b/bin/reth/src/stage/run.rs index c66792668..5eaeaf361 100644 --- a/bin/reth/src/stage/run.rs +++ b/bin/reth/src/stage/run.rs @@ -12,6 +12,7 @@ use crate::{ version::SHORT_VERSION, }; use clap::Parser; +use futures::future::poll_fn; use reth_beacon_consensus::BeaconConsensus; use reth_config::Config; use reth_db::init_db; @@ -24,7 +25,7 @@ use reth_stages::{ IndexAccountHistoryStage, IndexStorageHistoryStage, MerkleStage, SenderRecoveryStage, StorageHashingStage, TransactionLookupStage, }, - ExecInput, ExecOutput, Stage, UnwindInput, + ExecInput, Stage, UnwindInput, }; use std::{any::Any, net::SocketAddr, path::PathBuf, sync::Arc}; use tracing::*; @@ -175,8 +176,8 @@ impl Command { .await?; let fetch_client = Arc::new(network.fetch_client().await?); - let stage = BodyStage { - downloader: BodiesDownloaderBuilder::default() + let stage = BodyStage::new( + BodiesDownloaderBuilder::default() .with_stream_batch_size(batch_size as usize) .with_request_limit(config.stages.bodies.downloader_request_limit) .with_max_buffered_blocks_size_bytes( @@ -187,8 +188,7 @@ impl Command { config.stages.bodies.downloader_max_concurrent_requests, ) .build(fetch_client, consensus.clone(), db.clone()), - consensus: consensus.clone(), - }; + ); (Box::new(stage), None) } @@ -242,7 +242,7 @@ impl Command { if !self.skip_unwind { while unwind.checkpoint.block_number > self.from { - let unwind_output = unwind_stage.unwind(&provider_rw, unwind).await?; + let unwind_output = unwind_stage.unwind(&provider_rw, unwind)?; unwind.checkpoint = unwind_output.checkpoint; if self.commit { @@ -257,19 +257,20 @@ impl Command { checkpoint: Some(checkpoint.with_block_number(self.from)), }; - while let ExecOutput { checkpoint: stage_progress, done: false } = - exec_stage.execute(&provider_rw, input).await? - { - input.checkpoint = Some(stage_progress); + loop { + poll_fn(|cx| exec_stage.poll_execute_ready(cx, input)).await?; + let output = exec_stage.execute(&provider_rw, input)?; + + input.checkpoint = Some(output.checkpoint); if self.commit { provider_rw.commit()?; provider_rw = factory.provider_rw()?; } - } - if self.commit { - provider_rw.commit()?; + if output.done { + break + } } Ok(()) diff --git a/crates/consensus/beacon/src/engine/test_utils.rs b/crates/consensus/beacon/src/engine/test_utils.rs index 092ce9f5e..f58ebf013 100644 --- a/crates/consensus/beacon/src/engine/test_utils.rs +++ b/crates/consensus/beacon/src/engine/test_utils.rs @@ -26,17 +26,15 @@ use reth_payload_builder::test_utils::spawn_test_payload_service; use reth_primitives::{BlockNumber, ChainSpec, PruneModes, Receipt, B256, U256}; use reth_provider::{ providers::BlockchainProvider, test_utils::TestExecutorFactory, BlockExecutor, - BundleStateWithReceipts, ExecutorFactory, ProviderFactory, PrunableBlockExecutor, + BundleStateWithReceipts, ExecutorFactory, HeaderSyncMode, ProviderFactory, + PrunableBlockExecutor, }; use reth_prune::Pruner; use reth_revm::Factory; use reth_rpc_types::engine::{ CancunPayloadFields, ExecutionPayload, ForkchoiceState, ForkchoiceUpdated, PayloadStatus, }; -use reth_stages::{ - sets::DefaultStages, stages::HeaderSyncMode, test_utils::TestStages, ExecOutput, Pipeline, - StageError, -}; +use reth_stages::{sets::DefaultStages, test_utils::TestStages, ExecOutput, Pipeline, StageError}; use reth_tasks::TokioTaskExecutor; use std::{collections::VecDeque, sync::Arc}; use tokio::sync::{oneshot, watch}; @@ -502,6 +500,7 @@ where .into_task(); Pipeline::builder().add_stages(DefaultStages::new( + ProviderFactory::new(db.clone(), self.base_config.chain_spec.clone()), HeaderSyncMode::Tip(tip_rx.clone()), Arc::clone(&consensus), header_downloader, diff --git a/crates/interfaces/src/provider.rs b/crates/interfaces/src/provider.rs index f5f0a7fcc..c2137b4b7 100644 --- a/crates/interfaces/src/provider.rs +++ b/crates/interfaces/src/provider.rs @@ -20,6 +20,9 @@ pub enum ProviderError { /// Error when recovering the sender for a transaction #[error("failed to recover sender for transaction")] SenderRecoveryError, + /// Inconsistent header gap. + #[error("inconsistent header gap in the database")] + InconsistentHeaderGap, /// The header number was not found for the given block hash. #[error("block hash {0} does not exist in Headers table")] BlockHashNotFound(BlockHash), diff --git a/crates/stages/benches/criterion.rs b/crates/stages/benches/criterion.rs index 9e55781b7..ad210165c 100644 --- a/crates/stages/benches/criterion.rs +++ b/crates/stages/benches/criterion.rs @@ -12,7 +12,7 @@ use reth_stages::{ test_utils::TestTransaction, ExecInput, Stage, UnwindInput, }; -use std::{path::PathBuf, sync::Arc}; +use std::{future::poll_fn, path::PathBuf, sync::Arc}; mod setup; use setup::StageRange; @@ -138,7 +138,10 @@ fn measure_stage_with_path( let mut stage = stage.clone(); let factory = ProviderFactory::new(tx.tx.db(), MAINNET.clone()); let provider = factory.provider_rw().unwrap(); - stage.execute(&provider, input).await.unwrap(); + poll_fn(|cx| stage.poll_execute_ready(cx, input)) + .await + .and_then(|_| stage.execute(&provider, input)) + .unwrap(); provider.commit().unwrap(); }, ) diff --git a/crates/stages/benches/setup/mod.rs b/crates/stages/benches/setup/mod.rs index f5c45be9b..806f2d78f 100644 --- a/crates/stages/benches/setup/mod.rs +++ b/crates/stages/benches/setup/mod.rs @@ -47,7 +47,6 @@ pub(crate) fn stage_unwind>( // Clear previous run stage .unwind(&provider, unwind) - .await .map_err(|e| { format!( "{e}\nMake sure your test database at `{}` isn't too old and incompatible with newer stage changes.", @@ -67,22 +66,20 @@ pub(crate) fn unwind_hashes>( ) { let (input, unwind) = range; - tokio::runtime::Runtime::new().unwrap().block_on(async { - let mut stage = stage.clone(); - let factory = ProviderFactory::new(tx.tx.db(), MAINNET.clone()); - let provider = factory.provider_rw().unwrap(); + let mut stage = stage.clone(); + let factory = ProviderFactory::new(tx.tx.db(), MAINNET.clone()); + let provider = factory.provider_rw().unwrap(); - StorageHashingStage::default().unwind(&provider, unwind).await.unwrap(); - AccountHashingStage::default().unwind(&provider, unwind).await.unwrap(); + StorageHashingStage::default().unwind(&provider, unwind).unwrap(); + AccountHashingStage::default().unwind(&provider, unwind).unwrap(); - // Clear previous run - stage.unwind(&provider, unwind).await.unwrap(); + // Clear previous run + stage.unwind(&provider, unwind).unwrap(); - AccountHashingStage::default().execute(&provider, input).await.unwrap(); - StorageHashingStage::default().execute(&provider, input).await.unwrap(); + AccountHashingStage::default().execute(&provider, input).unwrap(); + StorageHashingStage::default().execute(&provider, input).unwrap(); - provider.commit().unwrap(); - }); + provider.commit().unwrap(); } // Helper for generating testdata for the benchmarks. diff --git a/crates/stages/src/error.rs b/crates/stages/src/error.rs index 180a8ca5a..8795868d0 100644 --- a/crates/stages/src/error.rs +++ b/crates/stages/src/error.rs @@ -50,6 +50,9 @@ pub enum StageError { #[source] error: Box, }, + /// The headers stage is missing sync gap. + #[error("missing sync gap")] + MissingSyncGap, /// The stage encountered a database error. #[error("internal database error occurred: {0}")] Database(#[from] DbError), @@ -94,6 +97,7 @@ impl StageError { StageError::Download(_) | StageError::DatabaseIntegrity(_) | StageError::StageCheckpoint(_) | + StageError::MissingSyncGap | StageError::ChannelClosed | StageError::Fatal(_) ) diff --git a/crates/stages/src/lib.rs b/crates/stages/src/lib.rs index f30471182..bf9ba9e8d 100644 --- a/crates/stages/src/lib.rs +++ b/crates/stages/src/lib.rs @@ -22,8 +22,11 @@ //! # use reth_primitives::{PeerId, MAINNET, B256}; //! # use reth_stages::Pipeline; //! # use reth_stages::sets::DefaultStages; -//! # use reth_stages::stages::HeaderSyncMode; //! # use tokio::sync::watch; +//! # use reth_provider::ProviderFactory; +//! # use reth_provider::HeaderSyncMode; +//! # +//! # let chain_spec = MAINNET.clone(); //! # let consensus: Arc = Arc::new(TestConsensus::default()); //! # let headers_downloader = ReverseHeadersDownloaderBuilder::default().build( //! # Arc::new(TestHeadersClient::default()), @@ -36,19 +39,20 @@ //! # db.clone() //! # ); //! # let (tip_tx, tip_rx) = watch::channel(B256::default()); -//! # let factory = Factory::new(MAINNET.clone()); +//! # let factory = Factory::new(chain_spec.clone()); //! // Create a pipeline that can fully sync //! # let pipeline = //! Pipeline::builder() //! .with_tip_sender(tip_tx) //! .add_stages(DefaultStages::new( +//! ProviderFactory::new(db.clone(), chain_spec.clone()), //! HeaderSyncMode::Tip(tip_rx), //! consensus, //! headers_downloader, //! bodies_downloader, //! factory, //! )) -//! .build(db, MAINNET.clone()); +//! .build(db, chain_spec.clone()); //! ``` //! //! ## Feature Flags diff --git a/crates/stages/src/pipeline/mod.rs b/crates/stages/src/pipeline/mod.rs index f5955a5df..718809abc 100644 --- a/crates/stages/src/pipeline/mod.rs +++ b/crates/stages/src/pipeline/mod.rs @@ -5,11 +5,13 @@ use crate::{ use futures_util::Future; use reth_db::database::Database; use reth_primitives::{ - constants::BEACON_CONSENSUS_REORG_UNWIND_DEPTH, stage::StageId, BlockNumber, ChainSpec, B256, + constants::BEACON_CONSENSUS_REORG_UNWIND_DEPTH, + stage::{StageCheckpoint, StageId}, + BlockNumber, ChainSpec, B256, }; use reth_provider::{ProviderFactory, StageCheckpointReader, StageCheckpointWriter}; use reth_tokio_util::EventListeners; -use std::{pin::Pin, sync::Arc}; +use std::{future::poll_fn, pin::Pin, sync::Arc}; use tokio::sync::watch; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::*; @@ -217,10 +219,7 @@ where let stage_id = stage.id(); trace!(target: "sync::pipeline", stage = %stage_id, "Executing stage"); - let next = self - .execute_stage_to_completion(previous_stage, stage_index) - .instrument(info_span!("execute", stage = %stage_id)) - .await?; + let next = self.execute_stage_to_completion(previous_stage, stage_index).await?; trace!(target: "sync::pipeline", stage = %stage_id, ?next, "Completed stage"); @@ -232,7 +231,7 @@ where } ControlFlow::Continue { block_number } => self.progress.update(block_number), ControlFlow::Unwind { target, bad_block } => { - self.unwind(target, Some(bad_block.number)).await?; + self.unwind(target, Some(bad_block.number))?; return Ok(ControlFlow::Unwind { target, bad_block }) } } @@ -254,7 +253,7 @@ where /// Unwind the stages to the target block. /// /// If the unwind is due to a bad block the number of that block should be specified. - pub async fn unwind( + pub fn unwind( &mut self, to: BlockNumber, bad_block: Option, @@ -293,7 +292,7 @@ where let input = UnwindInput { checkpoint, unwind_to: to, bad_block }; self.listeners.notify(PipelineEvent::Unwinding { stage_id, input }); - let output = stage.unwind(&provider_rw, input).await; + let output = stage.unwind(&provider_rw, input); match output { Ok(unwind_output) => { checkpoint = unwind_output.checkpoint; @@ -346,10 +345,9 @@ where let target = self.max_block.or(previous_stage); let factory = ProviderFactory::new(&self.db, self.chain_spec.clone()); - let mut provider_rw = factory.provider_rw()?; loop { - let prev_checkpoint = provider_rw.get_stage_checkpoint(stage_id)?; + let prev_checkpoint = factory.get_stage_checkpoint(stage_id)?; let stage_reached_max_block = prev_checkpoint .zip(self.max_block) @@ -370,6 +368,16 @@ where }) } + let exec_input = ExecInput { target, checkpoint: prev_checkpoint }; + + if let Err(err) = poll_fn(|cx| stage.poll_execute_ready(cx, exec_input)).await { + self.listeners.notify(PipelineEvent::Error { stage_id }); + match on_stage_error(&factory, stage_id, prev_checkpoint, err)? { + Some(ctrl) => return Ok(ctrl), + None => continue, + }; + } + self.listeners.notify(PipelineEvent::Running { pipeline_stages_progress: event::PipelineStagesProgress { current: stage_index + 1, @@ -379,10 +387,8 @@ where checkpoint: prev_checkpoint, }); - match stage - .execute(&provider_rw, ExecInput { target, checkpoint: prev_checkpoint }) - .await - { + let provider_rw = factory.provider_rw()?; + match stage.execute(&provider_rw, exec_input) { Ok(out @ ExecOutput { checkpoint, done }) => { made_progress |= checkpoint.block_number != prev_checkpoint.unwrap_or_default().block_number; @@ -425,9 +431,7 @@ where result: out.clone(), }); - // TODO: Make the commit interval configurable provider_rw.commit()?; - provider_rw = factory.provider_rw()?; if done { let block_number = checkpoint.block_number; @@ -439,94 +443,93 @@ where } } Err(err) => { + drop(provider_rw); self.listeners.notify(PipelineEvent::Error { stage_id }); - - let out = if let StageError::DetachedHead { local_head, header, error } = err { - warn!(target: "sync::pipeline", stage = %stage_id, ?local_head, ?header, ?error, "Stage encountered detached head"); - - // We unwind because of a detached head. - let unwind_to = local_head - .number - .saturating_sub(BEACON_CONSENSUS_REORG_UNWIND_DEPTH) - .max(1); - Ok(ControlFlow::Unwind { target: unwind_to, bad_block: local_head }) - } else if let StageError::Block { block, error } = err { - match error { - BlockErrorKind::Validation(validation_error) => { - error!( - target: "sync::pipeline", - stage = %stage_id, - bad_block = %block.number, - "Stage encountered a validation error: {validation_error}" - ); - - // FIXME: When handling errors, we do not commit the database - // transaction. This leads to the Merkle - // stage not clearing its checkpoint, and - // restarting from an invalid place. - drop(provider_rw); - provider_rw = factory.provider_rw()?; - provider_rw.save_stage_checkpoint_progress( - StageId::MerkleExecute, - vec![], - )?; - provider_rw.save_stage_checkpoint( - StageId::MerkleExecute, - prev_checkpoint.unwrap_or_default(), - )?; - provider_rw.commit()?; - - // We unwind because of a validation error. If the unwind itself - // fails, we bail entirely, - // otherwise we restart the execution loop from the - // beginning. - Ok(ControlFlow::Unwind { - target: prev_checkpoint.unwrap_or_default().block_number, - bad_block: block, - }) - } - BlockErrorKind::Execution(execution_error) => { - error!( - target: "sync::pipeline", - stage = %stage_id, - bad_block = %block.number, - "Stage encountered an execution error: {execution_error}" - ); - - // We unwind because of an execution error. If the unwind itself - // fails, we bail entirely, - // otherwise we restart - // the execution loop from the beginning. - Ok(ControlFlow::Unwind { - target: prev_checkpoint.unwrap_or_default().block_number, - bad_block: block, - }) - } - } - } else if err.is_fatal() { - error!( - target: "sync::pipeline", - stage = %stage_id, - "Stage encountered a fatal error: {err}." - ); - Err(err.into()) - } else { - // On other errors we assume they are recoverable if we discard the - // transaction and run the stage again. - warn!( - target: "sync::pipeline", - stage = %stage_id, - "Stage encountered a non-fatal error: {err}. Retrying..." - ); - continue - }; - return out + if let Some(ctrl) = on_stage_error(&factory, stage_id, prev_checkpoint, err)? { + return Ok(ctrl) + } } } } } } +fn on_stage_error( + factory: &ProviderFactory, + stage_id: StageId, + prev_checkpoint: Option, + err: StageError, +) -> Result, PipelineError> { + if let StageError::DetachedHead { local_head, header, error } = err { + warn!(target: "sync::pipeline", stage = %stage_id, ?local_head, ?header, ?error, "Stage encountered detached head"); + + // We unwind because of a detached head. + let unwind_to = + local_head.number.saturating_sub(BEACON_CONSENSUS_REORG_UNWIND_DEPTH).max(1); + Ok(Some(ControlFlow::Unwind { target: unwind_to, bad_block: local_head })) + } else if let StageError::Block { block, error } = err { + match error { + BlockErrorKind::Validation(validation_error) => { + error!( + target: "sync::pipeline", + stage = %stage_id, + bad_block = %block.number, + "Stage encountered a validation error: {validation_error}" + ); + + // FIXME: When handling errors, we do not commit the database transaction. This + // leads to the Merkle stage not clearing its checkpoint, and restarting from an + // invalid place. + let provider_rw = factory.provider_rw()?; + provider_rw.save_stage_checkpoint_progress(StageId::MerkleExecute, vec![])?; + provider_rw.save_stage_checkpoint( + StageId::MerkleExecute, + prev_checkpoint.unwrap_or_default(), + )?; + provider_rw.commit()?; + + // We unwind because of a validation error. If the unwind itself + // fails, we bail entirely, + // otherwise we restart the execution loop from the + // beginning. + Ok(Some(ControlFlow::Unwind { + target: prev_checkpoint.unwrap_or_default().block_number, + bad_block: block, + })) + } + BlockErrorKind::Execution(execution_error) => { + error!( + target: "sync::pipeline", + stage = %stage_id, + bad_block = %block.number, + "Stage encountered an execution error: {execution_error}" + ); + + // We unwind because of an execution error. If the unwind itself + // fails, we bail entirely, + // otherwise we restart + // the execution loop from the beginning. + Ok(Some(ControlFlow::Unwind { + target: prev_checkpoint.unwrap_or_default().block_number, + bad_block: block, + })) + } + } + } else if err.is_fatal() { + error!(target: "sync::pipeline", stage = %stage_id, "Stage encountered a fatal error: {err}"); + Err(err.into()) + } else { + // On other errors we assume they are recoverable if we discard the + // transaction and run the stage again. + warn!( + target: "sync::pipeline", + stage = %stage_id, + "Stage encountered a non-fatal error: {err}. Retrying..." + ); + Ok(None) + } +} + impl std::fmt::Debug for Pipeline { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Pipeline") @@ -660,7 +663,7 @@ mod tests { pipeline.run().await.expect("Could not run pipeline"); // Unwind - pipeline.unwind(1, None).await.expect("Could not unwind pipeline"); + pipeline.unwind(1, None).expect("Could not unwind pipeline"); }); // Check that the stages were unwound in reverse order @@ -764,7 +767,7 @@ mod tests { pipeline.run().await.expect("Could not run pipeline"); // Unwind - pipeline.unwind(50, None).await.expect("Could not unwind pipeline"); + pipeline.unwind(50, None).expect("Could not unwind pipeline"); }); // Check that the stages were unwound in reverse order diff --git a/crates/stages/src/sets.rs b/crates/stages/src/sets.rs index f49714e01..5a9ac7942 100644 --- a/crates/stages/src/sets.rs +++ b/crates/stages/src/sets.rs @@ -38,7 +38,7 @@ //! ``` use crate::{ stages::{ - AccountHashingStage, BodyStage, ExecutionStage, FinishStage, HeaderStage, HeaderSyncMode, + AccountHashingStage, BodyStage, ExecutionStage, FinishStage, HeaderStage, IndexAccountHistoryStage, IndexStorageHistoryStage, MerkleStage, SenderRecoveryStage, StorageHashingStage, TotalDifficultyStage, TransactionLookupStage, }, @@ -49,7 +49,7 @@ use reth_interfaces::{ consensus::Consensus, p2p::{bodies::downloader::BodyDownloader, headers::downloader::HeaderDownloader}, }; -use reth_provider::ExecutorFactory; +use reth_provider::{ExecutorFactory, HeaderSyncGapProvider, HeaderSyncMode}; use std::sync::Arc; /// A set containing all stages to run a fully syncing instance of reth. @@ -75,16 +75,17 @@ use std::sync::Arc; /// - [`IndexAccountHistoryStage`] /// - [`FinishStage`] #[derive(Debug)] -pub struct DefaultStages { +pub struct DefaultStages { /// Configuration for the online stages - online: OnlineStages, + online: OnlineStages, /// Executor factory needs for execution stage executor_factory: EF, } -impl DefaultStages { +impl DefaultStages { /// Create a new set of default stages with default values. pub fn new( + provider: Provider, header_mode: HeaderSyncMode, consensus: Arc, header_downloader: H, @@ -95,13 +96,19 @@ impl DefaultStages { EF: ExecutorFactory, { Self { - online: OnlineStages::new(header_mode, consensus, header_downloader, body_downloader), + online: OnlineStages::new( + provider, + header_mode, + consensus, + header_downloader, + body_downloader, + ), executor_factory, } } } -impl DefaultStages +impl DefaultStages where EF: ExecutorFactory, { @@ -114,9 +121,10 @@ where } } -impl StageSet for DefaultStages +impl StageSet for DefaultStages where DB: Database, + Provider: HeaderSyncGapProvider + 'static, H: HeaderDownloader + 'static, B: BodyDownloader + 'static, EF: ExecutorFactory, @@ -131,7 +139,9 @@ where /// These stages *can* be run without network access if the specified downloaders are /// themselves offline. #[derive(Debug)] -pub struct OnlineStages { +pub struct OnlineStages { + /// Sync gap provider for the headers stage. + provider: Provider, /// The sync mode for the headers stage. header_mode: HeaderSyncMode, /// The consensus engine used to validate incoming data. @@ -142,60 +152,64 @@ pub struct OnlineStages { body_downloader: B, } -impl OnlineStages { +impl OnlineStages { /// Create a new set of online stages with default values. pub fn new( + provider: Provider, header_mode: HeaderSyncMode, consensus: Arc, header_downloader: H, body_downloader: B, ) -> Self { - Self { header_mode, consensus, header_downloader, body_downloader } + Self { provider, header_mode, consensus, header_downloader, body_downloader } } } -impl OnlineStages +impl OnlineStages where + Provider: HeaderSyncGapProvider + 'static, H: HeaderDownloader + 'static, B: BodyDownloader + 'static, { /// Create a new builder using the given headers stage. pub fn builder_with_headers( - headers: HeaderStage, + headers: HeaderStage, body_downloader: B, consensus: Arc, ) -> StageSetBuilder { StageSetBuilder::default() .add_stage(headers) .add_stage(TotalDifficultyStage::new(consensus.clone())) - .add_stage(BodyStage { downloader: body_downloader, consensus }) + .add_stage(BodyStage::new(body_downloader)) } /// Create a new builder using the given bodies stage. pub fn builder_with_bodies( bodies: BodyStage, + provider: Provider, mode: HeaderSyncMode, header_downloader: H, consensus: Arc, ) -> StageSetBuilder { StageSetBuilder::default() - .add_stage(HeaderStage::new(header_downloader, mode)) + .add_stage(HeaderStage::new(provider, header_downloader, mode)) .add_stage(TotalDifficultyStage::new(consensus.clone())) .add_stage(bodies) } } -impl StageSet for OnlineStages +impl StageSet for OnlineStages where DB: Database, + Provider: HeaderSyncGapProvider + 'static, H: HeaderDownloader + 'static, B: BodyDownloader + 'static, { fn builder(self) -> StageSetBuilder { StageSetBuilder::default() - .add_stage(HeaderStage::new(self.header_downloader, self.header_mode)) + .add_stage(HeaderStage::new(self.provider, self.header_downloader, self.header_mode)) .add_stage(TotalDifficultyStage::new(self.consensus.clone())) - .add_stage(BodyStage { downloader: self.body_downloader, consensus: self.consensus }) + .add_stage(BodyStage::new(self.body_downloader)) } } diff --git a/crates/stages/src/stage.rs b/crates/stages/src/stage.rs index 95e397cbe..55a491a83 100644 --- a/crates/stages/src/stage.rs +++ b/crates/stages/src/stage.rs @@ -1,5 +1,4 @@ use crate::error::StageError; -use async_trait::async_trait; use reth_db::database::Database; use reth_primitives::{ stage::{StageCheckpoint, StageId}, @@ -9,6 +8,7 @@ use reth_provider::{BlockReader, DatabaseProviderRW, ProviderError, Transactions use std::{ cmp::{max, min}, ops::{Range, RangeInclusive}, + task::{Context, Poll}, }; /// Stage execution input, see [Stage::execute]. @@ -189,22 +189,55 @@ pub struct UnwindOutput { /// Stages are executed as part of a pipeline where they are executed serially. /// /// Stages receive [`DatabaseProviderRW`]. -#[async_trait] pub trait Stage: Send + Sync { /// Get the ID of the stage. /// /// Stage IDs must be unique. fn id(&self) -> StageId; + /// Returns `Poll::Ready(Ok(()))` when the stage is ready to execute the given range. + /// + /// This method is heavily inspired by [tower](https://crates.io/crates/tower)'s `Service` trait. + /// Any asynchronous tasks or communication should be handled in `poll_ready`, e.g. moving + /// downloaded items from downloaders to an internal buffer in the stage. + /// + /// If the stage has any pending external state, then `Poll::Pending` is returned. + /// + /// If `Poll::Ready(Err(_))` is returned, the stage may not be able to execute anymore + /// depending on the specific error. In that case, an unwind must be issued instead. + /// + /// Once `Poll::Ready(Ok(()))` is returned, the stage may be executed once using `execute`. + /// Until the stage has been executed, repeated calls to `poll_ready` must return either + /// `Poll::Ready(Ok(()))` or `Poll::Ready(Err(_))`. + /// + /// Note that `poll_ready` may reserve shared resources that are consumed in a subsequent call + /// of `execute`, e.g. internal buffers. It is crucial for implementations to not assume that + /// `execute` will always be invoked and to ensure that those resources are appropriately + /// released if the stage is dropped before `execute` is called. + /// + /// For the same reason, it is also important that any shared resources do not exhibit + /// unbounded growth on repeated calls to `poll_ready`. + /// + /// Unwinds may happen without consulting `poll_ready` first. + fn poll_execute_ready( + &mut self, + _cx: &mut Context<'_>, + _input: ExecInput, + ) -> Poll> { + Poll::Ready(Ok(())) + } + /// Execute the stage. - async fn execute( + /// It is expected that the stage will write all necessary data to the database + /// upon invoking this method. + fn execute( &mut self, provider: &DatabaseProviderRW<'_, &DB>, input: ExecInput, ) -> Result; /// Unwind the stage. - async fn unwind( + fn unwind( &mut self, provider: &DatabaseProviderRW<'_, &DB>, input: UnwindInput, diff --git a/crates/stages/src/stages/bodies.rs b/crates/stages/src/stages/bodies.rs index 8da7e6511..cb908ebf9 100644 --- a/crates/stages/src/stages/bodies.rs +++ b/crates/stages/src/stages/bodies.rs @@ -8,13 +8,10 @@ use reth_db::{ transaction::{DbTx, DbTxMut}, DatabaseError, }; -use reth_interfaces::{ - consensus::Consensus, - p2p::bodies::{downloader::BodyDownloader, response::BlockResponse}, -}; +use reth_interfaces::p2p::bodies::{downloader::BodyDownloader, response::BlockResponse}; use reth_primitives::stage::{EntitiesCheckpoint, StageCheckpoint, StageId}; use reth_provider::DatabaseProviderRW; -use std::sync::Arc; +use std::task::{ready, Context, Poll}; use tracing::*; // TODO(onbjerg): Metrics and events (gradual status for e.g. CLI) @@ -51,21 +48,55 @@ use tracing::*; #[derive(Debug)] pub struct BodyStage { /// The body downloader. - pub downloader: D, - /// The consensus engine. - pub consensus: Arc, + downloader: D, + /// Block response buffer. + buffer: Vec, +} + +impl BodyStage { + /// Create new bodies stage from downloader. + pub fn new(downloader: D) -> Self { + Self { downloader, buffer: Vec::new() } + } } -#[async_trait::async_trait] impl Stage for BodyStage { /// Return the id of the stage fn id(&self) -> StageId { StageId::Bodies } + fn poll_execute_ready( + &mut self, + cx: &mut Context<'_>, + input: ExecInput, + ) -> Poll> { + if input.target_reached() || !self.buffer.is_empty() { + return Poll::Ready(Ok(())) + } + + // Update the header range on the downloader + self.downloader.set_download_range(input.next_block_range())?; + + // Poll next downloader item. + let maybe_next_result = ready!(self.downloader.try_poll_next_unpin(cx)); + + // Task downloader can return `None` only if the response relaying channel was closed. This + // is a fatal error to prevent the pipeline from running forever. + let response = match maybe_next_result { + Some(Ok(downloaded)) => { + self.buffer.extend(downloaded); + Ok(()) + } + Some(Err(err)) => Err(err.into()), + None => Err(StageError::ChannelClosed), + }; + Poll::Ready(response) + } + /// Download block bodies from the last checkpoint for this stage up until the latest synced /// header, limited by the stage's batch size. - async fn execute( + fn execute( &mut self, provider: &DatabaseProviderRW<'_, &DB>, input: ExecInput, @@ -73,11 +104,7 @@ impl Stage for BodyStage { if input.target_reached() { return Ok(ExecOutput::done(input.checkpoint())) } - - let range = input.next_block_range(); - // Update the header range on the downloader - self.downloader.set_download_range(range.clone())?; - let (from_block, to_block) = range.into_inner(); + let (from_block, to_block) = input.next_block_range().into_inner(); // Cursors used to write bodies, ommers and transactions let tx = provider.tx_ref(); @@ -91,16 +118,9 @@ impl Stage for BodyStage { let mut next_tx_num = tx_cursor.last()?.map(|(id, _)| id + 1).unwrap_or_default(); debug!(target: "sync::stages::bodies", stage_progress = from_block, target = to_block, start_tx_id = next_tx_num, "Commencing sync"); - - // Task downloader can return `None` only if the response relaying channel was closed. This - // is a fatal error to prevent the pipeline from running forever. - let downloaded_bodies = - self.downloader.try_next().await?.ok_or(StageError::ChannelClosed)?; - - trace!(target: "sync::stages::bodies", bodies_len = downloaded_bodies.len(), "Writing blocks"); - + trace!(target: "sync::stages::bodies", bodies_len = self.buffer.len(), "Writing blocks"); let mut highest_block = from_block; - for response in downloaded_bodies { + for response in self.buffer.drain(..) { // Write block let block_number = response.block_number(); @@ -161,11 +181,13 @@ impl Stage for BodyStage { } /// Unwind the stage. - async fn unwind( + fn unwind( &mut self, provider: &DatabaseProviderRW<'_, &DB>, input: UnwindInput, ) -> Result { + self.buffer.clear(); + let tx = provider.tx_ref(); // Cursors to unwind bodies, ommers let mut body_cursor = tx.cursor_write::()?; @@ -476,7 +498,6 @@ mod tests { test_utils::{ generators, generators::{random_block_range, random_signed_tx}, - TestConsensus, }, }; use reth_primitives::{BlockBody, BlockNumber, SealedBlock, SealedHeader, TxNumber, B256}; @@ -505,7 +526,6 @@ mod tests { /// A helper struct for running the [BodyStage]. pub(crate) struct BodyTestRunner { - pub(crate) consensus: Arc, responses: HashMap, tx: TestTransaction, batch_size: u64, @@ -514,7 +534,6 @@ mod tests { impl Default for BodyTestRunner { fn default() -> Self { Self { - consensus: Arc::new(TestConsensus::default()), responses: HashMap::default(), tx: TestTransaction::default(), batch_size: 1000, @@ -540,14 +559,11 @@ mod tests { } fn stage(&self) -> Self::S { - BodyStage { - downloader: TestBodyDownloader::new( - self.tx.inner_raw(), - self.responses.clone(), - self.batch_size, - ), - consensus: self.consensus.clone(), - } + BodyStage::new(TestBodyDownloader::new( + self.tx.inner_raw(), + self.responses.clone(), + self.batch_size, + )) } } diff --git a/crates/stages/src/stages/execution.rs b/crates/stages/src/stages/execution.rs index a53bef070..d6ffc67df 100644 --- a/crates/stages/src/stages/execution.rs +++ b/crates/stages/src/stages/execution.rs @@ -331,7 +331,6 @@ fn calculate_gas_used_from_headers( Ok(gas_total) } -#[async_trait::async_trait] impl Stage for ExecutionStage { /// Return the id of the stage fn id(&self) -> StageId { @@ -339,7 +338,7 @@ impl Stage for ExecutionStage { } /// Execute the stage - async fn execute( + fn execute( &mut self, provider: &DatabaseProviderRW<'_, &DB>, input: ExecInput, @@ -348,7 +347,7 @@ impl Stage for ExecutionStage { } /// Unwind the stage. - async fn unwind( + fn unwind( &mut self, provider: &DatabaseProviderRW<'_, &DB>, input: UnwindInput, @@ -685,8 +684,8 @@ mod tests { provider.commit().unwrap(); let provider = factory.provider_rw().unwrap(); - let mut execution_stage = stage(); - let output = execution_stage.execute(&provider, input).await.unwrap(); + let mut execution_stage: ExecutionStage = stage(); + let output = execution_stage.execute(&provider, input).unwrap(); provider.commit().unwrap(); assert_matches!(output, ExecOutput { checkpoint: StageCheckpoint { @@ -787,7 +786,7 @@ mod tests { // execute let provider = factory.provider_rw().unwrap(); let mut execution_stage = stage(); - let result = execution_stage.execute(&provider, input).await.unwrap(); + let result = execution_stage.execute(&provider, input).unwrap(); provider.commit().unwrap(); let provider = factory.provider_rw().unwrap(); @@ -797,7 +796,6 @@ mod tests { &provider, UnwindInput { checkpoint: result.checkpoint, unwind_to: 0, bad_block: None }, ) - .await .unwrap(); assert_matches!(result, UnwindOutput { @@ -886,7 +884,7 @@ mod tests { // execute let provider = factory.provider_rw().unwrap(); let mut execution_stage = stage(); - let _ = execution_stage.execute(&provider, input).await.unwrap(); + let _ = execution_stage.execute(&provider, input).unwrap(); provider.commit().unwrap(); // assert unwind stage diff --git a/crates/stages/src/stages/finish.rs b/crates/stages/src/stages/finish.rs index 751c4e37b..d0a63e890 100644 --- a/crates/stages/src/stages/finish.rs +++ b/crates/stages/src/stages/finish.rs @@ -11,13 +11,12 @@ use reth_provider::DatabaseProviderRW; #[non_exhaustive] pub struct FinishStage; -#[async_trait::async_trait] impl Stage for FinishStage { fn id(&self) -> StageId { StageId::Finish } - async fn execute( + fn execute( &mut self, _provider: &DatabaseProviderRW<'_, &DB>, input: ExecInput, @@ -25,7 +24,7 @@ impl Stage for FinishStage { Ok(ExecOutput { checkpoint: StageCheckpoint::new(input.target()), done: true }) } - async fn unwind( + fn unwind( &mut self, _provider: &DatabaseProviderRW<'_, &DB>, input: UnwindInput, diff --git a/crates/stages/src/stages/hashing_account.rs b/crates/stages/src/stages/hashing_account.rs index 896bfc976..4eab05e09 100644 --- a/crates/stages/src/stages/hashing_account.rs +++ b/crates/stages/src/stages/hashing_account.rs @@ -21,8 +21,8 @@ use std::{ cmp::max, fmt::Debug, ops::{Range, RangeInclusive}, + sync::mpsc, }; -use tokio::sync::mpsc; use tracing::*; /// Account hashing stage hashes plain account. @@ -125,7 +125,6 @@ impl AccountHashingStage { } } -#[async_trait::async_trait] impl Stage for AccountHashingStage { /// Return the id of the stage fn id(&self) -> StageId { @@ -133,7 +132,7 @@ impl Stage for AccountHashingStage { } /// Execute the stage. - async fn execute( + fn execute( &mut self, provider: &DatabaseProviderRW<'_, &DB>, input: ExecInput, @@ -190,7 +189,7 @@ impl Stage for AccountHashingStage { ) { // An _unordered_ channel to receive results from a rayon job - let (tx, rx) = mpsc::unbounded_channel(); + let (tx, rx) = mpsc::channel(); channels.push(rx); let chunk = chunk.collect::, _>>()?; @@ -205,8 +204,8 @@ impl Stage for AccountHashingStage { let mut hashed_batch = Vec::with_capacity(self.commit_threshold as usize); // Iterate over channels and append the hashed accounts. - for mut channel in channels { - while let Some(hashed) = channel.recv().await { + for channel in channels { + while let Ok(hashed) = channel.recv() { hashed_batch.push(hashed); } } @@ -265,7 +264,7 @@ impl Stage for AccountHashingStage { } /// Unwind the stage. - async fn unwind( + fn unwind( &mut self, provider: &DatabaseProviderRW<'_, &DB>, input: UnwindInput, diff --git a/crates/stages/src/stages/hashing_storage.rs b/crates/stages/src/stages/hashing_storage.rs index 2580b58c9..da2fd38ac 100644 --- a/crates/stages/src/stages/hashing_storage.rs +++ b/crates/stages/src/stages/hashing_storage.rs @@ -44,7 +44,6 @@ impl Default for StorageHashingStage { } } -#[async_trait::async_trait] impl Stage for StorageHashingStage { /// Return the id of the stage fn id(&self) -> StageId { @@ -52,7 +51,7 @@ impl Stage for StorageHashingStage { } /// Execute the stage. - async fn execute( + fn execute( &mut self, provider: &DatabaseProviderRW<'_, &DB>, input: ExecInput, @@ -191,7 +190,7 @@ impl Stage for StorageHashingStage { } /// Unwind the stage. - async fn unwind( + fn unwind( &mut self, provider: &DatabaseProviderRW<'_, &DB>, input: UnwindInput, diff --git a/crates/stages/src/stages/headers.rs b/crates/stages/src/stages/headers.rs index e57b736d6..9ad06a198 100644 --- a/crates/stages/src/stages/headers.rs +++ b/crates/stages/src/stages/headers.rs @@ -2,38 +2,24 @@ use crate::{ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput} use futures_util::StreamExt; use reth_db::{ cursor::{DbCursorRO, DbCursorRW}, - database::Database, + database::{Database, DatabaseGAT}, tables, transaction::{DbTx, DbTxMut}, }; use reth_interfaces::{ - p2p::headers::{ - downloader::{HeaderDownloader, SyncTarget}, - error::HeadersDownloaderError, - }, + p2p::headers::{downloader::HeaderDownloader, error::HeadersDownloaderError}, provider::ProviderError, }; use reth_primitives::{ stage::{ CheckpointBlockRange, EntitiesCheckpoint, HeadersCheckpoint, StageCheckpoint, StageId, }, - BlockHashOrNumber, BlockNumber, SealedHeader, B256, + BlockHashOrNumber, BlockNumber, SealedHeader, }; -use reth_provider::DatabaseProviderRW; -use tokio::sync::watch; +use reth_provider::{DatabaseProviderRW, HeaderSyncGap, HeaderSyncGapProvider, HeaderSyncMode}; +use std::task::{ready, Context, Poll}; use tracing::*; -/// The header sync mode. -#[derive(Debug)] -pub enum HeaderSyncMode { - /// A sync mode in which the stage continuously requests the downloader for - /// next blocks. - Continuous, - /// A sync mode in which the stage polls the receiver for the next tip - /// to download from. - Tip(watch::Receiver), -} - /// The headers stage. /// /// The headers stage downloads all block headers from the highest block in the local database to @@ -48,27 +34,33 @@ pub enum HeaderSyncMode { /// NOTE: This stage downloads headers in reverse. Upon returning the control flow to the pipeline, /// the stage checkpoint is not updated until this stage is done. #[derive(Debug)] -pub struct HeaderStage { +pub struct HeaderStage { + /// Database handle. + provider: Provider, /// Strategy for downloading the headers - downloader: D, + downloader: Downloader, /// The sync mode for the stage. mode: HeaderSyncMode, + /// Current sync gap. + sync_gap: Option, + /// Header buffer. + buffer: Vec, } // === impl HeaderStage === -impl HeaderStage +impl HeaderStage where - D: HeaderDownloader, + Downloader: HeaderDownloader, { /// Create a new header stage - pub fn new(downloader: D, mode: HeaderSyncMode) -> Self { - Self { downloader, mode } + pub fn new(database: Provider, downloader: Downloader, mode: HeaderSyncMode) -> Self { + Self { provider: database, downloader, mode, sync_gap: None, buffer: Vec::new() } } fn is_stage_done( &self, - tx: &>::TXMut, + tx: &>::TXMut, checkpoint: u64, ) -> Result { let mut header_cursor = tx.cursor_read::()?; @@ -79,75 +71,12 @@ where Ok(header_cursor.next()?.map(|(next_num, _)| head_num + 1 == next_num).unwrap_or_default()) } - /// Get the head and tip of the range we need to sync - /// - /// See also [SyncTarget] - async fn get_sync_gap( - &mut self, - provider: &DatabaseProviderRW<'_, &DB>, - checkpoint: u64, - ) -> Result { - // Create a cursor over canonical header hashes - let mut cursor = provider.tx_ref().cursor_read::()?; - let mut header_cursor = provider.tx_ref().cursor_read::()?; - - // Get head hash and reposition the cursor - let (head_num, head_hash) = cursor - .seek_exact(checkpoint)? - .ok_or_else(|| ProviderError::HeaderNotFound(checkpoint.into()))?; - - // Construct head - let (_, head) = header_cursor - .seek_exact(head_num)? - .ok_or_else(|| ProviderError::HeaderNotFound(head_num.into()))?; - let local_head = head.seal(head_hash); - - // Look up the next header - let next_header = cursor - .next()? - .map(|(next_num, next_hash)| -> Result { - let (_, next) = header_cursor - .seek_exact(next_num)? - .ok_or_else(|| ProviderError::HeaderNotFound(next_num.into()))?; - Ok(next.seal(next_hash)) - }) - .transpose()?; - - // Decide the tip or error out on invalid input. - // If the next element found in the cursor is not the "expected" next block per our current - // checkpoint, then there is a gap in the database and we should start downloading in - // reverse from there. Else, it should use whatever the forkchoice state reports. - let target = match next_header { - Some(header) if checkpoint + 1 != header.number => SyncTarget::Gap(header), - None => self - .next_sync_target(head_num) - .await - .ok_or(StageError::StageCheckpoint(checkpoint))?, - _ => return Err(StageError::StageCheckpoint(checkpoint)), - }; - - Ok(SyncGap { local_head, target }) - } - - async fn next_sync_target(&mut self, head: BlockNumber) -> Option { - match self.mode { - HeaderSyncMode::Tip(ref mut rx) => { - let tip = rx.wait_for(|tip| !tip.is_zero()).await.ok()?; - Some(SyncTarget::Tip(*tip)) - } - HeaderSyncMode::Continuous => { - trace!(target: "sync::stages::headers", head, "No next header found, using continuous sync strategy"); - Some(SyncTarget::TipNum(head + 1)) - } - } - } - /// Write downloaded headers to the given transaction /// /// Note: this writes the headers with rising block numbers. fn write_headers( &self, - tx: &>::TXMut, + tx: &>::TXMut, headers: Vec, ) -> Result, StageError> { trace!(target: "sync::stages::headers", len = headers.len(), "writing headers"); @@ -178,10 +107,10 @@ where } } -#[async_trait::async_trait] -impl Stage for HeaderStage +impl Stage for HeaderStage where DB: Database, + Provider: HeaderSyncGapProvider, D: HeaderDownloader, { /// Return the id of the stage @@ -189,20 +118,27 @@ where StageId::Headers } - /// Download the headers in reverse order (falling block numbers) - /// starting from the tip of the chain - async fn execute( + fn poll_execute_ready( &mut self, - provider: &DatabaseProviderRW<'_, &DB>, + cx: &mut Context<'_>, input: ExecInput, - ) -> Result { - let tx = provider.tx_ref(); + ) -> Poll> { let current_checkpoint = input.checkpoint(); + // Return if buffer already has some items. + if !self.buffer.is_empty() { + trace!( + target: "sync::stages::headers", + checkpoint = %current_checkpoint.block_number, + "Buffer is not empty" + ); + return Poll::Ready(Ok(())) + } + // Lookup the head and tip of the sync range - let gap = self.get_sync_gap(provider, current_checkpoint.block_number).await?; - let local_head = gap.local_head.number; + let gap = self.provider.sync_gap(self.mode.clone(), current_checkpoint.block_number)?; let tip = gap.target.tip(); + self.sync_gap = Some(gap.clone()); // Nothing to sync if gap.is_closed() { @@ -212,7 +148,7 @@ where target = ?tip, "Target block already reached" ); - return Ok(ExecOutput::done(current_checkpoint)) + return Poll::Ready(Ok(())) } debug!(target: "sync::stages::headers", ?tip, head = ?gap.local_head.hash(), "Commencing sync"); @@ -220,31 +156,44 @@ where // let the downloader know what to sync self.downloader.update_sync_gap(gap.local_head, gap.target); - // The downloader returns the headers in descending order starting from the tip - // down to the local head (latest block in db). - // Task downloader can return `None` only if the response relaying channel was closed. This - // is a fatal error to prevent the pipeline from running forever. - let downloaded_headers = match self.downloader.next().await { - Some(Ok(headers)) => headers, + let result = match ready!(self.downloader.poll_next_unpin(cx)) { + Some(Ok(headers)) => { + info!(target: "sync::stages::headers", len = headers.len(), "Received headers"); + self.buffer.extend(headers); + Ok(()) + } Some(Err(HeadersDownloaderError::DetachedHead { local_head, header, error })) => { error!(target: "sync::stages::headers", ?error, "Cannot attach header to head"); - return Err(StageError::DetachedHead { local_head, header, error }) + Err(StageError::DetachedHead { local_head, header, error }) } - None => return Err(StageError::ChannelClosed), + None => Err(StageError::ChannelClosed), }; + Poll::Ready(result) + } - info!(target: "sync::stages::headers", len = downloaded_headers.len(), "Received headers"); + /// Download the headers in reverse order (falling block numbers) + /// starting from the tip of the chain + fn execute( + &mut self, + provider: &DatabaseProviderRW<'_, &DB>, + input: ExecInput, + ) -> Result { + let current_checkpoint = input.checkpoint(); + if self.buffer.is_empty() { + return Ok(ExecOutput::done(current_checkpoint)) + } + let gap = self.sync_gap.clone().ok_or(StageError::MissingSyncGap)?; + let local_head = gap.local_head.number; + let tip = gap.target.tip(); + + let downloaded_headers = std::mem::take(&mut self.buffer); let tip_block_number = match tip { // If tip is hash and it equals to the first downloaded header's hash, we can use // the block number of this header as tip. - BlockHashOrNumber::Hash(hash) => downloaded_headers.first().and_then(|header| { - if header.hash == hash { - Some(header.number) - } else { - None - } - }), + BlockHashOrNumber::Hash(hash) => downloaded_headers + .first() + .and_then(|header| (header.hash == hash).then_some(header.number)), // If tip is number, we can just grab it and not resolve using downloaded headers. BlockHashOrNumber::Number(number) => Some(number), }; @@ -254,13 +203,14 @@ where // syncing towards, we need to take into account already synced headers from the database. // It is `None`, if tip didn't change and we're still downloading headers for previously // calculated gap. + let tx = provider.tx_ref(); let target_block_number = if let Some(tip_block_number) = tip_block_number { let local_max_block_number = tx .cursor_read::()? .last()? .map(|(canonical_block, _)| canonical_block); - Some(tip_block_number.max(local_max_block_number.unwrap_or(tip_block_number))) + Some(tip_block_number.max(local_max_block_number.unwrap_or_default())) } else { None }; @@ -278,18 +228,17 @@ where // `target_block_number` is guaranteed to be `Some`, because on the first iteration // we download the header for missing tip and use its block number. _ => { + let target = target_block_number.expect("No downloaded header for tip found"); HeadersCheckpoint { block_range: CheckpointBlockRange { from: input.checkpoint().block_number, - to: target_block_number.expect("No downloaded header for tip found"), + to: target, }, progress: EntitiesCheckpoint { // Set processed to the local head block number + number // of block already filled in the gap. - processed: local_head + - (target_block_number.unwrap_or_default() - - tip_block_number.unwrap_or_default()), - total: target_block_number.expect("No downloaded header for tip found"), + processed: local_head + (target - tip_block_number.unwrap_or_default()), + total: target, }, } } @@ -326,12 +275,14 @@ where } /// Unwind the stage. - async fn unwind( + fn unwind( &mut self, provider: &DatabaseProviderRW<'_, &DB>, input: UnwindInput, ) -> Result { - // TODO: handle bad block + self.buffer.clear(); + self.sync_gap.take(); + provider.unwind_table_by_walker::( input.unwind_to + 1, )?; @@ -359,46 +310,22 @@ where } } -/// Represents a gap to sync: from `local_head` to `target` -#[derive(Debug)] -pub struct SyncGap { - /// The local head block. Represents lower bound of sync range. - pub local_head: SealedHeader, - - /// The sync target. Represents upper bound of sync range. - pub target: SyncTarget, -} - -// === impl SyncGap === - -impl SyncGap { - /// Returns `true` if the gap from the head to the target was closed - #[inline] - pub fn is_closed(&self) -> bool { - match self.target.tip() { - BlockHashOrNumber::Hash(hash) => self.local_head.hash() == hash, - BlockHashOrNumber::Number(num) => self.local_head.number == num, - } - } -} - #[cfg(test)] mod tests { - use super::*; use crate::test_utils::{ stage_test_suite, ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner, }; use assert_matches::assert_matches; - use rand::Rng; - use reth_interfaces::test_utils::{generators, generators::random_header}; - use reth_primitives::{stage::StageUnitCheckpoint, B256, MAINNET}; + use reth_interfaces::test_utils::generators::random_header; + use reth_primitives::{stage::StageUnitCheckpoint, B256}; use reth_provider::ProviderFactory; use test_runner::HeadersTestRunner; mod test_runner { use super::*; use crate::test_utils::{TestRunnerError, TestTransaction}; + use reth_db::{test_utils::TempDatabase, DatabaseEnv}; use reth_downloaders::headers::reverse_headers::{ ReverseHeadersDownloader, ReverseHeadersDownloaderBuilder, }; @@ -409,6 +336,7 @@ mod tests { use reth_primitives::U256; use reth_provider::{BlockHashReader, BlockNumReader, HeaderProvider}; use std::sync::Arc; + use tokio::sync::watch; pub(crate) struct HeadersTestRunner { pub(crate) client: TestHeadersClient, @@ -437,17 +365,18 @@ mod tests { } impl StageTestRunner for HeadersTestRunner { - type S = HeaderStage; + type S = HeaderStage>>, D>; fn tx(&self) -> &TestTransaction { &self.tx } fn stage(&self) -> Self::S { - HeaderStage { - mode: HeaderSyncMode::Tip(self.channel.1.clone()), - downloader: (*self.downloader_factory)(), - } + HeaderStage::new( + self.tx.factory.clone(), + (*self.downloader_factory)(), + HeaderSyncMode::Tip(self.channel.1.clone()), + ) } } @@ -599,65 +528,6 @@ mod tests { assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed"); } - /// Test the head and tip range lookup - #[tokio::test] - async fn head_and_tip_lookup() { - let runner = HeadersTestRunner::default(); - let factory = ProviderFactory::new(runner.tx().tx.as_ref(), MAINNET.clone()); - let provider = factory.provider_rw().unwrap(); - let tx = provider.tx_ref(); - let mut stage = runner.stage(); - - let mut rng = generators::rng(); - - let consensus_tip = rng.gen(); - runner.send_tip(consensus_tip); - - // Genesis - let checkpoint = 0; - let head = random_header(&mut rng, 0, None); - let gap_fill = random_header(&mut rng, 1, Some(head.hash())); - let gap_tip = random_header(&mut rng, 2, Some(gap_fill.hash())); - - // Empty database - assert_matches!( - stage.get_sync_gap(&provider, checkpoint).await, - Err(StageError::DatabaseIntegrity(ProviderError::HeaderNotFound(block_number))) - if block_number.as_number().unwrap() == checkpoint - ); - - // Checkpoint and no gap - tx.put::(head.number, head.hash()) - .expect("failed to write canonical"); - tx.put::(head.number, head.clone().unseal()) - .expect("failed to write header"); - - let gap = stage.get_sync_gap(&provider, checkpoint).await.unwrap(); - assert_eq!(gap.local_head, head); - assert_eq!(gap.target.tip(), consensus_tip.into()); - - // Checkpoint and gap - tx.put::(gap_tip.number, gap_tip.hash()) - .expect("failed to write canonical"); - tx.put::(gap_tip.number, gap_tip.clone().unseal()) - .expect("failed to write header"); - - let gap = stage.get_sync_gap(&provider, checkpoint).await.unwrap(); - assert_eq!(gap.local_head, head); - assert_eq!(gap.target.tip(), gap_tip.parent_hash.into()); - - // Checkpoint and gap closed - tx.put::(gap_fill.number, gap_fill.hash()) - .expect("failed to write canonical"); - tx.put::(gap_fill.number, gap_fill.clone().unseal()) - .expect("failed to write header"); - - assert_matches!( - stage.get_sync_gap(&provider, checkpoint).await, - Err(StageError::StageCheckpoint(_checkpoint)) if _checkpoint == checkpoint - ); - } - /// Execute the stage in two steps #[tokio::test] async fn execute_from_previous_checkpoint() { diff --git a/crates/stages/src/stages/index_account_history.rs b/crates/stages/src/stages/index_account_history.rs index 0945538c3..b1e7721dc 100644 --- a/crates/stages/src/stages/index_account_history.rs +++ b/crates/stages/src/stages/index_account_history.rs @@ -35,7 +35,6 @@ impl Default for IndexAccountHistoryStage { } } -#[async_trait::async_trait] impl Stage for IndexAccountHistoryStage { /// Return the id of the stage fn id(&self) -> StageId { @@ -43,7 +42,7 @@ impl Stage for IndexAccountHistoryStage { } /// Execute the stage. - async fn execute( + fn execute( &mut self, provider: &DatabaseProviderRW<'_, &DB>, mut input: ExecInput, @@ -86,7 +85,7 @@ impl Stage for IndexAccountHistoryStage { } /// Unwind the stage. - async fn unwind( + fn unwind( &mut self, provider: &DatabaseProviderRW<'_, &DB>, input: UnwindInput, @@ -178,17 +177,17 @@ mod tests { .unwrap() } - async fn run(tx: &TestTransaction, run_to: u64) { + fn run(tx: &TestTransaction, run_to: u64) { let input = ExecInput { target: Some(run_to), ..Default::default() }; let mut stage = IndexAccountHistoryStage::default(); let factory = ProviderFactory::new(tx.tx.as_ref(), MAINNET.clone()); let provider = factory.provider_rw().unwrap(); - let out = stage.execute(&provider, input).await.unwrap(); + let out = stage.execute(&provider, input).unwrap(); assert_eq!(out, ExecOutput { checkpoint: StageCheckpoint::new(5), done: true }); provider.commit().unwrap(); } - async fn unwind(tx: &TestTransaction, unwind_from: u64, unwind_to: u64) { + fn unwind(tx: &TestTransaction, unwind_from: u64, unwind_to: u64) { let input = UnwindInput { checkpoint: StageCheckpoint::new(unwind_from), unwind_to, @@ -197,7 +196,7 @@ mod tests { let mut stage = IndexAccountHistoryStage::default(); let factory = ProviderFactory::new(tx.tx.as_ref(), MAINNET.clone()); let provider = factory.provider_rw().unwrap(); - let out = stage.unwind(&provider, input).await.unwrap(); + let out = stage.unwind(&provider, input).unwrap(); assert_eq!(out, UnwindOutput { checkpoint: StageCheckpoint::new(unwind_to) }); provider.commit().unwrap(); } @@ -211,14 +210,14 @@ mod tests { partial_setup(&tx); // run - run(&tx, 5).await; + run(&tx, 5); // verify let table = cast(tx.table::().unwrap()); assert_eq!(table, BTreeMap::from([(shard(u64::MAX), vec![4, 5])])); // unwind - unwind(&tx, 5, 0).await; + unwind(&tx, 5, 0); // verify initial state let table = tx.table::().unwrap(); @@ -239,14 +238,14 @@ mod tests { .unwrap(); // run - run(&tx, 5).await; + run(&tx, 5); // verify let table = cast(tx.table::().unwrap()); assert_eq!(table, BTreeMap::from([(shard(u64::MAX), vec![1, 2, 3, 4, 5]),])); // unwind - unwind(&tx, 5, 0).await; + unwind(&tx, 5, 0); // verify initial state let table = cast(tx.table::().unwrap()); @@ -268,7 +267,7 @@ mod tests { .unwrap(); // run - run(&tx, 5).await; + run(&tx, 5); // verify let table = cast(tx.table::().unwrap()); @@ -278,7 +277,7 @@ mod tests { ); // unwind - unwind(&tx, 5, 0).await; + unwind(&tx, 5, 0); // verify initial state let table = cast(tx.table::().unwrap()); @@ -300,7 +299,7 @@ mod tests { .unwrap(); // run - run(&tx, 5).await; + run(&tx, 5); // verify close_full_list.push(4); @@ -309,7 +308,7 @@ mod tests { assert_eq!(table, BTreeMap::from([(shard(u64::MAX), close_full_list.clone()),])); // unwind - unwind(&tx, 5, 0).await; + unwind(&tx, 5, 0); // verify initial state close_full_list.pop(); @@ -335,7 +334,7 @@ mod tests { .unwrap(); // run - run(&tx, 5).await; + run(&tx, 5); // verify close_full_list.push(4); @@ -346,7 +345,7 @@ mod tests { ); // unwind - unwind(&tx, 5, 0).await; + unwind(&tx, 5, 0); // verify initial state close_full_list.pop(); @@ -370,7 +369,7 @@ mod tests { }) .unwrap(); - run(&tx, 5).await; + run(&tx, 5); // verify let table = cast(tx.table::().unwrap()); @@ -384,7 +383,7 @@ mod tests { ); // unwind - unwind(&tx, 5, 0).await; + unwind(&tx, 5, 0); // verify initial state let table = cast(tx.table::().unwrap()); @@ -434,7 +433,7 @@ mod tests { }; let factory = ProviderFactory::new(tx.tx.as_ref(), MAINNET.clone()); let provider = factory.provider_rw().unwrap(); - let out = stage.execute(&provider, input).await.unwrap(); + let out = stage.execute(&provider, input).unwrap(); assert_eq!(out, ExecOutput { checkpoint: StageCheckpoint::new(20000), done: true }); provider.commit().unwrap(); @@ -443,7 +442,7 @@ mod tests { assert_eq!(table, BTreeMap::from([(shard(u64::MAX), vec![36, 100])])); // unwind - unwind(&tx, 20000, 0).await; + unwind(&tx, 20000, 0); // verify initial state let table = tx.table::().unwrap(); diff --git a/crates/stages/src/stages/index_storage_history.rs b/crates/stages/src/stages/index_storage_history.rs index b1e27aed1..f9896fb4f 100644 --- a/crates/stages/src/stages/index_storage_history.rs +++ b/crates/stages/src/stages/index_storage_history.rs @@ -34,7 +34,6 @@ impl Default for IndexStorageHistoryStage { } } -#[async_trait::async_trait] impl Stage for IndexStorageHistoryStage { /// Return the id of the stage fn id(&self) -> StageId { @@ -42,7 +41,7 @@ impl Stage for IndexStorageHistoryStage { } /// Execute the stage. - async fn execute( + fn execute( &mut self, provider: &DatabaseProviderRW<'_, &DB>, mut input: ExecInput, @@ -84,7 +83,7 @@ impl Stage for IndexStorageHistoryStage { } /// Unwind the stage. - async fn unwind( + fn unwind( &mut self, provider: &DatabaseProviderRW<'_, &DB>, input: UnwindInput, @@ -188,17 +187,17 @@ mod tests { .unwrap() } - async fn run(tx: &TestTransaction, run_to: u64) { + fn run(tx: &TestTransaction, run_to: u64) { let input = ExecInput { target: Some(run_to), ..Default::default() }; let mut stage = IndexStorageHistoryStage::default(); let factory = ProviderFactory::new(tx.tx.as_ref(), MAINNET.clone()); let provider = factory.provider_rw().unwrap(); - let out = stage.execute(&provider, input).await.unwrap(); + let out = stage.execute(&provider, input).unwrap(); assert_eq!(out, ExecOutput { checkpoint: StageCheckpoint::new(5), done: true }); provider.commit().unwrap(); } - async fn unwind(tx: &TestTransaction, unwind_from: u64, unwind_to: u64) { + fn unwind(tx: &TestTransaction, unwind_from: u64, unwind_to: u64) { let input = UnwindInput { checkpoint: StageCheckpoint::new(unwind_from), unwind_to, @@ -207,7 +206,7 @@ mod tests { let mut stage = IndexStorageHistoryStage::default(); let factory = ProviderFactory::new(tx.tx.as_ref(), MAINNET.clone()); let provider = factory.provider_rw().unwrap(); - let out = stage.unwind(&provider, input).await.unwrap(); + let out = stage.unwind(&provider, input).unwrap(); assert_eq!(out, UnwindOutput { checkpoint: StageCheckpoint::new(unwind_to) }); provider.commit().unwrap(); } @@ -221,14 +220,14 @@ mod tests { partial_setup(&tx); // run - run(&tx, 5).await; + run(&tx, 5); // verify let table = cast(tx.table::().unwrap()); assert_eq!(table, BTreeMap::from([(shard(u64::MAX), vec![4, 5]),])); // unwind - unwind(&tx, 5, 0).await; + unwind(&tx, 5, 0); // verify initial state let table = tx.table::().unwrap(); @@ -249,14 +248,14 @@ mod tests { .unwrap(); // run - run(&tx, 5).await; + run(&tx, 5); // verify let table = cast(tx.table::().unwrap()); assert_eq!(table, BTreeMap::from([(shard(u64::MAX), vec![1, 2, 3, 4, 5]),])); // unwind - unwind(&tx, 5, 0).await; + unwind(&tx, 5, 0); // verify initial state let table = cast(tx.table::().unwrap()); @@ -281,7 +280,7 @@ mod tests { .unwrap(); // run - run(&tx, 5).await; + run(&tx, 5); // verify let table = cast(tx.table::().unwrap()); @@ -291,7 +290,7 @@ mod tests { ); // unwind - unwind(&tx, 5, 0).await; + unwind(&tx, 5, 0); // verify initial state let table = cast(tx.table::().unwrap()); @@ -313,7 +312,7 @@ mod tests { .unwrap(); // run - run(&tx, 5).await; + run(&tx, 5); // verify close_full_list.push(4); @@ -322,7 +321,7 @@ mod tests { assert_eq!(table, BTreeMap::from([(shard(u64::MAX), close_full_list.clone()),])); // unwind - unwind(&tx, 5, 0).await; + unwind(&tx, 5, 0); // verify initial state close_full_list.pop(); @@ -348,7 +347,7 @@ mod tests { .unwrap(); // run - run(&tx, 5).await; + run(&tx, 5); // verify close_full_list.push(4); @@ -359,7 +358,7 @@ mod tests { ); // unwind - unwind(&tx, 5, 0).await; + unwind(&tx, 5, 0); // verify initial state close_full_list.pop(); @@ -383,7 +382,7 @@ mod tests { }) .unwrap(); - run(&tx, 5).await; + run(&tx, 5); // verify let table = cast(tx.table::().unwrap()); @@ -397,7 +396,7 @@ mod tests { ); // unwind - unwind(&tx, 5, 0).await; + unwind(&tx, 5, 0); // verify initial state let table = cast(tx.table::().unwrap()); @@ -447,7 +446,7 @@ mod tests { }; let factory = ProviderFactory::new(tx.tx.as_ref(), MAINNET.clone()); let provider = factory.provider_rw().unwrap(); - let out = stage.execute(&provider, input).await.unwrap(); + let out = stage.execute(&provider, input).unwrap(); assert_eq!(out, ExecOutput { checkpoint: StageCheckpoint::new(20000), done: true }); provider.commit().unwrap(); @@ -456,7 +455,7 @@ mod tests { assert_eq!(table, BTreeMap::from([(shard(u64::MAX), vec![36, 100]),])); // unwind - unwind(&tx, 20000, 0).await; + unwind(&tx, 20000, 0); // verify initial state let table = tx.table::().unwrap(); diff --git a/crates/stages/src/stages/merkle.rs b/crates/stages/src/stages/merkle.rs index cd02696ce..4354b5628 100644 --- a/crates/stages/src/stages/merkle.rs +++ b/crates/stages/src/stages/merkle.rs @@ -113,7 +113,6 @@ impl MerkleStage { } } -#[async_trait::async_trait] impl Stage for MerkleStage { /// Return the id of the stage fn id(&self) -> StageId { @@ -126,7 +125,7 @@ impl Stage for MerkleStage { } /// Execute the stage. - async fn execute( + fn execute( &mut self, provider: &DatabaseProviderRW<'_, &DB>, input: ExecInput, @@ -260,7 +259,7 @@ impl Stage for MerkleStage { } /// Unwind the stage. - async fn unwind( + fn unwind( &mut self, provider: &DatabaseProviderRW<'_, &DB>, input: UnwindInput, diff --git a/crates/stages/src/stages/mod.rs b/crates/stages/src/stages/mod.rs index d4eeaf2d3..771de3586 100644 --- a/crates/stages/src/stages/mod.rs +++ b/crates/stages/src/stages/mod.rs @@ -139,7 +139,7 @@ mod tests { prune_modes.clone(), ); - execution_stage.execute(&provider, input).await.unwrap(); + execution_stage.execute(&provider, input).unwrap(); assert_eq!( provider.receipts_by_block(1.into()).unwrap().unwrap().len(), expect_num_receipts @@ -163,9 +163,9 @@ mod tests { if let Some(PruneMode::Full) = prune_modes.account_history { // Full is not supported - assert!(acc_indexing_stage.execute(&provider, input).await.is_err()); + assert!(acc_indexing_stage.execute(&provider, input).is_err()); } else { - acc_indexing_stage.execute(&provider, input).await.unwrap(); + acc_indexing_stage.execute(&provider, input).unwrap(); let mut account_history: Cursor = provider.tx_ref().cursor_read::().unwrap(); assert_eq!(account_history.walk(None).unwrap().count(), expect_num_acc_changesets); @@ -179,9 +179,9 @@ mod tests { if let Some(PruneMode::Full) = prune_modes.storage_history { // Full is not supported - assert!(acc_indexing_stage.execute(&provider, input).await.is_err()); + assert!(acc_indexing_stage.execute(&provider, input).is_err()); } else { - storage_indexing_stage.execute(&provider, input).await.unwrap(); + storage_indexing_stage.execute(&provider, input).unwrap(); let mut storage_history = provider.tx_ref().cursor_read::().unwrap(); diff --git a/crates/stages/src/stages/sender_recovery.rs b/crates/stages/src/stages/sender_recovery.rs index 80ffb040a..cdafd9e62 100644 --- a/crates/stages/src/stages/sender_recovery.rs +++ b/crates/stages/src/stages/sender_recovery.rs @@ -16,9 +16,8 @@ use reth_primitives::{ use reth_provider::{ BlockReader, DatabaseProviderRW, HeaderProvider, ProviderError, PruneCheckpointReader, }; -use std::fmt::Debug; +use std::{fmt::Debug, sync::mpsc}; use thiserror::Error; -use tokio::sync::mpsc; use tracing::*; /// The sender recovery stage iterates over existing transactions, @@ -44,7 +43,6 @@ impl Default for SenderRecoveryStage { } } -#[async_trait::async_trait] impl Stage for SenderRecoveryStage { /// Return the id of the stage fn id(&self) -> StageId { @@ -56,7 +54,7 @@ impl Stage for SenderRecoveryStage { /// collect transactions within that range, /// recover signer for each transaction and store entries in /// the [`TxSenders`][reth_db::tables::TxSenders] table. - async fn execute( + fn execute( &mut self, provider: &DatabaseProviderRW<'_, &DB>, input: ExecInput, @@ -110,7 +108,7 @@ impl Stage for SenderRecoveryStage { for chunk in &tx_walker.chunks(chunk_size) { // An _unordered_ channel to receive results from a rayon job - let (recovered_senders_tx, recovered_senders_rx) = mpsc::unbounded_channel(); + let (recovered_senders_tx, recovered_senders_rx) = mpsc::channel(); channels.push(recovered_senders_rx); // Note: Unfortunate side-effect of how chunk is designed in itertools (it is not Send) let chunk: Vec<_> = chunk.collect(); @@ -128,8 +126,8 @@ impl Stage for SenderRecoveryStage { } // Iterate over channels and append the sender in the order that they are received. - for mut channel in channels { - while let Some(recovered) = channel.recv().await { + for channel in channels { + while let Ok(recovered) = channel.recv() { let (tx_id, sender) = match recovered { Ok(result) => result, Err(error) => { @@ -168,7 +166,7 @@ impl Stage for SenderRecoveryStage { } /// Unwind the stage. - async fn unwind( + fn unwind( &mut self, provider: &DatabaseProviderRW<'_, &DB>, input: UnwindInput, diff --git a/crates/stages/src/stages/total_difficulty.rs b/crates/stages/src/stages/total_difficulty.rs index ea1e20630..1cdaa971c 100644 --- a/crates/stages/src/stages/total_difficulty.rs +++ b/crates/stages/src/stages/total_difficulty.rs @@ -41,7 +41,6 @@ impl TotalDifficultyStage { } } -#[async_trait::async_trait] impl Stage for TotalDifficultyStage { /// Return the id of the stage fn id(&self) -> StageId { @@ -49,7 +48,7 @@ impl Stage for TotalDifficultyStage { } /// Write total difficulty entries - async fn execute( + fn execute( &mut self, provider: &DatabaseProviderRW<'_, &DB>, input: ExecInput, @@ -99,7 +98,7 @@ impl Stage for TotalDifficultyStage { } /// Unwind the stage. - async fn unwind( + fn unwind( &mut self, provider: &DatabaseProviderRW<'_, &DB>, input: UnwindInput, diff --git a/crates/stages/src/stages/tx_lookup.rs b/crates/stages/src/stages/tx_lookup.rs index 758fa4033..0de9ce74b 100644 --- a/crates/stages/src/stages/tx_lookup.rs +++ b/crates/stages/src/stages/tx_lookup.rs @@ -42,7 +42,6 @@ impl TransactionLookupStage { } } -#[async_trait::async_trait] impl Stage for TransactionLookupStage { /// Return the id of the stage fn id(&self) -> StageId { @@ -50,7 +49,7 @@ impl Stage for TransactionLookupStage { } /// Write transaction hash -> id entries - async fn execute( + fn execute( &mut self, provider: &DatabaseProviderRW<'_, &DB>, mut input: ExecInput, @@ -128,7 +127,7 @@ impl Stage for TransactionLookupStage { } /// Unwind the stage. - async fn unwind( + fn unwind( &mut self, provider: &DatabaseProviderRW<'_, &DB>, input: UnwindInput, diff --git a/crates/stages/src/test_utils/runner.rs b/crates/stages/src/test_utils/runner.rs index 9bc08638d..96c44cacb 100644 --- a/crates/stages/src/test_utils/runner.rs +++ b/crates/stages/src/test_utils/runner.rs @@ -4,7 +4,7 @@ use reth_db::DatabaseEnv; use reth_interfaces::db::DatabaseError; use reth_primitives::MAINNET; use reth_provider::{ProviderError, ProviderFactory}; -use std::{borrow::Borrow, sync::Arc}; +use std::{borrow::Borrow, future::poll_fn, sync::Arc}; use tokio::sync::oneshot; #[derive(thiserror::Error, Debug)] @@ -48,10 +48,13 @@ pub(crate) trait ExecuteStageTestRunner: StageTestRunner { let (db, mut stage) = (self.tx().inner_raw(), self.stage()); tokio::spawn(async move { let factory = ProviderFactory::new(db.db(), MAINNET.clone()); - let provider = factory.provider_rw().unwrap(); - let result = stage.execute(&provider, input).await; - provider.commit().expect("failed to commit"); + let result = poll_fn(|cx| stage.poll_execute_ready(cx, input)).await.and_then(|_| { + let provider_rw = factory.provider_rw().unwrap(); + let result = stage.execute(&provider_rw, input); + provider_rw.commit().expect("failed to commit"); + result + }); tx.send(result).expect("failed to send message") }); rx @@ -76,7 +79,7 @@ pub(crate) trait UnwindStageTestRunner: StageTestRunner { let factory = ProviderFactory::new(db.db(), MAINNET.clone()); let provider = factory.provider_rw().unwrap(); - let result = stage.unwind(&provider, input).await; + let result = stage.unwind(&provider, input); provider.commit().expect("failed to commit"); tx.send(result).expect("failed to send result"); }); diff --git a/crates/stages/src/test_utils/stage.rs b/crates/stages/src/test_utils/stage.rs index 65ea51362..85e88841b 100644 --- a/crates/stages/src/test_utils/stage.rs +++ b/crates/stages/src/test_utils/stage.rs @@ -40,13 +40,12 @@ impl TestStage { } } -#[async_trait::async_trait] impl Stage for TestStage { fn id(&self) -> StageId { self.id } - async fn execute( + fn execute( &mut self, _: &DatabaseProviderRW<'_, &DB>, _input: ExecInput, @@ -56,7 +55,7 @@ impl Stage for TestStage { .unwrap_or_else(|| panic!("Test stage {} executed too many times.", self.id)) } - async fn unwind( + fn unwind( &mut self, _: &DatabaseProviderRW<'_, &DB>, _input: UnwindInput, diff --git a/crates/storage/provider/src/lib.rs b/crates/storage/provider/src/lib.rs index 87118a635..194c60d50 100644 --- a/crates/storage/provider/src/lib.rs +++ b/crates/storage/provider/src/lib.rs @@ -21,11 +21,11 @@ pub use traits::{ BlockWriter, BlockchainTreePendingStateProvider, BundleStateDataProvider, CanonChainTracker, CanonStateNotification, CanonStateNotificationSender, CanonStateNotifications, CanonStateSubscriptions, ChainSpecProvider, ChangeSetReader, EvmEnvProvider, ExecutorFactory, - HashingWriter, HeaderProvider, HistoryWriter, PrunableBlockExecutor, PruneCheckpointReader, - PruneCheckpointWriter, ReceiptProvider, ReceiptProviderIdExt, StageCheckpointReader, - StageCheckpointWriter, StateProvider, StateProviderBox, StateProviderFactory, - StateRootProvider, StorageReader, TransactionVariant, TransactionsProvider, - TransactionsProviderExt, WithdrawalsProvider, + HashingWriter, HeaderProvider, HeaderSyncGap, HeaderSyncGapProvider, HeaderSyncMode, + HistoryWriter, PrunableBlockExecutor, PruneCheckpointReader, PruneCheckpointWriter, + ReceiptProvider, ReceiptProviderIdExt, StageCheckpointReader, StageCheckpointWriter, + StateProvider, StateProviderBox, StateProviderFactory, StateRootProvider, StorageReader, + TransactionVariant, TransactionsProvider, TransactionsProviderExt, WithdrawalsProvider, }; /// Provider trait implementations. diff --git a/crates/storage/provider/src/providers/database/mod.rs b/crates/storage/provider/src/providers/database/mod.rs index 38b4be901..c21cbdd68 100644 --- a/crates/storage/provider/src/providers/database/mod.rs +++ b/crates/storage/provider/src/providers/database/mod.rs @@ -5,8 +5,9 @@ use crate::{ }, traits::{BlockSource, ReceiptProvider}, BlockHashReader, BlockNumReader, BlockReader, ChainSpecProvider, EvmEnvProvider, - HeaderProvider, ProviderError, PruneCheckpointReader, StageCheckpointReader, StateProviderBox, - TransactionVariant, TransactionsProvider, WithdrawalsProvider, + HeaderProvider, HeaderSyncGap, HeaderSyncGapProvider, HeaderSyncMode, ProviderError, + PruneCheckpointReader, StageCheckpointReader, StateProviderBox, TransactionVariant, + TransactionsProvider, WithdrawalsProvider, }; use reth_db::{database::Database, init_db, models::StoredBlockBodyIndices, DatabaseEnv}; use reth_interfaces::{db::LogLevel, provider::ProviderResult, RethError, RethResult}; @@ -196,6 +197,16 @@ impl ProviderFactory { } } +impl HeaderSyncGapProvider for ProviderFactory { + fn sync_gap( + &self, + mode: HeaderSyncMode, + highest_uninterrupted_block: BlockNumber, + ) -> RethResult { + self.provider()?.sync_gap(mode, highest_uninterrupted_block) + } +} + impl HeaderProvider for ProviderFactory { fn header(&self, block_hash: &BlockHash) -> ProviderResult> { self.provider()?.header(block_hash) @@ -477,19 +488,32 @@ impl PruneCheckpointReader for ProviderFactory { #[cfg(test)] mod tests { use super::ProviderFactory; - use crate::{BlockHashReader, BlockNumReader, BlockWriter, TransactionsProvider}; + use crate::{ + BlockHashReader, BlockNumReader, BlockWriter, HeaderSyncGapProvider, HeaderSyncMode, + TransactionsProvider, + }; use alloy_rlp::Decodable; use assert_matches::assert_matches; + use rand::Rng; use reth_db::{ tables, test_utils::{create_test_rw_db, ERROR_TEMPDIR}, + transaction::DbTxMut, DatabaseEnv, }; - use reth_interfaces::test_utils::{generators, generators::random_block}; + use reth_interfaces::{ + provider::ProviderError, + test_utils::{ + generators, + generators::{random_block, random_header}, + }, + RethError, + }; use reth_primitives::{ hex_literal::hex, ChainSpecBuilder, PruneMode, PruneModes, SealedBlock, TxNumber, B256, }; use std::{ops::RangeInclusive, sync::Arc}; + use tokio::sync::watch; #[test] fn common_history_provider() { @@ -618,4 +642,73 @@ mod tests { ) } } + + #[test] + fn header_sync_gap_lookup() { + let mut rng = generators::rng(); + let chain_spec = ChainSpecBuilder::mainnet().build(); + let db = create_test_rw_db(); + let factory = ProviderFactory::new(db, Arc::new(chain_spec)); + let provider = factory.provider_rw().unwrap(); + + let consensus_tip = rng.gen(); + let (_tip_tx, tip_rx) = watch::channel(consensus_tip); + let mode = HeaderSyncMode::Tip(tip_rx); + + // Genesis + let checkpoint = 0; + let head = random_header(&mut rng, 0, None); + let gap_fill = random_header(&mut rng, 1, Some(head.hash())); + let gap_tip = random_header(&mut rng, 2, Some(gap_fill.hash())); + + // Empty database + assert_matches!( + provider.sync_gap(mode.clone(), checkpoint), + Err(RethError::Provider(ProviderError::HeaderNotFound(block_number))) + if block_number.as_number().unwrap() == checkpoint + ); + + // Checkpoint and no gap + provider + .tx_ref() + .put::(head.number, head.hash()) + .expect("failed to write canonical"); + provider + .tx_ref() + .put::(head.number, head.clone().unseal()) + .expect("failed to write header"); + + let gap = provider.sync_gap(mode.clone(), checkpoint).unwrap(); + assert_eq!(gap.local_head, head); + assert_eq!(gap.target.tip(), consensus_tip.into()); + + // Checkpoint and gap + provider + .tx_ref() + .put::(gap_tip.number, gap_tip.hash()) + .expect("failed to write canonical"); + provider + .tx_ref() + .put::(gap_tip.number, gap_tip.clone().unseal()) + .expect("failed to write header"); + + let gap = provider.sync_gap(mode.clone(), checkpoint).unwrap(); + assert_eq!(gap.local_head, head); + assert_eq!(gap.target.tip(), gap_tip.parent_hash.into()); + + // Checkpoint and gap closed + provider + .tx_ref() + .put::(gap_fill.number, gap_fill.hash()) + .expect("failed to write canonical"); + provider + .tx_ref() + .put::(gap_fill.number, gap_fill.clone().unseal()) + .expect("failed to write header"); + + assert_matches!( + provider.sync_gap(mode, checkpoint), + Err(RethError::Provider(ProviderError::InconsistentHeaderGap)) + ); + } } diff --git a/crates/storage/provider/src/providers/database/provider.rs b/crates/storage/provider/src/providers/database/provider.rs index 198aeb553..ad289f198 100644 --- a/crates/storage/provider/src/providers/database/provider.rs +++ b/crates/storage/provider/src/providers/database/provider.rs @@ -5,10 +5,10 @@ use crate::{ AccountExtReader, BlockSource, ChangeSetReader, ReceiptProvider, StageCheckpointWriter, }, AccountReader, BlockExecutionWriter, BlockHashReader, BlockNumReader, BlockReader, BlockWriter, - Chain, EvmEnvProvider, HashingWriter, HeaderProvider, HistoryWriter, OriginalValuesKnown, - ProviderError, PruneCheckpointReader, PruneCheckpointWriter, StageCheckpointReader, - StorageReader, TransactionVariant, TransactionsProvider, TransactionsProviderExt, - WithdrawalsProvider, + Chain, EvmEnvProvider, HashingWriter, HeaderProvider, HeaderSyncGap, HeaderSyncGapProvider, + HeaderSyncMode, HistoryWriter, OriginalValuesKnown, ProviderError, PruneCheckpointReader, + PruneCheckpointWriter, StageCheckpointReader, StorageReader, TransactionVariant, + TransactionsProvider, TransactionsProviderExt, WithdrawalsProvider, }; use itertools::{izip, Itertools}; use reth_db::{ @@ -24,7 +24,11 @@ use reth_db::{ transaction::{DbTx, DbTxMut}, BlockNumberList, DatabaseError, }; -use reth_interfaces::provider::{ProviderResult, RootMismatch}; +use reth_interfaces::{ + p2p::headers::downloader::SyncTarget, + provider::{ProviderResult, RootMismatch}, + RethError, RethResult, +}; use reth_primitives::{ keccak256, revm::{ @@ -868,6 +872,57 @@ impl ChangeSetReader for DatabaseProvider { } } +impl HeaderSyncGapProvider for DatabaseProvider { + fn sync_gap( + &self, + mode: HeaderSyncMode, + highest_uninterrupted_block: BlockNumber, + ) -> RethResult { + // Create a cursor over canonical header hashes + let mut cursor = self.tx.cursor_read::()?; + let mut header_cursor = self.tx.cursor_read::()?; + + // Get head hash and reposition the cursor + let (head_num, head_hash) = cursor + .seek_exact(highest_uninterrupted_block)? + .ok_or_else(|| ProviderError::HeaderNotFound(highest_uninterrupted_block.into()))?; + + // Construct head + let (_, head) = header_cursor + .seek_exact(head_num)? + .ok_or_else(|| ProviderError::HeaderNotFound(head_num.into()))?; + let local_head = head.seal(head_hash); + + // Look up the next header + let next_header = cursor + .next()? + .map(|(next_num, next_hash)| -> Result { + let (_, next) = header_cursor + .seek_exact(next_num)? + .ok_or_else(|| ProviderError::HeaderNotFound(next_num.into()))?; + Ok(next.seal(next_hash)) + }) + .transpose()?; + + // Decide the tip or error out on invalid input. + // If the next element found in the cursor is not the "expected" next block per our current + // checkpoint, then there is a gap in the database and we should start downloading in + // reverse from there. Else, it should use whatever the forkchoice state reports. + let target = match next_header { + Some(header) if highest_uninterrupted_block + 1 != header.number => { + SyncTarget::Gap(header) + } + None => match mode { + HeaderSyncMode::Tip(rx) => SyncTarget::Tip(*rx.borrow()), + HeaderSyncMode::Continuous => SyncTarget::TipNum(head_num + 1), + }, + _ => return Err(ProviderError::InconsistentHeaderGap.into()), + }; + + Ok(HeaderSyncGap { local_head, target }) + } +} + impl HeaderProvider for DatabaseProvider { fn header(&self, block_hash: &BlockHash) -> ProviderResult> { if let Some(num) = self.block_number(*block_hash)? { diff --git a/crates/storage/provider/src/traits/header_sync_gap.rs b/crates/storage/provider/src/traits/header_sync_gap.rs new file mode 100644 index 000000000..576a26a9e --- /dev/null +++ b/crates/storage/provider/src/traits/header_sync_gap.rs @@ -0,0 +1,50 @@ +use auto_impl::auto_impl; +use reth_interfaces::{p2p::headers::downloader::SyncTarget, RethResult}; +use reth_primitives::{BlockHashOrNumber, BlockNumber, SealedHeader, B256}; +use tokio::sync::watch; + +/// The header sync mode. +#[derive(Clone, Debug)] +pub enum HeaderSyncMode { + /// A sync mode in which the stage continuously requests the downloader for + /// next blocks. + Continuous, + /// A sync mode in which the stage polls the receiver for the next tip + /// to download from. + Tip(watch::Receiver), +} + +/// Represents a gap to sync: from `local_head` to `target` +#[derive(Clone, Debug)] +pub struct HeaderSyncGap { + /// The local head block. Represents lower bound of sync range. + pub local_head: SealedHeader, + + /// The sync target. Represents upper bound of sync range. + pub target: SyncTarget, +} + +impl HeaderSyncGap { + /// Returns `true` if the gap from the head to the target was closed + #[inline] + pub fn is_closed(&self) -> bool { + match self.target.tip() { + BlockHashOrNumber::Hash(hash) => self.local_head.hash() == hash, + BlockHashOrNumber::Number(num) => self.local_head.number == num, + } + } +} + +/// Client trait for determining the current headers sync gap. +#[auto_impl(&, Arc)] +pub trait HeaderSyncGapProvider: Send + Sync { + /// Find a current sync gap for the headers depending on the [HeaderSyncMode] and the last + /// uninterrupted block number. Last uninterrupted block represents the block number before + /// which there are no gaps. It's up to the caller to ensure that last uninterrupted block is + /// determined correctly. + fn sync_gap( + &self, + mode: HeaderSyncMode, + highest_uninterrupted_block: BlockNumber, + ) -> RethResult; +} diff --git a/crates/storage/provider/src/traits/mod.rs b/crates/storage/provider/src/traits/mod.rs index 8134a1961..64f806f5f 100644 --- a/crates/storage/provider/src/traits/mod.rs +++ b/crates/storage/provider/src/traits/mod.rs @@ -27,6 +27,9 @@ pub use chain_info::CanonChainTracker; mod header; pub use header::HeaderProvider; +mod header_sync_gap; +pub use header_sync_gap::{HeaderSyncGap, HeaderSyncGapProvider, HeaderSyncMode}; + mod receipts; pub use receipts::{ReceiptProvider, ReceiptProviderIdExt}; diff --git a/testing/ef-tests/src/cases/blockchain_test.rs b/testing/ef-tests/src/cases/blockchain_test.rs index 5d9a4bf86..d77555029 100644 --- a/testing/ef-tests/src/cases/blockchain_test.rs +++ b/testing/ef-tests/src/cases/blockchain_test.rs @@ -111,8 +111,7 @@ impl Case for BlockchainTestCase { .expect("Could not build tokio RT") .block_on(async { // ignore error - let _ = - stage.execute(&provider, ExecInput { target, checkpoint: None }).await; + let _ = stage.execute(&provider, ExecInput { target, checkpoint: None }); }); }