chore(sync): migrate pipeline to ProviderFactory (#5532)

Author: Roman Krasiuk
Date: 2023-11-22 08:40:56 -08:00
Committed by: GitHub
Parent: 5e378b13ca
Commit: 5ae4fd1c65
41 changed files with 700 additions and 774 deletions
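This migrates pipeline construction to take a ProviderFactory directly instead of a raw database handle plus chain spec. Call sites now build one factory up front and clone it wherever a `db`/`chain_spec` pair used to be threaded through. A minimal sketch of the new pattern, assuming the `create_test_provider_factory` helper from `reth_provider::test_utils` shown in this diff (stages elided):

use reth_provider::test_utils::create_test_provider_factory;
use reth_stages::Pipeline;

fn build_example() -> eyre::Result<()> {
    let provider_factory = create_test_provider_factory();
    let _pipeline = Pipeline::builder()
        // .add_stages(...) unchanged from before
        .build(provider_factory.clone());
    // The same factory also serves ad-hoc reads, e.g. stage checkpoints:
    let _provider = provider_factory.provider()?;
    Ok(())
}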

View File

@ -86,6 +86,7 @@ impl ImportCommand {
info!(target: "reth::cli", path = ?db_path, "Opening database");
let db = Arc::new(init_db(db_path, self.db.log_level)?);
info!(target: "reth::cli", "Database opened");
let provider_factory = ProviderFactory::new(db.clone(), self.chain.clone());
debug!(target: "reth::cli", chain=%self.chain.chain, genesis=?self.chain.genesis_hash(), "Initializing genesis");
@ -102,15 +103,15 @@ impl ImportCommand {
let tip = file_client.tip().expect("file client has no tip");
info!(target: "reth::cli", "Chain file imported");
let (mut pipeline, events) =
self.build_import_pipeline(config, Arc::clone(&db), &consensus, file_client).await?;
let (mut pipeline, events) = self
.build_import_pipeline(config, provider_factory.clone(), &consensus, file_client)
.await?;
// override the tip
pipeline.set_tip(tip);
debug!(target: "reth::cli", ?tip, "Tip manually set");
let factory = ProviderFactory::new(db.clone(), self.chain.clone());
let provider = factory.provider()?;
let provider = provider_factory.provider()?;
let latest_block_number =
provider.get_stage_checkpoint(StageId::Finish)?.map(|ch| ch.block_number);
@ -130,7 +131,7 @@ impl ImportCommand {
async fn build_import_pipeline<DB, C>(
&self,
config: Config,
db: DB,
provider_factory: ProviderFactory<DB>,
consensus: &Arc<C>,
file_client: Arc<FileClient>,
) -> eyre::Result<(Pipeline<DB>, impl Stream<Item = NodeEvent>)>
@ -147,11 +148,7 @@ impl ImportCommand {
.into_task();
let body_downloader = BodiesDownloaderBuilder::from(config.stages.bodies)
.build(
file_client.clone(),
consensus.clone(),
ProviderFactory::new(db.clone(), self.chain.clone()),
)
.build(file_client.clone(), consensus.clone(), provider_factory.clone())
.into_task();
let (tip_tx, tip_rx) = watch::channel(B256::ZERO);
@ -164,7 +161,7 @@ impl ImportCommand {
.with_max_block(max_block)
.add_stages(
DefaultStages::new(
ProviderFactory::new(db.clone(), self.chain.clone()),
provider_factory.clone(),
HeaderSyncMode::Tip(tip_rx),
consensus.clone(),
header_downloader,
@ -194,7 +191,7 @@ impl ImportCommand {
config.prune.map(|prune| prune.segments).unwrap_or_default(),
)),
)
.build(db, self.chain.clone());
.build(provider_factory);
let events = pipeline.events().map(Into::into);

View File

@ -89,7 +89,7 @@ impl Command {
config: &Config,
client: Client,
consensus: Arc<dyn Consensus>,
db: DB,
provider_factory: ProviderFactory<DB>,
task_executor: &TaskExecutor,
) -> eyre::Result<Pipeline<DB>>
where
@ -102,11 +102,7 @@ impl Command {
.into_task_with(task_executor);
let body_downloader = BodiesDownloaderBuilder::from(config.stages.bodies)
.build(
client,
Arc::clone(&consensus),
ProviderFactory::new(db.clone(), self.chain.clone()),
)
.build(client, Arc::clone(&consensus), provider_factory.clone())
.into_task_with(task_executor);
let stage_conf = &config.stages;
@ -119,7 +115,7 @@ impl Command {
.with_tip_sender(tip_tx)
.add_stages(
DefaultStages::new(
ProviderFactory::new(db.clone(), self.chain.clone()),
provider_factory.clone(),
header_mode,
Arc::clone(&consensus),
header_downloader,
@ -148,7 +144,7 @@ impl Command {
config.prune.as_ref().map(|prune| prune.segments.clone()).unwrap_or_default(),
)),
)
.build(db, self.chain.clone());
.build(provider_factory);
Ok(pipeline)
}
@ -206,6 +202,7 @@ impl Command {
let db_path = data_dir.db_path();
fs::create_dir_all(&db_path)?;
let db = Arc::new(init_db(db_path, self.db.log_level)?);
let provider_factory = ProviderFactory::new(db.clone(), self.chain.clone());
debug!(target: "reth::cli", chain=%self.chain.chain, genesis=?self.chain.genesis_hash(), "Initializing genesis");
init_genesis(db.clone(), self.chain.clone())?;
@ -231,12 +228,11 @@ impl Command {
&config,
fetch_client.clone(),
Arc::clone(&consensus),
db.clone(),
provider_factory.clone(),
&ctx.task_executor,
)?;
let factory = ProviderFactory::new(db.clone(), self.chain.clone());
let provider = factory.provider()?;
let provider = provider_factory.provider()?;
let latest_block_number =
provider.get_stage_checkpoint(StageId::Finish)?.map(|ch| ch.block_number);
@ -270,7 +266,7 @@ impl Command {
// Unwind the pipeline without committing.
{
factory
provider_factory
.provider_rw()?
.take_block_and_execution_range(&self.chain, next_block..=target_block)?;
}

View File

@ -259,14 +259,16 @@ impl<Ext: RethCliExt> NodeCommand<Ext> {
let db = Arc::new(init_db(&db_path, self.db.log_level)?.with_metrics());
info!(target: "reth::cli", "Database opened");
let mut provider_factory = ProviderFactory::new(Arc::clone(&db), Arc::clone(&self.chain));
// configure snapshotter
let snapshotter = reth_snapshot::Snapshotter::new(
db.clone(),
provider_factory.clone(),
data_dir.snapshots_path(),
self.chain.clone(),
self.chain.snapshot_block_interval,
)?;
let provider_factory = ProviderFactory::new(Arc::clone(&db), Arc::clone(&self.chain))
provider_factory = provider_factory
.with_snapshots(data_dir.snapshots_path(), snapshotter.highest_snapshot_receiver());
self.start_metrics_endpoint(prometheus_handle, Arc::clone(&db)).await?;
@ -309,7 +311,8 @@ impl<Ext: RethCliExt> NodeCommand<Ext> {
let head = self.lookup_head(Arc::clone(&db)).wrap_err("the head block is missing")?;
// setup the blockchain provider
let blockchain_db = BlockchainProvider::new(provider_factory, blockchain_tree.clone())?;
let blockchain_db =
BlockchainProvider::new(provider_factory.clone(), blockchain_tree.clone())?;
let blob_store = InMemoryBlobStore::default();
let validator = TransactionValidationTaskExecutor::eth_builder(Arc::clone(&self.chain))
.with_head_timestamp(head.timestamp)
@ -417,7 +420,7 @@ impl<Ext: RethCliExt> NodeCommand<Ext> {
&config,
client.clone(),
Arc::clone(&consensus),
db.clone(),
provider_factory,
&ctx.task_executor,
sync_metrics_tx,
prune_config.clone(),
@ -437,7 +440,7 @@ impl<Ext: RethCliExt> NodeCommand<Ext> {
&config,
network_client.clone(),
Arc::clone(&consensus),
db.clone(),
provider_factory,
&ctx.task_executor,
sync_metrics_tx,
prune_config.clone(),
@ -601,7 +604,7 @@ impl<Ext: RethCliExt> NodeCommand<Ext> {
config: &Config,
client: Client,
consensus: Arc<dyn Consensus>,
db: DB,
provider_factory: ProviderFactory<DB>,
task_executor: &TaskExecutor,
metrics_tx: reth_stages::MetricEventsSender,
prune_config: Option<PruneConfig>,
@ -617,16 +620,12 @@ impl<Ext: RethCliExt> NodeCommand<Ext> {
.into_task_with(task_executor);
let body_downloader = BodiesDownloaderBuilder::from(config.stages.bodies)
.build(
client,
Arc::clone(&consensus),
ProviderFactory::new(db.clone(), self.chain.clone()),
)
.build(client, Arc::clone(&consensus), provider_factory.clone())
.into_task_with(task_executor);
let pipeline = self
.build_pipeline(
db,
provider_factory,
config,
header_downloader,
body_downloader,
@ -848,7 +847,7 @@ impl<Ext: RethCliExt> NodeCommand<Ext> {
#[allow(clippy::too_many_arguments)]
async fn build_pipeline<DB, H, B>(
&self,
db: DB,
provider_factory: ProviderFactory<DB>,
config: &Config,
header_downloader: H,
body_downloader: B,
@ -900,7 +899,7 @@ impl<Ext: RethCliExt> NodeCommand<Ext> {
.with_metrics_tx(metrics_tx.clone())
.add_stages(
DefaultStages::new(
ProviderFactory::new(db.clone(), self.chain.clone()),
provider_factory.clone(),
header_mode,
Arc::clone(&consensus),
header_downloader,
@ -953,7 +952,7 @@ impl<Ext: RethCliExt> NodeCommand<Ext> {
prune_modes.storage_history,
)),
)
.build(db, self.chain.clone());
.build(provider_factory);
Ok(pipeline)
}

View File

@ -124,7 +124,7 @@ impl Command {
let db = Arc::new(init_db(db_path, self.db.log_level)?);
info!(target: "reth::cli", "Database opened");
let factory = ProviderFactory::new(&db, self.chain.clone());
let factory = ProviderFactory::new(Arc::clone(&db), self.chain.clone());
let mut provider_rw = factory.provider_rw()?;
if let Some(listen_addr) = self.metrics {

View File

@ -403,7 +403,10 @@ mod tests {
constants::ETHEREUM_BLOCK_GAS_LIMIT, stage::StageCheckpoint, BlockBody, ChainSpec,
ChainSpecBuilder, Header, SealedHeader, MAINNET,
};
use reth_provider::{test_utils::TestExecutorFactory, BundleStateWithReceipts};
use reth_provider::{
test_utils::{create_test_provider_factory_with_chain_spec, TestExecutorFactory},
BundleStateWithReceipts, ProviderFactory,
};
use reth_stages::{test_utils::TestStages, ExecOutput, StageError};
use reth_tasks::TokioTaskExecutor;
use std::{collections::VecDeque, future::poll_fn, sync::Arc};
@ -451,7 +454,6 @@ mod tests {
/// Builds the pipeline.
fn build(self, chain_spec: Arc<ChainSpec>) -> Pipeline<Arc<TempDatabase<DatabaseEnv>>> {
reth_tracing::init_test_tracing();
let db = create_test_rw_db();
let executor_factory = TestExecutorFactory::new(chain_spec.clone());
executor_factory.extend(self.executor_results);
@ -466,7 +468,7 @@ mod tests {
pipeline = pipeline.with_max_block(max_block);
}
pipeline.build(db, chain_spec)
pipeline.build(create_test_provider_factory_with_chain_spec(chain_spec))
}
}

View File

@ -516,7 +516,7 @@ where
pipeline = pipeline.with_max_block(max_block);
}
let pipeline = pipeline.build(db.clone(), self.base_config.chain_spec.clone());
let pipeline = pipeline.build(provider_factory.clone());
// Setup blockchain tree
let externals = TreeExternals::new(provider_factory.clone(), consensus, executor_factory);

View File

@ -1,5 +1,5 @@
use super::headers::client::HeadersRequest;
use crate::{consensus::ConsensusError, provider::ProviderError};
use crate::{consensus::ConsensusError, db::DatabaseError, provider::ProviderError};
use reth_network_api::ReputationChangeKind;
use reth_primitives::{
BlockHashOrNumber, BlockNumber, GotExpected, GotExpectedBoxed, Header, WithPeerId, B256,
@ -182,6 +182,12 @@ pub enum DownloadError {
Provider(#[from] ProviderError),
}
impl From<DatabaseError> for DownloadError {
fn from(error: DatabaseError) -> Self {
Self::Provider(ProviderError::Database(error))
}
}
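With this conversion in place, downloader code can propagate raw database failures with `?` wherever a `DownloadResult` is expected. A minimal sketch (the function is illustrative, and the import paths assume the crate layout shown above):

use reth_interfaces::db::DatabaseError;
use reth_interfaces::p2p::error::{DownloadError, DownloadResult};

fn forward_db_result(res: Result<(), DatabaseError>) -> DownloadResult<()> {
    // `?` converts DatabaseError into DownloadError::Provider(ProviderError::Database(_))
    res?;
    Ok(())
}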
#[cfg(test)]
mod tests {
use super::*;

View File

@ -90,16 +90,16 @@ mod tests {
};
use reth_primitives::{BlockNumber, PruneCheckpoint, PruneMode, PruneSegment, B256};
use reth_provider::PruneCheckpointReader;
use reth_stages::test_utils::TestTransaction;
use reth_stages::test_utils::TestStageDB;
use std::{collections::BTreeMap, ops::AddAssign};
#[test]
fn prune() {
let tx = TestTransaction::default();
let db = TestStageDB::default();
let mut rng = generators::rng();
let blocks = random_block_range(&mut rng, 1..=5000, B256::ZERO, 0..1);
tx.insert_blocks(blocks.iter(), None).expect("insert blocks");
db.insert_blocks(blocks.iter(), None).expect("insert blocks");
let accounts =
random_eoa_account_range(&mut rng, 0..2).into_iter().collect::<BTreeMap<_, _>>();
@ -111,10 +111,10 @@ mod tests {
0..0,
0..0,
);
tx.insert_changesets(changesets.clone(), None).expect("insert changesets");
tx.insert_history(changesets.clone(), None).expect("insert history");
db.insert_changesets(changesets.clone(), None).expect("insert changesets");
db.insert_history(changesets.clone(), None).expect("insert history");
let account_occurrences = tx.table::<tables::AccountHistory>().unwrap().into_iter().fold(
let account_occurrences = db.table::<tables::AccountHistory>().unwrap().into_iter().fold(
BTreeMap::<_, usize>::new(),
|mut map, (key, _)| {
map.entry(key.key).or_default().add_assign(1);
@ -124,17 +124,19 @@ mod tests {
assert!(account_occurrences.into_iter().any(|(_, occurrences)| occurrences > 1));
assert_eq!(
tx.table::<tables::AccountChangeSet>().unwrap().len(),
db.table::<tables::AccountChangeSet>().unwrap().len(),
changesets.iter().flatten().count()
);
let original_shards = tx.table::<tables::AccountHistory>().unwrap();
let original_shards = db.table::<tables::AccountHistory>().unwrap();
let test_prune = |to_block: BlockNumber, run: usize, expected_result: (bool, usize)| {
let prune_mode = PruneMode::Before(to_block);
let input = PruneInput {
previous_checkpoint: tx
.inner()
previous_checkpoint: db
.factory
.provider()
.unwrap()
.get_prune_checkpoint(PruneSegment::AccountHistory)
.unwrap(),
to_block,
@ -142,7 +144,7 @@ mod tests {
};
let segment = AccountHistory::new(prune_mode);
let provider = tx.inner_rw();
let provider = db.factory.provider_rw().unwrap();
let result = segment.prune(&provider, input).unwrap();
assert_matches!(
result,
@ -200,11 +202,11 @@ mod tests {
);
assert_eq!(
tx.table::<tables::AccountChangeSet>().unwrap().len(),
db.table::<tables::AccountChangeSet>().unwrap().len(),
pruned_changesets.values().flatten().count()
);
let actual_shards = tx.table::<tables::AccountHistory>().unwrap();
let actual_shards = db.table::<tables::AccountHistory>().unwrap();
let expected_shards = original_shards
.iter()
@ -221,7 +223,11 @@ mod tests {
assert_eq!(actual_shards, expected_shards);
assert_eq!(
tx.inner().get_prune_checkpoint(PruneSegment::AccountHistory).unwrap(),
db.factory
.provider()
.unwrap()
.get_prune_checkpoint(PruneSegment::AccountHistory)
.unwrap(),
Some(PruneCheckpoint {
block_number: Some(last_pruned_block_number),
tx_number: None,
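Every prune segment test below follows this same migration: `TestTransaction` becomes `TestStageDB`, which exposes a `ProviderFactory` through its `factory` field, so tests open providers instead of raw transactions. A minimal sketch of the new access pattern (the segment choice is illustrative):

use reth_primitives::PruneSegment;
use reth_provider::PruneCheckpointReader;
use reth_stages::test_utils::TestStageDB;

#[test]
fn provider_access_pattern() {
    let db = TestStageDB::default();
    // Read-only access, previously `tx.inner()`:
    let _checkpoint = db
        .factory
        .provider()
        .unwrap()
        .get_prune_checkpoint(PruneSegment::AccountHistory)
        .unwrap();
    // Read-write access, previously `tx.inner_rw()`:
    let _provider_rw = db.factory.provider_rw().unwrap();
}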

View File

@ -116,25 +116,27 @@ mod tests {
use reth_interfaces::test_utils::{generators, generators::random_header_range};
use reth_primitives::{BlockNumber, PruneCheckpoint, PruneMode, PruneSegment, B256};
use reth_provider::PruneCheckpointReader;
use reth_stages::test_utils::TestTransaction;
use reth_stages::test_utils::TestStageDB;
#[test]
fn prune() {
let tx = TestTransaction::default();
let db = TestStageDB::default();
let mut rng = generators::rng();
let headers = random_header_range(&mut rng, 0..100, B256::ZERO);
tx.insert_headers_with_td(headers.iter()).expect("insert headers");
db.insert_headers_with_td(headers.iter()).expect("insert headers");
assert_eq!(tx.table::<tables::CanonicalHeaders>().unwrap().len(), headers.len());
assert_eq!(tx.table::<tables::Headers>().unwrap().len(), headers.len());
assert_eq!(tx.table::<tables::HeaderTD>().unwrap().len(), headers.len());
assert_eq!(db.table::<tables::CanonicalHeaders>().unwrap().len(), headers.len());
assert_eq!(db.table::<tables::Headers>().unwrap().len(), headers.len());
assert_eq!(db.table::<tables::HeaderTD>().unwrap().len(), headers.len());
let test_prune = |to_block: BlockNumber, expected_result: (bool, usize)| {
let prune_mode = PruneMode::Before(to_block);
let input = PruneInput {
previous_checkpoint: tx
.inner()
previous_checkpoint: db
.factory
.provider()
.unwrap()
.get_prune_checkpoint(PruneSegment::Headers)
.unwrap(),
to_block,
@ -142,15 +144,17 @@ mod tests {
};
let segment = Headers::new(prune_mode);
let next_block_number_to_prune = tx
.inner()
let next_block_number_to_prune = db
.factory
.provider()
.unwrap()
.get_prune_checkpoint(PruneSegment::Headers)
.unwrap()
.and_then(|checkpoint| checkpoint.block_number)
.map(|block_number| block_number + 1)
.unwrap_or_default();
let provider = tx.inner_rw();
let provider = db.factory.provider_rw().unwrap();
let result = segment.prune(&provider, input).unwrap();
assert_matches!(
result,
@ -169,19 +173,19 @@ mod tests {
.min(next_block_number_to_prune + input.delete_limit as BlockNumber / 3 - 1);
assert_eq!(
tx.table::<tables::CanonicalHeaders>().unwrap().len(),
db.table::<tables::CanonicalHeaders>().unwrap().len(),
headers.len() - (last_pruned_block_number + 1) as usize
);
assert_eq!(
tx.table::<tables::Headers>().unwrap().len(),
db.table::<tables::Headers>().unwrap().len(),
headers.len() - (last_pruned_block_number + 1) as usize
);
assert_eq!(
tx.table::<tables::HeaderTD>().unwrap().len(),
db.table::<tables::HeaderTD>().unwrap().len(),
headers.len() - (last_pruned_block_number + 1) as usize
);
assert_eq!(
tx.inner().get_prune_checkpoint(PruneSegment::Headers).unwrap(),
db.factory.provider().unwrap().get_prune_checkpoint(PruneSegment::Headers).unwrap(),
Some(PruneCheckpoint {
block_number: Some(last_pruned_block_number),
tx_number: None,
@ -196,7 +200,7 @@ mod tests {
#[test]
fn prune_cannot_be_done() {
let tx = TestTransaction::default();
let db = TestStageDB::default();
let input = PruneInput {
previous_checkpoint: None,
@ -206,7 +210,7 @@ mod tests {
};
let segment = Headers::new(PruneMode::Full);
let provider = tx.inner_rw();
let provider = db.factory.provider_rw().unwrap();
let result = segment.prune(&provider, input).unwrap();
assert_eq!(result, PruneOutput::not_done());
}

View File

@ -99,16 +99,16 @@ mod tests {
};
use reth_primitives::{BlockNumber, PruneCheckpoint, PruneMode, PruneSegment, TxNumber, B256};
use reth_provider::PruneCheckpointReader;
use reth_stages::test_utils::TestTransaction;
use reth_stages::test_utils::TestStageDB;
use std::ops::Sub;
#[test]
fn prune() {
let tx = TestTransaction::default();
let db = TestStageDB::default();
let mut rng = generators::rng();
let blocks = random_block_range(&mut rng, 1..=10, B256::ZERO, 2..3);
tx.insert_blocks(blocks.iter(), None).expect("insert blocks");
db.insert_blocks(blocks.iter(), None).expect("insert blocks");
let mut receipts = Vec::new();
for block in &blocks {
@ -117,22 +117,24 @@ mod tests {
.push((receipts.len() as u64, random_receipt(&mut rng, transaction, Some(0))));
}
}
tx.insert_receipts(receipts.clone()).expect("insert receipts");
db.insert_receipts(receipts.clone()).expect("insert receipts");
assert_eq!(
tx.table::<tables::Transactions>().unwrap().len(),
db.table::<tables::Transactions>().unwrap().len(),
blocks.iter().map(|block| block.body.len()).sum::<usize>()
);
assert_eq!(
tx.table::<tables::Transactions>().unwrap().len(),
tx.table::<tables::Receipts>().unwrap().len()
db.table::<tables::Transactions>().unwrap().len(),
db.table::<tables::Receipts>().unwrap().len()
);
let test_prune = |to_block: BlockNumber, expected_result: (bool, usize)| {
let prune_mode = PruneMode::Before(to_block);
let input = PruneInput {
previous_checkpoint: tx
.inner()
previous_checkpoint: db
.factory
.provider()
.unwrap()
.get_prune_checkpoint(PruneSegment::Receipts)
.unwrap(),
to_block,
@ -140,8 +142,10 @@ mod tests {
};
let segment = Receipts::new(prune_mode);
let next_tx_number_to_prune = tx
.inner()
let next_tx_number_to_prune = db
.factory
.provider()
.unwrap()
.get_prune_checkpoint(PruneSegment::Receipts)
.unwrap()
.and_then(|checkpoint| checkpoint.tx_number)
@ -156,7 +160,7 @@ mod tests {
.min(next_tx_number_to_prune as usize + input.delete_limit)
.sub(1);
let provider = tx.inner_rw();
let provider = db.factory.provider_rw().unwrap();
let result = segment.prune(&provider, input).unwrap();
assert_matches!(
result,
@ -187,11 +191,15 @@ mod tests {
.checked_sub(if result.done { 0 } else { 1 });
assert_eq!(
tx.table::<tables::Receipts>().unwrap().len(),
db.table::<tables::Receipts>().unwrap().len(),
receipts.len() - (last_pruned_tx_number + 1)
);
assert_eq!(
tx.inner().get_prune_checkpoint(PruneSegment::Receipts).unwrap(),
db.factory
.provider()
.unwrap()
.get_prune_checkpoint(PruneSegment::Receipts)
.unwrap(),
Some(PruneCheckpoint {
block_number: last_pruned_block_number,
tx_number: Some(last_pruned_tx_number as TxNumber),

View File

@ -216,12 +216,12 @@ mod tests {
};
use reth_primitives::{PruneMode, PruneSegment, ReceiptsLogPruneConfig, B256};
use reth_provider::{PruneCheckpointReader, TransactionsProvider};
use reth_stages::test_utils::TestTransaction;
use reth_stages::test_utils::TestStageDB;
use std::collections::BTreeMap;
#[test]
fn prune_receipts_by_logs() {
let tx = TestTransaction::default();
let db = TestStageDB::default();
let mut rng = generators::rng();
let tip = 20000;
@ -231,7 +231,7 @@ mod tests {
random_block_range(&mut rng, (tip - 100 + 1)..=tip, B256::ZERO, 1..5),
]
.concat();
tx.insert_blocks(blocks.iter(), None).expect("insert blocks");
db.insert_blocks(blocks.iter(), None).expect("insert blocks");
let mut receipts = Vec::new();
@ -247,19 +247,19 @@ mod tests {
receipts.push((receipts.len() as u64, receipt));
}
}
tx.insert_receipts(receipts).expect("insert receipts");
db.insert_receipts(receipts).expect("insert receipts");
assert_eq!(
tx.table::<tables::Transactions>().unwrap().len(),
db.table::<tables::Transactions>().unwrap().len(),
blocks.iter().map(|block| block.body.len()).sum::<usize>()
);
assert_eq!(
tx.table::<tables::Transactions>().unwrap().len(),
tx.table::<tables::Receipts>().unwrap().len()
db.table::<tables::Transactions>().unwrap().len(),
db.table::<tables::Receipts>().unwrap().len()
);
let run_prune = || {
let provider = tx.inner_rw();
let provider = db.factory.provider_rw().unwrap();
let prune_before_block: usize = 20;
let prune_mode = PruneMode::Before(prune_before_block as u64);
@ -269,8 +269,10 @@ mod tests {
let result = ReceiptsByLogs::new(receipts_log_filter).prune(
&provider,
PruneInput {
previous_checkpoint: tx
.inner()
previous_checkpoint: db
.factory
.provider()
.unwrap()
.get_prune_checkpoint(PruneSegment::ContractLogs)
.unwrap(),
to_block: tip,
@ -282,8 +284,10 @@ mod tests {
assert_matches!(result, Ok(_));
let output = result.unwrap();
let (pruned_block, pruned_tx) = tx
.inner()
let (pruned_block, pruned_tx) = db
.factory
.provider()
.unwrap()
.get_prune_checkpoint(PruneSegment::ContractLogs)
.unwrap()
.map(|checkpoint| (checkpoint.block_number.unwrap(), checkpoint.tx_number.unwrap()))
@ -293,7 +297,7 @@ mod tests {
let unprunable = pruned_block.saturating_sub(prune_before_block as u64 - 1);
assert_eq!(
tx.table::<tables::Receipts>().unwrap().len(),
db.table::<tables::Receipts>().unwrap().len(),
blocks.iter().map(|block| block.body.len()).sum::<usize>() -
((pruned_tx + 1) - unprunable) as usize
);
@ -303,7 +307,7 @@ mod tests {
while !run_prune() {}
let provider = tx.inner();
let provider = db.factory.provider().unwrap();
let mut cursor = provider.tx_ref().cursor_read::<tables::Receipts>().unwrap();
let walker = cursor.walk(None).unwrap();
for receipt in walker {

View File

@ -81,16 +81,16 @@ mod tests {
use reth_interfaces::test_utils::{generators, generators::random_block_range};
use reth_primitives::{BlockNumber, PruneCheckpoint, PruneMode, PruneSegment, TxNumber, B256};
use reth_provider::PruneCheckpointReader;
use reth_stages::test_utils::TestTransaction;
use reth_stages::test_utils::TestStageDB;
use std::ops::Sub;
#[test]
fn prune() {
let tx = TestTransaction::default();
let db = TestStageDB::default();
let mut rng = generators::rng();
let blocks = random_block_range(&mut rng, 1..=10, B256::ZERO, 2..3);
tx.insert_blocks(blocks.iter(), None).expect("insert blocks");
db.insert_blocks(blocks.iter(), None).expect("insert blocks");
let mut transaction_senders = Vec::new();
for block in &blocks {
@ -101,23 +101,25 @@ mod tests {
));
}
}
tx.insert_transaction_senders(transaction_senders.clone())
db.insert_transaction_senders(transaction_senders.clone())
.expect("insert transaction senders");
assert_eq!(
tx.table::<tables::Transactions>().unwrap().len(),
db.table::<tables::Transactions>().unwrap().len(),
blocks.iter().map(|block| block.body.len()).sum::<usize>()
);
assert_eq!(
tx.table::<tables::Transactions>().unwrap().len(),
tx.table::<tables::TxSenders>().unwrap().len()
db.table::<tables::Transactions>().unwrap().len(),
db.table::<tables::TxSenders>().unwrap().len()
);
let test_prune = |to_block: BlockNumber, expected_result: (bool, usize)| {
let prune_mode = PruneMode::Before(to_block);
let input = PruneInput {
previous_checkpoint: tx
.inner()
previous_checkpoint: db
.factory
.provider()
.unwrap()
.get_prune_checkpoint(PruneSegment::SenderRecovery)
.unwrap(),
to_block,
@ -125,8 +127,10 @@ mod tests {
};
let segment = SenderRecovery::new(prune_mode);
let next_tx_number_to_prune = tx
.inner()
let next_tx_number_to_prune = db
.factory
.provider()
.unwrap()
.get_prune_checkpoint(PruneSegment::SenderRecovery)
.unwrap()
.and_then(|checkpoint| checkpoint.tx_number)
@ -155,7 +159,7 @@ mod tests {
.into_inner()
.0;
let provider = tx.inner_rw();
let provider = db.factory.provider_rw().unwrap();
let result = segment.prune(&provider, input).unwrap();
assert_matches!(
result,
@ -174,11 +178,15 @@ mod tests {
last_pruned_block_number.checked_sub(if result.done { 0 } else { 1 });
assert_eq!(
tx.table::<tables::TxSenders>().unwrap().len(),
db.table::<tables::TxSenders>().unwrap().len(),
transaction_senders.len() - (last_pruned_tx_number + 1)
);
assert_eq!(
tx.inner().get_prune_checkpoint(PruneSegment::SenderRecovery).unwrap(),
db.factory
.provider()
.unwrap()
.get_prune_checkpoint(PruneSegment::SenderRecovery)
.unwrap(),
Some(PruneCheckpoint {
block_number: last_pruned_block_number,
tx_number: Some(last_pruned_tx_number as TxNumber),

View File

@ -94,16 +94,16 @@ mod tests {
};
use reth_primitives::{BlockNumber, PruneCheckpoint, PruneMode, PruneSegment, B256};
use reth_provider::PruneCheckpointReader;
use reth_stages::test_utils::TestTransaction;
use reth_stages::test_utils::TestStageDB;
use std::{collections::BTreeMap, ops::AddAssign};
#[test]
fn prune() {
let tx = TestTransaction::default();
let db = TestStageDB::default();
let mut rng = generators::rng();
let blocks = random_block_range(&mut rng, 0..=5000, B256::ZERO, 0..1);
tx.insert_blocks(blocks.iter(), None).expect("insert blocks");
db.insert_blocks(blocks.iter(), None).expect("insert blocks");
let accounts =
random_eoa_account_range(&mut rng, 0..2).into_iter().collect::<BTreeMap<_, _>>();
@ -115,10 +115,10 @@ mod tests {
2..3,
1..2,
);
tx.insert_changesets(changesets.clone(), None).expect("insert changesets");
tx.insert_history(changesets.clone(), None).expect("insert history");
db.insert_changesets(changesets.clone(), None).expect("insert changesets");
db.insert_history(changesets.clone(), None).expect("insert history");
let storage_occurrences = tx.table::<tables::StorageHistory>().unwrap().into_iter().fold(
let storage_occurrences = db.table::<tables::StorageHistory>().unwrap().into_iter().fold(
BTreeMap::<_, usize>::new(),
|mut map, (key, _)| {
map.entry((key.address, key.sharded_key.key)).or_default().add_assign(1);
@ -128,17 +128,19 @@ mod tests {
assert!(storage_occurrences.into_iter().any(|(_, occurrences)| occurrences > 1));
assert_eq!(
tx.table::<tables::StorageChangeSet>().unwrap().len(),
db.table::<tables::StorageChangeSet>().unwrap().len(),
changesets.iter().flatten().flat_map(|(_, _, entries)| entries).count()
);
let original_shards = tx.table::<tables::StorageHistory>().unwrap();
let original_shards = db.table::<tables::StorageHistory>().unwrap();
let test_prune = |to_block: BlockNumber, run: usize, expected_result: (bool, usize)| {
let prune_mode = PruneMode::Before(to_block);
let input = PruneInput {
previous_checkpoint: tx
.inner()
previous_checkpoint: db
.factory
.provider()
.unwrap()
.get_prune_checkpoint(PruneSegment::StorageHistory)
.unwrap(),
to_block,
@ -146,7 +148,7 @@ mod tests {
};
let segment = StorageHistory::new(prune_mode);
let provider = tx.inner_rw();
let provider = db.factory.provider_rw().unwrap();
let result = segment.prune(&provider, input).unwrap();
assert_matches!(
result,
@ -206,11 +208,11 @@ mod tests {
);
assert_eq!(
tx.table::<tables::StorageChangeSet>().unwrap().len(),
db.table::<tables::StorageChangeSet>().unwrap().len(),
pruned_changesets.values().flatten().count()
);
let actual_shards = tx.table::<tables::StorageHistory>().unwrap();
let actual_shards = db.table::<tables::StorageHistory>().unwrap();
let expected_shards = original_shards
.iter()
@ -227,7 +229,11 @@ mod tests {
assert_eq!(actual_shards, expected_shards);
assert_eq!(
tx.inner().get_prune_checkpoint(PruneSegment::StorageHistory).unwrap(),
db.factory
.provider()
.unwrap()
.get_prune_checkpoint(PruneSegment::StorageHistory)
.unwrap(),
Some(PruneCheckpoint {
block_number: Some(last_pruned_block_number),
tx_number: None,

View File

@ -104,16 +104,16 @@ mod tests {
use reth_interfaces::test_utils::{generators, generators::random_block_range};
use reth_primitives::{BlockNumber, PruneCheckpoint, PruneMode, PruneSegment, TxNumber, B256};
use reth_provider::PruneCheckpointReader;
use reth_stages::test_utils::TestTransaction;
use reth_stages::test_utils::TestStageDB;
use std::ops::Sub;
#[test]
fn prune() {
let tx = TestTransaction::default();
let db = TestStageDB::default();
let mut rng = generators::rng();
let blocks = random_block_range(&mut rng, 1..=10, B256::ZERO, 2..3);
tx.insert_blocks(blocks.iter(), None).expect("insert blocks");
db.insert_blocks(blocks.iter(), None).expect("insert blocks");
let mut tx_hash_numbers = Vec::new();
for block in &blocks {
@ -121,22 +121,24 @@ mod tests {
tx_hash_numbers.push((transaction.hash, tx_hash_numbers.len() as u64));
}
}
tx.insert_tx_hash_numbers(tx_hash_numbers.clone()).expect("insert tx hash numbers");
db.insert_tx_hash_numbers(tx_hash_numbers.clone()).expect("insert tx hash numbers");
assert_eq!(
tx.table::<tables::Transactions>().unwrap().len(),
db.table::<tables::Transactions>().unwrap().len(),
blocks.iter().map(|block| block.body.len()).sum::<usize>()
);
assert_eq!(
tx.table::<tables::Transactions>().unwrap().len(),
tx.table::<tables::TxHashNumber>().unwrap().len()
db.table::<tables::Transactions>().unwrap().len(),
db.table::<tables::TxHashNumber>().unwrap().len()
);
let test_prune = |to_block: BlockNumber, expected_result: (bool, usize)| {
let prune_mode = PruneMode::Before(to_block);
let input = PruneInput {
previous_checkpoint: tx
.inner()
previous_checkpoint: db
.factory
.provider()
.unwrap()
.get_prune_checkpoint(PruneSegment::TransactionLookup)
.unwrap(),
to_block,
@ -144,8 +146,10 @@ mod tests {
};
let segment = TransactionLookup::new(prune_mode);
let next_tx_number_to_prune = tx
.inner()
let next_tx_number_to_prune = db
.factory
.provider()
.unwrap()
.get_prune_checkpoint(PruneSegment::TransactionLookup)
.unwrap()
.and_then(|checkpoint| checkpoint.tx_number)
@ -174,7 +178,7 @@ mod tests {
.into_inner()
.0;
let provider = tx.inner_rw();
let provider = db.factory.provider_rw().unwrap();
let result = segment.prune(&provider, input).unwrap();
assert_matches!(
result,
@ -193,11 +197,15 @@ mod tests {
last_pruned_block_number.checked_sub(if result.done { 0 } else { 1 });
assert_eq!(
tx.table::<tables::TxHashNumber>().unwrap().len(),
db.table::<tables::TxHashNumber>().unwrap().len(),
tx_hash_numbers.len() - (last_pruned_tx_number + 1)
);
assert_eq!(
tx.inner().get_prune_checkpoint(PruneSegment::TransactionLookup).unwrap(),
db.factory
.provider()
.unwrap()
.get_prune_checkpoint(PruneSegment::TransactionLookup)
.unwrap(),
Some(PruneCheckpoint {
block_number: last_pruned_block_number,
tx_number: Some(last_pruned_tx_number as TxNumber),

View File

@ -80,26 +80,28 @@ mod tests {
use reth_interfaces::test_utils::{generators, generators::random_block_range};
use reth_primitives::{BlockNumber, PruneCheckpoint, PruneMode, PruneSegment, TxNumber, B256};
use reth_provider::PruneCheckpointReader;
use reth_stages::test_utils::TestTransaction;
use reth_stages::test_utils::TestStageDB;
use std::ops::Sub;
#[test]
fn prune() {
let tx = TestTransaction::default();
let db = TestStageDB::default();
let mut rng = generators::rng();
let blocks = random_block_range(&mut rng, 1..=100, B256::ZERO, 2..3);
tx.insert_blocks(blocks.iter(), None).expect("insert blocks");
db.insert_blocks(blocks.iter(), None).expect("insert blocks");
let transactions = blocks.iter().flat_map(|block| &block.body).collect::<Vec<_>>();
assert_eq!(tx.table::<tables::Transactions>().unwrap().len(), transactions.len());
assert_eq!(db.table::<tables::Transactions>().unwrap().len(), transactions.len());
let test_prune = |to_block: BlockNumber, expected_result: (bool, usize)| {
let prune_mode = PruneMode::Before(to_block);
let input = PruneInput {
previous_checkpoint: tx
.inner()
previous_checkpoint: db
.factory
.provider()
.unwrap()
.get_prune_checkpoint(PruneSegment::Transactions)
.unwrap(),
to_block,
@ -107,15 +109,17 @@ mod tests {
};
let segment = Transactions::new(prune_mode);
let next_tx_number_to_prune = tx
.inner()
let next_tx_number_to_prune = db
.factory
.provider()
.unwrap()
.get_prune_checkpoint(PruneSegment::Transactions)
.unwrap()
.and_then(|checkpoint| checkpoint.tx_number)
.map(|tx_number| tx_number + 1)
.unwrap_or_default();
let provider = tx.inner_rw();
let provider = db.factory.provider_rw().unwrap();
let result = segment.prune(&provider, input).unwrap();
assert_matches!(
result,
@ -154,11 +158,15 @@ mod tests {
.checked_sub(if result.done { 0 } else { 1 });
assert_eq!(
tx.table::<tables::Transactions>().unwrap().len(),
db.table::<tables::Transactions>().unwrap().len(),
transactions.len() - (last_pruned_tx_number + 1)
);
assert_eq!(
tx.inner().get_prune_checkpoint(PruneSegment::Transactions).unwrap(),
db.factory
.provider()
.unwrap()
.get_prune_checkpoint(PruneSegment::Transactions)
.unwrap(),
Some(PruneCheckpoint {
block_number: last_pruned_block_number,
tx_number: Some(last_pruned_tx_number as TxNumber),

View File

@ -5,14 +5,13 @@ use reth_db::database::Database;
use reth_interfaces::{RethError, RethResult};
use reth_primitives::{
snapshot::{iter_snapshots, HighestSnapshots},
BlockNumber, ChainSpec, TxNumber,
BlockNumber, TxNumber,
};
use reth_provider::{BlockReader, DatabaseProviderRO, ProviderFactory, TransactionsProviderExt};
use std::{
collections::HashMap,
ops::RangeInclusive,
path::{Path, PathBuf},
sync::Arc,
};
use tokio::sync::watch;
use tracing::warn;
@ -94,15 +93,14 @@ impl SnapshotTargets {
impl<DB: Database> Snapshotter<DB> {
/// Creates a new [Snapshotter].
pub fn new(
db: DB,
provider_factory: ProviderFactory<DB>,
snapshots_path: impl AsRef<Path>,
chain_spec: Arc<ChainSpec>,
block_interval: u64,
) -> RethResult<Self> {
let (highest_snapshots_notifier, highest_snapshots_tracker) = watch::channel(None);
let mut snapshotter = Self {
provider_factory: ProviderFactory::new(db, chain_spec),
provider_factory,
snapshots_path: snapshots_path.as_ref().into(),
highest_snapshots: HighestSnapshots::default(),
highest_snapshots_notifier,
@ -329,16 +327,14 @@ mod tests {
test_utils::{generators, generators::random_block_range},
RethError,
};
use reth_primitives::{snapshot::HighestSnapshots, B256, MAINNET};
use reth_stages::test_utils::TestTransaction;
use reth_primitives::{snapshot::HighestSnapshots, B256};
use reth_stages::test_utils::TestStageDB;
#[test]
fn new() {
let tx = TestTransaction::default();
let db = TestStageDB::default();
let snapshots_dir = tempfile::TempDir::new().unwrap();
let snapshotter =
Snapshotter::new(tx.inner_raw(), snapshots_dir.into_path(), MAINNET.clone(), 2)
.unwrap();
let snapshotter = Snapshotter::new(db.factory, snapshots_dir.into_path(), 2).unwrap();
assert_eq!(
*snapshotter.highest_snapshot_receiver().borrow(),
@ -348,16 +344,14 @@ mod tests {
#[test]
fn get_snapshot_targets() {
let tx = TestTransaction::default();
let db = TestStageDB::default();
let snapshots_dir = tempfile::TempDir::new().unwrap();
let mut rng = generators::rng();
let blocks = random_block_range(&mut rng, 0..=3, B256::ZERO, 2..3);
tx.insert_blocks(blocks.iter(), None).expect("insert blocks");
db.insert_blocks(blocks.iter(), None).expect("insert blocks");
let mut snapshotter =
Snapshotter::new(tx.inner_raw(), snapshots_dir.into_path(), MAINNET.clone(), 2)
.unwrap();
let mut snapshotter = Snapshotter::new(db.factory, snapshots_dir.into_path(), 2).unwrap();
// Snapshot targets have data per part up to the passed finalized block number,
// respecting the block interval
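The Snapshotter constructor drops its chain spec argument as well, since the factory already carries one. A minimal sketch using the test helpers from this commit (the interval of 2 mirrors the tests above):

use reth_provider::test_utils::create_test_provider_factory;
use reth_snapshot::Snapshotter;

fn build_test_snapshotter() -> eyre::Result<()> {
    let provider_factory = create_test_provider_factory();
    let snapshots_dir = tempfile::TempDir::new()?;
    let snapshotter = Snapshotter::new(provider_factory, snapshots_dir.into_path(), 2)?;
    let _highest = snapshotter.highest_snapshot_receiver();
    Ok(())
}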

View File

@ -9,7 +9,7 @@ use reth_primitives::{stage::StageCheckpoint, MAINNET};
use reth_provider::ProviderFactory;
use reth_stages::{
stages::{MerkleStage, SenderRecoveryStage, TotalDifficultyStage, TransactionLookupStage},
test_utils::TestTransaction,
test_utils::TestStageDB,
ExecInput, Stage, StageExt, UnwindInput,
};
use std::{path::PathBuf, sync::Arc};
@ -123,9 +123,9 @@ fn measure_stage_with_path<F, S>(
label: String,
) where
S: Clone + Stage<DatabaseEnv>,
F: Fn(S, &TestTransaction, StageRange),
F: Fn(S, &TestStageDB, StageRange),
{
let tx = TestTransaction::new(&path);
let tx = TestStageDB::new(&path);
let (input, _) = stage_range;
group.bench_function(label, move |b| {
@ -136,7 +136,7 @@ fn measure_stage_with_path<F, S>(
},
|_| async {
let mut stage = stage.clone();
let factory = ProviderFactory::new(tx.tx.db(), MAINNET.clone());
let factory = ProviderFactory::new(tx.factory.db(), MAINNET.clone());
let provider = factory.provider_rw().unwrap();
stage
.execute_ready(input)
@ -157,7 +157,7 @@ fn measure_stage<F, S>(
label: String,
) where
S: Clone + Stage<DatabaseEnv>,
F: Fn(S, &TestTransaction, StageRange),
F: Fn(S, &TestStageDB, StageRange),
{
let path = setup::txs_testdata(block_interval.end);

View File

@ -5,7 +5,7 @@ use reth_db::{
use reth_primitives::stage::StageCheckpoint;
use reth_stages::{
stages::{AccountHashingStage, SeedOpts},
test_utils::TestTransaction,
test_utils::TestStageDB,
ExecInput, UnwindInput,
};
use std::path::{Path, PathBuf};
@ -31,8 +31,8 @@ pub fn prepare_account_hashing(num_blocks: u64) -> (PathBuf, AccountHashingStage
fn find_stage_range(db: &Path) -> StageRange {
let mut stage_range = None;
TestTransaction::new(db)
.tx
TestStageDB::new(db)
.factory
.view(|tx| {
let mut cursor = tx.cursor_read::<tables::BlockBodyIndices>()?;
let from = cursor.first()?.unwrap().0;
@ -62,8 +62,8 @@ fn generate_testdata_db(num_blocks: u64) -> (PathBuf, StageRange) {
// create the dirs
std::fs::create_dir_all(&path).unwrap();
println!("Account Hashing testdata not found, generating to {:?}", path.display());
let tx = TestTransaction::new(&path);
let provider = tx.inner_rw();
let tx = TestStageDB::new(&path);
let provider = tx.provider_rw();
let _accounts = AccountHashingStage::seed(&provider, opts);
provider.commit().expect("failed to commit");
}

View File

@ -16,7 +16,7 @@ use reth_primitives::{Account, Address, SealedBlock, B256, MAINNET};
use reth_provider::ProviderFactory;
use reth_stages::{
stages::{AccountHashingStage, StorageHashingStage},
test_utils::TestTransaction,
test_utils::TestStageDB,
ExecInput, Stage, UnwindInput,
};
use reth_trie::StateRoot;
@ -34,14 +34,14 @@ pub(crate) type StageRange = (ExecInput, UnwindInput);
pub(crate) fn stage_unwind<S: Clone + Stage<DatabaseEnv>>(
stage: S,
tx: &TestTransaction,
db: &TestStageDB,
range: StageRange,
) {
let (_, unwind) = range;
tokio::runtime::Runtime::new().unwrap().block_on(async {
let mut stage = stage.clone();
let factory = ProviderFactory::new(tx.tx.db(), MAINNET.clone());
let factory = ProviderFactory::new(db.factory.db(), MAINNET.clone());
let provider = factory.provider_rw().unwrap();
// Clear previous run
@ -50,7 +50,7 @@ pub(crate) fn stage_unwind<S: Clone + Stage<DatabaseEnv>>(
.map_err(|e| {
format!(
"{e}\nMake sure your test database at `{}` isn't too old and incompatible with newer stage changes.",
tx.path.as_ref().unwrap().display()
db.path.as_ref().unwrap().display()
)
})
.unwrap();
@ -61,13 +61,13 @@ pub(crate) fn stage_unwind<S: Clone + Stage<DatabaseEnv>>(
pub(crate) fn unwind_hashes<S: Clone + Stage<DatabaseEnv>>(
stage: S,
tx: &TestTransaction,
db: &TestStageDB,
range: StageRange,
) {
let (input, unwind) = range;
let mut stage = stage.clone();
let factory = ProviderFactory::new(tx.tx.db(), MAINNET.clone());
let factory = ProviderFactory::new(db.factory.db(), MAINNET.clone());
let provider = factory.provider_rw().unwrap();
StorageHashingStage::default().unwind(&provider, unwind).unwrap();
@ -105,7 +105,7 @@ pub(crate) fn txs_testdata(num_blocks: u64) -> PathBuf {
// create the dirs
std::fs::create_dir_all(&path).unwrap();
println!("Transactions testdata not found, generating to {:?}", path.display());
let tx = TestTransaction::new(&path);
let tx = TestStageDB::new(&path);
let accounts: BTreeMap<Address, Account> = concat([
random_eoa_account_range(&mut rng, 0..n_eoa),
@ -127,7 +127,8 @@ pub(crate) fn txs_testdata(num_blocks: u64) -> PathBuf {
tx.insert_accounts_and_storages(start_state.clone()).unwrap();
// make first block after genesis have valid state root
let (root, updates) = StateRoot::new(tx.inner_rw().tx_ref()).root_with_updates().unwrap();
let (root, updates) =
StateRoot::new(tx.provider_rw().tx_ref()).root_with_updates().unwrap();
let second_block = blocks.get_mut(1).unwrap();
let cloned_second = second_block.clone();
let mut updated_header = cloned_second.header.unseal();
@ -153,7 +154,7 @@ pub(crate) fn txs_testdata(num_blocks: u64) -> PathBuf {
// make last block have valid state root
let root = {
let tx_mut = tx.inner_rw();
let tx_mut = tx.provider_rw();
let root = StateRoot::new(tx_mut.tx_ref()).root().unwrap();
tx_mut.commit().unwrap();
root

View File

@ -13,7 +13,6 @@
//!
//! ```
//! # use std::sync::Arc;
//! # use reth_db::test_utils::create_test_rw_db;
//! # use reth_downloaders::bodies::bodies::BodiesDownloaderBuilder;
//! # use reth_downloaders::headers::reverse_headers::ReverseHeadersDownloaderBuilder;
//! # use reth_interfaces::consensus::Consensus;
@ -25,6 +24,7 @@
//! # use tokio::sync::watch;
//! # use reth_provider::ProviderFactory;
//! # use reth_provider::HeaderSyncMode;
//! # use reth_provider::test_utils::create_test_provider_factory;
//! #
//! # let chain_spec = MAINNET.clone();
//! # let consensus: Arc<dyn Consensus> = Arc::new(TestConsensus::default());
@ -32,11 +32,11 @@
//! # Arc::new(TestHeadersClient::default()),
//! # consensus.clone()
//! # );
//! # let db = create_test_rw_db();
//! # let provider_factory = create_test_provider_factory();
//! # let bodies_downloader = BodiesDownloaderBuilder::default().build(
//! # Arc::new(TestBodiesClient { responder: |_| Ok((PeerId::ZERO, vec![]).into()) }),
//! # consensus.clone(),
//! # ProviderFactory::new(db.clone(), MAINNET.clone())
//! # provider_factory.clone()
//! # );
//! # let (tip_tx, tip_rx) = watch::channel(B256::default());
//! # let factory = Factory::new(chain_spec.clone());
@ -45,14 +45,14 @@
//! Pipeline::builder()
//! .with_tip_sender(tip_tx)
//! .add_stages(DefaultStages::new(
//! ProviderFactory::new(db.clone(), chain_spec.clone()),
//! provider_factory.clone(),
//! HeaderSyncMode::Tip(tip_rx),
//! consensus,
//! headers_downloader,
//! bodies_downloader,
//! factory,
//! ))
//! .build(db, chain_spec.clone());
//! .build(provider_factory);
//! ```
//!
//! ## Feature Flags

View File

@ -1,8 +1,7 @@
use std::sync::Arc;
use crate::{pipeline::BoxedStage, MetricEventsSender, Pipeline, Stage, StageSet};
use reth_db::database::Database;
use reth_primitives::{stage::StageId, BlockNumber, ChainSpec, B256};
use reth_primitives::{stage::StageId, BlockNumber, B256};
use reth_provider::ProviderFactory;
use tokio::sync::watch;
/// Builds a [`Pipeline`].
@ -68,13 +67,10 @@ where
}
/// Builds the final [`Pipeline`] using the given provider factory.
///
/// Note: it's expected that this is either an [Arc] or an Arc wrapper type.
pub fn build(self, db: DB, chain_spec: Arc<ChainSpec>) -> Pipeline<DB> {
pub fn build(self, provider_factory: ProviderFactory<DB>) -> Pipeline<DB> {
let Self { stages, max_block, tip_tx, metrics_tx } = self;
Pipeline {
db,
chain_spec,
provider_factory,
stages,
max_block,
tip_tx,

View File

@ -7,11 +7,11 @@ use reth_db::database::Database;
use reth_primitives::{
constants::BEACON_CONSENSUS_REORG_UNWIND_DEPTH,
stage::{StageCheckpoint, StageId},
BlockNumber, ChainSpec, B256,
BlockNumber, B256,
};
use reth_provider::{ProviderFactory, StageCheckpointReader, StageCheckpointWriter};
use reth_tokio_util::EventListeners;
use std::{pin::Pin, sync::Arc};
use std::pin::Pin;
use tokio::sync::watch;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::*;
@ -93,10 +93,8 @@ pub type PipelineWithResult<DB> = (Pipeline<DB>, Result<ControlFlow, PipelineErr
///
/// The [DefaultStages](crate::sets::DefaultStages) are used to fully sync reth.
pub struct Pipeline<DB: Database> {
/// The Database
db: DB,
/// Chain spec
chain_spec: Arc<ChainSpec>,
/// Provider factory.
provider_factory: ProviderFactory<DB>,
/// All configured stages in the order they will be executed.
stages: Vec<BoxedStage<DB>>,
/// The maximum block number to sync to.
@ -141,8 +139,7 @@ where
/// Registers progress metrics for each registered stage
pub fn register_metrics(&mut self) -> Result<(), PipelineError> {
let Some(metrics_tx) = &mut self.metrics_tx else { return Ok(()) };
let factory = ProviderFactory::new(&self.db, self.chain_spec.clone());
let provider = factory.provider()?;
let provider = self.provider_factory.provider()?;
for stage in &self.stages {
let stage_id = stage.id();
@ -236,10 +233,8 @@ where
}
}
let factory = ProviderFactory::new(&self.db, self.chain_spec.clone());
previous_stage = Some(
factory
self.provider_factory
.provider()?
.get_stage_checkpoint(stage_id)?
.unwrap_or_default()
@ -261,8 +256,7 @@ where
// Unwind stages in reverse order of execution
let unwind_pipeline = self.stages.iter_mut().rev();
let factory = ProviderFactory::new(&self.db, self.chain_spec.clone());
let mut provider_rw = factory.provider_rw()?;
let mut provider_rw = self.provider_factory.provider_rw()?;
for stage in unwind_pipeline {
let stage_id = stage.id();
@ -319,7 +313,7 @@ where
.notify(PipelineEvent::Unwound { stage_id, result: unwind_output });
provider_rw.commit()?;
provider_rw = factory.provider_rw()?;
provider_rw = self.provider_factory.provider_rw()?;
}
Err(err) => {
self.listeners.notify(PipelineEvent::Error { stage_id });
@ -344,10 +338,8 @@ where
let mut made_progress = false;
let target = self.max_block.or(previous_stage);
let factory = ProviderFactory::new(&self.db, self.chain_spec.clone());
loop {
let prev_checkpoint = factory.get_stage_checkpoint(stage_id)?;
let prev_checkpoint = self.provider_factory.get_stage_checkpoint(stage_id)?;
let stage_reached_max_block = prev_checkpoint
.zip(self.max_block)
@ -372,7 +364,7 @@ where
if let Err(err) = stage.execute_ready(exec_input).await {
self.listeners.notify(PipelineEvent::Error { stage_id });
match on_stage_error(&factory, stage_id, prev_checkpoint, err)? {
match on_stage_error(&self.provider_factory, stage_id, prev_checkpoint, err)? {
Some(ctrl) => return Ok(ctrl),
None => continue,
};
@ -388,7 +380,7 @@ where
target,
});
let provider_rw = factory.provider_rw()?;
let provider_rw = self.provider_factory.provider_rw()?;
match stage.execute(&provider_rw, exec_input) {
Ok(out @ ExecOutput { checkpoint, done }) => {
made_progress |=
@ -426,7 +418,9 @@ where
Err(err) => {
drop(provider_rw);
self.listeners.notify(PipelineEvent::Error { stage_id });
if let Some(ctrl) = on_stage_error(&factory, stage_id, prev_checkpoint, err)? {
if let Some(ctrl) =
on_stage_error(&self.provider_factory, stage_id, prev_checkpoint, err)?
{
return Ok(ctrl)
}
}
@ -526,13 +520,13 @@ mod tests {
use super::*;
use crate::{test_utils::TestStage, UnwindOutput};
use assert_matches::assert_matches;
use reth_db::test_utils::create_test_rw_db;
use reth_interfaces::{
consensus,
provider::ProviderError,
test_utils::{generators, generators::random_header},
};
use reth_primitives::{stage::StageCheckpoint, MAINNET};
use reth_primitives::stage::StageCheckpoint;
use reth_provider::test_utils::create_test_provider_factory;
use tokio_stream::StreamExt;
#[test]
@ -565,7 +559,7 @@ mod tests {
/// Runs a simple pipeline.
#[tokio::test]
async fn run_pipeline() {
let db = create_test_rw_db();
let provider_factory = create_test_provider_factory();
let mut pipeline = Pipeline::builder()
.add_stage(
@ -577,7 +571,7 @@ mod tests {
.add_exec(Ok(ExecOutput { checkpoint: StageCheckpoint::new(10), done: true })),
)
.with_max_block(10)
.build(db, MAINNET.clone());
.build(provider_factory);
let events = pipeline.events();
// Run pipeline
@ -618,7 +612,7 @@ mod tests {
/// Unwinds a simple pipeline.
#[tokio::test]
async fn unwind_pipeline() {
let db = create_test_rw_db();
let provider_factory = create_test_provider_factory();
let mut pipeline = Pipeline::builder()
.add_stage(
@ -637,7 +631,7 @@ mod tests {
.add_unwind(Ok(UnwindOutput { checkpoint: StageCheckpoint::new(1) })),
)
.with_max_block(10)
.build(db, MAINNET.clone());
.build(provider_factory);
let events = pipeline.events();
// Run pipeline
@ -731,7 +725,7 @@ mod tests {
/// Unwinds a pipeline with intermediate progress.
#[tokio::test]
async fn unwind_pipeline_with_intermediate_progress() {
let db = create_test_rw_db();
let provider_factory = create_test_provider_factory();
let mut pipeline = Pipeline::builder()
.add_stage(
@ -744,7 +738,7 @@ mod tests {
.add_exec(Ok(ExecOutput { checkpoint: StageCheckpoint::new(10), done: true })),
)
.with_max_block(10)
.build(db, MAINNET.clone());
.build(provider_factory);
let events = pipeline.events();
// Run pipeline
@ -816,7 +810,7 @@ mod tests {
/// - The pipeline finishes
#[tokio::test]
async fn run_pipeline_with_unwind() {
let db = create_test_rw_db();
let provider_factory = create_test_provider_factory();
let mut pipeline = Pipeline::builder()
.add_stage(
@ -841,7 +835,7 @@ mod tests {
.add_exec(Ok(ExecOutput { checkpoint: StageCheckpoint::new(10), done: true })),
)
.with_max_block(10)
.build(db, MAINNET.clone());
.build(provider_factory);
let events = pipeline.events();
// Run pipeline
@ -913,7 +907,7 @@ mod tests {
#[tokio::test]
async fn pipeline_error_handling() {
// Non-fatal
let db = create_test_rw_db();
let provider_factory = create_test_provider_factory();
let mut pipeline = Pipeline::builder()
.add_stage(
TestStage::new(StageId::Other("NonFatal"))
@ -921,17 +915,17 @@ mod tests {
.add_exec(Ok(ExecOutput { checkpoint: StageCheckpoint::new(10), done: true })),
)
.with_max_block(10)
.build(db, MAINNET.clone());
.build(provider_factory);
let result = pipeline.run().await;
assert_matches!(result, Ok(()));
// Fatal
let db = create_test_rw_db();
let provider_factory = create_test_provider_factory();
let mut pipeline = Pipeline::builder()
.add_stage(TestStage::new(StageId::Other("Fatal")).add_exec(Err(
StageError::DatabaseIntegrity(ProviderError::BlockBodyIndicesNotFound(5)),
)))
.build(db, MAINNET.clone());
.build(provider_factory);
let result = pipeline.run().await;
assert_matches!(
result,

View File

@ -14,13 +14,12 @@
//! # use reth_stages::sets::{OfflineStages};
//! # use reth_revm::Factory;
//! # use reth_primitives::MAINNET;
//! use reth_db::test_utils::create_test_rw_db;
//! # use reth_provider::test_utils::create_test_provider_factory;
//!
//! # let factory = Factory::new(MAINNET.clone());
//! # let db = create_test_rw_db();
//! # let provider_factory = create_test_provider_factory();
//! // Build a pipeline with all offline stages.
//! # let pipeline =
//! Pipeline::builder().add_stages(OfflineStages::new(factory)).build(db, MAINNET.clone());
//! # let pipeline = Pipeline::builder().add_stages(OfflineStages::new(factory)).build(provider_factory);
//! ```
//!
//! ```ignore

View File

@ -234,14 +234,14 @@ pub trait Stage<DB: Database>: Send + Sync {
/// upon invoking this method.
fn execute(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError>;
/// Unwind the stage.
fn unwind(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError>;
}
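Stages now receive `DatabaseProviderRW<DB>` rather than `DatabaseProviderRW<&DB>`, removing the extra reference level from the database generic. A minimal, illustrative function in the new shape (not from this commit):

use reth_db::database::Database;
use reth_provider::DatabaseProviderRW;
use reth_stages::{ExecInput, ExecOutput, StageError};

fn execute_noop<DB: Database>(
    provider: &DatabaseProviderRW<DB>,
    input: ExecInput,
) -> Result<ExecOutput, StageError> {
    // The provider still exposes the underlying read-write transaction.
    let _tx = provider.tx_ref();
    Ok(ExecOutput { checkpoint: input.checkpoint(), done: true })
}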

View File

@ -98,7 +98,7 @@ impl<DB: Database, D: BodyDownloader> Stage<DB> for BodyStage<D> {
/// header, limited by the stage's batch size.
fn execute(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
if input.target_reached() {
@ -185,7 +185,7 @@ impl<DB: Database, D: BodyDownloader> Stage<DB> for BodyStage<D> {
/// Unwind the stage.
fn unwind(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
self.buffer.take();
@ -440,7 +440,7 @@ mod tests {
// Delete a transaction
runner
.tx()
.db()
.commit(|tx| {
let mut tx_cursor = tx.cursor_write::<tables::Transactions>()?;
tx_cursor.last()?.expect("Could not read last transaction");
@ -471,7 +471,7 @@ mod tests {
use crate::{
stages::bodies::BodyStage,
test_utils::{
ExecuteStageTestRunner, StageTestRunner, TestRunnerError, TestTransaction,
ExecuteStageTestRunner, StageTestRunner, TestRunnerError, TestStageDB,
UnwindStageTestRunner,
},
ExecInput, ExecOutput, UnwindInput,
@ -479,12 +479,11 @@ mod tests {
use futures_util::Stream;
use reth_db::{
cursor::DbCursorRO,
database::Database,
models::{StoredBlockBodyIndices, StoredBlockOmmers},
tables,
test_utils::TempDatabase,
transaction::{DbTx, DbTxMut},
DatabaseEnv, DatabaseError,
DatabaseEnv,
};
use reth_interfaces::{
p2p::{
@ -494,7 +493,7 @@ mod tests {
response::BlockResponse,
},
download::DownloadClient,
error::{DownloadError, DownloadResult},
error::DownloadResult,
priority::Priority,
},
test_utils::{
@ -503,6 +502,7 @@ mod tests {
},
};
use reth_primitives::{BlockBody, BlockNumber, SealedBlock, SealedHeader, TxNumber, B256};
use reth_provider::ProviderFactory;
use std::{
collections::{HashMap, VecDeque},
ops::RangeInclusive,
@ -529,17 +529,13 @@ mod tests {
/// A helper struct for running the [BodyStage].
pub(crate) struct BodyTestRunner {
responses: HashMap<B256, BlockBody>,
tx: TestTransaction,
db: TestStageDB,
batch_size: u64,
}
impl Default for BodyTestRunner {
fn default() -> Self {
Self {
responses: HashMap::default(),
tx: TestTransaction::default(),
batch_size: 1000,
}
Self { responses: HashMap::default(), db: TestStageDB::default(), batch_size: 1000 }
}
}
@ -556,13 +552,13 @@ mod tests {
impl StageTestRunner for BodyTestRunner {
type S = BodyStage<TestBodyDownloader>;
fn tx(&self) -> &TestTransaction {
&self.tx
fn db(&self) -> &TestStageDB {
&self.db
}
fn stage(&self) -> Self::S {
BodyStage::new(TestBodyDownloader::new(
self.tx.inner_raw(),
self.db.factory.clone(),
self.responses.clone(),
self.batch_size,
))
@ -578,10 +574,10 @@ mod tests {
let end = input.target();
let mut rng = generators::rng();
let blocks = random_block_range(&mut rng, start..=end, GENESIS_HASH, 0..2);
self.tx.insert_headers_with_td(blocks.iter().map(|block| &block.header))?;
self.db.insert_headers_with_td(blocks.iter().map(|block| &block.header))?;
if let Some(progress) = blocks.first() {
// Insert last progress data
self.tx.commit(|tx| {
self.db.commit(|tx| {
let body = StoredBlockBodyIndices {
first_tx_num: 0,
tx_count: progress.body.len() as u64,
@ -629,16 +625,16 @@ mod tests {
impl UnwindStageTestRunner for BodyTestRunner {
fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
self.tx.ensure_no_entry_above::<tables::BlockBodyIndices, _>(
self.db.ensure_no_entry_above::<tables::BlockBodyIndices, _>(
input.unwind_to,
|key| key,
)?;
self.tx
self.db
.ensure_no_entry_above::<tables::BlockOmmers, _>(input.unwind_to, |key| key)?;
if let Some(last_tx_id) = self.get_last_tx_id()? {
self.tx
self.db
.ensure_no_entry_above::<tables::Transactions, _>(last_tx_id, |key| key)?;
self.tx.ensure_no_entry_above::<tables::TransactionBlock, _>(
self.db.ensure_no_entry_above::<tables::TransactionBlock, _>(
last_tx_id,
|key| key,
)?;
@ -650,7 +646,7 @@ mod tests {
impl BodyTestRunner {
/// Get the last available tx id if any
pub(crate) fn get_last_tx_id(&self) -> Result<Option<TxNumber>, TestRunnerError> {
let last_body = self.tx.query(|tx| {
let last_body = self.db.query(|tx| {
let v = tx.cursor_read::<tables::BlockBodyIndices>()?.last()?;
Ok(v)
})?;
@ -668,7 +664,7 @@ mod tests {
prev_progress: BlockNumber,
highest_block: BlockNumber,
) -> Result<(), TestRunnerError> {
self.tx.query(|tx| {
self.db.query(|tx| {
// Acquire cursors on body related tables
let mut headers_cursor = tx.cursor_read::<tables::Headers>()?;
let mut bodies_cursor = tx.cursor_read::<tables::BlockBodyIndices>()?;
@ -759,7 +755,7 @@ mod tests {
/// A [BodyDownloader] that is backed by an internal [HashMap] for testing.
#[derive(Debug)]
pub(crate) struct TestBodyDownloader {
db: Arc<TempDatabase<DatabaseEnv>>,
provider_factory: ProviderFactory<Arc<TempDatabase<DatabaseEnv>>>,
responses: HashMap<B256, BlockBody>,
headers: VecDeque<SealedHeader>,
batch_size: u64,
@ -767,11 +763,11 @@ mod tests {
impl TestBodyDownloader {
pub(crate) fn new(
db: Arc<TempDatabase<DatabaseEnv>>,
provider_factory: ProviderFactory<Arc<TempDatabase<DatabaseEnv>>>,
responses: HashMap<B256, BlockBody>,
batch_size: u64,
) -> Self {
Self { db, responses, headers: VecDeque::default(), batch_size }
Self { provider_factory, responses, headers: VecDeque::default(), batch_size }
}
}
@ -780,27 +776,19 @@ mod tests {
&mut self,
range: RangeInclusive<BlockNumber>,
) -> DownloadResult<()> {
self.headers = VecDeque::from(
self.db
.view(|tx| -> Result<Vec<SealedHeader>, DatabaseError> {
let mut header_cursor = tx.cursor_read::<tables::Headers>()?;
let mut canonical_cursor =
tx.cursor_read::<tables::CanonicalHeaders>()?;
let walker = canonical_cursor.walk_range(range)?;
let mut headers = Vec::default();
for entry in walker {
let (num, hash) = entry?;
let (_, header) =
header_cursor.seek_exact(num)?.expect("missing header");
headers.push(header.seal(hash));
}
Ok(headers)
})
.map_err(|err| DownloadError::Provider(err.into()))?
.map_err(|err| DownloadError::Provider(err.into()))?,
);
let provider = self.provider_factory.provider()?;
let mut header_cursor = provider.tx_ref().cursor_read::<tables::Headers>()?;
let mut canonical_cursor =
provider.tx_ref().cursor_read::<tables::CanonicalHeaders>()?;
let walker = canonical_cursor.walk_range(range)?;
for entry in walker {
let (num, hash) = entry?;
let (_, header) = header_cursor.seek_exact(num)?.expect("missing header");
self.headers.push_back(header.seal(hash));
}
Ok(())
}
}
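
For readers skimming the hunk above: the test downloader no longer captures a raw database handle and runs a nested `view` closure (hence the stacked `map_err` calls in the removed code); it asks its `ProviderFactory` for a short-lived read-only provider and drives cursors off `tx_ref()`. A condensed sketch of the new read path, assembled only from lines shown above (the enclosing method's signature is truncated in the hunk, so it is omitted here):

    // open a short-lived read-only provider from the shared factory
    let provider = self.provider_factory.provider()?;
    let mut header_cursor = provider.tx_ref().cursor_read::<tables::Headers>()?;
    let mut canonical_cursor = provider.tx_ref().cursor_read::<tables::CanonicalHeaders>()?;

    // walk the canonical range and seal each header with its hash
    for entry in canonical_cursor.walk_range(range)? {
        let (num, hash) = entry?;
        let (_, header) = header_cursor.seek_exact(num)?.expect("missing header");
        self.headers.push_back(header.seal(hash));
    }
    Ok(())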

View File

@ -110,7 +110,7 @@ impl<EF: ExecutorFactory> ExecutionStage<EF> {
/// Execute the stage.
pub fn execute_inner<DB: Database>(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
if input.target_reached() {
@ -228,7 +228,7 @@ impl<EF: ExecutorFactory> ExecutionStage<EF> {
/// been previously executed.
fn adjust_prune_modes<DB: Database>(
&self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
start_block: u64,
max_block: u64,
) -> Result<PruneModes, StageError> {
@ -247,7 +247,7 @@ impl<EF: ExecutorFactory> ExecutionStage<EF> {
}
fn execution_checkpoint<DB: Database>(
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
start_block: BlockNumber,
max_block: BlockNumber,
checkpoint: StageCheckpoint,
@ -314,7 +314,7 @@ fn execution_checkpoint<DB: Database>(
}
fn calculate_gas_used_from_headers<DB: Database>(
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
range: RangeInclusive<BlockNumber>,
) -> Result<u64, DatabaseError> {
let mut gas_total = 0;
@ -340,7 +340,7 @@ impl<EF: ExecutorFactory, DB: Database> Stage<DB> for ExecutionStage<EF> {
/// Execute the stage
fn execute(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
self.execute_inner(provider, input)
@ -349,7 +349,7 @@ impl<EF: ExecutorFactory, DB: Database> Stage<DB> for ExecutionStage<EF> {
/// Unwind the stage.
fn unwind(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
let tx = provider.tx_ref();
@ -491,7 +491,7 @@ impl ExecutionStageThresholds {
#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::TestTransaction;
use crate::test_utils::TestStageDB;
use alloy_rlp::Decodable;
use assert_matches::assert_matches;
use reth_db::{models::AccountBeforeTx, test_utils::create_test_rw_db};
@ -826,9 +826,8 @@ mod tests {
#[tokio::test]
async fn test_selfdestruct() {
let test_tx = TestTransaction::default();
let factory = ProviderFactory::new(test_tx.tx.as_ref(), MAINNET.clone());
let provider = factory.provider_rw().unwrap();
let test_db = TestStageDB::default();
let provider = test_db.factory.provider_rw().unwrap();
let input = ExecInput { target: Some(1), checkpoint: None };
let mut genesis_rlp = hex!("f901f8f901f3a00000000000000000000000000000000000000000000000000000000000000000a01dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347942adc25665018aa1fe0e6bc666dac8fc2697ff9baa0c9ceb8372c88cb461724d8d3d87e8b933f6fc5f679d4841800e662f4428ffd0da056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421a056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421b90100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008302000080830f4240808000a00000000000000000000000000000000000000000000000000000000000000000880000000000000000c0c0").as_slice();
let genesis = SealedBlock::decode(&mut genesis_rlp).unwrap();
@ -853,7 +852,7 @@ mod tests {
Account { nonce: 0, balance: U256::ZERO, bytecode_hash: Some(code_hash) };
// set account
let provider = factory.provider_rw().unwrap();
let provider = test_db.factory.provider_rw().unwrap();
provider.tx_ref().put::<tables::PlainAccountState>(caller_address, caller_info).unwrap();
provider
.tx_ref()
@ -882,13 +881,13 @@ mod tests {
provider.commit().unwrap();
// execute
let provider = factory.provider_rw().unwrap();
let provider = test_db.factory.provider_rw().unwrap();
let mut execution_stage = stage();
let _ = execution_stage.execute(&provider, input).unwrap();
provider.commit().unwrap();
// assert unwind stage
let provider = factory.provider_rw().unwrap();
let provider = test_db.factory.provider_rw().unwrap();
assert_eq!(provider.basic_account(destroyed_address), Ok(None), "Account was destroyed");
assert_eq!(
@ -898,8 +897,8 @@ mod tests {
);
// drops tx so that it returns write privilege to test_db
drop(provider);
let plain_accounts = test_tx.table::<tables::PlainAccountState>().unwrap();
let plain_storage = test_tx.table::<tables::PlainStorageState>().unwrap();
let plain_accounts = test_db.table::<tables::PlainAccountState>().unwrap();
let plain_storage = test_db.table::<tables::PlainStorageState>().unwrap();
assert_eq!(
plain_accounts,
@ -924,8 +923,8 @@ mod tests {
);
assert!(plain_storage.is_empty());
let account_changesets = test_tx.table::<tables::AccountChangeSet>().unwrap();
let storage_changesets = test_tx.table::<tables::StorageChangeSet>().unwrap();
let account_changesets = test_db.table::<tables::AccountChangeSet>().unwrap();
let storage_changesets = test_db.table::<tables::StorageChangeSet>().unwrap();
assert_eq!(
account_changesets,

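Taken together, the selfdestruct test above exercises the whole `TestStageDB` lifecycle. A minimal sketch of that round-trip, using only the helpers visible in these hunks (`factory`, `tx_ref`, `table`); `caller_address` and `caller_info` are the values set up earlier in the test:

    let test_db = TestStageDB::default();

    // write through an exclusive read-write provider, then commit
    let provider = test_db.factory.provider_rw().unwrap();
    provider.tx_ref().put::<tables::PlainAccountState>(caller_address, caller_info).unwrap();
    provider.commit().unwrap();

    // read whole tables back through the test helper
    let plain_accounts = test_db.table::<tables::PlainAccountState>().unwrap();
    assert_eq!(plain_accounts.len(), 1);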
View File

@ -18,7 +18,7 @@ impl<DB: Database> Stage<DB> for FinishStage {
fn execute(
&mut self,
_provider: &DatabaseProviderRW<&DB>,
_provider: &DatabaseProviderRW<DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
Ok(ExecOutput { checkpoint: StageCheckpoint::new(input.target()), done: true })
@ -26,7 +26,7 @@ impl<DB: Database> Stage<DB> for FinishStage {
fn unwind(
&mut self,
_provider: &DatabaseProviderRW<&DB>,
_provider: &DatabaseProviderRW<DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
Ok(UnwindOutput { checkpoint: StageCheckpoint::new(input.unwind_to) })
@ -38,7 +38,7 @@ mod tests {
use super::*;
use crate::test_utils::{
stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, TestRunnerError,
TestTransaction, UnwindStageTestRunner,
TestStageDB, UnwindStageTestRunner,
};
use reth_interfaces::test_utils::{
generators,
@ -50,14 +50,14 @@ mod tests {
#[derive(Default)]
struct FinishTestRunner {
tx: TestTransaction,
db: TestStageDB,
}
impl StageTestRunner for FinishTestRunner {
type S = FinishStage;
fn tx(&self) -> &TestTransaction {
&self.tx
fn db(&self) -> &TestStageDB {
&self.db
}
fn stage(&self) -> Self::S {
@ -72,7 +72,7 @@ mod tests {
let start = input.checkpoint().block_number;
let mut rng = generators::rng();
let head = random_header(&mut rng, start, None);
self.tx.insert_headers_with_td(std::iter::once(&head))?;
self.db.insert_headers_with_td(std::iter::once(&head))?;
// use previous progress as seed size
let end = input.target.unwrap_or_default() + 1;
@ -82,7 +82,7 @@ mod tests {
}
let mut headers = random_header_range(&mut rng, start + 1..end, head.hash());
self.tx.insert_headers_with_td(headers.iter())?;
self.db.insert_headers_with_td(headers.iter())?;
headers.insert(0, head);
Ok(headers)
}
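
Across every runner in this commit the trait-level change is the same accessor rename. A sketch of the `StageTestRunner` surface as implied by these hunks; the real trait constrains `S` to be a stage and carries more members, this shows just the rename:

    trait StageTestRunner {
        type S; // the stage under test
        // was: fn tx(&self) -> &TestTransaction
        fn db(&self) -> &TestStageDB;
        fn stage(&self) -> Self::S;
    }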

View File

@ -134,7 +134,7 @@ impl<DB: Database> Stage<DB> for AccountHashingStage {
/// Execute the stage.
fn execute(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
if input.target_reached() {
@ -266,7 +266,7 @@ impl<DB: Database> Stage<DB> for AccountHashingStage {
/// Unwind the stage.
fn unwind(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
let (range, unwind_progress, _) =
@ -288,7 +288,7 @@ impl<DB: Database> Stage<DB> for AccountHashingStage {
}
fn stage_checkpoint_progress<DB: Database>(
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
) -> Result<EntitiesCheckpoint, DatabaseError> {
Ok(EntitiesCheckpoint {
processed: provider.tx_ref().entries::<tables::HashedAccount>()? as u64,
@ -341,7 +341,7 @@ mod tests {
done: true,
}) if block_number == previous_stage &&
processed == total &&
total == runner.tx.table::<tables::PlainAccountState>().unwrap().len() as u64
total == runner.db.table::<tables::PlainAccountState>().unwrap().len() as u64
);
// Validate the stage execution
@ -368,7 +368,7 @@ mod tests {
let result = rx.await.unwrap();
let fifth_address = runner
.tx
.db
.query(|tx| {
let (address, _) = tx
.cursor_read::<tables::PlainAccountState>()?
@ -398,9 +398,9 @@ mod tests {
},
done: false
}) if address == fifth_address &&
total == runner.tx.table::<tables::PlainAccountState>().unwrap().len() as u64
total == runner.db.table::<tables::PlainAccountState>().unwrap().len() as u64
);
assert_eq!(runner.tx.table::<tables::HashedAccount>().unwrap().len(), 5);
assert_eq!(runner.db.table::<tables::HashedAccount>().unwrap().len(), 5);
// second run, hash next five accounts.
input.checkpoint = Some(result.unwrap().checkpoint);
@ -425,9 +425,9 @@ mod tests {
},
done: true
}) if processed == total &&
total == runner.tx.table::<tables::PlainAccountState>().unwrap().len() as u64
total == runner.db.table::<tables::PlainAccountState>().unwrap().len() as u64
);
assert_eq!(runner.tx.table::<tables::HashedAccount>().unwrap().len(), 10);
assert_eq!(runner.db.table::<tables::HashedAccount>().unwrap().len(), 10);
// Validate the stage execution
assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation");
@ -437,14 +437,14 @@ mod tests {
use super::*;
use crate::{
stages::hashing_account::AccountHashingStage,
test_utils::{StageTestRunner, TestTransaction},
test_utils::{StageTestRunner, TestStageDB},
ExecInput, ExecOutput, UnwindInput,
};
use reth_db::{cursor::DbCursorRO, tables, transaction::DbTx};
use reth_primitives::Address;
pub(crate) struct AccountHashingTestRunner {
pub(crate) tx: TestTransaction,
pub(crate) db: TestStageDB,
commit_threshold: u64,
clean_threshold: u64,
}
@ -462,7 +462,7 @@ mod tests {
/// Iterates over PlainAccount table and checks that the accounts match the ones
/// in the HashedAccount table
pub(crate) fn check_hashed_accounts(&self) -> Result<(), TestRunnerError> {
self.tx.query(|tx| {
self.db.query(|tx| {
let mut acc_cursor = tx.cursor_read::<tables::PlainAccountState>()?;
let mut hashed_acc_cursor = tx.cursor_read::<tables::HashedAccount>()?;
@ -481,7 +481,7 @@ mod tests {
/// Same as check_hashed_accounts, except it checks against the old account state,
/// namely, the same account with nonce - 1 and balance - 1.
pub(crate) fn check_old_hashed_accounts(&self) -> Result<(), TestRunnerError> {
self.tx.query(|tx| {
self.db.query(|tx| {
let mut acc_cursor = tx.cursor_read::<tables::PlainAccountState>()?;
let mut hashed_acc_cursor = tx.cursor_read::<tables::HashedAccount>()?;
@ -506,19 +506,15 @@ mod tests {
impl Default for AccountHashingTestRunner {
fn default() -> Self {
Self {
tx: TestTransaction::default(),
commit_threshold: 1000,
clean_threshold: 1000,
}
Self { db: TestStageDB::default(), commit_threshold: 1000, clean_threshold: 1000 }
}
}
impl StageTestRunner for AccountHashingTestRunner {
type S = AccountHashingStage;
fn tx(&self) -> &TestTransaction {
&self.tx
fn db(&self) -> &TestStageDB {
&self.db
}
fn stage(&self) -> Self::S {
@ -534,7 +530,7 @@ mod tests {
type Seed = Vec<(Address, Account)>;
fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
let provider = self.tx.inner_rw();
let provider = self.db.factory.provider_rw()?;
let res = Ok(AccountHashingStage::seed(
&provider,
SeedOpts { blocks: 1..=input.target(), accounts: 0..10, txs: 0..3 },

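The seed step above now borrows a read-write provider from the runner's own factory instead of `self.tx.inner_rw()`. A sketch of the full step; the hunk is truncated, so the closing of the call and the commit are assumptions:

    let provider = self.db.factory.provider_rw()?;
    let res = Ok(AccountHashingStage::seed(
        &provider,
        SeedOpts { blocks: 1..=input.target(), accounts: 0..10, txs: 0..3 },
    )
    .unwrap());
    provider.commit()?; // assumed: persist the seeded accounts before the stage runs
    res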
View File

@ -53,7 +53,7 @@ impl<DB: Database> Stage<DB> for StorageHashingStage {
/// Execute the stage.
fn execute(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
let tx = provider.tx_ref();
@ -192,7 +192,7 @@ impl<DB: Database> Stage<DB> for StorageHashingStage {
/// Unwind the stage.
fn unwind(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
let (range, unwind_progress, _) =
@ -213,7 +213,7 @@ impl<DB: Database> Stage<DB> for StorageHashingStage {
}
fn stage_checkpoint_progress<DB: Database>(
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
) -> Result<EntitiesCheckpoint, DatabaseError> {
Ok(EntitiesCheckpoint {
processed: provider.tx_ref().entries::<tables::HashedStorage>()? as u64,
@ -226,7 +226,7 @@ mod tests {
use super::*;
use crate::test_utils::{
stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, TestRunnerError,
TestTransaction, UnwindStageTestRunner,
TestStageDB, UnwindStageTestRunner,
};
use assert_matches::assert_matches;
use rand::Rng;
@ -282,7 +282,7 @@ mod tests {
},
..
}) if processed == previous_checkpoint.progress.processed + 1 &&
total == runner.tx.table::<tables::PlainStorageState>().unwrap().len() as u64);
total == runner.db.table::<tables::PlainStorageState>().unwrap().len() as u64);
// Continue from checkpoint
input.checkpoint = Some(checkpoint);
@ -296,7 +296,7 @@ mod tests {
},
..
}) if processed == total &&
total == runner.tx.table::<tables::PlainStorageState>().unwrap().len() as u64);
total == runner.db.table::<tables::PlainStorageState>().unwrap().len() as u64);
// Validate the stage execution
assert!(
@ -331,7 +331,7 @@ mod tests {
let result = rx.await.unwrap();
let (progress_address, progress_key) = runner
.tx
.db
.query(|tx| {
let (address, entry) = tx
.cursor_read::<tables::PlainStorageState>()?
@ -363,9 +363,9 @@ mod tests {
},
done: false
}) if address == progress_address && storage == progress_key &&
total == runner.tx.table::<tables::PlainStorageState>().unwrap().len() as u64
total == runner.db.table::<tables::PlainStorageState>().unwrap().len() as u64
);
assert_eq!(runner.tx.table::<tables::HashedStorage>().unwrap().len(), 500);
assert_eq!(runner.db.table::<tables::HashedStorage>().unwrap().len(), 500);
// second run with commit threshold of 2 to check if subkey is set.
runner.set_commit_threshold(2);
@ -375,7 +375,7 @@ mod tests {
let result = rx.await.unwrap();
let (progress_address, progress_key) = runner
.tx
.db
.query(|tx| {
let (address, entry) = tx
.cursor_read::<tables::PlainStorageState>()?
@ -409,9 +409,9 @@ mod tests {
},
done: false
}) if address == progress_address && storage == progress_key &&
total == runner.tx.table::<tables::PlainStorageState>().unwrap().len() as u64
total == runner.db.table::<tables::PlainStorageState>().unwrap().len() as u64
);
assert_eq!(runner.tx.table::<tables::HashedStorage>().unwrap().len(), 502);
assert_eq!(runner.db.table::<tables::HashedStorage>().unwrap().len(), 502);
// third and last run, hash the rest of the storages.
runner.set_commit_threshold(1000);
@ -441,11 +441,11 @@ mod tests {
},
done: true
}) if processed == total &&
total == runner.tx.table::<tables::PlainStorageState>().unwrap().len() as u64
total == runner.db.table::<tables::PlainStorageState>().unwrap().len() as u64
);
assert_eq!(
runner.tx.table::<tables::HashedStorage>().unwrap().len(),
runner.tx.table::<tables::PlainStorageState>().unwrap().len()
runner.db.table::<tables::HashedStorage>().unwrap().len(),
runner.db.table::<tables::PlainStorageState>().unwrap().len()
);
// Validate the stage execution
@ -453,22 +453,22 @@ mod tests {
}
struct StorageHashingTestRunner {
tx: TestTransaction,
db: TestStageDB,
commit_threshold: u64,
clean_threshold: u64,
}
impl Default for StorageHashingTestRunner {
fn default() -> Self {
Self { tx: TestTransaction::default(), commit_threshold: 1000, clean_threshold: 1000 }
Self { db: TestStageDB::default(), commit_threshold: 1000, clean_threshold: 1000 }
}
}
impl StageTestRunner for StorageHashingTestRunner {
type S = StorageHashingStage;
fn tx(&self) -> &TestTransaction {
&self.tx
fn db(&self) -> &TestStageDB {
&self.db
}
fn stage(&self) -> Self::S {
@ -493,7 +493,7 @@ mod tests {
let blocks = random_block_range(&mut rng, stage_progress..=end, B256::ZERO, 0..3);
self.tx.insert_headers(blocks.iter().map(|block| &block.header))?;
self.db.insert_headers(blocks.iter().map(|block| &block.header))?;
let iter = blocks.iter();
let mut next_tx_num = 0;
@ -501,7 +501,7 @@ mod tests {
for progress in iter {
// Insert last progress data
let block_number = progress.number;
self.tx.commit(|tx| {
self.db.commit(|tx| {
progress.body.iter().try_for_each(
|transaction| -> Result<(), reth_db::DatabaseError> {
tx.put::<tables::TxHashNumber>(transaction.hash(), next_tx_num)?;
@ -552,7 +552,8 @@ mod tests {
first_tx_num = next_tx_num;
tx.put::<tables::BlockBodyIndices>(progress.number, body)
tx.put::<tables::BlockBodyIndices>(progress.number, body)?;
Ok(())
})?;
}
@ -592,7 +593,7 @@ mod tests {
}
fn check_hashed_storage(&self) -> Result<(), TestRunnerError> {
self.tx
self.db
.query(|tx| {
let mut storage_cursor = tx.cursor_dup_read::<tables::PlainStorageState>()?;
let mut hashed_storage_cursor =
@ -661,7 +662,7 @@ mod tests {
fn unwind_storage(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
tracing::debug!("unwinding storage...");
let target_block = input.unwind_to;
self.tx.commit(|tx| {
self.db.commit(|tx| {
let mut storage_cursor = tx.cursor_dup_write::<tables::PlainStorageState>()?;
let mut changeset_cursor = tx.cursor_dup_read::<tables::StorageChangeSet>()?;

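One subtle change recurs throughout this commit: the `commit` closure on `TestStageDB` returns `Result<(), _>`, so a trailing `put` can no longer be the closure's tail expression. Condensed before/after, taken from the body-indices lines above:

    // before: the put's Result was the closure's return value
    self.tx.commit(|tx| {
        // ...
        tx.put::<tables::BlockBodyIndices>(progress.number, body)
    })?;

    // after: propagate with `?` and return unit explicitly
    self.db.commit(|tx| {
        // ...
        tx.put::<tables::BlockBodyIndices>(progress.number, body)?;
        Ok(())
    })?;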
View File

@ -176,7 +176,7 @@ where
/// starting from the tip of the chain
fn execute(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
let current_checkpoint = input.checkpoint();
@ -279,7 +279,7 @@ where
/// Unwind the stage.
fn unwind(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
self.buffer.take();
@ -326,7 +326,7 @@ mod tests {
mod test_runner {
use super::*;
use crate::test_utils::{TestRunnerError, TestTransaction};
use crate::test_utils::{TestRunnerError, TestStageDB};
use reth_db::{test_utils::TempDatabase, DatabaseEnv};
use reth_downloaders::headers::reverse_headers::{
ReverseHeadersDownloader, ReverseHeadersDownloaderBuilder,
@ -344,7 +344,7 @@ mod tests {
pub(crate) client: TestHeadersClient,
channel: (watch::Sender<B256>, watch::Receiver<B256>),
downloader_factory: Box<dyn Fn() -> D + Send + Sync + 'static>,
tx: TestTransaction,
db: TestStageDB,
}
impl Default for HeadersTestRunner<TestHeaderDownloader> {
@ -361,7 +361,7 @@ mod tests {
1000,
)
}),
tx: TestTransaction::default(),
db: TestStageDB::default(),
}
}
}
@ -369,13 +369,13 @@ mod tests {
impl<D: HeaderDownloader + 'static> StageTestRunner for HeadersTestRunner<D> {
type S = HeaderStage<ProviderFactory<Arc<TempDatabase<DatabaseEnv>>>, D>;
fn tx(&self) -> &TestTransaction {
&self.tx
fn db(&self) -> &TestStageDB {
&self.db
}
fn stage(&self) -> Self::S {
HeaderStage::new(
self.tx.factory.clone(),
self.db.factory.clone(),
(*self.downloader_factory)(),
HeaderSyncMode::Tip(self.channel.1.clone()),
)
@ -390,9 +390,10 @@ mod tests {
let mut rng = generators::rng();
let start = input.checkpoint().block_number;
let head = random_header(&mut rng, start, None);
self.tx.insert_headers(std::iter::once(&head))?;
self.db.insert_headers(std::iter::once(&head))?;
// patch td table for `update_head` call
self.tx.commit(|tx| tx.put::<tables::HeaderTD>(head.number, U256::ZERO.into()))?;
self.db
.commit(|tx| Ok(tx.put::<tables::HeaderTD>(head.number, U256::ZERO.into())?))?;
// use previous checkpoint as seed size
let end = input.target.unwrap_or_default() + 1;
@ -415,7 +416,7 @@ mod tests {
let initial_checkpoint = input.checkpoint().block_number;
match output {
Some(output) if output.checkpoint.block_number > initial_checkpoint => {
let provider = self.tx.factory.provider()?;
let provider = self.db.factory.provider()?;
for block_num in (initial_checkpoint..output.checkpoint.block_number).rev()
{
// look up the header hash
@ -442,7 +443,7 @@ mod tests {
headers.last().unwrap().hash()
} else {
let tip = random_header(&mut generators::rng(), 0, None);
self.tx.insert_headers(std::iter::once(&tip))?;
self.db.insert_headers(std::iter::once(&tip))?;
tip.hash()
};
self.send_tip(tip);
@ -467,7 +468,7 @@ mod tests {
.stream_batch_size(500)
.build(client.clone(), Arc::new(TestConsensus::default()))
}),
tx: TestTransaction::default(),
db: TestStageDB::default(),
}
}
}
@ -477,10 +478,10 @@ mod tests {
&self,
block: BlockNumber,
) -> Result<(), TestRunnerError> {
self.tx
self.db
.ensure_no_entry_above_by_value::<tables::HeaderNumbers, _>(block, |val| val)?;
self.tx.ensure_no_entry_above::<tables::CanonicalHeaders, _>(block, |key| key)?;
self.tx.ensure_no_entry_above::<tables::Headers, _>(block, |key| key)?;
self.db.ensure_no_entry_above::<tables::CanonicalHeaders, _>(block, |key| key)?;
self.db.ensure_no_entry_above::<tables::Headers, _>(block, |key| key)?;
Ok(())
}
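
Pulling the scattered header-runner lines above into one coherent impl; every name here appears in the hunks, only the assembly is editorial:

    impl<D: HeaderDownloader + 'static> StageTestRunner for HeadersTestRunner<D> {
        type S = HeaderStage<ProviderFactory<Arc<TempDatabase<DatabaseEnv>>>, D>;

        fn db(&self) -> &TestStageDB {
            &self.db
        }

        fn stage(&self) -> Self::S {
            HeaderStage::new(
                // the stage now owns a clone of the test factory
                self.db.factory.clone(),
                (*self.downloader_factory)(),
                HeaderSyncMode::Tip(self.channel.1.clone()),
            )
        }
    }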

View File

@ -44,7 +44,7 @@ impl<DB: Database> Stage<DB> for IndexAccountHistoryStage {
/// Execute the stage.
fn execute(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
mut input: ExecInput,
) -> Result<ExecOutput, StageError> {
if let Some((target_prunable_block, prune_mode)) = self
@ -87,7 +87,7 @@ impl<DB: Database> Stage<DB> for IndexAccountHistoryStage {
/// Unwind the stage.
fn unwind(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
let (range, unwind_progress, _) =
@ -105,7 +105,7 @@ mod tests {
use super::*;
use crate::test_utils::{
stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, TestRunnerError,
TestTransaction, UnwindStageTestRunner,
TestStageDB, UnwindStageTestRunner,
};
use itertools::Itertools;
use reth_db::{
@ -122,8 +122,7 @@ mod tests {
generators,
generators::{random_block_range, random_changeset_range, random_contract_account_range},
};
use reth_primitives::{address, Address, BlockNumber, PruneMode, B256, MAINNET};
use reth_provider::ProviderFactory;
use reth_primitives::{address, Address, BlockNumber, PruneMode, B256};
use std::collections::BTreeMap;
const ADDRESS: Address = address!("0000000000000000000000000000000000000001");
@ -153,9 +152,9 @@ mod tests {
.collect()
}
fn partial_setup(tx: &TestTransaction) {
fn partial_setup(db: &TestStageDB) {
// setup
tx.commit(|tx| {
db.commit(|tx| {
// we just need first and last
tx.put::<tables::BlockBodyIndices>(
0,
@ -177,25 +176,23 @@ mod tests {
.unwrap()
}
fn run(tx: &TestTransaction, run_to: u64) {
fn run(db: &TestStageDB, run_to: u64) {
let input = ExecInput { target: Some(run_to), ..Default::default() };
let mut stage = IndexAccountHistoryStage::default();
let factory = ProviderFactory::new(tx.tx.as_ref(), MAINNET.clone());
let provider = factory.provider_rw().unwrap();
let provider = db.factory.provider_rw().unwrap();
let out = stage.execute(&provider, input).unwrap();
assert_eq!(out, ExecOutput { checkpoint: StageCheckpoint::new(5), done: true });
provider.commit().unwrap();
}
fn unwind(tx: &TestTransaction, unwind_from: u64, unwind_to: u64) {
fn unwind(db: &TestStageDB, unwind_from: u64, unwind_to: u64) {
let input = UnwindInput {
checkpoint: StageCheckpoint::new(unwind_from),
unwind_to,
..Default::default()
};
let mut stage = IndexAccountHistoryStage::default();
let factory = ProviderFactory::new(tx.tx.as_ref(), MAINNET.clone());
let provider = factory.provider_rw().unwrap();
let provider = db.factory.provider_rw().unwrap();
let out = stage.unwind(&provider, input).unwrap();
assert_eq!(out, UnwindOutput { checkpoint: StageCheckpoint::new(unwind_to) });
provider.commit().unwrap();
@ -204,116 +201,116 @@ mod tests {
#[tokio::test]
async fn insert_index_to_empty() {
// init
let tx = TestTransaction::default();
let db = TestStageDB::default();
// setup
partial_setup(&tx);
partial_setup(&db);
// run
run(&tx, 5);
run(&db, 5);
// verify
let table = cast(tx.table::<tables::AccountHistory>().unwrap());
let table = cast(db.table::<tables::AccountHistory>().unwrap());
assert_eq!(table, BTreeMap::from([(shard(u64::MAX), vec![4, 5])]));
// unwind
unwind(&tx, 5, 0);
unwind(&db, 5, 0);
// verify initial state
let table = tx.table::<tables::AccountHistory>().unwrap();
let table = db.table::<tables::AccountHistory>().unwrap();
assert!(table.is_empty());
}
#[tokio::test]
async fn insert_index_to_not_empty_shard() {
// init
let tx = TestTransaction::default();
let db = TestStageDB::default();
// setup
partial_setup(&tx);
tx.commit(|tx| {
partial_setup(&db);
db.commit(|tx| {
tx.put::<tables::AccountHistory>(shard(u64::MAX), list(&[1, 2, 3])).unwrap();
Ok(())
})
.unwrap();
// run
run(&tx, 5);
run(&db, 5);
// verify
let table = cast(tx.table::<tables::AccountHistory>().unwrap());
let table = cast(db.table::<tables::AccountHistory>().unwrap());
assert_eq!(table, BTreeMap::from([(shard(u64::MAX), vec![1, 2, 3, 4, 5]),]));
// unwind
unwind(&tx, 5, 0);
unwind(&db, 5, 0);
// verify initial state
let table = cast(tx.table::<tables::AccountHistory>().unwrap());
let table = cast(db.table::<tables::AccountHistory>().unwrap());
assert_eq!(table, BTreeMap::from([(shard(u64::MAX), vec![1, 2, 3]),]));
}
#[tokio::test]
async fn insert_index_to_full_shard() {
// init
let tx = TestTransaction::default();
let db = TestStageDB::default();
let full_list = vec![3; NUM_OF_INDICES_IN_SHARD];
// setup
partial_setup(&tx);
tx.commit(|tx| {
partial_setup(&db);
db.commit(|tx| {
tx.put::<tables::AccountHistory>(shard(u64::MAX), list(&full_list)).unwrap();
Ok(())
})
.unwrap();
// run
run(&tx, 5);
run(&db, 5);
// verify
let table = cast(tx.table::<tables::AccountHistory>().unwrap());
let table = cast(db.table::<tables::AccountHistory>().unwrap());
assert_eq!(
table,
BTreeMap::from([(shard(3), full_list.clone()), (shard(u64::MAX), vec![4, 5])])
);
// unwind
unwind(&tx, 5, 0);
unwind(&db, 5, 0);
// verify initial state
let table = cast(tx.table::<tables::AccountHistory>().unwrap());
let table = cast(db.table::<tables::AccountHistory>().unwrap());
assert_eq!(table, BTreeMap::from([(shard(u64::MAX), full_list)]));
}
#[tokio::test]
async fn insert_index_to_fill_shard() {
// init
let tx = TestTransaction::default();
let db = TestStageDB::default();
let mut close_full_list = vec![1; NUM_OF_INDICES_IN_SHARD - 2];
// setup
partial_setup(&tx);
tx.commit(|tx| {
partial_setup(&db);
db.commit(|tx| {
tx.put::<tables::AccountHistory>(shard(u64::MAX), list(&close_full_list)).unwrap();
Ok(())
})
.unwrap();
// run
run(&tx, 5);
run(&db, 5);
// verify
close_full_list.push(4);
close_full_list.push(5);
let table = cast(tx.table::<tables::AccountHistory>().unwrap());
let table = cast(db.table::<tables::AccountHistory>().unwrap());
assert_eq!(table, BTreeMap::from([(shard(u64::MAX), close_full_list.clone()),]));
// unwind
unwind(&tx, 5, 0);
unwind(&db, 5, 0);
// verify initial state
close_full_list.pop();
close_full_list.pop();
let table = cast(tx.table::<tables::AccountHistory>().unwrap());
let table = cast(db.table::<tables::AccountHistory>().unwrap());
assert_eq!(table, BTreeMap::from([(shard(u64::MAX), close_full_list),]));
// verify initial state
@ -322,46 +319,46 @@ mod tests {
#[tokio::test]
async fn insert_index_second_half_shard() {
// init
let tx = TestTransaction::default();
let db = TestStageDB::default();
let mut close_full_list = vec![1; NUM_OF_INDICES_IN_SHARD - 1];
// setup
partial_setup(&tx);
tx.commit(|tx| {
partial_setup(&db);
db.commit(|tx| {
tx.put::<tables::AccountHistory>(shard(u64::MAX), list(&close_full_list)).unwrap();
Ok(())
})
.unwrap();
// run
run(&tx, 5);
run(&db, 5);
// verify
close_full_list.push(4);
let table = cast(tx.table::<tables::AccountHistory>().unwrap());
let table = cast(db.table::<tables::AccountHistory>().unwrap());
assert_eq!(
table,
BTreeMap::from([(shard(4), close_full_list.clone()), (shard(u64::MAX), vec![5])])
);
// unwind
unwind(&tx, 5, 0);
unwind(&db, 5, 0);
// verify initial state
close_full_list.pop();
let table = cast(tx.table::<tables::AccountHistory>().unwrap());
let table = cast(db.table::<tables::AccountHistory>().unwrap());
assert_eq!(table, BTreeMap::from([(shard(u64::MAX), close_full_list),]));
}
#[tokio::test]
async fn insert_index_to_third_shard() {
// init
let tx = TestTransaction::default();
let db = TestStageDB::default();
let full_list = vec![1; NUM_OF_INDICES_IN_SHARD];
// setup
partial_setup(&tx);
tx.commit(|tx| {
partial_setup(&db);
db.commit(|tx| {
tx.put::<tables::AccountHistory>(shard(1), list(&full_list)).unwrap();
tx.put::<tables::AccountHistory>(shard(2), list(&full_list)).unwrap();
tx.put::<tables::AccountHistory>(shard(u64::MAX), list(&[2, 3])).unwrap();
@ -369,10 +366,10 @@ mod tests {
})
.unwrap();
run(&tx, 5);
run(&db, 5);
// verify
let table = cast(tx.table::<tables::AccountHistory>().unwrap());
let table = cast(db.table::<tables::AccountHistory>().unwrap());
assert_eq!(
table,
BTreeMap::from([
@ -383,10 +380,10 @@ mod tests {
);
// unwind
unwind(&tx, 5, 0);
unwind(&db, 5, 0);
// verify initial state
let table = cast(tx.table::<tables::AccountHistory>().unwrap());
let table = cast(db.table::<tables::AccountHistory>().unwrap());
assert_eq!(
table,
BTreeMap::from([
@ -400,10 +397,10 @@ mod tests {
#[tokio::test]
async fn insert_index_with_prune_mode() {
// init
let tx = TestTransaction::default();
let db = TestStageDB::default();
// setup
tx.commit(|tx| {
db.commit(|tx| {
// we just need first and last
tx.put::<tables::BlockBodyIndices>(
0,
@ -431,43 +428,42 @@ mod tests {
prune_mode: Some(PruneMode::Before(36)),
..Default::default()
};
let factory = ProviderFactory::new(tx.tx.as_ref(), MAINNET.clone());
let provider = factory.provider_rw().unwrap();
let provider = db.factory.provider_rw().unwrap();
let out = stage.execute(&provider, input).unwrap();
assert_eq!(out, ExecOutput { checkpoint: StageCheckpoint::new(20000), done: true });
provider.commit().unwrap();
// verify
let table = cast(tx.table::<tables::AccountHistory>().unwrap());
let table = cast(db.table::<tables::AccountHistory>().unwrap());
assert_eq!(table, BTreeMap::from([(shard(u64::MAX), vec![36, 100])]));
// unwind
unwind(&tx, 20000, 0);
unwind(&db, 20000, 0);
// verify initial state
let table = tx.table::<tables::AccountHistory>().unwrap();
let table = db.table::<tables::AccountHistory>().unwrap();
assert!(table.is_empty());
}
stage_test_suite_ext!(IndexAccountHistoryTestRunner, index_account_history);
struct IndexAccountHistoryTestRunner {
pub(crate) tx: TestTransaction,
pub(crate) db: TestStageDB,
commit_threshold: u64,
prune_mode: Option<PruneMode>,
}
impl Default for IndexAccountHistoryTestRunner {
fn default() -> Self {
Self { tx: TestTransaction::default(), commit_threshold: 1000, prune_mode: None }
Self { db: TestStageDB::default(), commit_threshold: 1000, prune_mode: None }
}
}
impl StageTestRunner for IndexAccountHistoryTestRunner {
type S = IndexAccountHistoryStage;
fn tx(&self) -> &TestTransaction {
&self.tx
fn db(&self) -> &TestStageDB {
&self.db
}
fn stage(&self) -> Self::S {
@ -500,7 +496,7 @@ mod tests {
);
// add block changeset from block 1.
self.tx.insert_changesets(transitions, Some(start))?;
self.db.insert_changesets(transitions, Some(start))?;
Ok(())
}
@ -522,7 +518,7 @@ mod tests {
ExecOutput { checkpoint: StageCheckpoint::new(input.target()), done: true }
);
let provider = self.tx.inner();
let provider = self.db.factory.provider()?;
let mut changeset_cursor =
provider.tx_ref().cursor_read::<tables::AccountChangeSet>()?;
@ -568,7 +564,7 @@ mod tests {
};
}
let table = cast(self.tx.table::<tables::AccountHistory>().unwrap());
let table = cast(self.db.table::<tables::AccountHistory>().unwrap());
assert_eq!(table, result);
}
Ok(())
@ -577,7 +573,7 @@ mod tests {
impl UnwindStageTestRunner for IndexAccountHistoryTestRunner {
fn validate_unwind(&self, _input: UnwindInput) -> Result<(), TestRunnerError> {
let table = self.tx.table::<tables::AccountHistory>().unwrap();
let table = self.db.table::<tables::AccountHistory>().unwrap();
assert!(table.is_empty());
Ok(())
}
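
The dominant edit in these history-stage tests is the same two-line collapse everywhere, shown side by side here as drawn from the `run`/`unwind` helpers above:

    // before: build a throwaway factory around the raw test transaction
    let factory = ProviderFactory::new(tx.tx.as_ref(), MAINNET.clone());
    let provider = factory.provider_rw().unwrap();

    // after: reuse the factory that TestStageDB already owns
    let provider = db.factory.provider_rw().unwrap();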

View File

@ -43,7 +43,7 @@ impl<DB: Database> Stage<DB> for IndexStorageHistoryStage {
/// Execute the stage.
fn execute(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
mut input: ExecInput,
) -> Result<ExecOutput, StageError> {
if let Some((target_prunable_block, prune_mode)) = self
@ -85,7 +85,7 @@ impl<DB: Database> Stage<DB> for IndexStorageHistoryStage {
/// Unwind the stage.
fn unwind(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
let (range, unwind_progress, _) =
@ -102,7 +102,7 @@ mod tests {
use super::*;
use crate::test_utils::{
stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, TestRunnerError,
TestTransaction, UnwindStageTestRunner,
TestStageDB, UnwindStageTestRunner,
};
use itertools::Itertools;
use reth_db::{
@ -121,9 +121,8 @@ mod tests {
generators::{random_block_range, random_changeset_range, random_contract_account_range},
};
use reth_primitives::{
address, b256, Address, BlockNumber, PruneMode, StorageEntry, B256, MAINNET, U256,
address, b256, Address, BlockNumber, PruneMode, StorageEntry, B256, U256,
};
use reth_provider::ProviderFactory;
use std::collections::BTreeMap;
const ADDRESS: Address = address!("0000000000000000000000000000000000000001");
@ -163,9 +162,9 @@ mod tests {
.collect()
}
fn partial_setup(tx: &TestTransaction) {
fn partial_setup(db: &TestStageDB) {
// setup
tx.commit(|tx| {
db.commit(|tx| {
// we just need first and last
tx.put::<tables::BlockBodyIndices>(
0,
@ -187,25 +186,23 @@ mod tests {
.unwrap()
}
fn run(tx: &TestTransaction, run_to: u64) {
fn run(db: &TestStageDB, run_to: u64) {
let input = ExecInput { target: Some(run_to), ..Default::default() };
let mut stage = IndexStorageHistoryStage::default();
let factory = ProviderFactory::new(tx.tx.as_ref(), MAINNET.clone());
let provider = factory.provider_rw().unwrap();
let provider = db.factory.provider_rw().unwrap();
let out = stage.execute(&provider, input).unwrap();
assert_eq!(out, ExecOutput { checkpoint: StageCheckpoint::new(5), done: true });
provider.commit().unwrap();
}
fn unwind(tx: &TestTransaction, unwind_from: u64, unwind_to: u64) {
fn unwind(db: &TestStageDB, unwind_from: u64, unwind_to: u64) {
let input = UnwindInput {
checkpoint: StageCheckpoint::new(unwind_from),
unwind_to,
..Default::default()
};
let mut stage = IndexStorageHistoryStage::default();
let factory = ProviderFactory::new(tx.tx.as_ref(), MAINNET.clone());
let provider = factory.provider_rw().unwrap();
let provider = db.factory.provider_rw().unwrap();
let out = stage.unwind(&provider, input).unwrap();
assert_eq!(out, UnwindOutput { checkpoint: StageCheckpoint::new(unwind_to) });
provider.commit().unwrap();
@ -214,119 +211,119 @@ mod tests {
#[tokio::test]
async fn insert_index_to_empty() {
// init
let tx = TestTransaction::default();
let db = TestStageDB::default();
// setup
partial_setup(&tx);
partial_setup(&db);
// run
run(&tx, 5);
run(&db, 5);
// verify
let table = cast(tx.table::<tables::StorageHistory>().unwrap());
let table = cast(db.table::<tables::StorageHistory>().unwrap());
assert_eq!(table, BTreeMap::from([(shard(u64::MAX), vec![4, 5]),]));
// unwind
unwind(&tx, 5, 0);
unwind(&db, 5, 0);
// verify initial state
let table = tx.table::<tables::StorageHistory>().unwrap();
let table = db.table::<tables::StorageHistory>().unwrap();
assert!(table.is_empty());
}
#[tokio::test]
async fn insert_index_to_not_empty_shard() {
// init
let tx = TestTransaction::default();
let db = TestStageDB::default();
// setup
partial_setup(&tx);
tx.commit(|tx| {
partial_setup(&db);
db.commit(|tx| {
tx.put::<tables::StorageHistory>(shard(u64::MAX), list(&[1, 2, 3])).unwrap();
Ok(())
})
.unwrap();
// run
run(&tx, 5);
run(&db, 5);
// verify
let table = cast(tx.table::<tables::StorageHistory>().unwrap());
let table = cast(db.table::<tables::StorageHistory>().unwrap());
assert_eq!(table, BTreeMap::from([(shard(u64::MAX), vec![1, 2, 3, 4, 5]),]));
// unwind
unwind(&tx, 5, 0);
unwind(&db, 5, 0);
// verify initial state
let table = cast(tx.table::<tables::StorageHistory>().unwrap());
let table = cast(db.table::<tables::StorageHistory>().unwrap());
assert_eq!(table, BTreeMap::from([(shard(u64::MAX), vec![1, 2, 3]),]));
}
#[tokio::test]
async fn insert_index_to_full_shard() {
// init
let tx = TestTransaction::default();
let db = TestStageDB::default();
let _input = ExecInput { target: Some(5), ..Default::default() };
// the change itself does not matter, only that the account is present in the changeset.
let full_list = vec![3; NUM_OF_INDICES_IN_SHARD];
// setup
partial_setup(&tx);
tx.commit(|tx| {
partial_setup(&db);
db.commit(|tx| {
tx.put::<tables::StorageHistory>(shard(u64::MAX), list(&full_list)).unwrap();
Ok(())
})
.unwrap();
// run
run(&tx, 5);
run(&db, 5);
// verify
let table = cast(tx.table::<tables::StorageHistory>().unwrap());
let table = cast(db.table::<tables::StorageHistory>().unwrap());
assert_eq!(
table,
BTreeMap::from([(shard(3), full_list.clone()), (shard(u64::MAX), vec![4, 5])])
);
// unwind
unwind(&tx, 5, 0);
unwind(&db, 5, 0);
// verify initial state
let table = cast(tx.table::<tables::StorageHistory>().unwrap());
let table = cast(db.table::<tables::StorageHistory>().unwrap());
assert_eq!(table, BTreeMap::from([(shard(u64::MAX), full_list)]));
}
#[tokio::test]
async fn insert_index_to_fill_shard() {
// init
let tx = TestTransaction::default();
let db = TestStageDB::default();
let mut close_full_list = vec![1; NUM_OF_INDICES_IN_SHARD - 2];
// setup
partial_setup(&tx);
tx.commit(|tx| {
partial_setup(&db);
db.commit(|tx| {
tx.put::<tables::StorageHistory>(shard(u64::MAX), list(&close_full_list)).unwrap();
Ok(())
})
.unwrap();
// run
run(&tx, 5);
run(&db, 5);
// verify
close_full_list.push(4);
close_full_list.push(5);
let table = cast(tx.table::<tables::StorageHistory>().unwrap());
let table = cast(db.table::<tables::StorageHistory>().unwrap());
assert_eq!(table, BTreeMap::from([(shard(u64::MAX), close_full_list.clone()),]));
// unwind
unwind(&tx, 5, 0);
unwind(&db, 5, 0);
// verify initial state
close_full_list.pop();
close_full_list.pop();
let table = cast(tx.table::<tables::StorageHistory>().unwrap());
let table = cast(db.table::<tables::StorageHistory>().unwrap());
assert_eq!(table, BTreeMap::from([(shard(u64::MAX), close_full_list),]));
// verify initial state
@ -335,46 +332,46 @@ mod tests {
#[tokio::test]
async fn insert_index_second_half_shard() {
// init
let tx = TestTransaction::default();
let db = TestStageDB::default();
let mut close_full_list = vec![1; NUM_OF_INDICES_IN_SHARD - 1];
// setup
partial_setup(&tx);
tx.commit(|tx| {
partial_setup(&db);
db.commit(|tx| {
tx.put::<tables::StorageHistory>(shard(u64::MAX), list(&close_full_list)).unwrap();
Ok(())
})
.unwrap();
// run
run(&tx, 5);
run(&db, 5);
// verify
close_full_list.push(4);
let table = cast(tx.table::<tables::StorageHistory>().unwrap());
let table = cast(db.table::<tables::StorageHistory>().unwrap());
assert_eq!(
table,
BTreeMap::from([(shard(4), close_full_list.clone()), (shard(u64::MAX), vec![5])])
);
// unwind
unwind(&tx, 5, 0);
unwind(&db, 5, 0);
// verify initial state
close_full_list.pop();
let table = cast(tx.table::<tables::StorageHistory>().unwrap());
let table = cast(db.table::<tables::StorageHistory>().unwrap());
assert_eq!(table, BTreeMap::from([(shard(u64::MAX), close_full_list),]));
}
#[tokio::test]
async fn insert_index_to_third_shard() {
// init
let tx = TestTransaction::default();
let db = TestStageDB::default();
let full_list = vec![1; NUM_OF_INDICES_IN_SHARD];
// setup
partial_setup(&tx);
tx.commit(|tx| {
partial_setup(&db);
db.commit(|tx| {
tx.put::<tables::StorageHistory>(shard(1), list(&full_list)).unwrap();
tx.put::<tables::StorageHistory>(shard(2), list(&full_list)).unwrap();
tx.put::<tables::StorageHistory>(shard(u64::MAX), list(&[2, 3])).unwrap();
@ -382,10 +379,10 @@ mod tests {
})
.unwrap();
run(&tx, 5);
run(&db, 5);
// verify
let table = cast(tx.table::<tables::StorageHistory>().unwrap());
let table = cast(db.table::<tables::StorageHistory>().unwrap());
assert_eq!(
table,
BTreeMap::from([
@ -396,10 +393,10 @@ mod tests {
);
// unwind
unwind(&tx, 5, 0);
unwind(&db, 5, 0);
// verify initial state
let table = cast(tx.table::<tables::StorageHistory>().unwrap());
let table = cast(db.table::<tables::StorageHistory>().unwrap());
assert_eq!(
table,
BTreeMap::from([
@ -413,10 +410,10 @@ mod tests {
#[tokio::test]
async fn insert_index_with_prune_mode() {
// init
let tx = TestTransaction::default();
let db = TestStageDB::default();
// setup
tx.commit(|tx| {
db.commit(|tx| {
// we just need first and last
tx.put::<tables::BlockBodyIndices>(
0,
@ -444,43 +441,42 @@ mod tests {
prune_mode: Some(PruneMode::Before(36)),
..Default::default()
};
let factory = ProviderFactory::new(tx.tx.as_ref(), MAINNET.clone());
let provider = factory.provider_rw().unwrap();
let provider = db.factory.provider_rw().unwrap();
let out = stage.execute(&provider, input).unwrap();
assert_eq!(out, ExecOutput { checkpoint: StageCheckpoint::new(20000), done: true });
provider.commit().unwrap();
// verify
let table = cast(tx.table::<tables::StorageHistory>().unwrap());
let table = cast(db.table::<tables::StorageHistory>().unwrap());
assert_eq!(table, BTreeMap::from([(shard(u64::MAX), vec![36, 100]),]));
// unwind
unwind(&tx, 20000, 0);
unwind(&db, 20000, 0);
// verify initial state
let table = tx.table::<tables::StorageHistory>().unwrap();
let table = db.table::<tables::StorageHistory>().unwrap();
assert!(table.is_empty());
}
stage_test_suite_ext!(IndexStorageHistoryTestRunner, index_storage_history);
struct IndexStorageHistoryTestRunner {
pub(crate) tx: TestTransaction,
pub(crate) db: TestStageDB,
commit_threshold: u64,
prune_mode: Option<PruneMode>,
}
impl Default for IndexStorageHistoryTestRunner {
fn default() -> Self {
Self { tx: TestTransaction::default(), commit_threshold: 1000, prune_mode: None }
Self { db: TestStageDB::default(), commit_threshold: 1000, prune_mode: None }
}
}
impl StageTestRunner for IndexStorageHistoryTestRunner {
type S = IndexStorageHistoryStage;
fn tx(&self) -> &TestTransaction {
&self.tx
fn db(&self) -> &TestStageDB {
&self.db
}
fn stage(&self) -> Self::S {
@ -513,7 +509,7 @@ mod tests {
);
// add block changeset from block 1.
self.tx.insert_changesets(transitions, Some(start))?;
self.db.insert_changesets(transitions, Some(start))?;
Ok(())
}
@ -535,7 +531,7 @@ mod tests {
ExecOutput { checkpoint: StageCheckpoint::new(input.target()), done: true }
);
let provider = self.tx.inner();
let provider = self.db.factory.provider()?;
let mut changeset_cursor =
provider.tx_ref().cursor_read::<tables::StorageChangeSet>()?;
@ -586,7 +582,7 @@ mod tests {
};
}
let table = cast(self.tx.table::<tables::StorageHistory>().unwrap());
let table = cast(self.db.table::<tables::StorageHistory>().unwrap());
assert_eq!(table, result);
}
Ok(())
@ -595,7 +591,7 @@ mod tests {
impl UnwindStageTestRunner for IndexStorageHistoryTestRunner {
fn validate_unwind(&self, _input: UnwindInput) -> Result<(), TestRunnerError> {
let table = self.tx.table::<tables::StorageHistory>().unwrap();
let table = self.db.table::<tables::StorageHistory>().unwrap();
assert!(table.is_empty());
Ok(())
}
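
As a worked example of the prune-mode path above (`IndexStorageHistoryStage`): the stage is built with an explicit `prune_mode`, executed once, and expected to jump straight to the target checkpoint. The `ExecInput` target is inferred from the asserted checkpoint; treat this as a sketch, not the verbatim test:

    let mut stage = IndexStorageHistoryStage {
        prune_mode: Some(PruneMode::Before(36)),
        ..Default::default()
    };
    let input = ExecInput { target: Some(20000), ..Default::default() };

    let provider = db.factory.provider_rw().unwrap();
    let out = stage.execute(&provider, input).unwrap();
    assert_eq!(out, ExecOutput { checkpoint: StageCheckpoint::new(20000), done: true });
    provider.commit().unwrap();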

View File

@ -80,7 +80,7 @@ impl MerkleStage {
/// Gets the hashing progress
pub fn get_execution_checkpoint<DB: Database>(
&self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
) -> Result<Option<MerkleCheckpoint>, StageError> {
let buf =
provider.get_stage_checkpoint_progress(StageId::MerkleExecute)?.unwrap_or_default();
@ -96,7 +96,7 @@ impl MerkleStage {
/// Saves the hashing progress
pub fn save_execution_checkpoint<DB: Database>(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
checkpoint: Option<MerkleCheckpoint>,
) -> Result<(), StageError> {
let mut buf = vec![];
@ -127,7 +127,7 @@ impl<DB: Database> Stage<DB> for MerkleStage {
/// Execute the stage.
fn execute(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
let threshold = match self {
@ -261,7 +261,7 @@ impl<DB: Database> Stage<DB> for MerkleStage {
/// Unwind the stage.
fn unwind(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
let tx = provider.tx_ref();
@ -338,7 +338,7 @@ mod tests {
use super::*;
use crate::test_utils::{
stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, TestRunnerError,
TestTransaction, UnwindStageTestRunner,
TestStageDB, UnwindStageTestRunner,
};
use assert_matches::assert_matches;
use reth_db::{
@ -392,8 +392,8 @@ mod tests {
done: true
}) if block_number == previous_stage && processed == total &&
total == (
runner.tx.table::<tables::HashedAccount>().unwrap().len() +
runner.tx.table::<tables::HashedStorage>().unwrap().len()
runner.db.table::<tables::HashedAccount>().unwrap().len() +
runner.db.table::<tables::HashedStorage>().unwrap().len()
) as u64
);
@ -432,8 +432,8 @@ mod tests {
done: true
}) if block_number == previous_stage && processed == total &&
total == (
runner.tx.table::<tables::HashedAccount>().unwrap().len() +
runner.tx.table::<tables::HashedStorage>().unwrap().len()
runner.db.table::<tables::HashedAccount>().unwrap().len() +
runner.db.table::<tables::HashedStorage>().unwrap().len()
) as u64
);
@ -442,21 +442,21 @@ mod tests {
}
struct MerkleTestRunner {
tx: TestTransaction,
db: TestStageDB,
clean_threshold: u64,
}
impl Default for MerkleTestRunner {
fn default() -> Self {
Self { tx: TestTransaction::default(), clean_threshold: 10000 }
Self { db: TestStageDB::default(), clean_threshold: 10000 }
}
}
impl StageTestRunner for MerkleTestRunner {
type S = MerkleStage;
fn tx(&self) -> &TestTransaction {
&self.tx
fn db(&self) -> &TestStageDB {
&self.db
}
fn stage(&self) -> Self::S {
@ -479,7 +479,7 @@ mod tests {
.into_iter()
.collect::<BTreeMap<_, _>>();
self.tx.insert_accounts_and_storages(
self.db.insert_accounts_and_storages(
accounts.iter().map(|(addr, acc)| (*addr, (*acc, std::iter::empty()))),
)?;
@ -498,7 +498,7 @@ mod tests {
let head_hash = sealed_head.hash();
let mut blocks = vec![sealed_head];
blocks.extend(random_block_range(&mut rng, start..=end, head_hash, 0..3));
self.tx.insert_blocks(blocks.iter(), None)?;
self.db.insert_blocks(blocks.iter(), None)?;
let (transitions, final_state) = random_changeset_range(
&mut rng,
@ -508,11 +508,11 @@ mod tests {
0..256,
);
// add block changeset from block 1.
self.tx.insert_changesets(transitions, Some(start))?;
self.tx.insert_accounts_and_storages(final_state)?;
self.db.insert_changesets(transitions, Some(start))?;
self.db.insert_accounts_and_storages(final_state)?;
// Calculate state root
let root = self.tx.query(|tx| {
let root = self.db.query(|tx| {
let mut accounts = BTreeMap::default();
let mut accounts_cursor = tx.cursor_read::<tables::HashedAccount>()?;
let mut storage_cursor = tx.cursor_dup_read::<tables::HashedStorage>()?;
@ -536,10 +536,11 @@ mod tests {
})?;
let last_block_number = end;
self.tx.commit(|tx| {
self.db.commit(|tx| {
let mut last_header = tx.get::<tables::Headers>(last_block_number)?.unwrap();
last_header.state_root = root;
tx.put::<tables::Headers>(last_block_number, last_header)
tx.put::<tables::Headers>(last_block_number, last_header)?;
Ok(())
})?;
Ok(blocks)
@ -564,7 +565,7 @@ mod tests {
fn before_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
let target_block = input.unwind_to + 1;
self.tx
self.db
.commit(|tx| {
let mut storage_changesets_cursor =
tx.cursor_dup_read::<tables::StorageChangeSet>().unwrap();

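The merkle runner above leans on the two closure helpers that `TestStageDB` keeps from its predecessor; their assumed shapes, inferred from the call sites in these hunks:

    // query: run a read-only closure against a fresh transaction
    let last_header = self.db.query(|tx| {
        Ok(tx.get::<tables::Headers>(last_block_number)?.unwrap())
    })?;

    // commit: run a read-write closure, then commit it atomically
    self.db.commit(|tx| {
        let mut header = last_header.clone();
        header.state_root = root; // `root` is the state root computed above
        tx.put::<tables::Headers>(last_block_number, header)?;
        Ok(())
    })?;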
View File

@ -42,7 +42,7 @@ mod tests {
use crate::{
stage::Stage,
stages::{ExecutionStage, IndexAccountHistoryStage, IndexStorageHistoryStage},
test_utils::TestTransaction,
test_utils::TestStageDB,
ExecInput,
};
use alloy_rlp::Decodable;
@ -50,17 +50,17 @@ mod tests {
cursor::DbCursorRO,
mdbx::{cursor::Cursor, RW},
tables,
test_utils::TempDatabase,
transaction::{DbTx, DbTxMut},
AccountHistory, DatabaseEnv,
};
use reth_interfaces::test_utils::generators::{self, random_block};
use reth_primitives::{
address, hex_literal::hex, keccak256, Account, Bytecode, ChainSpecBuilder, PruneMode,
PruneModes, SealedBlock, MAINNET, U256,
PruneModes, SealedBlock, U256,
};
use reth_provider::{
AccountExtReader, BlockWriter, DatabaseProviderRW, ProviderFactory, ReceiptProvider,
StorageReader,
AccountExtReader, BlockWriter, ProviderFactory, ReceiptProvider, StorageReader,
};
use reth_revm::Factory;
use std::sync::Arc;
@ -68,18 +68,17 @@ mod tests {
#[tokio::test]
#[ignore]
async fn test_prune() {
let test_tx = TestTransaction::default();
let factory = Arc::new(ProviderFactory::new(test_tx.tx.db(), MAINNET.clone()));
let test_db = TestStageDB::default();
let provider = factory.provider_rw().unwrap();
let provider_rw = test_db.factory.provider_rw().unwrap();
let tip = 66;
let input = ExecInput { target: Some(tip), checkpoint: None };
let mut genesis_rlp = hex!("f901faf901f5a00000000000000000000000000000000000000000000000000000000000000000a01dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347942adc25665018aa1fe0e6bc666dac8fc2697ff9baa045571b40ae66ca7480791bbb2887286e4e4c4b1b298b191c889d6959023a32eda056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421a056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421b901000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000083020000808502540be400808000a00000000000000000000000000000000000000000000000000000000000000000880000000000000000c0c0").as_slice();
let genesis = SealedBlock::decode(&mut genesis_rlp).unwrap();
let mut block_rlp = hex!("f90262f901f9a075c371ba45999d87f4542326910a11af515897aebce5265d3f6acd1f1161f82fa01dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347942adc25665018aa1fe0e6bc666dac8fc2697ff9baa098f2dcd87c8ae4083e7017a05456c14eea4b1db2032126e27b3b1563d57d7cc0a08151d548273f6683169524b66ca9fe338b9ce42bc3540046c828fd939ae23bcba03f4e5c2ec5b2170b711d97ee755c160457bb58d8daa338e835ec02ae6860bbabb901000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000083020000018502540be40082a8798203e800a00000000000000000000000000000000000000000000000000000000000000000880000000000000000f863f861800a8405f5e10094100000000000000000000000000000000000000080801ba07e09e26678ed4fac08a249ebe8ed680bf9051a5e14ad223e4b2b9d26e0208f37a05f6e3f188e3e6eab7d7d3b6568f5eac7d687b08d307d3154ccd8c87b4630509bc0").as_slice();
let block = SealedBlock::decode(&mut block_rlp).unwrap();
provider.insert_block(genesis, None, None).unwrap();
provider.insert_block(block.clone(), None, None).unwrap();
provider_rw.insert_block(genesis, None, None).unwrap();
provider_rw.insert_block(block.clone(), None, None).unwrap();
// Fill with bogus blocks to respect PruneMode distance.
let mut head = block.hash;
@ -87,22 +86,22 @@ mod tests {
for block_number in 2..=tip {
let nblock = random_block(&mut rng, block_number, Some(head), Some(0), Some(0));
head = nblock.hash;
provider.insert_block(nblock, None, None).unwrap();
provider_rw.insert_block(nblock, None, None).unwrap();
}
provider.commit().unwrap();
provider_rw.commit().unwrap();
// insert pre state
let provider = factory.provider_rw().unwrap();
let provider_rw = test_db.factory.provider_rw().unwrap();
let code = hex!("5a465a905090036002900360015500");
let code_hash = keccak256(hex!("5a465a905090036002900360015500"));
provider
provider_rw
.tx_ref()
.put::<tables::PlainAccountState>(
address!("1000000000000000000000000000000000000000"),
Account { nonce: 0, balance: U256::ZERO, bytecode_hash: Some(code_hash) },
)
.unwrap();
provider
provider_rw
.tx_ref()
.put::<tables::PlainAccountState>(
address!("a94f5374fce5edbc8e2a8697c15331677e6ebf0b"),
@ -113,18 +112,18 @@ mod tests {
},
)
.unwrap();
provider
provider_rw
.tx_ref()
.put::<tables::Bytecodes>(code_hash, Bytecode::new_raw(code.to_vec().into()))
.unwrap();
provider.commit().unwrap();
provider_rw.commit().unwrap();
let check_pruning = |factory: Arc<ProviderFactory<_>>,
let check_pruning = |factory: ProviderFactory<Arc<TempDatabase<DatabaseEnv>>>,
prune_modes: PruneModes,
expect_num_receipts: usize,
expect_num_acc_changesets: usize,
expect_num_storage_changesets: usize| async move {
let provider: DatabaseProviderRW<&DatabaseEnv> = factory.provider_rw().unwrap();
let provider = factory.provider_rw().unwrap();
// Check execution and create receipts and changesets according to the pruning
// configuration
@ -195,34 +194,34 @@ mod tests {
// In an unpruned configuration there is 1 receipt, 3 changed accounts and 1 changed
// storage.
let mut prune = PruneModes::none();
check_pruning(factory.clone(), prune.clone(), 1, 3, 1).await;
check_pruning(test_db.factory.clone(), prune.clone(), 1, 3, 1).await;
prune.receipts = Some(PruneMode::Full);
prune.account_history = Some(PruneMode::Full);
prune.storage_history = Some(PruneMode::Full);
// This will result in error for account_history and storage_history, which is caught.
check_pruning(factory.clone(), prune.clone(), 0, 0, 0).await;
check_pruning(test_db.factory.clone(), prune.clone(), 0, 0, 0).await;
prune.receipts = Some(PruneMode::Before(1));
prune.account_history = Some(PruneMode::Before(1));
prune.storage_history = Some(PruneMode::Before(1));
check_pruning(factory.clone(), prune.clone(), 1, 3, 1).await;
check_pruning(test_db.factory.clone(), prune.clone(), 1, 3, 1).await;
prune.receipts = Some(PruneMode::Before(2));
prune.account_history = Some(PruneMode::Before(2));
prune.storage_history = Some(PruneMode::Before(2));
// The one account is the miner
check_pruning(factory.clone(), prune.clone(), 0, 1, 0).await;
check_pruning(test_db.factory.clone(), prune.clone(), 0, 1, 0).await;
prune.receipts = Some(PruneMode::Distance(66));
prune.account_history = Some(PruneMode::Distance(66));
prune.storage_history = Some(PruneMode::Distance(66));
check_pruning(factory.clone(), prune.clone(), 1, 3, 1).await;
check_pruning(test_db.factory.clone(), prune.clone(), 1, 3, 1).await;
prune.receipts = Some(PruneMode::Distance(64));
prune.account_history = Some(PruneMode::Distance(64));
prune.storage_history = Some(PruneMode::Distance(64));
// The one account is the miner
check_pruning(factory.clone(), prune.clone(), 0, 1, 0).await;
check_pruning(test_db.factory.clone(), prune.clone(), 0, 1, 0).await;
}
}
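
The closure-signature change in the prune test is worth spelling out: the factory is now passed by value as the concrete test type (it is cloned freely throughout this commit, so that is evidently cheap), and the provider's type is inferred where it previously needed the `DatabaseProviderRW<&DatabaseEnv>` annotation. A trimmed sketch, with the expectation parameters elided:

    let check_pruning = |factory: ProviderFactory<Arc<TempDatabase<DatabaseEnv>>>,
                         prune_modes: PruneModes| async move {
        // provider type is now inferred from the factory's DB parameter
        let provider = factory.provider_rw().unwrap();
        // ... execute the stage and assert receipt/changeset counts
    };

    check_pruning(test_db.factory.clone(), PruneModes::none()).await;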

View File

@ -56,7 +56,7 @@ impl<DB: Database> Stage<DB> for SenderRecoveryStage {
/// the [`TxSenders`][reth_db::tables::TxSenders] table.
fn execute(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
if input.target_reached() {
@ -168,7 +168,7 @@ impl<DB: Database> Stage<DB> for SenderRecoveryStage {
/// Unwind the stage.
fn unwind(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
let (_, unwind_to, _) = input.unwind_block_range_with_threshold(self.commit_threshold);
@ -207,7 +207,7 @@ fn recover_sender(
}
fn stage_checkpoint<DB: Database>(
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
) -> Result<EntitiesCheckpoint, StageError> {
let pruned_entries = provider
.get_prune_checkpoint(PruneSegment::SenderRecovery)?
@ -250,14 +250,14 @@ mod tests {
};
use reth_primitives::{
stage::StageUnitCheckpoint, BlockNumber, PruneCheckpoint, PruneMode, SealedBlock,
TransactionSigned, B256, MAINNET,
TransactionSigned, B256,
};
use reth_provider::{ProviderFactory, PruneCheckpointWriter, TransactionsProvider};
use reth_provider::{PruneCheckpointWriter, TransactionsProvider};
use super::*;
use crate::test_utils::{
stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, TestRunnerError,
TestTransaction, UnwindStageTestRunner,
TestStageDB, UnwindStageTestRunner,
};
stage_test_suite_ext!(SenderRecoveryTestRunner, sender_recovery);
@ -288,7 +288,7 @@ mod tests {
)
})
.collect::<Vec<_>>();
runner.tx.insert_blocks(blocks.iter(), None).expect("failed to insert blocks");
runner.db.insert_blocks(blocks.iter(), None).expect("failed to insert blocks");
let rx = runner.execute(input);
@ -322,9 +322,9 @@ mod tests {
// Manually seed once with full input range
let seed =
random_block_range(&mut rng, stage_progress + 1..=previous_stage, B256::ZERO, 0..4); // set tx count range high enough to hit the threshold
runner.tx.insert_blocks(seed.iter(), None).expect("failed to seed execution");
runner.db.insert_blocks(seed.iter(), None).expect("failed to seed execution");
let total_transactions = runner.tx.table::<tables::Transactions>().unwrap().len() as u64;
let total_transactions = runner.db.table::<tables::Transactions>().unwrap().len() as u64;
let first_input = ExecInput {
target: Some(previous_stage),
@ -348,7 +348,7 @@ mod tests {
ExecOutput {
checkpoint: StageCheckpoint::new(expected_progress).with_entities_stage_checkpoint(
EntitiesCheckpoint {
processed: runner.tx.table::<tables::TxSenders>().unwrap().len() as u64,
processed: runner.db.table::<tables::TxSenders>().unwrap().len() as u64,
total: total_transactions
}
),
@ -379,11 +379,11 @@ mod tests {
#[test]
fn stage_checkpoint_pruned() {
let tx = TestTransaction::default();
let db = TestStageDB::default();
let mut rng = generators::rng();
let blocks = random_block_range(&mut rng, 0..=100, B256::ZERO, 0..10);
tx.insert_blocks(blocks.iter(), None).expect("insert blocks");
db.insert_blocks(blocks.iter(), None).expect("insert blocks");
let max_pruned_block = 30;
let max_processed_block = 70;
@ -399,9 +399,9 @@ mod tests {
tx_number += 1;
}
}
tx.insert_transaction_senders(tx_senders).expect("insert tx hash numbers");
db.insert_transaction_senders(tx_senders).expect("insert tx hash numbers");
let provider = tx.inner_rw();
let provider = db.factory.provider_rw().unwrap();
provider
.save_prune_checkpoint(
PruneSegment::SenderRecovery,
@ -419,10 +419,7 @@ mod tests {
.expect("save stage checkpoint");
provider.commit().expect("commit");
let db = tx.inner_raw();
let factory = ProviderFactory::new(db.as_ref(), MAINNET.clone());
let provider = factory.provider_rw().expect("provider rw");
let provider = db.factory.provider_rw().unwrap();
assert_eq!(
stage_checkpoint(&provider).expect("stage checkpoint"),
EntitiesCheckpoint {
@ -436,13 +433,13 @@ mod tests {
}
struct SenderRecoveryTestRunner {
tx: TestTransaction,
db: TestStageDB,
threshold: u64,
}
impl Default for SenderRecoveryTestRunner {
fn default() -> Self {
Self { threshold: 1000, tx: TestTransaction::default() }
Self { threshold: 1000, db: TestStageDB::default() }
}
}
@ -459,16 +456,17 @@ mod tests {
/// not empty.
fn ensure_no_senders_by_block(&self, block: BlockNumber) -> Result<(), TestRunnerError> {
let body_result = self
.tx
.inner_rw()
.db
.factory
.provider_rw()?
.block_body_indices(block)?
.ok_or(ProviderError::BlockBodyIndicesNotFound(block));
match body_result {
Ok(body) => self
.tx
.db
.ensure_no_entry_above::<tables::TxSenders, _>(body.last_tx_num(), |key| key)?,
Err(_) => {
assert!(self.tx.table_is_empty::<tables::TxSenders>()?);
assert!(self.db.table_is_empty::<tables::TxSenders>()?);
}
};
@ -479,8 +477,8 @@ mod tests {
impl StageTestRunner for SenderRecoveryTestRunner {
type S = SenderRecoveryStage;
fn tx(&self) -> &TestTransaction {
&self.tx
fn db(&self) -> &TestStageDB {
&self.db
}
fn stage(&self) -> Self::S {
@ -497,7 +495,7 @@ mod tests {
let end = input.target();
let blocks = random_block_range(&mut rng, stage_progress..=end, B256::ZERO, 0..2);
self.tx.insert_blocks(blocks.iter(), None)?;
self.db.insert_blocks(blocks.iter(), None)?;
Ok(blocks)
}
@ -508,7 +506,7 @@ mod tests {
) -> Result<(), TestRunnerError> {
match output {
Some(output) => {
let provider = self.tx.inner();
let provider = self.db.factory.provider()?;
let start_block = input.next_block();
let end_block = output.checkpoint.block_number;

View File

@ -50,7 +50,7 @@ impl<DB: Database> Stage<DB> for TotalDifficultyStage {
/// Write total difficulty entries
fn execute(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
let tx = provider.tx_ref();
@ -100,7 +100,7 @@ impl<DB: Database> Stage<DB> for TotalDifficultyStage {
/// Unwind the stage.
fn unwind(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
let (_, unwind_to, _) = input.unwind_block_range_with_threshold(self.commit_threshold);
@ -138,7 +138,7 @@ mod tests {
use super::*;
use crate::test_utils::{
stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, TestRunnerError,
TestTransaction, UnwindStageTestRunner,
TestStageDB, UnwindStageTestRunner,
};
stage_test_suite_ext!(TotalDifficultyTestRunner, total_difficulty);
@ -171,7 +171,7 @@ mod tests {
total
}))
}, done: false }) if block_number == expected_progress && processed == 1 + threshold &&
total == runner.tx.table::<tables::Headers>().unwrap().len() as u64
total == runner.db.table::<tables::Headers>().unwrap().len() as u64
);
// Execute second time
@ -189,14 +189,14 @@ mod tests {
total
}))
}, done: true }) if block_number == previous_stage && processed == total &&
total == runner.tx.table::<tables::Headers>().unwrap().len() as u64
total == runner.db.table::<tables::Headers>().unwrap().len() as u64
);
assert!(runner.validate_execution(first_input, result.ok()).is_ok(), "validation failed");
}
struct TotalDifficultyTestRunner {
tx: TestTransaction,
db: TestStageDB,
consensus: Arc<TestConsensus>,
commit_threshold: u64,
}
@ -204,7 +204,7 @@ mod tests {
impl Default for TotalDifficultyTestRunner {
fn default() -> Self {
Self {
tx: Default::default(),
db: Default::default(),
consensus: Arc::new(TestConsensus::default()),
commit_threshold: 500,
}
@ -214,8 +214,8 @@ mod tests {
impl StageTestRunner for TotalDifficultyTestRunner {
type S = TotalDifficultyStage;
fn tx(&self) -> &TestTransaction {
&self.tx
fn db(&self) -> &TestStageDB {
&self.db
}
fn stage(&self) -> Self::S {
@ -234,15 +234,16 @@ mod tests {
let mut rng = generators::rng();
let start = input.checkpoint().block_number;
let head = random_header(&mut rng, start, None);
self.tx.insert_headers(std::iter::once(&head))?;
self.tx.commit(|tx| {
self.db.insert_headers(std::iter::once(&head))?;
self.db.commit(|tx| {
let td: U256 = tx
.cursor_read::<tables::HeaderTD>()?
.last()?
.map(|(_, v)| v)
.unwrap_or_default()
.into();
tx.put::<tables::HeaderTD>(head.number, (td + head.difficulty).into())
tx.put::<tables::HeaderTD>(head.number, (td + head.difficulty).into())?;
Ok(())
})?;
// use previous progress as seed size
@ -253,7 +254,7 @@ mod tests {
}
let mut headers = random_header_range(&mut rng, start + 1..end, head.hash());
self.tx.insert_headers(headers.iter())?;
self.db.insert_headers(headers.iter())?;
headers.insert(0, head);
Ok(headers)
}
@ -267,7 +268,7 @@ mod tests {
let initial_stage_progress = input.checkpoint().block_number;
match output {
Some(output) if output.checkpoint.block_number > initial_stage_progress => {
let provider = self.tx.inner();
let provider = self.db.factory.provider()?;
let mut header_cursor = provider.tx_ref().cursor_read::<tables::Headers>()?;
let (_, mut current_header) = header_cursor
@ -301,7 +302,7 @@ mod tests {
impl TotalDifficultyTestRunner {
fn check_no_td_above(&self, block: BlockNumber) -> Result<(), TestRunnerError> {
self.tx.ensure_no_entry_above::<tables::HeaderTD, _>(block, |num| num)?;
self.db.ensure_no_entry_above::<tables::HeaderTD, _>(block, |num| num)?;
Ok(())
}

View File

@ -51,7 +51,7 @@ impl<DB: Database> Stage<DB> for TransactionLookupStage {
/// Write transaction hash -> id entries
fn execute(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
mut input: ExecInput,
) -> Result<ExecOutput, StageError> {
if let Some((target_prunable_block, prune_mode)) = self
@ -129,7 +129,7 @@ impl<DB: Database> Stage<DB> for TransactionLookupStage {
/// Unwind the stage.
fn unwind(
&mut self,
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
let tx = provider.tx_ref();
@ -164,7 +164,7 @@ impl<DB: Database> Stage<DB> for TransactionLookupStage {
}
fn stage_checkpoint<DB: Database>(
provider: &DatabaseProviderRW<&DB>,
provider: &DatabaseProviderRW<DB>,
) -> Result<EntitiesCheckpoint, StageError> {
let pruned_entries = provider
.get_prune_checkpoint(PruneSegment::TransactionLookup)?
@ -186,7 +186,7 @@ mod tests {
use super::*;
use crate::test_utils::{
stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, TestRunnerError,
TestTransaction, UnwindStageTestRunner,
TestStageDB, UnwindStageTestRunner,
};
use assert_matches::assert_matches;
use reth_interfaces::test_utils::{
@ -195,11 +195,8 @@ mod tests {
};
use reth_primitives::{
stage::StageUnitCheckpoint, BlockNumber, PruneCheckpoint, PruneMode, SealedBlock, B256,
MAINNET,
};
use reth_provider::{
BlockReader, ProviderError, ProviderFactory, PruneCheckpointWriter, TransactionsProvider,
};
use reth_provider::{BlockReader, ProviderError, PruneCheckpointWriter, TransactionsProvider};
use std::ops::Sub;
// Implement stage test suite.
@ -230,7 +227,7 @@ mod tests {
)
})
.collect::<Vec<_>>();
runner.tx.insert_blocks(blocks.iter(), None).expect("failed to insert blocks");
runner.db.insert_blocks(blocks.iter(), None).expect("failed to insert blocks");
let rx = runner.execute(input);
@ -246,7 +243,7 @@ mod tests {
total
}))
}, done: true }) if block_number == previous_stage && processed == total &&
total == runner.tx.table::<tables::Transactions>().unwrap().len() as u64
total == runner.db.table::<tables::Transactions>().unwrap().len() as u64
);
// Validate the stage execution
@ -269,9 +266,9 @@ mod tests {
// Seed only once with full input range
let seed =
random_block_range(&mut rng, stage_progress + 1..=previous_stage, B256::ZERO, 0..4); // set tx count range high enough to hit the threshold
runner.tx.insert_blocks(seed.iter(), None).expect("failed to seed execution");
runner.db.insert_blocks(seed.iter(), None).expect("failed to seed execution");
let total_txs = runner.tx.table::<tables::Transactions>().unwrap().len() as u64;
let total_txs = runner.db.table::<tables::Transactions>().unwrap().len() as u64;
// Execute first time
let result = runner.execute(first_input).await.unwrap();
@ -290,7 +287,7 @@ mod tests {
ExecOutput {
checkpoint: StageCheckpoint::new(expected_progress).with_entities_stage_checkpoint(
EntitiesCheckpoint {
processed: runner.tx.table::<tables::TxHashNumber>().unwrap().len() as u64,
processed: runner.db.table::<tables::TxHashNumber>().unwrap().len() as u64,
total: total_txs
}
),
@ -334,7 +331,7 @@ mod tests {
// Seed only once with full input range
let seed =
random_block_range(&mut rng, stage_progress + 1..=previous_stage, B256::ZERO, 0..2);
runner.tx.insert_blocks(seed.iter(), None).expect("failed to seed execution");
runner.db.insert_blocks(seed.iter(), None).expect("failed to seed execution");
runner.set_prune_mode(PruneMode::Before(prune_target));
@ -352,7 +349,7 @@ mod tests {
total
}))
}, done: true }) if block_number == previous_stage && processed == total &&
total == runner.tx.table::<tables::Transactions>().unwrap().len() as u64
total == runner.db.table::<tables::Transactions>().unwrap().len() as u64
);
// Validate the stage execution
@ -361,11 +358,11 @@ mod tests {
#[test]
fn stage_checkpoint_pruned() {
let tx = TestTransaction::default();
let db = TestStageDB::default();
let mut rng = generators::rng();
let blocks = random_block_range(&mut rng, 0..=100, B256::ZERO, 0..10);
tx.insert_blocks(blocks.iter(), None).expect("insert blocks");
db.insert_blocks(blocks.iter(), None).expect("insert blocks");
let max_pruned_block = 30;
let max_processed_block = 70;
@ -380,9 +377,9 @@ mod tests {
tx_hash_number += 1;
}
}
tx.insert_tx_hash_numbers(tx_hash_numbers).expect("insert tx hash numbers");
db.insert_tx_hash_numbers(tx_hash_numbers).expect("insert tx hash numbers");
let provider = tx.inner_rw();
let provider = db.factory.provider_rw().unwrap();
provider
.save_prune_checkpoint(
PruneSegment::TransactionLookup,
@ -401,10 +398,7 @@ mod tests {
.expect("save stage checkpoint");
provider.commit().expect("commit");
let db = tx.inner_raw();
let factory = ProviderFactory::new(db.as_ref(), MAINNET.clone());
let provider = factory.provider_rw().expect("provider rw");
let provider = db.factory.provider_rw().unwrap();
assert_eq!(
stage_checkpoint(&provider).expect("stage checkpoint"),
EntitiesCheckpoint {
@ -418,14 +412,14 @@ mod tests {
}
struct TransactionLookupTestRunner {
tx: TestTransaction,
db: TestStageDB,
commit_threshold: u64,
prune_mode: Option<PruneMode>,
}
impl Default for TransactionLookupTestRunner {
fn default() -> Self {
Self { tx: TestTransaction::default(), commit_threshold: 1000, prune_mode: None }
Self { db: TestStageDB::default(), commit_threshold: 1000, prune_mode: None }
}
}
@ -447,17 +441,18 @@ mod tests {
/// not empty.
fn ensure_no_hash_by_block(&self, number: BlockNumber) -> Result<(), TestRunnerError> {
let body_result = self
.tx
.inner_rw()
.db
.factory
.provider_rw()?
.block_body_indices(number)?
.ok_or(ProviderError::BlockBodyIndicesNotFound(number));
match body_result {
Ok(body) => self.tx.ensure_no_entry_above_by_value::<tables::TxHashNumber, _>(
Ok(body) => self.db.ensure_no_entry_above_by_value::<tables::TxHashNumber, _>(
body.last_tx_num(),
|key| key,
)?,
Err(_) => {
assert!(self.tx.table_is_empty::<tables::TxHashNumber>()?);
assert!(self.db.table_is_empty::<tables::TxHashNumber>()?);
}
};
@ -468,8 +463,8 @@ mod tests {
impl StageTestRunner for TransactionLookupTestRunner {
type S = TransactionLookupStage;
fn tx(&self) -> &TestTransaction {
&self.tx
fn db(&self) -> &TestStageDB {
&self.db
}
fn stage(&self) -> Self::S {
@ -489,7 +484,7 @@ mod tests {
let mut rng = generators::rng();
let blocks = random_block_range(&mut rng, stage_progress + 1..=end, B256::ZERO, 0..2);
self.tx.insert_blocks(blocks.iter(), None)?;
self.db.insert_blocks(blocks.iter(), None)?;
Ok(blocks)
}
@ -500,7 +495,7 @@ mod tests {
) -> Result<(), TestRunnerError> {
match output {
Some(output) => {
let provider = self.tx.inner();
let provider = self.db.factory.provider()?;
if let Some((target_prunable_block, _)) = self
.prune_mode

View File

@ -10,7 +10,7 @@ pub(crate) use runner::{
};
mod test_db;
pub use test_db::TestTransaction;
pub use test_db::TestStageDB;
mod stage;
pub use stage::TestStage;

View File

@ -1,6 +1,6 @@
use super::TestTransaction;
use super::TestStageDB;
use crate::{ExecInput, ExecOutput, Stage, StageError, StageExt, UnwindInput, UnwindOutput};
use reth_db::DatabaseEnv;
use reth_db::{test_utils::TempDatabase, DatabaseEnv};
use reth_interfaces::db::DatabaseError;
use reth_primitives::MAINNET;
use reth_provider::{ProviderError, ProviderFactory};
@ -19,10 +19,10 @@ pub(crate) enum TestRunnerError {
/// A generic test runner for stages.
pub(crate) trait StageTestRunner {
type S: Stage<DatabaseEnv> + 'static;
type S: Stage<Arc<TempDatabase<DatabaseEnv>>> + 'static;
/// Return a reference to the database.
fn tx(&self) -> &TestTransaction;
fn db(&self) -> &TestStageDB;
/// Return an instance of a Stage.
fn stage(&self) -> Self::S;
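
Concretely, every runner now carries a `TestStageDB` and returns it from `db()`. A sketch of a conforming implementation, patterned on `SenderRecoveryTestRunner` from the hunks above (the runner struct and the `commit_threshold` initializer are illustrative assumptions):

```rust
// Hypothetical runner, mirroring SenderRecoveryTestRunner above.
struct ExampleStageTestRunner {
    db: TestStageDB,
    threshold: u64,
}

impl StageTestRunner for ExampleStageTestRunner {
    type S = SenderRecoveryStage;

    // Hand the test suite the shared temporary database.
    fn db(&self) -> &TestStageDB {
        &self.db
    }

    // Build a fresh stage instance for each run; field name assumed.
    fn stage(&self) -> Self::S {
        SenderRecoveryStage { commit_threshold: self.threshold }
    }
}
```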
@ -45,12 +45,10 @@ pub(crate) trait ExecuteStageTestRunner: StageTestRunner {
/// Run [Stage::execute] and return a receiver for the result.
fn execute(&self, input: ExecInput) -> oneshot::Receiver<Result<ExecOutput, StageError>> {
let (tx, rx) = oneshot::channel();
let (db, mut stage) = (self.tx().inner_raw(), self.stage());
let (db, mut stage) = (self.db().factory.clone(), self.stage());
tokio::spawn(async move {
let factory = ProviderFactory::new(db.db(), MAINNET.clone());
let result = stage.execute_ready(input).await.and_then(|_| {
let provider_rw = factory.provider_rw().unwrap();
let provider_rw = db.provider_rw().unwrap();
let result = stage.execute(&provider_rw, input);
provider_rw.commit().expect("failed to commit");
result
@ -74,11 +72,9 @@ pub(crate) trait UnwindStageTestRunner: StageTestRunner {
/// Run [Stage::unwind] and return a receiver for the result.
async fn unwind(&self, input: UnwindInput) -> Result<UnwindOutput, StageError> {
let (tx, rx) = oneshot::channel();
let (db, mut stage) = (self.tx().inner_raw(), self.stage());
let (db, mut stage) = (self.db().factory.clone(), self.stage());
tokio::spawn(async move {
let factory = ProviderFactory::new(db.db(), MAINNET.clone());
let provider = factory.provider_rw().unwrap();
let provider = db.provider_rw().unwrap();
let result = stage.unwind(&provider, input);
provider.commit().expect("failed to commit");
tx.send(result).expect("failed to send result");

View File

@ -47,7 +47,7 @@ impl<DB: Database> Stage<DB> for TestStage {
fn execute(
&mut self,
_: &DatabaseProviderRW<&DB>,
_: &DatabaseProviderRW<DB>,
_input: ExecInput,
) -> Result<ExecOutput, StageError> {
self.exec_outputs
@ -57,7 +57,7 @@ impl<DB: Database> Stage<DB> for TestStage {
fn unwind(
&mut self,
_: &DatabaseProviderRW<&DB>,
_: &DatabaseProviderRW<DB>,
_input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
self.unwind_outputs

View File

@ -9,7 +9,7 @@ use reth_db::{
transaction::{DbTx, DbTxMut},
DatabaseEnv, DatabaseError as DbError,
};
use reth_interfaces::{test_utils::generators::ChangeSet, RethResult};
use reth_interfaces::{provider::ProviderResult, test_utils::generators::ChangeSet, RethResult};
use reth_primitives::{
keccak256, Account, Address, BlockNumber, Receipt, SealedBlock, SealedHeader, StorageEntry,
TxHash, TxNumber, B256, MAINNET, U256,
@ -18,80 +18,50 @@ use reth_provider::{DatabaseProviderRO, DatabaseProviderRW, HistoryWriter, Provi
use std::{
borrow::Borrow,
collections::BTreeMap,
ops::RangeInclusive,
ops::{Deref, RangeInclusive},
path::{Path, PathBuf},
sync::Arc,
};
/// The [TestTransaction] is used as an internal
/// database for testing stage implementation.
///
/// ```rust,ignore
/// let tx = TestTransaction::default();
/// stage.execute(&mut tx.container(), input);
/// ```
/// Test database that is used for testing stage implementations.
#[derive(Debug)]
pub struct TestTransaction {
/// DB
pub tx: Arc<TempDatabase<DatabaseEnv>>,
pub path: Option<PathBuf>,
pub struct TestStageDB {
pub factory: ProviderFactory<Arc<TempDatabase<DatabaseEnv>>>,
}
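
With the `path` and raw-handle fields gone, tests reach the database exclusively through the embedded `ProviderFactory`. A minimal usage sketch under the API shown in this diff (`example_usage` is an illustrative name):

```rust
use reth_interfaces::provider::ProviderResult;

fn example_usage() -> ProviderResult<()> {
    // Fresh temporary database wrapped in a ProviderFactory.
    let db = TestStageDB::default();

    // Read-write provider; replaces the removed `inner_rw()` helper.
    let provider_rw = db.factory.provider_rw()?;
    // ... write through provider_rw.tx_ref(), then persist:
    provider_rw.commit()?;

    // Read-only provider; replaces the removed `inner()` helper.
    let _provider = db.factory.provider()?;
    Ok(())
}
```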
impl Default for TestTransaction {
/// Create a new instance of [TestTransaction]
impl Default for TestStageDB {
/// Create a new instance of [TestStageDB]
fn default() -> Self {
let tx = create_test_rw_db();
Self { tx: tx.clone(), path: None, factory: ProviderFactory::new(tx, MAINNET.clone()) }
Self { factory: ProviderFactory::new(create_test_rw_db(), MAINNET.clone()) }
}
}
impl TestTransaction {
impl TestStageDB {
pub fn new(path: &Path) -> Self {
let tx = create_test_rw_db_with_path(path);
Self {
tx: tx.clone(),
path: Some(path.to_path_buf()),
factory: ProviderFactory::new(tx, MAINNET.clone()),
}
}
/// Return a database wrapped in [DatabaseProviderRW].
pub fn inner_rw(&self) -> DatabaseProviderRW<Arc<TempDatabase<DatabaseEnv>>> {
self.factory.provider_rw().expect("failed to create db container")
}
/// Return a database wrapped in [DatabaseProviderRO].
pub fn inner(&self) -> DatabaseProviderRO<Arc<TempDatabase<DatabaseEnv>>> {
self.factory.provider().expect("failed to create db container")
}
/// Get a pointer to an internal database.
pub fn inner_raw(&self) -> Arc<TempDatabase<DatabaseEnv>> {
self.tx.clone()
Self { factory: ProviderFactory::new(create_test_rw_db_with_path(path), MAINNET.clone()) }
}
/// Invoke a callback with a transaction, committing it afterwards
pub fn commit<F>(&self, f: F) -> Result<(), DbError>
pub fn commit<F>(&self, f: F) -> ProviderResult<()>
where
F: FnOnce(&<DatabaseEnv as Database>::TXMut) -> Result<(), DbError>,
F: FnOnce(&<DatabaseEnv as Database>::TXMut) -> ProviderResult<()>,
{
let mut tx = self.inner_rw();
let mut tx = self.factory.provider_rw()?;
f(tx.tx_ref())?;
tx.commit().expect("failed to commit");
Ok(())
}
/// Invoke a callback with a read transaction
pub fn query<F, R>(&self, f: F) -> Result<R, DbError>
pub fn query<F, Ok>(&self, f: F) -> ProviderResult<Ok>
where
F: FnOnce(&<DatabaseEnv as Database>::TX) -> Result<R, DbError>,
F: FnOnce(&<DatabaseEnv as Database>::TX) -> ProviderResult<Ok>,
{
f(self.inner().tx_ref())
f(self.factory.provider()?.tx_ref())
}
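
Both helpers now thread `ProviderResult` through the closure, so database errors convert with `?` and writes end in an explicit `Ok(())`. A sketch reusing the `HeaderTD` write from the total-difficulty hunk above (the `example_*` helper names are illustrative):

```rust
use reth_db::{cursor::DbCursorRO, tables, transaction::{DbTx, DbTxMut}};
use reth_interfaces::provider::ProviderResult;
use reth_primitives::U256;

fn example_commit(db: &TestStageDB) -> ProviderResult<()> {
    db.commit(|tx| {
        // DbTxMut::put yields a DatabaseError; `?` lifts it into
        // ProviderResult, matching the new closure signature.
        tx.put::<tables::HeaderTD>(0, U256::from(1).into())?;
        Ok(())
    })
}

fn example_query(db: &TestStageDB) -> ProviderResult<bool> {
    // Read-only counterpart: same error conversion, no commit.
    db.query(|tx| Ok(tx.cursor_read::<tables::HeaderTD>()?.last()?.is_none()))
}
```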
/// Check if the table is empty
pub fn table_is_empty<T: Table>(&self) -> Result<bool, DbError> {
pub fn table_is_empty<T: Table>(&self) -> ProviderResult<bool> {
self.query(|tx| {
let last = tx.cursor_read::<T>()?.last()?;
Ok(last.is_none())
@ -99,70 +69,21 @@ impl TestTransaction {
}
/// Return the full table as a Vec
pub fn table<T: Table>(&self) -> Result<Vec<KeyValue<T>>, DbError>
pub fn table<T: Table>(&self) -> ProviderResult<Vec<KeyValue<T>>>
where
T::Key: Default + Ord,
{
self.query(|tx| {
tx.cursor_read::<T>()?
Ok(tx
.cursor_read::<T>()?
.walk(Some(T::Key::default()))?
.collect::<Result<Vec<_>, DbError>>()
})
}
/// Map a collection of values and store them in the database.
/// This function commits the transaction before exiting.
///
/// ```rust,ignore
/// let tx = TestTransaction::default();
/// tx.map_put::<Table, _, _>(&items, |item| item)?;
/// ```
#[allow(dead_code)]
pub fn map_put<T, S, F>(&self, values: &[S], mut map: F) -> Result<(), DbError>
where
T: Table,
S: Clone,
F: FnMut(&S) -> TableRow<T>,
{
self.commit(|tx| {
values.iter().try_for_each(|src| {
let (k, v) = map(src);
tx.put::<T>(k, v)
})
})
}
/// Transform a collection of values using a callback and store
/// them in the database. The callback additionally accepts the
/// optional last element that was stored.
/// This function commits the transaction before exiting.
///
/// ```rust,ignore
/// let tx = TestTransaction::default();
/// tx.transform_append::<Table, _, _>(&items, |prev, item| prev.unwrap_or_default() + item)?;
/// ```
#[allow(dead_code)]
pub fn transform_append<T, S, F>(&self, values: &[S], mut transform: F) -> Result<(), DbError>
where
T: Table,
<T as Table>::Value: Clone,
S: Clone,
F: FnMut(&Option<<T as Table>::Value>, &S) -> TableRow<T>,
{
self.commit(|tx| {
let mut cursor = tx.cursor_write::<T>()?;
let mut last = cursor.last()?.map(|(_, v)| v);
values.iter().try_for_each(|src| {
let (k, v) = transform(&last, src);
last = Some(v.clone());
cursor.append(k, v)
})
.collect::<Result<Vec<_>, DbError>>()?)
})
}
/// Check that there is no table entry above a given
/// number by [Table::Key]
pub fn ensure_no_entry_above<T, F>(&self, num: u64, mut selector: F) -> Result<(), DbError>
pub fn ensure_no_entry_above<T, F>(&self, num: u64, mut selector: F) -> ProviderResult<()>
where
T: Table,
F: FnMut(T::Key) -> BlockNumber,
@ -182,7 +103,7 @@ impl TestTransaction {
&self,
num: u64,
mut selector: F,
) -> Result<(), DbError>
) -> ProviderResult<()>
where
T: Table,
F: FnMut(T::Value) -> BlockNumber,
@ -206,17 +127,19 @@ impl TestTransaction {
/// Insert ordered collection of [SealedHeader] into the corresponding tables
/// that are supposed to be populated by the headers stage.
pub fn insert_headers<'a, I>(&self, headers: I) -> Result<(), DbError>
pub fn insert_headers<'a, I>(&self, headers: I) -> ProviderResult<()>
where
I: Iterator<Item = &'a SealedHeader>,
{
self.commit(|tx| headers.into_iter().try_for_each(|header| Self::insert_header(tx, header)))
self.commit(|tx| {
Ok(headers.into_iter().try_for_each(|header| Self::insert_header(tx, header))?)
})
}
/// Inserts total difficulty of headers into the corresponding tables.
///
/// Superset functionality of [TestTransaction::insert_headers].
pub fn insert_headers_with_td<'a, I>(&self, headers: I) -> Result<(), DbError>
/// Superset functionality of [TestStageDB::insert_headers].
pub fn insert_headers_with_td<'a, I>(&self, headers: I) -> ProviderResult<()>
where
I: Iterator<Item = &'a SealedHeader>,
{
@ -225,16 +148,16 @@ impl TestTransaction {
headers.into_iter().try_for_each(|header| {
Self::insert_header(tx, header)?;
td += header.difficulty;
tx.put::<tables::HeaderTD>(header.number, td.into())
Ok(tx.put::<tables::HeaderTD>(header.number, td.into())?)
})
})
}
/// Insert ordered collection of [SealedBlock] into corresponding tables.
/// Superset functionality of [TestTransaction::insert_headers].
/// Superset functionality of [TestStageDB::insert_headers].
///
/// Assumes that there's a single transition for each transaction (i.e. no block rewards).
pub fn insert_blocks<'a, I>(&self, blocks: I, tx_offset: Option<u64>) -> Result<(), DbError>
pub fn insert_blocks<'a, I>(&self, blocks: I, tx_offset: Option<u64>) -> ProviderResult<()>
where
I: Iterator<Item = &'a SealedBlock>,
{
@ -266,45 +189,45 @@ impl TestTransaction {
})
}
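
Seeding in the stage tests above funnels through `insert_blocks`; a short sketch of that pattern (the generator import path is assumed from the test hunks earlier):

```rust
use reth_interfaces::test_utils::generators::{self, random_block_range}; // path assumed
use reth_primitives::B256;

fn seed_example(db: &TestStageDB) {
    let mut rng = generators::rng();
    // Blocks 0..=100 with 0..10 transactions each, as in the
    // stage_checkpoint_pruned tests above.
    let blocks = random_block_range(&mut rng, 0..=100, B256::ZERO, 0..10);
    db.insert_blocks(blocks.iter(), None).expect("insert blocks");
}
```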
pub fn insert_tx_hash_numbers<I>(&self, tx_hash_numbers: I) -> Result<(), DbError>
pub fn insert_tx_hash_numbers<I>(&self, tx_hash_numbers: I) -> ProviderResult<()>
where
I: IntoIterator<Item = (TxHash, TxNumber)>,
{
self.commit(|tx| {
tx_hash_numbers.into_iter().try_for_each(|(tx_hash, tx_num)| {
// Insert into tx hash numbers table.
tx.put::<tables::TxHashNumber>(tx_hash, tx_num)
Ok(tx.put::<tables::TxHashNumber>(tx_hash, tx_num)?)
})
})
}
/// Insert collection of ([TxNumber], [Receipt]) into the corresponding table.
pub fn insert_receipts<I>(&self, receipts: I) -> Result<(), DbError>
pub fn insert_receipts<I>(&self, receipts: I) -> ProviderResult<()>
where
I: IntoIterator<Item = (TxNumber, Receipt)>,
{
self.commit(|tx| {
receipts.into_iter().try_for_each(|(tx_num, receipt)| {
// Insert into receipts table.
tx.put::<tables::Receipts>(tx_num, receipt)
Ok(tx.put::<tables::Receipts>(tx_num, receipt)?)
})
})
}
pub fn insert_transaction_senders<I>(&self, transaction_senders: I) -> Result<(), DbError>
pub fn insert_transaction_senders<I>(&self, transaction_senders: I) -> ProviderResult<()>
where
I: IntoIterator<Item = (TxNumber, Address)>,
{
self.commit(|tx| {
transaction_senders.into_iter().try_for_each(|(tx_num, sender)| {
// Insert into the senders table.
tx.put::<tables::TxSenders>(tx_num, sender)
Ok(tx.put::<tables::TxSenders>(tx_num, sender)?)
})
})
}
/// Insert collection of ([Address], [Account]) into corresponding tables.
pub fn insert_accounts_and_storages<I, S>(&self, accounts: I) -> Result<(), DbError>
pub fn insert_accounts_and_storages<I, S>(&self, accounts: I) -> ProviderResult<()>
where
I: IntoIterator<Item = (Address, (Account, S))>,
S: IntoIterator<Item = StorageEntry>,
@ -350,7 +273,7 @@ impl TestTransaction {
&self,
changesets: I,
block_offset: Option<u64>,
) -> Result<(), DbError>
) -> ProviderResult<()>
where
I: IntoIterator<Item = ChangeSet>,
{
@ -369,14 +292,14 @@ impl TestTransaction {
// Insert into storage changeset.
old_storage.into_iter().try_for_each(|entry| {
tx.put::<tables::StorageChangeSet>(block_address, entry)
Ok(tx.put::<tables::StorageChangeSet>(block_address, entry)?)
})
})
})
})
}
pub fn insert_history<I>(&self, changesets: I, block_offset: Option<u64>) -> RethResult<()>
pub fn insert_history<I>(&self, changesets: I, block_offset: Option<u64>) -> ProviderResult<()>
where
I: IntoIterator<Item = ChangeSet>,
{
@ -392,10 +315,10 @@ impl TestTransaction {
}
}
let provider = self.factory.provider_rw()?;
provider.insert_account_history_index(accounts)?;
provider.insert_storage_history_index(storages)?;
provider.commit()?;
let provider_rw = self.factory.provider_rw()?;
provider_rw.insert_account_history_index(accounts)?;
provider_rw.insert_storage_history_index(storages)?;
provider_rw.commit()?;
Ok(())
}