refactor: remove Transaction and add DatabaseProvider to stages (#3034)

Co-authored-by: Georgios Konstantopoulos <me@gakonst.com>
This commit is contained in:
joshieDo
2023-06-12 23:37:58 +01:00
committed by GitHub
parent cfdd88d392
commit f55d88b8c4
58 changed files with 2326 additions and 2109 deletions

View File

@ -22,7 +22,7 @@ use reth_provider::{
chain::{ChainSplit, SplitAt},
post_state::PostState,
BlockNumProvider, CanonStateNotification, CanonStateNotificationSender,
CanonStateNotifications, Chain, ExecutorFactory, HeaderProvider, Transaction,
CanonStateNotifications, Chain, DatabaseProvider, ExecutorFactory, HeaderProvider,
};
use std::{
collections::{BTreeMap, HashMap},
@ -993,14 +993,18 @@ impl<DB: Database, C: Consensus, EF: ExecutorFactory> BlockchainTree<DB, C, EF>
/// Canonicalize the given chain and commit it to the database.
fn commit_canonical(&mut self, chain: Chain) -> Result<(), Error> {
let mut tx = Transaction::new(&self.externals.db)?;
let mut provider = DatabaseProvider::new_rw(
self.externals.db.tx_mut()?,
self.externals.chain_spec.clone(),
);
let (blocks, state) = chain.into_inner();
tx.append_blocks_with_post_state(blocks.into_blocks().collect(), state)
provider
.append_blocks_with_post_state(blocks.into_blocks().collect(), state)
.map_err(|e| BlockExecutionError::CanonicalCommit { inner: e.to_string() })?;
tx.commit()?;
provider.commit()?;
Ok(())
}
@ -1030,17 +1034,20 @@ impl<DB: Database, C: Consensus, EF: ExecutorFactory> BlockchainTree<DB, C, EF>
fn revert_canonical(&mut self, revert_until: BlockNumber) -> Result<Option<Chain>, Error> {
// read data that is needed for new sidechain
let mut tx = Transaction::new(&self.externals.db)?;
let provider = DatabaseProvider::new_rw(
self.externals.db.tx_mut()?,
self.externals.chain_spec.clone(),
);
let tip = tx.tip_number()?;
let tip = provider.last_block_number()?;
let revert_range = (revert_until + 1)..=tip;
info!(target: "blockchain_tree", "Unwinding canonical chain blocks: {:?}", revert_range);
// read block and execution result from database. and remove traces of block from tables.
let blocks_and_execution = tx
let blocks_and_execution = provider
.take_block_and_execution_range(self.externals.chain_spec.as_ref(), revert_range)
.map_err(|e| BlockExecutionError::CanonicalRevert { inner: e.to_string() })?;
tx.commit()?;
provider.commit()?;
if blocks_and_execution.is_empty() {
Ok(None)

View File

@ -1287,7 +1287,6 @@ mod tests {
use reth_primitives::{stage::StageCheckpoint, ChainSpec, ChainSpecBuilder, H256, MAINNET};
use reth_provider::{
providers::BlockchainProvider, test_utils::TestExecutorFactory, ShareableDatabase,
Transaction,
};
use reth_stages::{test_utils::TestStages, ExecOutput, PipelineError, StageError};
use reth_tasks::TokioTaskExecutor;
@ -1384,7 +1383,7 @@ mod tests {
let pipeline = Pipeline::builder()
.add_stages(TestStages::new(pipeline_exec_outputs, Default::default()))
.with_tip_sender(tip_tx)
.build(db.clone());
.build(db.clone(), chain_spec.clone());
// Setup blockchain tree
let externals =
@ -1436,7 +1435,7 @@ mod tests {
.build(),
);
let (consensus_engine, env) = setup_consensus_engine(
chain_spec,
chain_spec.clone(),
VecDeque::from([Err(StageError::ChannelClosed)]),
Vec::default(),
);
@ -1465,7 +1464,7 @@ mod tests {
.build(),
);
let (consensus_engine, env) = setup_consensus_engine(
chain_spec,
chain_spec.clone(),
VecDeque::from([Err(StageError::ChannelClosed)]),
Vec::default(),
);
@ -1505,7 +1504,7 @@ mod tests {
.build(),
);
let (consensus_engine, env) = setup_consensus_engine(
chain_spec,
chain_spec.clone(),
VecDeque::from([
Ok(ExecOutput { checkpoint: StageCheckpoint::new(1), done: true }),
Err(StageError::ChannelClosed),
@ -1538,7 +1537,7 @@ mod tests {
.build(),
);
let (mut consensus_engine, env) = setup_consensus_engine(
chain_spec,
chain_spec.clone(),
VecDeque::from([Ok(ExecOutput {
checkpoint: StageCheckpoint::new(max_block),
done: true,
@ -1557,12 +1556,15 @@ mod tests {
assert_matches!(rx.await, Ok(Ok(())));
}
fn insert_blocks<'a, DB: Database>(db: &DB, mut blocks: impl Iterator<Item = &'a SealedBlock>) {
let mut transaction = Transaction::new(db).unwrap();
blocks
.try_for_each(|b| transaction.insert_block(b.clone(), None))
.expect("failed to insert");
transaction.commit().unwrap();
fn insert_blocks<'a, DB: Database>(
db: &DB,
chain: Arc<ChainSpec>,
mut blocks: impl Iterator<Item = &'a SealedBlock>,
) {
let factory = ShareableDatabase::new(db, chain);
let mut provider = factory.provider_rw().unwrap();
blocks.try_for_each(|b| provider.insert_block(b.clone(), None)).expect("failed to insert");
provider.commit().unwrap();
}
mod fork_choice_updated {
@ -1581,7 +1583,7 @@ mod tests {
.build(),
);
let (consensus_engine, env) = setup_consensus_engine(
chain_spec,
chain_spec.clone(),
VecDeque::from([Ok(ExecOutput {
done: true,
checkpoint: StageCheckpoint::new(0),
@ -1612,7 +1614,7 @@ mod tests {
.build(),
);
let (consensus_engine, env) = setup_consensus_engine(
chain_spec,
chain_spec.clone(),
VecDeque::from([Ok(ExecOutput {
done: true,
checkpoint: StageCheckpoint::new(0),
@ -1622,7 +1624,7 @@ mod tests {
let genesis = random_block(0, None, None, Some(0));
let block1 = random_block(1, Some(genesis.hash), None, Some(0));
insert_blocks(env.db.as_ref(), [&genesis, &block1].into_iter());
insert_blocks(env.db.as_ref(), chain_spec.clone(), [&genesis, &block1].into_iter());
env.db
.update(|tx| {
tx.put::<tables::SyncStage>(
@ -1660,7 +1662,7 @@ mod tests {
.build(),
);
let (consensus_engine, env) = setup_consensus_engine(
chain_spec,
chain_spec.clone(),
VecDeque::from([
Ok(ExecOutput { done: true, checkpoint: StageCheckpoint::new(0) }),
Ok(ExecOutput { done: true, checkpoint: StageCheckpoint::new(0) }),
@ -1670,7 +1672,7 @@ mod tests {
let genesis = random_block(0, None, None, Some(0));
let block1 = random_block(1, Some(genesis.hash), None, Some(0));
insert_blocks(env.db.as_ref(), [&genesis, &block1].into_iter());
insert_blocks(env.db.as_ref(), chain_spec.clone(), [&genesis, &block1].into_iter());
let mut engine_rx = spawn_consensus_engine(consensus_engine);
@ -1686,7 +1688,7 @@ mod tests {
let invalid_rx = env.send_forkchoice_updated(next_forkchoice_state).await;
// Insert next head immediately after sending forkchoice update
insert_blocks(env.db.as_ref(), [&next_head].into_iter());
insert_blocks(env.db.as_ref(), chain_spec.clone(), [&next_head].into_iter());
let expected_result = ForkchoiceUpdated::from_status(PayloadStatusEnum::Syncing);
assert_matches!(invalid_rx, Ok(result) => assert_eq!(result, expected_result));
@ -1709,7 +1711,7 @@ mod tests {
.build(),
);
let (consensus_engine, env) = setup_consensus_engine(
chain_spec,
chain_spec.clone(),
VecDeque::from([Ok(ExecOutput {
done: true,
checkpoint: StageCheckpoint::new(0),
@ -1719,7 +1721,7 @@ mod tests {
let genesis = random_block(0, None, None, Some(0));
let block1 = random_block(1, Some(genesis.hash), None, Some(0));
insert_blocks(env.db.as_ref(), [&genesis, &block1].into_iter());
insert_blocks(env.db.as_ref(), chain_spec.clone(), [&genesis, &block1].into_iter());
let engine = spawn_consensus_engine(consensus_engine);
@ -1746,7 +1748,7 @@ mod tests {
.build(),
);
let (consensus_engine, env) = setup_consensus_engine(
chain_spec,
chain_spec.clone(),
VecDeque::from([
Ok(ExecOutput { done: true, checkpoint: StageCheckpoint::new(0) }),
Ok(ExecOutput { done: true, checkpoint: StageCheckpoint::new(0) }),
@ -1766,7 +1768,11 @@ mod tests {
let mut block3 = random_block(1, Some(genesis.hash), None, Some(0));
block3.header.difficulty = U256::from(1);
insert_blocks(env.db.as_ref(), [&genesis, &block1, &block2, &block3].into_iter());
insert_blocks(
env.db.as_ref(),
chain_spec.clone(),
[&genesis, &block1, &block2, &block3].into_iter(),
);
let _engine = spawn_consensus_engine(consensus_engine);
@ -1795,7 +1801,7 @@ mod tests {
.build(),
);
let (consensus_engine, env) = setup_consensus_engine(
chain_spec,
chain_spec.clone(),
VecDeque::from([
Ok(ExecOutput { done: true, checkpoint: StageCheckpoint::new(0) }),
Ok(ExecOutput { done: true, checkpoint: StageCheckpoint::new(0) }),
@ -1806,7 +1812,7 @@ mod tests {
let genesis = random_block(0, None, None, Some(0));
let block1 = random_block(1, Some(genesis.hash), None, Some(0));
insert_blocks(env.db.as_ref(), [&genesis, &block1].into_iter());
insert_blocks(env.db.as_ref(), chain_spec.clone(), [&genesis, &block1].into_iter());
let _engine = spawn_consensus_engine(consensus_engine);
@ -1842,7 +1848,7 @@ mod tests {
.build(),
);
let (consensus_engine, env) = setup_consensus_engine(
chain_spec,
chain_spec.clone(),
VecDeque::from([Ok(ExecOutput {
done: true,
checkpoint: StageCheckpoint::new(0),
@ -1875,7 +1881,7 @@ mod tests {
.build(),
);
let (consensus_engine, env) = setup_consensus_engine(
chain_spec,
chain_spec.clone(),
VecDeque::from([Ok(ExecOutput {
done: true,
checkpoint: StageCheckpoint::new(0),
@ -1886,7 +1892,11 @@ mod tests {
let genesis = random_block(0, None, None, Some(0));
let block1 = random_block(1, Some(genesis.hash), None, Some(0));
let block2 = random_block(2, Some(block1.hash), None, Some(0));
insert_blocks(env.db.as_ref(), [&genesis, &block1, &block2].into_iter());
insert_blocks(
env.db.as_ref(),
chain_spec.clone(),
[&genesis, &block1, &block2].into_iter(),
);
let mut engine_rx = spawn_consensus_engine(consensus_engine);
@ -1921,7 +1931,7 @@ mod tests {
.build(),
);
let (consensus_engine, env) = setup_consensus_engine(
chain_spec,
chain_spec.clone(),
VecDeque::from([Ok(ExecOutput {
done: true,
checkpoint: StageCheckpoint::new(0),
@ -1931,7 +1941,7 @@ mod tests {
let genesis = random_block(0, None, None, Some(0));
insert_blocks(env.db.as_ref(), [&genesis].into_iter());
insert_blocks(env.db.as_ref(), chain_spec.clone(), [&genesis].into_iter());
let mut engine_rx = spawn_consensus_engine(consensus_engine);
@ -1978,7 +1988,7 @@ mod tests {
.build(),
);
let (consensus_engine, env) = setup_consensus_engine(
chain_spec,
chain_spec.clone(),
VecDeque::from([Ok(ExecOutput {
done: true,
checkpoint: StageCheckpoint::new(0),
@ -1986,7 +1996,11 @@ mod tests {
Vec::from([exec_result2]),
);
insert_blocks(env.db.as_ref(), [&data.genesis, &block1].into_iter());
insert_blocks(
env.db.as_ref(),
chain_spec.clone(),
[&data.genesis, &block1].into_iter(),
);
let mut engine_rx = spawn_consensus_engine(consensus_engine);

View File

@ -19,6 +19,7 @@ reth-primitives = { path = "../../crates/primitives" }
reth-provider = { path = "../../crates/storage/provider", features = ["test-utils"] }
reth-net-nat = { path = "../../crates/net/nat" }
reth-stages = { path = "../stages" }
reth-interfaces = { path = "../interfaces" }
# io
serde = "1.0"

View File

@ -6,7 +6,7 @@ use reth_db::{
transaction::{DbTx, DbTxMut},
};
use reth_primitives::{stage::StageId, Account, Bytecode, ChainSpec, H256, U256};
use reth_provider::{PostState, Transaction, TransactionError};
use reth_provider::{DatabaseProviderRW, PostState, ShareableDatabase, TransactionError};
use std::{path::Path, sync::Arc};
use tracing::debug;
@ -39,6 +39,10 @@ pub enum InitDatabaseError {
/// Low-level database error.
#[error(transparent)]
DBError(#[from] reth_db::DatabaseError),
/// Internal error.
#[error(transparent)]
InternalError(#[from] reth_interfaces::Error),
}
/// Write the genesis block if it has not already been written
@ -66,11 +70,11 @@ pub fn init_genesis<DB: Database>(
drop(tx);
debug!("Writing genesis block.");
let tx = db.tx_mut()?;
// use transaction to insert genesis header
let transaction = Transaction::new_raw(&db, tx);
insert_genesis_hashes(transaction, genesis)?;
let shareable_db = ShareableDatabase::new(&db, chain.clone());
let provider_rw = shareable_db.provider_rw()?;
insert_genesis_hashes(provider_rw, genesis)?;
// Insert header
let tx = db.tx_mut()?;
@ -123,20 +127,21 @@ pub fn insert_genesis_state<DB: Database>(
/// Inserts hashes for the genesis state.
pub fn insert_genesis_hashes<DB: Database>(
mut transaction: Transaction<'_, DB>,
provider: DatabaseProviderRW<'_, &DB>,
genesis: &reth_primitives::Genesis,
) -> Result<(), InitDatabaseError> {
// insert and hash accounts to hashing table
let alloc_accounts =
genesis.alloc.clone().into_iter().map(|(addr, account)| (addr, Some(account.into())));
transaction.insert_account_for_hashing(alloc_accounts)?;
provider.insert_account_for_hashing(alloc_accounts)?;
let alloc_storage = genesis.alloc.clone().into_iter().filter_map(|(addr, account)| {
// only return Some if there is storage
account.storage.map(|storage| (addr, storage.into_iter().map(|(k, v)| (k, v.into()))))
});
transaction.insert_storage_for_hashing(alloc_storage)?;
transaction.commit()?;
provider.insert_storage_for_hashing(alloc_storage)?;
provider.commit()?;
Ok(())
}

View File

@ -5,7 +5,8 @@ use criterion::{
use pprof::criterion::{Output, PProfProfiler};
use reth_db::mdbx::{Env, WriteMap};
use reth_interfaces::test_utils::TestConsensus;
use reth_primitives::stage::StageCheckpoint;
use reth_primitives::{stage::StageCheckpoint, MAINNET};
use reth_provider::ShareableDatabase;
use reth_stages::{
stages::{MerkleStage, SenderRecoveryStage, TotalDifficultyStage, TransactionLookupStage},
test_utils::TestTransaction,
@ -135,9 +136,10 @@ fn measure_stage_with_path<F, S>(
},
|_| async {
let mut stage = stage.clone();
let mut db_tx = tx.inner();
stage.execute(&mut db_tx, input).await.unwrap();
db_tx.commit().unwrap();
let factory = ShareableDatabase::new(tx.tx.as_ref(), MAINNET.clone());
let mut provider = factory.provider_rw().unwrap();
stage.execute(&mut provider, input).await.unwrap();
provider.commit().unwrap();
},
)
});

View File

@ -63,8 +63,9 @@ fn generate_testdata_db(num_blocks: u64) -> (PathBuf, StageRange) {
std::fs::create_dir_all(&path).unwrap();
println!("Account Hashing testdata not found, generating to {:?}", path.display());
let tx = TestTransaction::new(&path);
let mut tx = tx.inner();
let _accounts = AccountHashingStage::seed(&mut tx, opts);
let mut provider = tx.inner();
let _accounts = AccountHashingStage::seed(&mut provider, opts);
provider.commit().expect("failed to commit");
}
(path, (ExecInput { target: Some(num_blocks), ..Default::default() }, UnwindInput::default()))
}

View File

@ -9,7 +9,8 @@ use reth_interfaces::test_utils::generators::{
random_block_range, random_contract_account_range, random_eoa_account_range,
random_transition_range,
};
use reth_primitives::{Account, Address, SealedBlock, H256};
use reth_primitives::{Account, Address, SealedBlock, H256, MAINNET};
use reth_provider::ShareableDatabase;
use reth_stages::{
stages::{AccountHashingStage, StorageHashingStage},
test_utils::TestTransaction,
@ -18,7 +19,6 @@ use reth_stages::{
use reth_trie::StateRoot;
use std::{
collections::BTreeMap,
ops::Deref,
path::{Path, PathBuf},
};
@ -38,11 +38,12 @@ pub(crate) fn stage_unwind<S: Clone + Stage<Env<WriteMap>>>(
tokio::runtime::Runtime::new().unwrap().block_on(async {
let mut stage = stage.clone();
let mut db_tx = tx.inner();
let factory = ShareableDatabase::new(tx.tx.as_ref(), MAINNET.clone());
let mut provider = factory.provider_rw().unwrap();
// Clear previous run
stage
.unwind(&mut db_tx, unwind)
.unwind(&mut provider, unwind)
.await
.map_err(|e| {
format!(
@ -52,7 +53,7 @@ pub(crate) fn stage_unwind<S: Clone + Stage<Env<WriteMap>>>(
})
.unwrap();
db_tx.commit().unwrap();
provider.commit().unwrap();
});
}
@ -65,18 +66,19 @@ pub(crate) fn unwind_hashes<S: Clone + Stage<Env<WriteMap>>>(
tokio::runtime::Runtime::new().unwrap().block_on(async {
let mut stage = stage.clone();
let mut db_tx = tx.inner();
let factory = ShareableDatabase::new(tx.tx.as_ref(), MAINNET.clone());
let mut provider = factory.provider_rw().unwrap();
StorageHashingStage::default().unwind(&mut db_tx, unwind).await.unwrap();
AccountHashingStage::default().unwind(&mut db_tx, unwind).await.unwrap();
StorageHashingStage::default().unwind(&mut provider, unwind).await.unwrap();
AccountHashingStage::default().unwind(&mut provider, unwind).await.unwrap();
// Clear previous run
stage.unwind(&mut db_tx, unwind).await.unwrap();
stage.unwind(&mut provider, unwind).await.unwrap();
AccountHashingStage::default().execute(&mut db_tx, input).await.unwrap();
StorageHashingStage::default().execute(&mut db_tx, input).await.unwrap();
AccountHashingStage::default().execute(&mut provider, input).await.unwrap();
StorageHashingStage::default().execute(&mut provider, input).await.unwrap();
db_tx.commit().unwrap();
provider.commit().unwrap();
});
}
@ -121,7 +123,7 @@ pub(crate) fn txs_testdata(num_blocks: u64) -> PathBuf {
tx.insert_accounts_and_storages(start_state.clone()).unwrap();
// make first block after genesis have valid state root
let (root, updates) = StateRoot::new(tx.inner().deref()).root_with_updates().unwrap();
let (root, updates) = StateRoot::new(tx.inner().tx_ref()).root_with_updates().unwrap();
let second_block = blocks.get_mut(1).unwrap();
let cloned_second = second_block.clone();
let mut updated_header = cloned_second.header.unseal();
@ -142,8 +144,8 @@ pub(crate) fn txs_testdata(num_blocks: u64) -> PathBuf {
// make last block have valid state root
let root = {
let mut tx_mut = tx.inner();
let root = StateRoot::new(tx_mut.deref()).root().unwrap();
let tx_mut = tx.inner();
let root = StateRoot::new(tx_mut.tx_ref()).root().unwrap();
tx_mut.commit().unwrap();
root
};

View File

@ -66,6 +66,9 @@ pub enum StageError {
/// rely on external downloaders
#[error("Invalid download response: {0}")]
Download(#[from] DownloadError),
/// Internal error
#[error(transparent)]
Internal(#[from] reth_interfaces::Error),
/// The stage encountered a recoverable error.
///
/// These types of errors are caught by the [Pipeline][crate::Pipeline] and trigger a restart
@ -104,6 +107,9 @@ pub enum PipelineError {
/// The pipeline encountered a database error.
#[error("A database error occurred.")]
Database(#[from] DbError),
/// The pipeline encountered an irrecoverable error in one of the stages.
#[error("An interface error occurred.")]
Interface(#[from] reth_interfaces::Error),
/// The pipeline encountered an error while trying to send an event.
#[error("The pipeline encountered an error while trying to send an event.")]
Channel(#[from] SendError<PipelineEvent>),

View File

@ -20,7 +20,7 @@
//!
//! ```
//! # use std::sync::Arc;
//! use reth_db::mdbx::test_utils::create_test_rw_db;
//! # use reth_db::mdbx::test_utils::create_test_rw_db;
//! # use reth_downloaders::bodies::bodies::BodiesDownloaderBuilder;
//! # use reth_downloaders::headers::reverse_headers::ReverseHeadersDownloaderBuilder;
//! # use reth_interfaces::consensus::Consensus;
@ -51,7 +51,7 @@
//! .add_stages(
//! DefaultStages::new(HeaderSyncMode::Tip(tip_rx), consensus, headers_downloader, bodies_downloader, factory)
//! )
//! .build(db);
//! .build(db, MAINNET.clone());
//! ```
mod error;
mod pipeline;

View File

@ -1,6 +1,8 @@
use std::sync::Arc;
use crate::{pipeline::BoxedStage, Pipeline, Stage, StageSet};
use reth_db::database::Database;
use reth_primitives::{stage::StageId, BlockNumber, H256};
use reth_primitives::{stage::StageId, BlockNumber, ChainSpec, H256};
use tokio::sync::watch;
/// Builds a [`Pipeline`].
@ -61,10 +63,11 @@ where
/// Builds the final [`Pipeline`] using the given database.
///
/// Note: it's expected that this is either an [Arc](std::sync::Arc) or an Arc wrapper type.
pub fn build(self, db: DB) -> Pipeline<DB> {
pub fn build(self, db: DB, chain_spec: Arc<ChainSpec>) -> Pipeline<DB> {
let Self { stages, max_block, tip_tx } = self;
Pipeline {
db,
chain_spec,
stages,
max_block,
tip_tx,

View File

@ -4,10 +4,10 @@ use reth_db::database::Database;
use reth_interfaces::executor::BlockExecutionError;
use reth_primitives::{
constants::BEACON_CONSENSUS_REORG_UNWIND_DEPTH, listener::EventListeners, stage::StageId,
BlockNumber, H256,
BlockNumber, ChainSpec, H256,
};
use reth_provider::{providers::get_stage_checkpoint, Transaction};
use std::pin::Pin;
use reth_provider::{providers::get_stage_checkpoint, ShareableDatabase};
use std::{pin::Pin, sync::Arc};
use tokio::sync::watch;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::*;
@ -93,6 +93,8 @@ pub type PipelineWithResult<DB> = (Pipeline<DB>, Result<ControlFlow, PipelineErr
pub struct Pipeline<DB: Database> {
/// The Database
db: DB,
/// Chain spec
chain_spec: Arc<ChainSpec>,
/// All configured stages in the order they will be executed.
stages: Vec<BoxedStage<DB>>,
/// The maximum block number to sync to.
@ -245,14 +247,15 @@ where
// Unwind stages in reverse order of execution
let unwind_pipeline = self.stages.iter_mut().rev();
let mut tx = Transaction::new(&self.db)?;
let shareable_db = ShareableDatabase::new(&self.db, self.chain_spec.clone());
let mut provider_rw = shareable_db.provider_rw().map_err(PipelineError::Interface)?;
for stage in unwind_pipeline {
let stage_id = stage.id();
let span = info_span!("Unwinding", stage = %stage_id);
let _enter = span.enter();
let mut checkpoint = tx.get_stage_checkpoint(stage_id)?.unwrap_or_default();
let mut checkpoint = provider_rw.get_stage_checkpoint(stage_id)?.unwrap_or_default();
if checkpoint.block_number < to {
debug!(target: "sync::pipeline", from = %checkpoint, %to, "Unwind point too far for stage");
self.listeners.notify(PipelineEvent::Skipped { stage_id });
@ -264,7 +267,7 @@ where
let input = UnwindInput { checkpoint, unwind_to: to, bad_block };
self.listeners.notify(PipelineEvent::Unwinding { stage_id, input });
let output = stage.unwind(&mut tx, input).await;
let output = stage.unwind(&mut provider_rw, input).await;
match output {
Ok(unwind_output) => {
checkpoint = unwind_output.checkpoint;
@ -282,12 +285,14 @@ where
// doesn't change when we unwind.
None,
);
tx.save_stage_checkpoint(stage_id, checkpoint)?;
provider_rw.save_stage_checkpoint(stage_id, checkpoint)?;
self.listeners
.notify(PipelineEvent::Unwound { stage_id, result: unwind_output });
tx.commit()?;
provider_rw.commit()?;
provider_rw =
shareable_db.provider_rw().map_err(PipelineError::Interface)?;
}
Err(err) => {
self.listeners.notify(PipelineEvent::Error { stage_id });
@ -312,10 +317,11 @@ where
let mut made_progress = false;
let target = self.max_block.or(previous_stage);
loop {
let mut tx = Transaction::new(&self.db)?;
let shareable_db = ShareableDatabase::new(&self.db, self.chain_spec.clone());
let mut provider_rw = shareable_db.provider_rw().map_err(PipelineError::Interface)?;
let prev_checkpoint = tx.get_stage_checkpoint(stage_id)?;
loop {
let prev_checkpoint = provider_rw.get_stage_checkpoint(stage_id)?;
let stage_reached_max_block = prev_checkpoint
.zip(self.max_block)
@ -343,7 +349,10 @@ where
checkpoint: prev_checkpoint,
});
match stage.execute(&mut tx, ExecInput { target, checkpoint: prev_checkpoint }).await {
match stage
.execute(&mut provider_rw, ExecInput { target, checkpoint: prev_checkpoint })
.await
{
Ok(out @ ExecOutput { checkpoint, done }) => {
made_progress |=
checkpoint.block_number != prev_checkpoint.unwrap_or_default().block_number;
@ -356,7 +365,7 @@ where
"Stage committed progress"
);
self.metrics.stage_checkpoint(stage_id, checkpoint, target);
tx.save_stage_checkpoint(stage_id, checkpoint)?;
provider_rw.save_stage_checkpoint(stage_id, checkpoint)?;
self.listeners.notify(PipelineEvent::Ran {
pipeline_position: stage_index + 1,
@ -366,7 +375,8 @@ where
});
// TODO: Make the commit interval configurable
tx.commit()?;
provider_rw.commit()?;
provider_rw = shareable_db.provider_rw().map_err(PipelineError::Interface)?;
if done {
let stage_progress = checkpoint.block_number;
@ -466,7 +476,7 @@ mod tests {
use reth_interfaces::{
consensus, provider::ProviderError, test_utils::generators::random_header,
};
use reth_primitives::stage::StageCheckpoint;
use reth_primitives::{stage::StageCheckpoint, MAINNET};
use tokio_stream::StreamExt;
#[test]
@ -511,7 +521,7 @@ mod tests {
.add_exec(Ok(ExecOutput { checkpoint: StageCheckpoint::new(10), done: true })),
)
.with_max_block(10)
.build(db);
.build(db, MAINNET.clone());
let events = pipeline.events();
// Run pipeline
@ -573,7 +583,7 @@ mod tests {
.add_unwind(Ok(UnwindOutput { checkpoint: StageCheckpoint::new(1) })),
)
.with_max_block(10)
.build(db);
.build(db, MAINNET.clone());
let events = pipeline.events();
// Run pipeline
@ -683,7 +693,7 @@ mod tests {
.add_exec(Ok(ExecOutput { checkpoint: StageCheckpoint::new(10), done: true })),
)
.with_max_block(10)
.build(db);
.build(db, MAINNET.clone());
let events = pipeline.events();
// Run pipeline
@ -776,7 +786,7 @@ mod tests {
.add_exec(Ok(ExecOutput { checkpoint: StageCheckpoint::new(10), done: true })),
)
.with_max_block(10)
.build(db);
.build(db, MAINNET.clone());
let events = pipeline.events();
// Run pipeline
@ -859,7 +869,7 @@ mod tests {
.add_exec(Ok(ExecOutput { checkpoint: StageCheckpoint::new(10), done: true })),
)
.with_max_block(10)
.build(db);
.build(db, MAINNET.clone());
let result = pipeline.run().await;
assert_matches!(result, Ok(()));
@ -869,7 +879,7 @@ mod tests {
.add_stage(TestStage::new(StageId::Other("Fatal")).add_exec(Err(
StageError::DatabaseIntegrity(ProviderError::BlockBodyIndicesNotFound(5)),
)))
.build(db);
.build(db, MAINNET.clone());
let result = pipeline.run().await;
assert_matches!(
result,

View File

@ -20,7 +20,7 @@
//! # let db = create_test_rw_db();
//! // Build a pipeline with all offline stages.
//! # let pipeline =
//! Pipeline::builder().add_stages(OfflineStages::new(factory)).build(db);
//! Pipeline::builder().add_stages(OfflineStages::new(factory)).build(db, MAINNET.clone());
//! ```
//!
//! ```ignore

View File

@ -5,7 +5,7 @@ use reth_primitives::{
stage::{StageCheckpoint, StageId},
BlockNumber, TxNumber,
};
use reth_provider::{ProviderError, Transaction};
use reth_provider::{DatabaseProviderRW, ProviderError};
use std::{
cmp::{max, min},
ops::RangeInclusive,
@ -75,11 +75,12 @@ impl ExecInput {
/// the number of transactions exceeds the threshold.
pub fn next_block_range_with_transaction_threshold<DB: Database>(
&self,
tx: &Transaction<'_, DB>,
provider: &DatabaseProviderRW<'_, DB>,
tx_threshold: u64,
) -> Result<(RangeInclusive<TxNumber>, RangeInclusive<BlockNumber>, bool), StageError> {
let start_block = self.next_block();
let start_block_body = tx
let start_block_body = provider
.tx_ref()
.get::<tables::BlockBodyIndices>(start_block)?
.ok_or(ProviderError::BlockBodyIndicesNotFound(start_block))?;
@ -88,7 +89,8 @@ impl ExecInput {
let first_tx_number = start_block_body.first_tx_num();
let mut last_tx_number = start_block_body.last_tx_num();
let mut end_block_number = start_block;
let mut body_indices_cursor = tx.cursor_read::<tables::BlockBodyIndices>()?;
let mut body_indices_cursor =
provider.tx_ref().cursor_read::<tables::BlockBodyIndices>()?;
for entry in body_indices_cursor.walk_range(start_block..=target_block)? {
let (block, body) = entry?;
last_tx_number = body.last_tx_num();
@ -171,8 +173,7 @@ pub struct UnwindOutput {
///
/// Stages are executed as part of a pipeline where they are executed serially.
///
/// Stages receive [`Transaction`] which manages the lifecycle of a transaction,
/// such as when to commit / reopen a new one etc.
/// Stages receive [`DatabaseProviderRW`].
#[async_trait]
pub trait Stage<DB: Database>: Send + Sync {
/// Get the ID of the stage.
@ -183,14 +184,14 @@ pub trait Stage<DB: Database>: Send + Sync {
/// Execute the stage.
async fn execute(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError>;
/// Unwind the stage.
async fn unwind(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError>;
}

View File

@ -13,8 +13,8 @@ use reth_interfaces::{
p2p::bodies::{downloader::BodyDownloader, response::BlockResponse},
};
use reth_primitives::stage::{EntitiesCheckpoint, StageCheckpoint, StageId};
use reth_provider::Transaction;
use std::{ops::Deref, sync::Arc};
use reth_provider::DatabaseProviderRW;
use std::sync::Arc;
use tracing::*;
// TODO(onbjerg): Metrics and events (gradual status for e.g. CLI)
@ -67,7 +67,7 @@ impl<DB: Database, D: BodyDownloader> Stage<DB> for BodyStage<D> {
/// header, limited by the stage's batch size.
async fn execute(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
if input.target_reached() {
@ -80,6 +80,7 @@ impl<DB: Database, D: BodyDownloader> Stage<DB> for BodyStage<D> {
let (from_block, to_block) = range.into_inner();
// Cursors used to write bodies, ommers and transactions
let tx = provider.tx_ref();
let mut block_indices_cursor = tx.cursor_write::<tables::BlockBodyIndices>()?;
let mut tx_cursor = tx.cursor_write::<tables::Transactions>()?;
let mut tx_block_cursor = tx.cursor_write::<tables::TransactionBlock>()?;
@ -154,7 +155,7 @@ impl<DB: Database, D: BodyDownloader> Stage<DB> for BodyStage<D> {
let done = highest_block == to_block;
Ok(ExecOutput {
checkpoint: StageCheckpoint::new(highest_block)
.with_entities_stage_checkpoint(stage_checkpoint(tx)?),
.with_entities_stage_checkpoint(stage_checkpoint(provider)?),
done,
})
}
@ -162,9 +163,10 @@ impl<DB: Database, D: BodyDownloader> Stage<DB> for BodyStage<D> {
/// Unwind the stage.
async fn unwind(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
let tx = provider.tx_ref();
// Cursors to unwind bodies, ommers
let mut body_cursor = tx.cursor_write::<tables::BlockBodyIndices>()?;
let mut transaction_cursor = tx.cursor_write::<tables::Transactions>()?;
@ -210,7 +212,7 @@ impl<DB: Database, D: BodyDownloader> Stage<DB> for BodyStage<D> {
Ok(UnwindOutput {
checkpoint: StageCheckpoint::new(input.unwind_to)
.with_entities_stage_checkpoint(stage_checkpoint(tx)?),
.with_entities_stage_checkpoint(stage_checkpoint(provider)?),
})
}
}
@ -219,11 +221,11 @@ impl<DB: Database, D: BodyDownloader> Stage<DB> for BodyStage<D> {
// beforehand how many bytes we need to download. So the good solution would be to measure the
// progress in gas as a proxy to size. Execution stage uses a similar approach.
fn stage_checkpoint<DB: Database>(
tx: &Transaction<'_, DB>,
provider: &DatabaseProviderRW<'_, DB>,
) -> Result<EntitiesCheckpoint, DatabaseError> {
Ok(EntitiesCheckpoint {
processed: tx.deref().entries::<tables::BlockBodyIndices>()? as u64,
total: tx.deref().entries::<tables::Headers>()? as u64,
processed: provider.tx_ref().entries::<tables::BlockBodyIndices>()? as u64,
total: provider.tx_ref().entries::<tables::Headers>()? as u64,
})
}

View File

@ -19,7 +19,8 @@ use reth_primitives::{
Block, BlockNumber, BlockWithSenders, Header, TransactionSigned, U256,
};
use reth_provider::{
post_state::PostState, BlockExecutor, ExecutorFactory, LatestStateProviderRef, Transaction,
post_state::PostState, BlockExecutor, BlockProvider, DatabaseProviderRW, ExecutorFactory,
HeaderProvider, LatestStateProviderRef, ProviderError, WithdrawalsProvider,
};
use std::{ops::RangeInclusive, time::Instant};
use tracing::*;
@ -83,22 +84,26 @@ impl<EF: ExecutorFactory> ExecutionStage<EF> {
Self::new(executor_factory, ExecutionStageThresholds::default())
}
// TODO: This should be in the block provider trait once we consolidate
// SharedDatabase/Transaction
// TODO(joshie): This should be in the block provider trait once we consolidate
fn read_block_with_senders<DB: Database>(
tx: &Transaction<'_, DB>,
provider: &DatabaseProviderRW<'_, &DB>,
block_number: BlockNumber,
) -> Result<(BlockWithSenders, U256), StageError> {
let header = tx.get_header(block_number)?;
let td = tx.get_td(block_number)?;
let ommers = tx.get::<tables::BlockOmmers>(block_number)?.unwrap_or_default().ommers;
let withdrawals = tx.get::<tables::BlockWithdrawals>(block_number)?.map(|v| v.withdrawals);
let header = provider
.header_by_number(block_number)?
.ok_or_else(|| ProviderError::HeaderNotFound(block_number.into()))?;
let td = provider
.header_td_by_number(block_number)?
.ok_or_else(|| ProviderError::HeaderNotFound(block_number.into()))?;
let ommers = provider.ommers(block_number.into())?.unwrap_or_default();
let withdrawals = provider.withdrawals_by_block(block_number.into(), header.timestamp)?;
// Get the block body
let body = tx.get::<tables::BlockBodyIndices>(block_number)?.unwrap();
let body = provider.block_body_indices(block_number)?;
let tx_range = body.tx_num_range();
// Get the transactions in the body
let tx = provider.tx_ref();
let (transactions, senders) = if tx_range.is_empty() {
(Vec::new(), Vec::new())
} else {
@ -135,7 +140,7 @@ impl<EF: ExecutorFactory> ExecutionStage<EF> {
/// Execute the stage.
pub fn execute_inner<DB: Database>(
&self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
if input.target_reached() {
@ -146,17 +151,18 @@ impl<EF: ExecutorFactory> ExecutionStage<EF> {
let max_block = input.target();
// Build executor
let mut executor = self.executor_factory.with_sp(LatestStateProviderRef::new(&**tx));
let mut executor =
self.executor_factory.with_sp(LatestStateProviderRef::new(provider.tx_ref()));
// Progress tracking
let mut stage_progress = start_block;
let mut stage_checkpoint =
execution_checkpoint(tx, start_block, max_block, input.checkpoint())?;
execution_checkpoint(provider, start_block, max_block, input.checkpoint())?;
// Execute block range
let mut state = PostState::default();
for block_number in start_block..=max_block {
let (block, td) = Self::read_block_with_senders(tx, block_number)?;
let (block, td) = Self::read_block_with_senders(provider, block_number)?;
// Configure the executor to use the current state.
trace!(target: "sync::stages::execution", number = block_number, txs = block.body.len(), "Executing block");
@ -190,7 +196,7 @@ impl<EF: ExecutorFactory> ExecutionStage<EF> {
// Write remaining changes
trace!(target: "sync::stages::execution", accounts = state.accounts().len(), "Writing updated state to database");
let start = Instant::now();
state.write_to_db(&**tx)?;
state.write_to_db(provider.tx_ref())?;
trace!(target: "sync::stages::execution", took = ?start.elapsed(), "Wrote state");
let done = stage_progress == max_block;
@ -203,7 +209,7 @@ impl<EF: ExecutorFactory> ExecutionStage<EF> {
}
fn execution_checkpoint<DB: Database>(
tx: &Transaction<'_, DB>,
provider: &DatabaseProviderRW<'_, &DB>,
start_block: BlockNumber,
max_block: BlockNumber,
checkpoint: StageCheckpoint,
@ -225,7 +231,7 @@ fn execution_checkpoint<DB: Database>(
block_range: CheckpointBlockRange { from: start_block, to: max_block },
progress: EntitiesCheckpoint {
processed,
total: total + calculate_gas_used_from_headers(tx, start_block..=max_block)?,
total: total + calculate_gas_used_from_headers(provider, start_block..=max_block)?,
},
},
// If checkpoint block range ends on the same block as our range, we take the previously
@ -242,7 +248,7 @@ fn execution_checkpoint<DB: Database>(
// to be processed not including the checkpoint range.
Some(ExecutionCheckpoint { progress: EntitiesCheckpoint { processed, .. }, .. }) => {
let after_checkpoint_block_number =
calculate_gas_used_from_headers(tx, checkpoint.block_number + 1..=max_block)?;
calculate_gas_used_from_headers(provider, checkpoint.block_number + 1..=max_block)?;
ExecutionCheckpoint {
block_range: CheckpointBlockRange { from: start_block, to: max_block },
@ -255,14 +261,14 @@ fn execution_checkpoint<DB: Database>(
// Otherwise, we recalculate the whole stage checkpoint including the amount of gas
// already processed, if there's any.
_ => {
let processed = calculate_gas_used_from_headers(tx, 0..=start_block - 1)?;
let processed = calculate_gas_used_from_headers(provider, 0..=start_block - 1)?;
ExecutionCheckpoint {
block_range: CheckpointBlockRange { from: start_block, to: max_block },
progress: EntitiesCheckpoint {
processed,
total: processed +
calculate_gas_used_from_headers(tx, start_block..=max_block)?,
calculate_gas_used_from_headers(provider, start_block..=max_block)?,
},
}
}
@ -270,13 +276,13 @@ fn execution_checkpoint<DB: Database>(
}
fn calculate_gas_used_from_headers<DB: Database>(
tx: &Transaction<'_, DB>,
provider: &DatabaseProviderRW<'_, &DB>,
range: RangeInclusive<BlockNumber>,
) -> Result<u64, DatabaseError> {
let mut gas_total = 0;
let start = Instant::now();
for entry in tx.cursor_read::<tables::Headers>()?.walk_range(range.clone())? {
for entry in provider.tx_ref().cursor_read::<tables::Headers>()?.walk_range(range.clone())? {
let (_, Header { gas_used, .. }) = entry?;
gas_total += gas_used;
}
@ -304,7 +310,7 @@ impl<EF: ExecutorFactory, DB: Database> Stage<DB> for ExecutionStage<EF> {
/// Execute the stage
async fn execute(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
// For Ethereum transactions that reaches the max call depth (1024) revm can use more stack
@ -321,7 +327,7 @@ impl<EF: ExecutorFactory, DB: Database> Stage<DB> for ExecutionStage<EF> {
.stack_size(BIG_STACK_SIZE)
.spawn_scoped(scope, || {
// execute and store output to results
self.execute_inner(tx, input)
self.execute_inner(provider, input)
})
.expect("Expects that thread name is not null");
handle.join().expect("Expects for thread to not panic")
@ -331,9 +337,10 @@ impl<EF: ExecutorFactory, DB: Database> Stage<DB> for ExecutionStage<EF> {
/// Unwind the stage.
async fn unwind(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
let tx = provider.tx_ref();
// Acquire changeset cursors
let mut account_changeset = tx.cursor_dup_write::<tables::AccountChangeSet>()?;
let mut storage_changeset = tx.cursor_dup_write::<tables::StorageChangeSet>()?;
@ -382,7 +389,7 @@ impl<EF: ExecutorFactory, DB: Database> Stage<DB> for ExecutionStage<EF> {
}
// Discard unwinded changesets
tx.unwind_table_by_num::<tables::AccountChangeSet>(unwind_to)?;
provider.unwind_table_by_num::<tables::AccountChangeSet>(unwind_to)?;
let mut rev_storage_changeset_walker = storage_changeset.walk_back(None)?;
while let Some((key, _)) = rev_storage_changeset_walker.next().transpose()? {
@ -394,7 +401,7 @@ impl<EF: ExecutorFactory, DB: Database> Stage<DB> for ExecutionStage<EF> {
}
// Look up the start index for the transaction range
let first_tx_num = tx.block_body_indices(*range.start())?.first_tx_num();
let first_tx_num = provider.block_body_indices(*range.start())?.first_tx_num();
let mut stage_checkpoint = input.checkpoint.execution_stage_checkpoint();
@ -461,15 +468,12 @@ mod tests {
};
use reth_primitives::{
hex_literal::hex, keccak256, stage::StageUnitCheckpoint, Account, Bytecode,
ChainSpecBuilder, SealedBlock, StorageEntry, H160, H256, U256,
ChainSpecBuilder, SealedBlock, StorageEntry, H160, H256, MAINNET, U256,
};
use reth_provider::insert_canonical_block;
use reth_provider::{insert_canonical_block, ShareableDatabase};
use reth_revm::Factory;
use reth_rlp::Decodable;
use std::{
ops::{Deref, DerefMut},
sync::Arc,
};
use std::sync::Arc;
fn stage() -> ExecutionStage<Factory> {
let factory =
@ -483,7 +487,8 @@ mod tests {
#[test]
fn execution_checkpoint_matches() {
let state_db = create_test_db::<WriteMap>(EnvKind::RW);
let tx = Transaction::new(state_db.as_ref()).unwrap();
let db = ShareableDatabase::new(state_db.as_ref(), MAINNET.clone());
let tx = db.provider_rw().unwrap();
let previous_stage_checkpoint = ExecutionCheckpoint {
block_range: CheckpointBlockRange { from: 0, to: 0 },
@ -507,15 +512,16 @@ mod tests {
#[test]
fn execution_checkpoint_precedes() {
let state_db = create_test_db::<WriteMap>(EnvKind::RW);
let mut tx = Transaction::new(state_db.as_ref()).unwrap();
let db = ShareableDatabase::new(state_db.as_ref(), MAINNET.clone());
let mut provider = db.provider_rw().unwrap();
let mut genesis_rlp = hex!("f901faf901f5a00000000000000000000000000000000000000000000000000000000000000000a01dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347942adc25665018aa1fe0e6bc666dac8fc2697ff9baa045571b40ae66ca7480791bbb2887286e4e4c4b1b298b191c889d6959023a32eda056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421a056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421b901000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000083020000808502540be400808000a00000000000000000000000000000000000000000000000000000000000000000880000000000000000c0c0").as_slice();
let genesis = SealedBlock::decode(&mut genesis_rlp).unwrap();
let mut block_rlp = hex!("f90262f901f9a075c371ba45999d87f4542326910a11af515897aebce5265d3f6acd1f1161f82fa01dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347942adc25665018aa1fe0e6bc666dac8fc2697ff9baa098f2dcd87c8ae4083e7017a05456c14eea4b1db2032126e27b3b1563d57d7cc0a08151d548273f6683169524b66ca9fe338b9ce42bc3540046c828fd939ae23bcba03f4e5c2ec5b2170b711d97ee755c160457bb58d8daa338e835ec02ae6860bbabb901000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000083020000018502540be40082a8798203e800a00000000000000000000000000000000000000000000000000000000000000000880000000000000000f863f861800a8405f5e10094100000000000000000000000000000000000000080801ba07e09e26678ed4fac08a249ebe8ed680bf9051a5e14ad223e4b2b9d26e0208f37a05f6e3f188e3e6eab7d7d3b6568f5eac7d687b08d307d3154ccd8c87b4630509bc0").as_slice();
let block = SealedBlock::decode(&mut block_rlp).unwrap();
insert_canonical_block(tx.deref_mut(), genesis, None).unwrap();
insert_canonical_block(tx.deref_mut(), block.clone(), None).unwrap();
tx.commit().unwrap();
insert_canonical_block(provider.tx_mut(), genesis, None).unwrap();
insert_canonical_block(provider.tx_mut(), block.clone(), None).unwrap();
provider.commit().unwrap();
let previous_stage_checkpoint = ExecutionCheckpoint {
block_range: CheckpointBlockRange { from: 0, to: 0 },
@ -526,7 +532,8 @@ mod tests {
stage_checkpoint: Some(StageUnitCheckpoint::Execution(previous_stage_checkpoint)),
};
let stage_checkpoint = execution_checkpoint(&tx, 1, 1, previous_checkpoint);
let provider = db.provider_rw().unwrap();
let stage_checkpoint = execution_checkpoint(&provider, 1, 1, previous_checkpoint);
assert_matches!(stage_checkpoint, Ok(ExecutionCheckpoint {
block_range: CheckpointBlockRange { from: 1, to: 1 },
@ -541,15 +548,16 @@ mod tests {
#[test]
fn execution_checkpoint_recalculate_full_previous_some() {
let state_db = create_test_db::<WriteMap>(EnvKind::RW);
let mut tx = Transaction::new(state_db.as_ref()).unwrap();
let db = ShareableDatabase::new(state_db.as_ref(), MAINNET.clone());
let mut provider = db.provider_rw().unwrap();
let mut genesis_rlp = hex!("f901faf901f5a00000000000000000000000000000000000000000000000000000000000000000a01dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347942adc25665018aa1fe0e6bc666dac8fc2697ff9baa045571b40ae66ca7480791bbb2887286e4e4c4b1b298b191c889d6959023a32eda056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421a056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421b901000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000083020000808502540be400808000a00000000000000000000000000000000000000000000000000000000000000000880000000000000000c0c0").as_slice();
let genesis = SealedBlock::decode(&mut genesis_rlp).unwrap();
let mut block_rlp = hex!("f90262f901f9a075c371ba45999d87f4542326910a11af515897aebce5265d3f6acd1f1161f82fa01dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347942adc25665018aa1fe0e6bc666dac8fc2697ff9baa098f2dcd87c8ae4083e7017a05456c14eea4b1db2032126e27b3b1563d57d7cc0a08151d548273f6683169524b66ca9fe338b9ce42bc3540046c828fd939ae23bcba03f4e5c2ec5b2170b711d97ee755c160457bb58d8daa338e835ec02ae6860bbabb901000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000083020000018502540be40082a8798203e800a00000000000000000000000000000000000000000000000000000000000000000880000000000000000f863f861800a8405f5e10094100000000000000000000000000000000000000080801ba07e09e26678ed4fac08a249ebe8ed680bf9051a5e14ad223e4b2b9d26e0208f37a05f6e3f188e3e6eab7d7d3b6568f5eac7d687b08d307d3154ccd8c87b4630509bc0").as_slice();
let block = SealedBlock::decode(&mut block_rlp).unwrap();
insert_canonical_block(tx.deref_mut(), genesis, None).unwrap();
insert_canonical_block(tx.deref_mut(), block.clone(), None).unwrap();
tx.commit().unwrap();
insert_canonical_block(provider.tx_mut(), genesis, None).unwrap();
insert_canonical_block(provider.tx_mut(), block.clone(), None).unwrap();
provider.commit().unwrap();
let previous_stage_checkpoint = ExecutionCheckpoint {
block_range: CheckpointBlockRange { from: 0, to: 0 },
@ -560,7 +568,8 @@ mod tests {
stage_checkpoint: Some(StageUnitCheckpoint::Execution(previous_stage_checkpoint)),
};
let stage_checkpoint = execution_checkpoint(&tx, 1, 1, previous_checkpoint);
let provider = db.provider_rw().unwrap();
let stage_checkpoint = execution_checkpoint(&provider, 1, 1, previous_checkpoint);
assert_matches!(stage_checkpoint, Ok(ExecutionCheckpoint {
block_range: CheckpointBlockRange { from: 1, to: 1 },
@ -575,19 +584,21 @@ mod tests {
#[test]
fn execution_checkpoint_recalculate_full_previous_none() {
let state_db = create_test_db::<WriteMap>(EnvKind::RW);
let mut tx = Transaction::new(state_db.as_ref()).unwrap();
let db = ShareableDatabase::new(state_db.as_ref(), MAINNET.clone());
let mut provider = db.provider_rw().unwrap();
let mut genesis_rlp = hex!("f901faf901f5a00000000000000000000000000000000000000000000000000000000000000000a01dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347942adc25665018aa1fe0e6bc666dac8fc2697ff9baa045571b40ae66ca7480791bbb2887286e4e4c4b1b298b191c889d6959023a32eda056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421a056e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421b901000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000083020000808502540be400808000a00000000000000000000000000000000000000000000000000000000000000000880000000000000000c0c0").as_slice();
let genesis = SealedBlock::decode(&mut genesis_rlp).unwrap();
let mut block_rlp = hex!("f90262f901f9a075c371ba45999d87f4542326910a11af515897aebce5265d3f6acd1f1161f82fa01dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347942adc25665018aa1fe0e6bc666dac8fc2697ff9baa098f2dcd87c8ae4083e7017a05456c14eea4b1db2032126e27b3b1563d57d7cc0a08151d548273f6683169524b66ca9fe338b9ce42bc3540046c828fd939ae23bcba03f4e5c2ec5b2170b711d97ee755c160457bb58d8daa338e835ec02ae6860bbabb901000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000083020000018502540be40082a8798203e800a00000000000000000000000000000000000000000000000000000000000000000880000000000000000f863f861800a8405f5e10094100000000000000000000000000000000000000080801ba07e09e26678ed4fac08a249ebe8ed680bf9051a5e14ad223e4b2b9d26e0208f37a05f6e3f188e3e6eab7d7d3b6568f5eac7d687b08d307d3154ccd8c87b4630509bc0").as_slice();
let block = SealedBlock::decode(&mut block_rlp).unwrap();
insert_canonical_block(tx.deref_mut(), genesis, None).unwrap();
insert_canonical_block(tx.deref_mut(), block.clone(), None).unwrap();
tx.commit().unwrap();
insert_canonical_block(provider.tx_mut(), genesis, None).unwrap();
insert_canonical_block(provider.tx_mut(), block.clone(), None).unwrap();
provider.commit().unwrap();
let previous_checkpoint = StageCheckpoint { block_number: 1, stage_checkpoint: None };
let stage_checkpoint = execution_checkpoint(&tx, 1, 1, previous_checkpoint);
let provider = db.provider_rw().unwrap();
let stage_checkpoint = execution_checkpoint(&provider, 1, 1, previous_checkpoint);
assert_matches!(stage_checkpoint, Ok(ExecutionCheckpoint {
block_range: CheckpointBlockRange { from: 1, to: 1 },
@ -603,7 +614,8 @@ mod tests {
// TODO cleanup the setup after https://github.com/paradigmxyz/reth/issues/332
// is merged as it has similar framework
let state_db = create_test_db::<WriteMap>(EnvKind::RW);
let mut tx = Transaction::new(state_db.as_ref()).unwrap();
let db = ShareableDatabase::new(state_db.as_ref(), MAINNET.clone());
let mut provider = db.provider_rw().unwrap();
let input = ExecInput {
target: Some(1),
/// The progress of this stage the last time it was executed.
@ -613,12 +625,13 @@ mod tests {
let genesis = SealedBlock::decode(&mut genesis_rlp).unwrap();
let mut block_rlp = hex!("f90262f901f9a075c371ba45999d87f4542326910a11af515897aebce5265d3f6acd1f1161f82fa01dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347942adc25665018aa1fe0e6bc666dac8fc2697ff9baa098f2dcd87c8ae4083e7017a05456c14eea4b1db2032126e27b3b1563d57d7cc0a08151d548273f6683169524b66ca9fe338b9ce42bc3540046c828fd939ae23bcba03f4e5c2ec5b2170b711d97ee755c160457bb58d8daa338e835ec02ae6860bbabb901000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000083020000018502540be40082a8798203e800a00000000000000000000000000000000000000000000000000000000000000000880000000000000000f863f861800a8405f5e10094100000000000000000000000000000000000000080801ba07e09e26678ed4fac08a249ebe8ed680bf9051a5e14ad223e4b2b9d26e0208f37a05f6e3f188e3e6eab7d7d3b6568f5eac7d687b08d307d3154ccd8c87b4630509bc0").as_slice();
let block = SealedBlock::decode(&mut block_rlp).unwrap();
insert_canonical_block(tx.deref_mut(), genesis, None).unwrap();
insert_canonical_block(tx.deref_mut(), block.clone(), None).unwrap();
tx.commit().unwrap();
insert_canonical_block(provider.tx_mut(), genesis, None).unwrap();
insert_canonical_block(provider.tx_mut(), block.clone(), None).unwrap();
provider.commit().unwrap();
// insert pre state
let db_tx = tx.deref_mut();
let mut provider = db.provider_rw().unwrap();
let db_tx = provider.tx_mut();
let acc1 = H160(hex!("1000000000000000000000000000000000000000"));
let acc2 = H160(hex!("a94f5374fce5edbc8e2a8697c15331677e6ebf0b"));
let code = hex!("5a465a905090036002900360015500");
@ -637,11 +650,12 @@ mod tests {
)
.unwrap();
db_tx.put::<tables::Bytecodes>(code_hash, Bytecode::new_raw(code.to_vec().into())).unwrap();
tx.commit().unwrap();
provider.commit().unwrap();
let mut provider = db.provider_rw().unwrap();
let mut execution_stage = stage();
let output = execution_stage.execute(&mut tx, input).await.unwrap();
tx.commit().unwrap();
let output = execution_stage.execute(&mut provider, input).await.unwrap();
provider.commit().unwrap();
assert_matches!(output, ExecOutput {
checkpoint: StageCheckpoint {
block_number: 1,
@ -658,7 +672,8 @@ mod tests {
},
done: true
} if processed == total && total == block.gas_used);
let tx = tx.deref_mut();
let mut provider = db.provider_rw().unwrap();
let tx = provider.tx_mut();
// check post state
let account1 = H160(hex!("1000000000000000000000000000000000000000"));
let account1_info =
@ -707,7 +722,8 @@ mod tests {
// is merged as it has similar framework
let state_db = create_test_db::<WriteMap>(EnvKind::RW);
let mut tx = Transaction::new(state_db.as_ref()).unwrap();
let db = ShareableDatabase::new(state_db.as_ref(), MAINNET.clone());
let mut provider = db.provider_rw().unwrap();
let input = ExecInput {
target: Some(1),
/// The progress of this stage the last time it was executed.
@ -717,16 +733,17 @@ mod tests {
let genesis = SealedBlock::decode(&mut genesis_rlp).unwrap();
let mut block_rlp = hex!("f90262f901f9a075c371ba45999d87f4542326910a11af515897aebce5265d3f6acd1f1161f82fa01dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347942adc25665018aa1fe0e6bc666dac8fc2697ff9baa098f2dcd87c8ae4083e7017a05456c14eea4b1db2032126e27b3b1563d57d7cc0a08151d548273f6683169524b66ca9fe338b9ce42bc3540046c828fd939ae23bcba03f4e5c2ec5b2170b711d97ee755c160457bb58d8daa338e835ec02ae6860bbabb901000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000083020000018502540be40082a8798203e800a00000000000000000000000000000000000000000000000000000000000000000880000000000000000f863f861800a8405f5e10094100000000000000000000000000000000000000080801ba07e09e26678ed4fac08a249ebe8ed680bf9051a5e14ad223e4b2b9d26e0208f37a05f6e3f188e3e6eab7d7d3b6568f5eac7d687b08d307d3154ccd8c87b4630509bc0").as_slice();
let block = SealedBlock::decode(&mut block_rlp).unwrap();
insert_canonical_block(tx.deref_mut(), genesis, None).unwrap();
insert_canonical_block(tx.deref_mut(), block.clone(), None).unwrap();
tx.commit().unwrap();
insert_canonical_block(provider.tx_mut(), genesis, None).unwrap();
insert_canonical_block(provider.tx_mut(), block.clone(), None).unwrap();
provider.commit().unwrap();
// variables
let code = hex!("5a465a905090036002900360015500");
let balance = U256::from(0x3635c9adc5dea00000u128);
let code_hash = keccak256(code);
// pre state
let db_tx = tx.deref_mut();
let mut provider = db.provider_rw().unwrap();
let db_tx = provider.tx_mut();
let acc1 = H160(hex!("1000000000000000000000000000000000000000"));
let acc1_info = Account { nonce: 0, balance: U256::ZERO, bytecode_hash: Some(code_hash) };
let acc2 = H160(hex!("a94f5374fce5edbc8e2a8697c15331677e6ebf0b"));
@ -735,17 +752,19 @@ mod tests {
db_tx.put::<tables::PlainAccountState>(acc1, acc1_info).unwrap();
db_tx.put::<tables::PlainAccountState>(acc2, acc2_info).unwrap();
db_tx.put::<tables::Bytecodes>(code_hash, Bytecode::new_raw(code.to_vec().into())).unwrap();
tx.commit().unwrap();
provider.commit().unwrap();
// execute
let mut provider = db.provider_rw().unwrap();
let mut execution_stage = stage();
let result = execution_stage.execute(&mut tx, input).await.unwrap();
tx.commit().unwrap();
let result = execution_stage.execute(&mut provider, input).await.unwrap();
provider.commit().unwrap();
let mut provider = db.provider_rw().unwrap();
let mut stage = stage();
let result = stage
.unwind(
&mut tx,
&mut provider,
UnwindInput { checkpoint: result.checkpoint, unwind_to: 0, bad_block: None },
)
.await
@ -768,7 +787,7 @@ mod tests {
} if total == block.gas_used);
// assert unwind stage
let db_tx = tx.deref();
let db_tx = provider.tx_ref();
assert_eq!(
db_tx.get::<tables::PlainAccountState>(acc1),
Ok(Some(acc1_info)),
@ -793,7 +812,8 @@ mod tests {
#[tokio::test]
async fn test_selfdestruct() {
let test_tx = TestTransaction::default();
let mut tx = test_tx.inner();
let factory = ShareableDatabase::new(test_tx.tx.as_ref(), MAINNET.clone());
let mut provider = factory.provider_rw().unwrap();
let input = ExecInput {
target: Some(1),
/// The progress of this stage the last time it was executed.
@ -803,9 +823,9 @@ mod tests {
let genesis = SealedBlock::decode(&mut genesis_rlp).unwrap();
let mut block_rlp = hex!("f9025ff901f7a0c86e8cc0310ae7c531c758678ddbfd16fc51c8cef8cec650b032de9869e8b94fa01dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347942adc25665018aa1fe0e6bc666dac8fc2697ff9baa050554882fbbda2c2fd93fdc466db9946ea262a67f7a76cc169e714f105ab583da00967f09ef1dfed20c0eacfaa94d5cd4002eda3242ac47eae68972d07b106d192a0e3c8b47fbfc94667ef4cceb17e5cc21e3b1eebd442cebb27f07562b33836290db90100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008302000001830f42408238108203e800a00000000000000000000000000000000000000000000000000000000000000000880000000000000000f862f860800a83061a8094095e7baea6a6c7c4c2dfeb977efac326af552d8780801ba072ed817487b84ba367d15d2f039b5fc5f087d0a8882fbdf73e8cb49357e1ce30a0403d800545b8fc544f92ce8124e2255f8c3c6af93f28243a120585d4c4c6a2a3c0").as_slice();
let block = SealedBlock::decode(&mut block_rlp).unwrap();
insert_canonical_block(tx.deref_mut(), genesis, None).unwrap();
insert_canonical_block(tx.deref_mut(), block.clone(), None).unwrap();
tx.commit().unwrap();
insert_canonical_block(provider.tx_mut(), genesis, None).unwrap();
insert_canonical_block(provider.tx_mut(), block.clone(), None).unwrap();
provider.commit().unwrap();
// variables
let caller_address = H160(hex!("a94f5374fce5edbc8e2a8697c15331677e6ebf0b"));
@ -817,50 +837,60 @@ mod tests {
let code_hash = keccak256(code);
// pre state
let db_tx = tx.deref_mut();
let caller_info = Account { nonce: 0, balance, bytecode_hash: None };
let destroyed_info =
Account { nonce: 0, balance: U256::ZERO, bytecode_hash: Some(code_hash) };
// set account
db_tx.put::<tables::PlainAccountState>(caller_address, caller_info).unwrap();
db_tx.put::<tables::PlainAccountState>(destroyed_address, destroyed_info).unwrap();
db_tx.put::<tables::Bytecodes>(code_hash, Bytecode::new_raw(code.to_vec().into())).unwrap();
let provider = factory.provider_rw().unwrap();
provider.tx_ref().put::<tables::PlainAccountState>(caller_address, caller_info).unwrap();
provider
.tx_ref()
.put::<tables::PlainAccountState>(destroyed_address, destroyed_info)
.unwrap();
provider
.tx_ref()
.put::<tables::Bytecodes>(code_hash, Bytecode::new_raw(code.to_vec().into()))
.unwrap();
// set storage to check when account gets destroyed.
db_tx
provider
.tx_ref()
.put::<tables::PlainStorageState>(
destroyed_address,
StorageEntry { key: H256::zero(), value: U256::ZERO },
)
.unwrap();
db_tx
provider
.tx_ref()
.put::<tables::PlainStorageState>(
destroyed_address,
StorageEntry { key: H256::from_low_u64_be(1), value: U256::from(1u64) },
)
.unwrap();
tx.commit().unwrap();
provider.commit().unwrap();
// execute
let mut provider = factory.provider_rw().unwrap();
let mut execution_stage = stage();
let _ = execution_stage.execute(&mut tx, input).await.unwrap();
tx.commit().unwrap();
let _ = execution_stage.execute(&mut provider, input).await.unwrap();
provider.commit().unwrap();
// assert unwind stage
let provider = factory.provider_rw().unwrap();
assert_eq!(
tx.deref().get::<tables::PlainAccountState>(destroyed_address),
provider.tx_ref().get::<tables::PlainAccountState>(destroyed_address),
Ok(None),
"Account was destroyed"
);
assert_eq!(
tx.deref().get::<tables::PlainStorageState>(destroyed_address),
provider.tx_ref().get::<tables::PlainStorageState>(destroyed_address),
Ok(None),
"There is storage for destroyed account"
);
// drops tx so that it returns write privilege to test_tx
drop(tx);
drop(provider);
let plain_accounts = test_tx.table::<tables::PlainAccountState>().unwrap();
let plain_storage = test_tx.table::<tables::PlainStorageState>().unwrap();

View File

@ -1,7 +1,7 @@
use crate::{ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput};
use reth_db::database::Database;
use reth_primitives::stage::{StageCheckpoint, StageId};
use reth_provider::Transaction;
use reth_provider::DatabaseProviderRW;
/// The finish stage.
///
@ -18,7 +18,7 @@ impl<DB: Database> Stage<DB> for FinishStage {
async fn execute(
&mut self,
_tx: &mut Transaction<'_, DB>,
_provider: &mut DatabaseProviderRW<'_, &DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
Ok(ExecOutput { checkpoint: StageCheckpoint::new(input.target()), done: true })
@ -26,7 +26,7 @@ impl<DB: Database> Stage<DB> for FinishStage {
async fn unwind(
&mut self,
_tx: &mut Transaction<'_, DB>,
_provider: &mut DatabaseProviderRW<'_, &DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
Ok(UnwindOutput { checkpoint: StageCheckpoint::new(input.unwind_to) })

View File

@ -16,11 +16,11 @@ use reth_primitives::{
StageId,
},
};
use reth_provider::Transaction;
use reth_provider::{AccountExtProvider, DatabaseProviderRW};
use std::{
cmp::max,
fmt::Debug,
ops::{Deref, Range, RangeInclusive},
ops::{Range, RangeInclusive},
};
use tokio::sync::mpsc;
use tracing::*;
@ -79,7 +79,7 @@ impl AccountHashingStage {
/// Proceeds to go to the `BlockTransitionIndex` end, go back `transitions` and change the
/// account state in the `AccountChangeSet` table.
pub fn seed<DB: Database>(
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, DB>,
opts: SeedOpts,
) -> Result<Vec<(reth_primitives::Address, reth_primitives::Account)>, StageError> {
use reth_db::models::AccountBeforeTx;
@ -92,18 +92,20 @@ impl AccountHashingStage {
let blocks = random_block_range(opts.blocks.clone(), H256::zero(), opts.txs);
for block in blocks {
insert_canonical_block(&**tx, block, None).unwrap();
insert_canonical_block(provider.tx_ref(), block, None).unwrap();
}
let mut accounts = random_eoa_account_range(opts.accounts);
{
// Account State generator
let mut account_cursor = tx.cursor_write::<tables::PlainAccountState>()?;
let mut account_cursor =
provider.tx_ref().cursor_write::<tables::PlainAccountState>()?;
accounts.sort_by(|a, b| a.0.cmp(&b.0));
for (addr, acc) in accounts.iter() {
account_cursor.append(*addr, *acc)?;
}
let mut acc_changeset_cursor = tx.cursor_write::<tables::AccountChangeSet>()?;
let mut acc_changeset_cursor =
provider.tx_ref().cursor_write::<tables::AccountChangeSet>()?;
for (t, (addr, acc)) in (opts.blocks).zip(&accounts) {
let Account { nonce, balance, .. } = acc;
let prev_acc = Account {
@ -116,8 +118,6 @@ impl AccountHashingStage {
}
}
tx.commit()?;
Ok(accounts)
}
}
@ -132,7 +132,7 @@ impl<DB: Database> Stage<DB> for AccountHashingStage {
/// Execute the stage.
async fn execute(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
if input.target_reached() {
@ -146,6 +146,7 @@ impl<DB: Database> Stage<DB> for AccountHashingStage {
// AccountHashing table. Also, if we start from genesis, we need to hash from scratch, as
// genesis accounts are not in changeset.
if to_block - from_block > self.clean_threshold || from_block == 1 {
let tx = provider.tx_ref();
let stage_checkpoint = input
.checkpoint
.and_then(|checkpoint| checkpoint.account_hashing_stage_checkpoint());
@ -231,7 +232,7 @@ impl<DB: Database> Stage<DB> for AccountHashingStage {
AccountHashingCheckpoint {
address: Some(next_address.key().unwrap()),
block_range: CheckpointBlockRange { from: from_block, to: to_block },
progress: stage_checkpoint_progress(tx)?,
progress: stage_checkpoint_progress(provider)?,
},
);
@ -240,20 +241,20 @@ impl<DB: Database> Stage<DB> for AccountHashingStage {
} else {
// Aggregate all transition changesets and make a list of accounts that have been
// changed.
let lists = tx.get_addresses_of_changed_accounts(from_block..=to_block)?;
let lists = provider.changed_accounts_with_range(from_block..=to_block)?;
// Iterate over plain state and get newest value.
// Assumption we are okay to make is that plainstate represent
// `previous_stage_progress` state.
let accounts = tx.get_plainstate_accounts(lists)?;
let accounts = provider.basic_accounts(lists)?;
// Insert and hash accounts to hashing table
tx.insert_account_for_hashing(accounts.into_iter())?;
provider.insert_account_for_hashing(accounts.into_iter())?;
}
// We finished the hashing stage, no future iterations is expected for the same block range,
// so no checkpoint is needed.
let checkpoint = StageCheckpoint::new(input.target())
.with_account_hashing_stage_checkpoint(AccountHashingCheckpoint {
progress: stage_checkpoint_progress(tx)?,
progress: stage_checkpoint_progress(provider)?,
..Default::default()
});
@ -263,19 +264,19 @@ impl<DB: Database> Stage<DB> for AccountHashingStage {
/// Unwind the stage.
async fn unwind(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
let (range, unwind_progress, _) =
input.unwind_block_range_with_threshold(self.commit_threshold);
// Aggregate all transition changesets and make a list of accounts that have been changed.
tx.unwind_account_hashing(range)?;
provider.unwind_account_hashing(range)?;
let mut stage_checkpoint =
input.checkpoint.account_hashing_stage_checkpoint().unwrap_or_default();
stage_checkpoint.progress = stage_checkpoint_progress(tx)?;
stage_checkpoint.progress = stage_checkpoint_progress(provider)?;
Ok(UnwindOutput {
checkpoint: StageCheckpoint::new(unwind_progress)
@ -285,11 +286,11 @@ impl<DB: Database> Stage<DB> for AccountHashingStage {
}
fn stage_checkpoint_progress<DB: Database>(
tx: &Transaction<'_, DB>,
provider: &DatabaseProviderRW<'_, &DB>,
) -> Result<EntitiesCheckpoint, DatabaseError> {
Ok(EntitiesCheckpoint {
processed: tx.deref().entries::<tables::HashedAccount>()? as u64,
total: tx.deref().entries::<tables::PlainAccountState>()? as u64,
processed: provider.tx_ref().entries::<tables::HashedAccount>()? as u64,
total: provider.tx_ref().entries::<tables::PlainAccountState>()? as u64,
})
}
@ -531,11 +532,14 @@ mod tests {
type Seed = Vec<(Address, Account)>;
fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
Ok(AccountHashingStage::seed(
&mut self.tx.inner(),
let mut provider = self.tx.inner();
let res = Ok(AccountHashingStage::seed(
&mut provider,
SeedOpts { blocks: 1..=input.target(), accounts: 0..10, txs: 0..3 },
)
.unwrap())
.unwrap());
provider.commit().expect("failed to commit");
res
}
fn validate_execution(

View File

@ -16,8 +16,8 @@ use reth_primitives::{
},
StorageEntry,
};
use reth_provider::Transaction;
use std::{collections::BTreeMap, fmt::Debug, ops::Deref};
use reth_provider::DatabaseProviderRW;
use std::{collections::BTreeMap, fmt::Debug};
use tracing::*;
/// Storage hashing stage hashes plain storage.
@ -54,9 +54,10 @@ impl<DB: Database> Stage<DB> for StorageHashingStage {
/// Execute the stage.
async fn execute(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
let tx = provider.tx_ref();
if input.target_reached() {
return Ok(ExecOutput::done(input.checkpoint()))
}
@ -161,7 +162,7 @@ impl<DB: Database> Stage<DB> for StorageHashingStage {
address: current_key,
storage: current_subkey,
block_range: CheckpointBlockRange { from: from_block, to: to_block },
progress: stage_checkpoint_progress(tx)?,
progress: stage_checkpoint_progress(provider)?,
},
);
@ -170,19 +171,20 @@ impl<DB: Database> Stage<DB> for StorageHashingStage {
} else {
// Aggregate all changesets and and make list of storages that have been
// changed.
let lists = tx.get_addresses_and_keys_of_changed_storages(from_block..=to_block)?;
let lists =
provider.get_addresses_and_keys_of_changed_storages(from_block..=to_block)?;
// iterate over plain state and get newest storage value.
// Assumption we are okay with is that plain state represent
// `previous_stage_progress` state.
let storages = tx.get_plainstate_storages(lists)?;
tx.insert_storage_for_hashing(storages.into_iter())?;
let storages = provider.get_plainstate_storages(lists)?;
provider.insert_storage_for_hashing(storages.into_iter())?;
}
// We finished the hashing stage, no future iterations is expected for the same block range,
// so no checkpoint is needed.
let checkpoint = StageCheckpoint::new(input.target())
.with_storage_hashing_stage_checkpoint(StorageHashingCheckpoint {
progress: stage_checkpoint_progress(tx)?,
progress: stage_checkpoint_progress(provider)?,
..Default::default()
});
@ -192,18 +194,18 @@ impl<DB: Database> Stage<DB> for StorageHashingStage {
/// Unwind the stage.
async fn unwind(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
let (range, unwind_progress, _) =
input.unwind_block_range_with_threshold(self.commit_threshold);
tx.unwind_storage_hashing(BlockNumberAddress::range(range))?;
provider.unwind_storage_hashing(BlockNumberAddress::range(range))?;
let mut stage_checkpoint =
input.checkpoint.storage_hashing_stage_checkpoint().unwrap_or_default();
stage_checkpoint.progress = stage_checkpoint_progress(tx)?;
stage_checkpoint.progress = stage_checkpoint_progress(provider)?;
Ok(UnwindOutput {
checkpoint: StageCheckpoint::new(unwind_progress)
@ -213,11 +215,11 @@ impl<DB: Database> Stage<DB> for StorageHashingStage {
}
fn stage_checkpoint_progress<DB: Database>(
tx: &Transaction<'_, DB>,
provider: &DatabaseProviderRW<'_, &DB>,
) -> Result<EntitiesCheckpoint, DatabaseError> {
Ok(EntitiesCheckpoint {
processed: tx.deref().entries::<tables::HashedStorage>()? as u64,
total: tx.deref().entries::<tables::PlainStorageState>()? as u64,
processed: provider.tx_ref().entries::<tables::HashedStorage>()? as u64,
total: provider.tx_ref().entries::<tables::PlainStorageState>()? as u64,
})
}

View File

@ -19,7 +19,8 @@ use reth_primitives::{
},
BlockHashOrNumber, BlockNumber, SealedHeader, H256,
};
use reth_provider::Transaction;
use reth_provider::DatabaseProviderRW;
use std::ops::Deref;
use tokio::sync::watch;
use tracing::*;
@ -68,7 +69,7 @@ where
fn is_stage_done<DB: Database>(
&self,
tx: &Transaction<'_, DB>,
tx: &<DB as reth_db::database::DatabaseGAT<'_>>::TXMut,
checkpoint: u64,
) -> Result<bool, StageError> {
let mut header_cursor = tx.cursor_read::<tables::CanonicalHeaders>()?;
@ -84,12 +85,12 @@ where
/// See also [SyncTarget]
async fn get_sync_gap<DB: Database>(
&mut self,
tx: &Transaction<'_, DB>,
provider: &DatabaseProviderRW<'_, &DB>,
checkpoint: u64,
) -> Result<SyncGap, StageError> {
// Create a cursor over canonical header hashes
let mut cursor = tx.cursor_read::<tables::CanonicalHeaders>()?;
let mut header_cursor = tx.cursor_read::<tables::Headers>()?;
let mut cursor = provider.tx_ref().cursor_read::<tables::CanonicalHeaders>()?;
let mut header_cursor = provider.tx_ref().cursor_read::<tables::Headers>()?;
// Get head hash and reposition the cursor
let (head_num, head_hash) = cursor
@ -149,7 +150,7 @@ where
/// Note: this writes the headers with rising block numbers.
fn write_headers<DB: Database>(
&self,
tx: &Transaction<'_, DB>,
tx: &<DB as reth_db::database::DatabaseGAT<'_>>::TXMut,
headers: Vec<SealedHeader>,
) -> Result<Option<BlockNumber>, StageError> {
trace!(target: "sync::stages::headers", len = headers.len(), "writing headers");
@ -195,13 +196,14 @@ where
/// starting from the tip of the chain
async fn execute(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
let tx = provider.tx_ref();
let current_checkpoint = input.checkpoint();
// Lookup the head and tip of the sync range
let gap = self.get_sync_gap(tx, current_checkpoint.block_number).await?;
let gap = self.get_sync_gap(provider.deref(), current_checkpoint.block_number).await?;
let local_head = gap.local_head.number;
let tip = gap.target.tip();
@ -301,7 +303,7 @@ where
// Write the headers to db
self.write_headers::<DB>(tx, downloaded_headers)?.unwrap_or_default();
if self.is_stage_done(tx, current_checkpoint.block_number)? {
if self.is_stage_done::<DB>(tx, current_checkpoint.block_number)? {
let checkpoint = current_checkpoint.block_number.max(
tx.cursor_read::<tables::CanonicalHeaders>()?
.last()?
@ -324,15 +326,15 @@ where
/// Unwind the stage.
async fn unwind(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
// TODO: handle bad block
tx.unwind_table_by_walker::<tables::CanonicalHeaders, tables::HeaderNumbers>(
provider.unwind_table_by_walker::<tables::CanonicalHeaders, tables::HeaderNumbers>(
input.unwind_to + 1,
)?;
tx.unwind_table_by_num::<tables::CanonicalHeaders>(input.unwind_to)?;
let unwound_headers = tx.unwind_table_by_num::<tables::Headers>(input.unwind_to)?;
provider.unwind_table_by_num::<tables::CanonicalHeaders>(input.unwind_to)?;
let unwound_headers = provider.unwind_table_by_num::<tables::Headers>(input.unwind_to)?;
let stage_checkpoint =
input.checkpoint.headers_stage_checkpoint().map(|stage_checkpoint| HeadersCheckpoint {
@ -380,13 +382,15 @@ impl SyncGap {
#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::{
stage_test_suite, ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner,
};
use assert_matches::assert_matches;
use reth_interfaces::test_utils::generators::random_header;
use reth_primitives::{stage::StageUnitCheckpoint, H256};
use reth_primitives::{stage::StageUnitCheckpoint, H256, MAINNET};
use reth_provider::ShareableDatabase;
use test_runner::HeadersTestRunner;
mod test_runner {
@ -598,7 +602,9 @@ mod tests {
#[tokio::test]
async fn head_and_tip_lookup() {
let runner = HeadersTestRunner::default();
let tx = runner.tx().inner();
let factory = ShareableDatabase::new(runner.tx().tx.as_ref(), MAINNET.clone());
let provider = factory.provider_rw().unwrap();
let tx = provider.tx_ref();
let mut stage = runner.stage();
let consensus_tip = H256::random();
@ -612,7 +618,7 @@ mod tests {
// Empty database
assert_matches!(
stage.get_sync_gap(&tx, checkpoint).await,
stage.get_sync_gap(&provider, checkpoint).await,
Err(StageError::DatabaseIntegrity(ProviderError::HeaderNotFound(block_number)))
if block_number.as_number().unwrap() == checkpoint
);
@ -623,7 +629,7 @@ mod tests {
tx.put::<tables::Headers>(head.number, head.clone().unseal())
.expect("failed to write header");
let gap = stage.get_sync_gap(&tx, checkpoint).await.unwrap();
let gap = stage.get_sync_gap(&provider, checkpoint).await.unwrap();
assert_eq!(gap.local_head, head);
assert_eq!(gap.target.tip(), consensus_tip.into());
@ -633,7 +639,7 @@ mod tests {
tx.put::<tables::Headers>(gap_tip.number, gap_tip.clone().unseal())
.expect("failed to write header");
let gap = stage.get_sync_gap(&tx, checkpoint).await.unwrap();
let gap = stage.get_sync_gap(&provider, checkpoint).await.unwrap();
assert_eq!(gap.local_head, head);
assert_eq!(gap.target.tip(), gap_tip.parent_hash.into());
@ -644,7 +650,7 @@ mod tests {
.expect("failed to write header");
assert_matches!(
stage.get_sync_gap(&tx, checkpoint).await,
stage.get_sync_gap(&provider, checkpoint).await,
Err(StageError::StageCheckpoint(_checkpoint)) if _checkpoint == checkpoint
);
}

View File

@ -6,11 +6,8 @@ use reth_primitives::{
},
BlockNumber,
};
use reth_provider::Transaction;
use std::{
fmt::Debug,
ops::{Deref, RangeInclusive},
};
use reth_provider::DatabaseProviderRW;
use std::{fmt::Debug, ops::RangeInclusive};
/// Stage is indexing history the account changesets generated in
/// [`ExecutionStage`][crate::stages::ExecutionStage]. For more information
@ -38,7 +35,7 @@ impl<DB: Database> Stage<DB> for IndexAccountHistoryStage {
/// Execute the stage.
async fn execute(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
if input.target_reached() {
@ -48,18 +45,18 @@ impl<DB: Database> Stage<DB> for IndexAccountHistoryStage {
let (range, is_final_range) = input.next_block_range_with_threshold(self.commit_threshold);
let mut stage_checkpoint = stage_checkpoint(
tx,
provider,
input.checkpoint(),
// It is important to provide the full block range into the checkpoint,
// not the one accounting for commit threshold, to get the correct range end.
&input.next_block_range(),
)?;
let indices = tx.get_account_transition_ids_from_changeset(range.clone())?;
let indices = provider.get_account_transition_ids_from_changeset(range.clone())?;
let changesets = indices.values().map(|blocks| blocks.len() as u64).sum::<u64>();
// Insert changeset to history index
tx.insert_account_history_index(indices)?;
provider.insert_account_history_index(indices)?;
stage_checkpoint.progress.processed += changesets;
@ -73,13 +70,13 @@ impl<DB: Database> Stage<DB> for IndexAccountHistoryStage {
/// Unwind the stage.
async fn unwind(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
let (range, unwind_progress, _) =
input.unwind_block_range_with_threshold(self.commit_threshold);
let changesets = tx.unwind_account_history_indices(range)?;
let changesets = provider.unwind_account_history_indices(range)?;
let checkpoint =
if let Some(mut stage_checkpoint) = input.checkpoint.index_history_stage_checkpoint() {
@ -105,7 +102,7 @@ impl<DB: Database> Stage<DB> for IndexAccountHistoryStage {
/// given block range and calculates the progress by counting the number of processed entries in the
/// [tables::AccountChangeSet] table within the given block range.
fn stage_checkpoint<DB: Database>(
tx: &Transaction<'_, DB>,
provider: &DatabaseProviderRW<'_, &DB>,
checkpoint: StageCheckpoint,
range: &RangeInclusive<BlockNumber>,
) -> Result<IndexHistoryCheckpoint, DatabaseError> {
@ -122,18 +119,19 @@ fn stage_checkpoint<DB: Database>(
block_range: CheckpointBlockRange::from(range),
progress: EntitiesCheckpoint {
processed: progress.processed,
total: tx.deref().entries::<tables::AccountChangeSet>()? as u64,
total: provider.tx_ref().entries::<tables::AccountChangeSet>()? as u64,
},
}
}
_ => IndexHistoryCheckpoint {
block_range: CheckpointBlockRange::from(range),
progress: EntitiesCheckpoint {
processed: tx
processed: provider
.tx_ref()
.cursor_read::<tables::AccountChangeSet>()?
.walk_range(0..=checkpoint.block_number)?
.count() as u64,
total: tx.deref().entries::<tables::AccountChangeSet>()? as u64,
total: provider.tx_ref().entries::<tables::AccountChangeSet>()? as u64,
},
},
})
@ -142,6 +140,7 @@ fn stage_checkpoint<DB: Database>(
#[cfg(test)]
mod tests {
use assert_matches::assert_matches;
use reth_provider::ShareableDatabase;
use std::collections::BTreeMap;
use super::*;
@ -155,7 +154,7 @@ mod tests {
transaction::DbTxMut,
BlockNumberList,
};
use reth_primitives::{hex_literal::hex, H160};
use reth_primitives::{hex_literal::hex, H160, MAINNET};
const ADDRESS: H160 = H160(hex!("0000000000000000000000000000000000000001"));
@ -211,8 +210,9 @@ mod tests {
async fn run(tx: &TestTransaction, run_to: u64) {
let input = ExecInput { target: Some(run_to), ..Default::default() };
let mut stage = IndexAccountHistoryStage::default();
let mut tx = tx.inner();
let out = stage.execute(&mut tx, input).await.unwrap();
let factory = ShareableDatabase::new(tx.tx.as_ref(), MAINNET.clone());
let mut provider = factory.provider_rw().unwrap();
let out = stage.execute(&mut provider, input).await.unwrap();
assert_eq!(
out,
ExecOutput {
@ -225,7 +225,7 @@ mod tests {
done: true
}
);
tx.commit().unwrap();
provider.commit().unwrap();
}
async fn unwind(tx: &TestTransaction, unwind_from: u64, unwind_to: u64) {
@ -235,10 +235,11 @@ mod tests {
..Default::default()
};
let mut stage = IndexAccountHistoryStage::default();
let mut tx = tx.inner();
let out = stage.unwind(&mut tx, input).await.unwrap();
let factory = ShareableDatabase::new(tx.tx.as_ref(), MAINNET.clone());
let mut provider = factory.provider_rw().unwrap();
let out = stage.unwind(&mut provider, input).await.unwrap();
assert_eq!(out, UnwindOutput { checkpoint: StageCheckpoint::new(unwind_to) });
tx.commit().unwrap();
provider.commit().unwrap();
}
#[tokio::test]
@ -448,10 +449,11 @@ mod tests {
// run
{
let mut stage = IndexAccountHistoryStage { commit_threshold: 4 }; // Two runs required
let mut tx = test_tx.inner();
let factory = ShareableDatabase::new(&test_tx.tx, MAINNET.clone());
let mut provider = factory.provider_rw().unwrap();
let mut input = ExecInput { target: Some(5), ..Default::default() };
let out = stage.execute(&mut tx, input).await.unwrap();
let out = stage.execute(&mut provider, input).await.unwrap();
assert_eq!(
out,
ExecOutput {
@ -466,7 +468,7 @@ mod tests {
);
input.checkpoint = Some(out.checkpoint);
let out = stage.execute(&mut tx, input).await.unwrap();
let out = stage.execute(&mut provider, input).await.unwrap();
assert_eq!(
out,
ExecOutput {
@ -480,7 +482,7 @@ mod tests {
}
);
tx.commit().unwrap();
provider.commit().unwrap();
}
// verify
@ -536,8 +538,11 @@ mod tests {
})
.unwrap();
let factory = ShareableDatabase::new(tx.tx.as_ref(), MAINNET.clone());
let provider = factory.provider_rw().unwrap();
assert_matches!(
stage_checkpoint(&tx.inner(), StageCheckpoint::new(1), &(1..=2)).unwrap(),
stage_checkpoint(&provider, StageCheckpoint::new(1), &(1..=2)).unwrap(),
IndexHistoryCheckpoint {
block_range: CheckpointBlockRange { from: 1, to: 2 },
progress: EntitiesCheckpoint { processed: 2, total: 4 }

View File

@ -9,11 +9,8 @@ use reth_primitives::{
},
BlockNumber,
};
use reth_provider::Transaction;
use std::{
fmt::Debug,
ops::{Deref, RangeInclusive},
};
use reth_provider::DatabaseProviderRW;
use std::{fmt::Debug, ops::RangeInclusive};
/// Stage is indexing history the account changesets generated in
/// [`ExecutionStage`][crate::stages::ExecutionStage]. For more information
@ -41,7 +38,7 @@ impl<DB: Database> Stage<DB> for IndexStorageHistoryStage {
/// Execute the stage.
async fn execute(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
if input.target_reached() {
@ -51,17 +48,17 @@ impl<DB: Database> Stage<DB> for IndexStorageHistoryStage {
let (range, is_final_range) = input.next_block_range_with_threshold(self.commit_threshold);
let mut stage_checkpoint = stage_checkpoint(
tx,
provider,
input.checkpoint(),
// It is important to provide the full block range into the checkpoint,
// not the one accounting for commit threshold, to get the correct range end.
&input.next_block_range(),
)?;
let indices = tx.get_storage_transition_ids_from_changeset(range.clone())?;
let indices = provider.get_storage_transition_ids_from_changeset(range.clone())?;
let changesets = indices.values().map(|blocks| blocks.len() as u64).sum::<u64>();
tx.insert_storage_history_index(indices)?;
provider.insert_storage_history_index(indices)?;
stage_checkpoint.progress.processed += changesets;
@ -75,13 +72,14 @@ impl<DB: Database> Stage<DB> for IndexStorageHistoryStage {
/// Unwind the stage.
async fn unwind(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
let (range, unwind_progress, _) =
input.unwind_block_range_with_threshold(self.commit_threshold);
let changesets = tx.unwind_storage_history_indices(BlockNumberAddress::range(range))?;
let changesets =
provider.unwind_storage_history_indices(BlockNumberAddress::range(range))?;
let checkpoint =
if let Some(mut stage_checkpoint) = input.checkpoint.index_history_stage_checkpoint() {
@ -106,7 +104,7 @@ impl<DB: Database> Stage<DB> for IndexStorageHistoryStage {
/// given block range and calculates the progress by counting the number of processed entries in the
/// [tables::StorageChangeSet] table within the given block range.
fn stage_checkpoint<DB: Database>(
tx: &Transaction<'_, DB>,
provider: &DatabaseProviderRW<'_, &DB>,
checkpoint: StageCheckpoint,
range: &RangeInclusive<BlockNumber>,
) -> Result<IndexHistoryCheckpoint, DatabaseError> {
@ -123,18 +121,19 @@ fn stage_checkpoint<DB: Database>(
block_range: CheckpointBlockRange::from(range),
progress: EntitiesCheckpoint {
processed: progress.processed,
total: tx.deref().entries::<tables::StorageChangeSet>()? as u64,
total: provider.tx_ref().entries::<tables::StorageChangeSet>()? as u64,
},
}
}
_ => IndexHistoryCheckpoint {
block_range: CheckpointBlockRange::from(range),
progress: EntitiesCheckpoint {
processed: tx
processed: provider
.tx_ref()
.cursor_read::<tables::StorageChangeSet>()?
.walk_range(BlockNumberAddress::range(0..=checkpoint.block_number))?
.count() as u64,
total: tx.deref().entries::<tables::StorageChangeSet>()? as u64,
total: provider.tx_ref().entries::<tables::StorageChangeSet>()? as u64,
},
},
})
@ -144,6 +143,7 @@ fn stage_checkpoint<DB: Database>(
mod tests {
use assert_matches::assert_matches;
use reth_provider::ShareableDatabase;
use std::collections::BTreeMap;
use super::*;
@ -157,7 +157,7 @@ mod tests {
transaction::DbTxMut,
BlockNumberList,
};
use reth_primitives::{hex_literal::hex, StorageEntry, H160, H256, U256};
use reth_primitives::{hex_literal::hex, StorageEntry, H160, H256, MAINNET, U256};
const ADDRESS: H160 = H160(hex!("0000000000000000000000000000000000000001"));
const STORAGE_KEY: H256 =
@ -223,8 +223,9 @@ mod tests {
async fn run(tx: &TestTransaction, run_to: u64) {
let input = ExecInput { target: Some(run_to), ..Default::default() };
let mut stage = IndexStorageHistoryStage::default();
let mut tx = tx.inner();
let out = stage.execute(&mut tx, input).await.unwrap();
let factory = ShareableDatabase::new(tx.tx.as_ref(), MAINNET.clone());
let mut provider = factory.provider_rw().unwrap();
let out = stage.execute(&mut provider, input).await.unwrap();
assert_eq!(
out,
ExecOutput {
@ -237,7 +238,7 @@ mod tests {
done: true
}
);
tx.commit().unwrap();
provider.commit().unwrap();
}
async fn unwind(tx: &TestTransaction, unwind_from: u64, unwind_to: u64) {
@ -247,10 +248,11 @@ mod tests {
..Default::default()
};
let mut stage = IndexStorageHistoryStage::default();
let mut tx = tx.inner();
let out = stage.unwind(&mut tx, input).await.unwrap();
let factory = ShareableDatabase::new(tx.tx.as_ref(), MAINNET.clone());
let mut provider = factory.provider_rw().unwrap();
let out = stage.unwind(&mut provider, input).await.unwrap();
assert_eq!(out, UnwindOutput { checkpoint: StageCheckpoint::new(unwind_to) });
tx.commit().unwrap();
provider.commit().unwrap();
}
#[tokio::test]
@ -463,10 +465,11 @@ mod tests {
// run
{
let mut stage = IndexStorageHistoryStage { commit_threshold: 4 }; // Two runs required
let mut tx = test_tx.inner();
let factory = ShareableDatabase::new(&test_tx.tx, MAINNET.clone());
let mut provider = factory.provider_rw().unwrap();
let mut input = ExecInput { target: Some(5), ..Default::default() };
let out = stage.execute(&mut tx, input).await.unwrap();
let out = stage.execute(&mut provider, input).await.unwrap();
assert_eq!(
out,
ExecOutput {
@ -481,7 +484,7 @@ mod tests {
);
input.checkpoint = Some(out.checkpoint);
let out = stage.execute(&mut tx, input).await.unwrap();
let out = stage.execute(&mut provider, input).await.unwrap();
assert_eq!(
out,
ExecOutput {
@ -495,7 +498,7 @@ mod tests {
}
);
tx.commit().unwrap();
provider.commit().unwrap();
}
// verify
@ -561,8 +564,11 @@ mod tests {
})
.unwrap();
let factory = ShareableDatabase::new(tx.tx.as_ref(), MAINNET.clone());
let provider = factory.provider_rw().unwrap();
assert_matches!(
stage_checkpoint(&tx.inner(), StageCheckpoint::new(1), &(1..=2)).unwrap(),
stage_checkpoint(&provider, StageCheckpoint::new(1), &(1..=2)).unwrap(),
IndexHistoryCheckpoint {
block_range: CheckpointBlockRange { from: 1, to: 2 },
progress: EntitiesCheckpoint { processed: 3, total: 6 }

View File

@ -12,12 +12,9 @@ use reth_primitives::{
trie::StoredSubNode,
BlockNumber, SealedHeader, H256,
};
use reth_provider::Transaction;
use reth_provider::{DatabaseProviderRW, HeaderProvider, ProviderError};
use reth_trie::{IntermediateStateRootState, StateRoot, StateRootProgress};
use std::{
fmt::Debug,
ops::{Deref, DerefMut},
};
use std::fmt::Debug;
use tracing::*;
/// The merkle hashing stage uses input from
@ -93,11 +90,10 @@ impl MerkleStage {
/// Gets the hashing progress
pub fn get_execution_checkpoint<DB: Database>(
&self,
tx: &Transaction<'_, DB>,
provider: &DatabaseProviderRW<'_, &DB>,
) -> Result<Option<MerkleCheckpoint>, StageError> {
let buf = tx
.get::<tables::SyncStageProgress>(StageId::MerkleExecute.to_string())?
.unwrap_or_default();
let buf =
provider.get_stage_checkpoint_progress(StageId::MerkleExecute)?.unwrap_or_default();
if buf.is_empty() {
return Ok(None)
@ -110,7 +106,7 @@ impl MerkleStage {
/// Saves the hashing progress
pub fn save_execution_checkpoint<DB: Database>(
&mut self,
tx: &Transaction<'_, DB>,
provider: &DatabaseProviderRW<'_, &DB>,
checkpoint: Option<MerkleCheckpoint>,
) -> Result<(), StageError> {
let mut buf = vec![];
@ -123,8 +119,7 @@ impl MerkleStage {
);
checkpoint.to_compact(&mut buf);
}
tx.put::<tables::SyncStageProgress>(StageId::MerkleExecute.to_string(), buf)?;
Ok(())
Ok(provider.save_stage_checkpoint_progress(StageId::MerkleExecute, buf)?)
}
}
@ -143,7 +138,7 @@ impl<DB: Database> Stage<DB> for MerkleStage {
/// Execute the stage.
async fn execute(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
let threshold = match self {
@ -160,10 +155,12 @@ impl<DB: Database> Stage<DB> for MerkleStage {
let (from_block, to_block) = range.clone().into_inner();
let current_block = input.target();
let block = tx.get_header(current_block)?;
let block = provider
.header_by_number(current_block)?
.ok_or_else(|| ProviderError::HeaderNotFound(current_block.into()))?;
let block_root = block.state_root;
let mut checkpoint = self.get_execution_checkpoint(tx)?;
let mut checkpoint = self.get_execution_checkpoint(provider)?;
let (trie_root, entities_checkpoint) = if range.is_empty() {
(block_root, input.checkpoint().entities_stage_checkpoint().unwrap_or_default())
@ -192,25 +189,27 @@ impl<DB: Database> Stage<DB> for MerkleStage {
);
// Reset the checkpoint and clear trie tables
checkpoint = None;
self.save_execution_checkpoint(tx, None)?;
tx.clear::<tables::AccountsTrie>()?;
tx.clear::<tables::StoragesTrie>()?;
self.save_execution_checkpoint(provider, None)?;
provider.tx_ref().clear::<tables::AccountsTrie>()?;
provider.tx_ref().clear::<tables::StoragesTrie>()?;
None
}
.unwrap_or(EntitiesCheckpoint {
processed: 0,
total: (tx.deref().entries::<tables::HashedAccount>()? +
tx.deref().entries::<tables::HashedStorage>()?) as u64,
total: (provider.tx_ref().entries::<tables::HashedAccount>()? +
provider.tx_ref().entries::<tables::HashedStorage>()?)
as u64,
});
let progress = StateRoot::new(tx.deref_mut())
let tx = provider.tx_ref();
let progress = StateRoot::new(tx)
.with_intermediate_state(checkpoint.map(IntermediateStateRootState::from))
.root_with_progress()
.map_err(|e| StageError::Fatal(Box::new(e)))?;
match progress {
StateRootProgress::Progress(state, hashed_entries_walked, updates) => {
updates.flush(tx.deref_mut())?;
updates.flush(tx)?;
let checkpoint = MerkleCheckpoint::new(
to_block,
@ -219,7 +218,7 @@ impl<DB: Database> Stage<DB> for MerkleStage {
state.walker_stack.into_iter().map(StoredSubNode::from).collect(),
state.hash_builder.into(),
);
self.save_execution_checkpoint(tx, Some(checkpoint))?;
self.save_execution_checkpoint(provider, Some(checkpoint))?;
entities_checkpoint.processed += hashed_entries_walked as u64;
@ -231,7 +230,7 @@ impl<DB: Database> Stage<DB> for MerkleStage {
})
}
StateRootProgress::Complete(root, hashed_entries_walked, updates) => {
updates.flush(tx.deref_mut())?;
updates.flush(tx)?;
entities_checkpoint.processed += hashed_entries_walked as u64;
@ -240,12 +239,13 @@ impl<DB: Database> Stage<DB> for MerkleStage {
}
} else {
debug!(target: "sync::stages::merkle::exec", current = ?current_block, target = ?to_block, "Updating trie");
let (root, updates) = StateRoot::incremental_root_with_updates(tx.deref_mut(), range)
.map_err(|e| StageError::Fatal(Box::new(e)))?;
updates.flush(tx.deref_mut())?;
let (root, updates) =
StateRoot::incremental_root_with_updates(provider.tx_ref(), range)
.map_err(|e| StageError::Fatal(Box::new(e)))?;
updates.flush(provider.tx_ref())?;
let total_hashed_entries = (tx.deref().entries::<tables::HashedAccount>()? +
tx.deref().entries::<tables::HashedStorage>()?)
let total_hashed_entries = (provider.tx_ref().entries::<tables::HashedAccount>()? +
provider.tx_ref().entries::<tables::HashedStorage>()?)
as u64;
let entities_checkpoint = EntitiesCheckpoint {
@ -260,7 +260,7 @@ impl<DB: Database> Stage<DB> for MerkleStage {
};
// Reset the checkpoint
self.save_execution_checkpoint(tx, None)?;
self.save_execution_checkpoint(provider, None)?;
self.validate_state_root(trie_root, block.seal_slow(), to_block)?;
@ -274,9 +274,10 @@ impl<DB: Database> Stage<DB> for MerkleStage {
/// Unwind the stage.
async fn unwind(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
let tx = provider.tx_ref();
let range = input.unwind_block_range();
if matches!(self, MerkleStage::Execution { .. }) {
info!(target: "sync::stages::merkle::unwind", "Stage is always skipped");
@ -286,8 +287,8 @@ impl<DB: Database> Stage<DB> for MerkleStage {
let mut entities_checkpoint =
input.checkpoint.entities_stage_checkpoint().unwrap_or(EntitiesCheckpoint {
processed: 0,
total: (tx.deref().entries::<tables::HashedAccount>()? +
tx.deref().entries::<tables::HashedStorage>()?) as u64,
total: (tx.entries::<tables::HashedAccount>()? +
tx.entries::<tables::HashedStorage>()?) as u64,
});
if input.unwind_to == 0 {
@ -304,16 +305,17 @@ impl<DB: Database> Stage<DB> for MerkleStage {
// Unwind trie only if there are transitions
if !range.is_empty() {
let (block_root, updates) =
StateRoot::incremental_root_with_updates(tx.deref_mut(), range)
.map_err(|e| StageError::Fatal(Box::new(e)))?;
let (block_root, updates) = StateRoot::incremental_root_with_updates(tx, range)
.map_err(|e| StageError::Fatal(Box::new(e)))?;
// Validate the calulated state root
let target = tx.get_header(input.unwind_to)?;
let target = provider
.header_by_number(input.unwind_to)?
.ok_or_else(|| ProviderError::HeaderNotFound(input.unwind_to.into()))?;
self.validate_state_root(block_root, target.seal_slow(), input.unwind_to)?;
// Validation passed, apply unwind changes to the database.
updates.flush(tx.deref_mut())?;
updates.flush(provider.tx_ref())?;
// TODO(alexey): update entities checkpoint
} else {

View File

@ -13,8 +13,8 @@ use reth_primitives::{
stage::{EntitiesCheckpoint, StageCheckpoint, StageId},
TransactionSignedNoHash, TxNumber, H160,
};
use reth_provider::{ProviderError, Transaction};
use std::{fmt::Debug, ops::Deref};
use reth_provider::{DatabaseProviderRW, HeaderProvider, ProviderError};
use std::fmt::Debug;
use thiserror::Error;
use tokio::sync::mpsc;
use tracing::*;
@ -56,7 +56,7 @@ impl<DB: Database> Stage<DB> for SenderRecoveryStage {
/// the [`TxSenders`][reth_db::tables::TxSenders] table.
async fn execute(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
if input.target_reached() {
@ -64,7 +64,7 @@ impl<DB: Database> Stage<DB> for SenderRecoveryStage {
}
let (tx_range, block_range, is_final_range) =
input.next_block_range_with_transaction_threshold(tx, self.commit_threshold)?;
input.next_block_range_with_transaction_threshold(provider, self.commit_threshold)?;
let end_block = *block_range.end();
// No transactions to walk over
@ -72,11 +72,13 @@ impl<DB: Database> Stage<DB> for SenderRecoveryStage {
info!(target: "sync::stages::sender_recovery", ?tx_range, "Target transaction already reached");
return Ok(ExecOutput {
checkpoint: StageCheckpoint::new(end_block)
.with_entities_stage_checkpoint(stage_checkpoint(tx)?),
.with_entities_stage_checkpoint(stage_checkpoint(provider)?),
done: is_final_range,
})
}
let tx = provider.tx_ref();
// Acquire the cursor for inserting elements
let mut senders_cursor = tx.cursor_write::<tables::TxSenders>()?;
@ -133,7 +135,9 @@ impl<DB: Database> Stage<DB> for SenderRecoveryStage {
// fetch the sealed header so we can use it in the sender recovery
// unwind
let sealed_header = tx.get_sealed_header(block_number)?;
let sealed_header = provider
.sealed_header(block_number)?
.ok_or(ProviderError::HeaderNotFound(block_number.into()))?;
return Err(StageError::Validation {
block: sealed_header,
error:
@ -150,7 +154,7 @@ impl<DB: Database> Stage<DB> for SenderRecoveryStage {
Ok(ExecOutput {
checkpoint: StageCheckpoint::new(end_block)
.with_entities_stage_checkpoint(stage_checkpoint(tx)?),
.with_entities_stage_checkpoint(stage_checkpoint(provider)?),
done: is_final_range,
})
}
@ -158,18 +162,18 @@ impl<DB: Database> Stage<DB> for SenderRecoveryStage {
/// Unwind the stage.
async fn unwind(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
let (_, unwind_to, _) = input.unwind_block_range_with_threshold(self.commit_threshold);
// Lookup latest tx id that we should unwind to
let latest_tx_id = tx.block_body_indices(unwind_to)?.last_tx_num();
tx.unwind_table_by_num::<tables::TxSenders>(latest_tx_id)?;
let latest_tx_id = provider.block_body_indices(unwind_to)?.last_tx_num();
provider.unwind_table_by_num::<tables::TxSenders>(latest_tx_id)?;
Ok(UnwindOutput {
checkpoint: StageCheckpoint::new(unwind_to)
.with_entities_stage_checkpoint(stage_checkpoint(tx)?),
.with_entities_stage_checkpoint(stage_checkpoint(provider)?),
})
}
}
@ -194,11 +198,11 @@ fn recover_sender(
}
fn stage_checkpoint<DB: Database>(
tx: &Transaction<'_, DB>,
provider: &DatabaseProviderRW<'_, &DB>,
) -> Result<EntitiesCheckpoint, DatabaseError> {
Ok(EntitiesCheckpoint {
processed: tx.deref().entries::<tables::TxSenders>()? as u64,
total: tx.deref().entries::<tables::Transactions>()? as u64,
processed: provider.tx_ref().entries::<tables::TxSenders>()? as u64,
total: provider.tx_ref().entries::<tables::Transactions>()? as u64,
})
}

View File

@ -11,8 +11,8 @@ use reth_primitives::{
stage::{EntitiesCheckpoint, StageCheckpoint, StageId},
U256,
};
use reth_provider::Transaction;
use std::{ops::Deref, sync::Arc};
use reth_provider::DatabaseProviderRW;
use std::sync::Arc;
use tracing::*;
/// The total difficulty stage.
@ -51,9 +51,10 @@ impl<DB: Database> Stage<DB> for TotalDifficultyStage {
/// Write total difficulty entries
async fn execute(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
let tx = provider.tx_ref();
if input.target_reached() {
return Ok(ExecOutput::done(input.checkpoint()))
}
@ -89,7 +90,7 @@ impl<DB: Database> Stage<DB> for TotalDifficultyStage {
Ok(ExecOutput {
checkpoint: StageCheckpoint::new(end_block)
.with_entities_stage_checkpoint(stage_checkpoint(tx)?),
.with_entities_stage_checkpoint(stage_checkpoint(provider)?),
done: is_final_range,
})
}
@ -97,26 +98,26 @@ impl<DB: Database> Stage<DB> for TotalDifficultyStage {
/// Unwind the stage.
async fn unwind(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
let (_, unwind_to, _) = input.unwind_block_range_with_threshold(self.commit_threshold);
tx.unwind_table_by_num::<tables::HeaderTD>(unwind_to)?;
provider.unwind_table_by_num::<tables::HeaderTD>(unwind_to)?;
Ok(UnwindOutput {
checkpoint: StageCheckpoint::new(unwind_to)
.with_entities_stage_checkpoint(stage_checkpoint(tx)?),
.with_entities_stage_checkpoint(stage_checkpoint(provider)?),
})
}
}
fn stage_checkpoint<DB: Database>(
tx: &Transaction<'_, DB>,
provider: &DatabaseProviderRW<'_, DB>,
) -> Result<EntitiesCheckpoint, DatabaseError> {
Ok(EntitiesCheckpoint {
processed: tx.deref().entries::<tables::HeaderTD>()? as u64,
total: tx.deref().entries::<tables::Headers>()? as u64,
processed: provider.tx_ref().entries::<tables::HeaderTD>()? as u64,
total: provider.tx_ref().entries::<tables::Headers>()? as u64,
})
}

View File

@ -13,8 +13,7 @@ use reth_primitives::{
stage::{EntitiesCheckpoint, StageCheckpoint, StageId},
TransactionSignedNoHash, TxNumber, H256,
};
use reth_provider::Transaction;
use std::ops::Deref;
use reth_provider::DatabaseProviderRW;
use tokio::sync::mpsc;
use tracing::*;
@ -52,19 +51,19 @@ impl<DB: Database> Stage<DB> for TransactionLookupStage {
/// Write transaction hash -> id entries
async fn execute(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
if input.target_reached() {
return Ok(ExecOutput::done(input.checkpoint()))
}
let (tx_range, block_range, is_final_range) =
input.next_block_range_with_transaction_threshold(tx, self.commit_threshold)?;
input.next_block_range_with_transaction_threshold(provider, self.commit_threshold)?;
let end_block = *block_range.end();
debug!(target: "sync::stages::transaction_lookup", ?tx_range, "Updating transaction lookup");
let tx = provider.tx_ref();
let mut tx_cursor = tx.cursor_read::<tables::Transactions>()?;
let tx_walker = tx_cursor.walk_range(tx_range)?;
@ -138,7 +137,7 @@ impl<DB: Database> Stage<DB> for TransactionLookupStage {
Ok(ExecOutput {
checkpoint: StageCheckpoint::new(end_block)
.with_entities_stage_checkpoint(stage_checkpoint(tx)?),
.with_entities_stage_checkpoint(stage_checkpoint(provider)?),
done: is_final_range,
})
}
@ -146,9 +145,10 @@ impl<DB: Database> Stage<DB> for TransactionLookupStage {
/// Unwind the stage.
async fn unwind(
&mut self,
tx: &mut Transaction<'_, DB>,
provider: &mut DatabaseProviderRW<'_, &DB>,
input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
let tx = provider.tx_ref();
let (range, unwind_to, _) = input.unwind_block_range_with_threshold(self.commit_threshold);
// Cursors to unwind tx hash to number
@ -174,17 +174,17 @@ impl<DB: Database> Stage<DB> for TransactionLookupStage {
Ok(UnwindOutput {
checkpoint: StageCheckpoint::new(unwind_to)
.with_entities_stage_checkpoint(stage_checkpoint(tx)?),
.with_entities_stage_checkpoint(stage_checkpoint(provider)?),
})
}
}
fn stage_checkpoint<DB: Database>(
tx: &Transaction<'_, DB>,
provider: &DatabaseProviderRW<'_, &DB>,
) -> Result<EntitiesCheckpoint, DatabaseError> {
Ok(EntitiesCheckpoint {
processed: tx.deref().entries::<tables::TxHashNumber>()? as u64,
total: tx.deref().entries::<tables::Transactions>()? as u64,
processed: provider.tx_ref().entries::<tables::TxHashNumber>()? as u64,
total: provider.tx_ref().entries::<tables::Transactions>()? as u64,
})
}

View File

@ -1,8 +1,9 @@
use super::TestTransaction;
use crate::{ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput};
use reth_db::mdbx::{Env, WriteMap};
use reth_provider::Transaction;
use std::borrow::Borrow;
use reth_primitives::MAINNET;
use reth_provider::ShareableDatabase;
use std::{borrow::Borrow, sync::Arc};
use tokio::sync::oneshot;
#[derive(thiserror::Error, Debug)]
@ -44,9 +45,11 @@ pub(crate) trait ExecuteStageTestRunner: StageTestRunner {
let (tx, rx) = oneshot::channel();
let (db, mut stage) = (self.tx().inner_raw(), self.stage());
tokio::spawn(async move {
let mut db = Transaction::new(db.borrow()).expect("failed to create db container");
let result = stage.execute(&mut db, input).await;
db.commit().expect("failed to commit");
let factory = ShareableDatabase::new(db.as_ref(), MAINNET.clone());
let mut provider = factory.provider_rw().unwrap();
let result = stage.execute(&mut provider, input).await;
provider.commit().expect("failed to commit");
tx.send(result).expect("failed to send message")
});
rx
@ -68,9 +71,11 @@ pub(crate) trait UnwindStageTestRunner: StageTestRunner {
let (tx, rx) = oneshot::channel();
let (db, mut stage) = (self.tx().inner_raw(), self.stage());
tokio::spawn(async move {
let mut db = Transaction::new(db.borrow()).expect("failed to create db container");
let result = stage.unwind(&mut db, input).await;
db.commit().expect("failed to commit");
let factory = ShareableDatabase::new(db.as_ref(), MAINNET.clone());
let mut provider = factory.provider_rw().unwrap();
let result = stage.unwind(&mut provider, input).await;
provider.commit().expect("failed to commit");
tx.send(result).expect("failed to send result");
});
Box::pin(rx).await.unwrap()

View File

@ -1,7 +1,7 @@
use crate::{ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput};
use reth_db::database::Database;
use reth_primitives::stage::StageId;
use reth_provider::Transaction;
use reth_provider::DatabaseProviderRW;
use std::collections::VecDeque;
#[derive(Debug)]
@ -48,7 +48,7 @@ impl<DB: Database> Stage<DB> for TestStage {
async fn execute(
&mut self,
_: &mut Transaction<'_, DB>,
_: &mut DatabaseProviderRW<'_, &DB>,
_input: ExecInput,
) -> Result<ExecOutput, StageError> {
self.exec_outputs
@ -58,7 +58,7 @@ impl<DB: Database> Stage<DB> for TestStage {
async fn unwind(
&mut self,
_: &mut Transaction<'_, DB>,
_: &mut DatabaseProviderRW<'_, &DB>,
_input: UnwindInput,
) -> Result<UnwindOutput, StageError> {
self.unwind_outputs

View File

@ -13,9 +13,10 @@ use reth_db::{
DatabaseError as DbError,
};
use reth_primitives::{
keccak256, Account, Address, BlockNumber, SealedBlock, SealedHeader, StorageEntry, H256, U256,
keccak256, Account, Address, BlockNumber, SealedBlock, SealedHeader, StorageEntry, H256,
MAINNET, U256,
};
use reth_provider::Transaction;
use reth_provider::{DatabaseProviderRW, ShareableDatabase};
use std::{
borrow::Borrow,
collections::BTreeMap,
@ -36,26 +37,30 @@ pub struct TestTransaction {
/// WriteMap DB
pub tx: Arc<Env<WriteMap>>,
pub path: Option<PathBuf>,
factory: ShareableDatabase<Arc<Env<WriteMap>>>,
}
impl Default for TestTransaction {
/// Create a new instance of [TestTransaction]
fn default() -> Self {
Self { tx: create_test_db::<WriteMap>(EnvKind::RW), path: None }
let tx = create_test_db::<WriteMap>(EnvKind::RW);
Self { tx: tx.clone(), path: None, factory: ShareableDatabase::new(tx, MAINNET.clone()) }
}
}
impl TestTransaction {
pub fn new(path: &Path) -> Self {
let tx = create_test_db::<WriteMap>(EnvKind::RW);
Self {
tx: Arc::new(create_test_db_with_path::<WriteMap>(EnvKind::RW, path)),
tx: tx.clone(),
path: Some(path.to_path_buf()),
factory: ShareableDatabase::new(tx, MAINNET.clone()),
}
}
/// Return a database wrapped in [Transaction].
pub fn inner(&self) -> Transaction<'_, Env<WriteMap>> {
Transaction::new(self.tx.borrow()).expect("failed to create db container")
/// Return a database wrapped in [DatabaseProviderRW].
pub fn inner(&self) -> DatabaseProviderRW<'_, Arc<Env<WriteMap>>> {
self.factory.provider_rw().expect("failed to create db container")
}
/// Get a pointer to an internal database.
@ -69,8 +74,8 @@ impl TestTransaction {
F: FnOnce(&mut Tx<'_, RW, WriteMap>) -> Result<(), DbError>,
{
let mut tx = self.inner();
f(&mut tx)?;
tx.commit()?;
f(tx.tx_mut())?;
tx.commit().expect("failed to commit");
Ok(())
}
@ -79,7 +84,7 @@ impl TestTransaction {
where
F: FnOnce(&Tx<'_, RW, WriteMap>) -> Result<R, DbError>,
{
f(&self.inner())
f(self.inner().tx_ref())
}
/// Check if the table is empty

View File

@ -4,7 +4,7 @@ use crate::{
transaction::{DbTx, DbTxMut},
DatabaseError,
};
use std::sync::Arc;
use std::{fmt::Debug, sync::Arc};
/// Implements the GAT method from:
/// <https://sabrinajewson.org/blog/the-better-alternative-to-lifetime-gats#the-better-gats>.
@ -12,9 +12,9 @@ use std::sync::Arc;
/// Sealed trait which cannot be implemented by 3rd parties, exposed only for implementers
pub trait DatabaseGAT<'a, __ImplicitBounds: Sealed = Bounds<&'a Self>>: Send + Sync {
/// RO database transaction
type TX: DbTx<'a> + Send + Sync;
type TX: DbTx<'a> + Send + Sync + Debug;
/// RW database transaction
type TXMut: DbTxMut<'a> + DbTx<'a> + TableImporter<'a> + Send + Sync;
type TXMut: DbTxMut<'a> + DbTx<'a> + TableImporter<'a> + Send + Sync + Debug;
}
/// Main Database trait that spawns transactions to be executed.

View File

@ -38,7 +38,7 @@ impl<'a> DatabaseGAT<'a> for DatabaseMock {
}
/// Mock read only tx
#[derive(Clone, Default)]
#[derive(Debug, Clone, Default)]
pub struct TxMock {
/// Table representation
_table: BTreeMap<Vec<u8>, Vec<u8>>,

View File

@ -11,11 +11,11 @@
/// Various provider traits.
mod traits;
pub use traits::{
AccountProvider, BlockExecutor, BlockHashProvider, BlockIdProvider, BlockNumProvider,
BlockProvider, BlockProviderIdExt, BlockSource, BlockchainTreePendingStateProvider,
CanonChainTracker, CanonStateNotification, CanonStateNotificationSender,
CanonStateNotifications, CanonStateSubscriptions, EvmEnvProvider, ExecutorFactory,
HeaderProvider, PostStateDataProvider, ReceiptProvider, ReceiptProviderIdExt,
AccountExtProvider, AccountProvider, BlockExecutor, BlockHashProvider, BlockIdProvider,
BlockNumProvider, BlockProvider, BlockProviderIdExt, BlockSource,
BlockchainTreePendingStateProvider, CanonChainTracker, CanonStateNotification,
CanonStateNotificationSender, CanonStateNotifications, CanonStateSubscriptions, EvmEnvProvider,
ExecutorFactory, HeaderProvider, PostStateDataProvider, ReceiptProvider, ReceiptProviderIdExt,
StageCheckpointProvider, StateProvider, StateProviderBox, StateProviderFactory,
StateRootProvider, TransactionsProvider, WithdrawalsProvider,
};
@ -23,8 +23,8 @@ pub use traits::{
/// Provider trait implementations.
pub mod providers;
pub use providers::{
HistoricalStateProvider, HistoricalStateProviderRef, LatestStateProvider,
LatestStateProviderRef, ShareableDatabase,
DatabaseProvider, DatabaseProviderRO, DatabaseProviderRW, HistoricalStateProvider,
HistoricalStateProviderRef, LatestStateProvider, LatestStateProviderRef, ShareableDatabase,
};
/// Execution result
@ -33,7 +33,7 @@ pub use post_state::PostState;
/// Helper types for interacting with the database
mod transaction;
pub use transaction::{Transaction, TransactionError};
pub use transaction::TransactionError;
/// Common database utilities.
mod utils;

View File

@ -18,7 +18,7 @@ use std::{ops::RangeBounds, sync::Arc};
use tracing::trace;
mod provider;
use provider::{DatabaseProvider, DatabaseProviderRO, DatabaseProviderRW};
pub use provider::{DatabaseProvider, DatabaseProviderRO, DatabaseProviderRW};
/// A common provider that fetches data from a database.
///
@ -34,16 +34,17 @@ pub struct ShareableDatabase<DB> {
impl<DB: Database> ShareableDatabase<DB> {
/// Returns a provider with a created `DbTx` inside, which allows fetching data from the
/// database using different types of providers. Example: [`HeaderProvider`]
/// [`BlockHashProvider`]
/// [`BlockHashProvider`]. This may fail if the inner read database transaction fails to open.
pub fn provider(&self) -> Result<DatabaseProviderRO<'_, DB>> {
Ok(DatabaseProvider::new(self.db.tx()?, self.chain_spec.clone()))
}
/// Returns a provider with a created `DbTxMut` inside, which allows fetching and updating
/// data from the database using different types of providers. Example: [`HeaderProvider`]
/// [`BlockHashProvider`]
/// [`BlockHashProvider`]. This may fail if the inner read/write database transaction fails to
/// open.
pub fn provider_rw(&self) -> Result<DatabaseProviderRW<'_, DB>> {
Ok(DatabaseProvider::new_rw(self.db.tx_mut()?, self.chain_spec.clone()))
Ok(DatabaseProviderRW(DatabaseProvider::new_rw(self.db.tx_mut()?, self.chain_spec.clone())))
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
//! Dummy blocks and data for tests
use crate::{post_state::PostState, Transaction};
use crate::{post_state::PostState, DatabaseProviderRW};
use reth_db::{database::Database, models::StoredBlockBodyIndices, tables};
use reth_primitives::{
hex_literal::hex, Account, BlockNumber, Bytes, Header, Log, Receipt, SealedBlock,
@ -10,9 +10,11 @@ use reth_rlp::Decodable;
use std::collections::BTreeMap;
/// Assert genesis block
pub fn assert_genesis_block<DB: Database>(tx: &Transaction<'_, DB>, g: SealedBlock) {
pub fn assert_genesis_block<DB: Database>(provider: &DatabaseProviderRW<'_, DB>, g: SealedBlock) {
let n = g.number;
let h = H256::zero();
let tx = provider;
// check if all tables are empty
assert_eq!(tx.table::<tables::Headers>().unwrap(), vec![(g.number, g.header.clone().unseal())]);

View File

@ -1,6 +1,7 @@
use auto_impl::auto_impl;
use reth_interfaces::Result;
use reth_primitives::{Account, Address};
use reth_primitives::{Account, Address, BlockNumber};
use std::{collections::BTreeSet, ops::RangeBounds};
/// Account provider
#[auto_impl(&, Arc, Box)]
@ -10,3 +11,22 @@ pub trait AccountProvider: Send + Sync {
/// Returns `None` if the account doesn't exist.
fn basic_account(&self, address: Address) -> Result<Option<Account>>;
}
/// Account provider
#[auto_impl(&, Arc, Box)]
pub trait AccountExtProvider: Send + Sync {
/// Iterate over account changesets and return all account address that were changed.
fn changed_accounts_with_range(
&self,
_range: impl RangeBounds<BlockNumber>,
) -> Result<BTreeSet<Address>>;
/// Get basic account information for multiple accounts. A more efficient version than calling
/// [`AccountProvider::basic_account`] repeatedly.
///
/// Returns `None` if the account doesn't exist.
fn basic_accounts(
&self,
_iter: impl IntoIterator<Item = Address>,
) -> Result<Vec<(Address, Option<Account>)>>;
}

View File

@ -1,7 +1,7 @@
//! Collection of common provider traits.
mod account;
pub use account::AccountProvider;
pub use account::{AccountExtProvider, AccountProvider};
mod block;
pub use block::{BlockProvider, BlockProviderIdExt, BlockSource};

File diff suppressed because it is too large Load Diff

View File

@ -523,14 +523,10 @@ mod tests {
keccak256,
proofs::KeccakHasher,
trie::{BranchNodeCompact, TrieMask},
Account, Address, H256, U256,
};
use reth_provider::Transaction;
use std::{
collections::BTreeMap,
ops::{Deref, DerefMut, Mul},
str::FromStr,
Account, Address, H256, MAINNET, U256,
};
use reth_provider::{DatabaseProviderRW, ShareableDatabase};
use std::{collections::BTreeMap, ops::Mul, str::FromStr};
fn insert_account<'a, TX: DbTxMut<'a>>(
tx: &mut TX,
@ -559,10 +555,12 @@ mod tests {
fn incremental_vs_full_root(inputs: &[&str], modified: &str) {
let db = create_test_rw_db();
let mut tx = Transaction::new(db.as_ref()).unwrap();
let factory = ShareableDatabase::new(db.as_ref(), MAINNET.clone());
let mut tx = factory.provider_rw().unwrap();
let hashed_address = H256::from_low_u64_be(1);
let mut hashed_storage_cursor = tx.cursor_dup_write::<tables::HashedStorage>().unwrap();
let mut hashed_storage_cursor =
tx.tx_ref().cursor_dup_write::<tables::HashedStorage>().unwrap();
let data = inputs.iter().map(|x| H256::from_str(x).unwrap());
let value = U256::from(0);
for key in data {
@ -571,7 +569,7 @@ mod tests {
// Generate the intermediate nodes on the receiving end of the channel
let (_, _, trie_updates) =
StorageRoot::new_hashed(tx.deref(), hashed_address).root_with_updates().unwrap();
StorageRoot::new_hashed(tx.tx_ref(), hashed_address).root_with_updates().unwrap();
// 1. Some state transition happens, update the hashed storage to the new value
let modified_key = H256::from_str(modified).unwrap();
@ -585,16 +583,16 @@ mod tests {
.unwrap();
// 2. Calculate full merkle root
let loader = StorageRoot::new_hashed(tx.deref(), hashed_address);
let loader = StorageRoot::new_hashed(tx.tx_ref(), hashed_address);
let modified_root = loader.root().unwrap();
// Update the intermediate roots table so that we can run the incremental verification
trie_updates.flush(tx.deref()).unwrap();
trie_updates.flush(tx.tx_ref()).unwrap();
// 3. Calculate the incremental root
let mut storage_changes = PrefixSet::default();
storage_changes.insert(Nibbles::unpack(modified_key));
let loader = StorageRoot::new_hashed(tx.deref_mut(), hashed_address)
let loader = StorageRoot::new_hashed(tx.tx_mut(), hashed_address)
.with_changed_prefixes(storage_changes);
let incremental_root = loader.root().unwrap();
@ -624,9 +622,10 @@ mod tests {
let hashed_address = keccak256(address);
let db = create_test_rw_db();
let mut tx = Transaction::new(db.as_ref()).unwrap();
let factory = ShareableDatabase::new(db.as_ref(), MAINNET.clone());
let tx = factory.provider_rw().unwrap();
for (key, value) in &storage {
tx.put::<tables::HashedStorage>(
tx.tx_ref().put::<tables::HashedStorage>(
hashed_address,
StorageEntry { key: keccak256(key), value: *value },
)
@ -634,7 +633,8 @@ mod tests {
}
tx.commit().unwrap();
let got = StorageRoot::new(tx.deref_mut(), address).root().unwrap();
let mut tx = factory.provider_rw().unwrap();
let got = StorageRoot::new(tx.tx_mut(), address).root().unwrap();
let expected = storage_root(storage.into_iter());
assert_eq!(expected, got);
});
@ -680,7 +680,8 @@ mod tests {
// This ensures we return an empty root when there are no storage entries
fn test_empty_storage_root() {
let db = create_test_rw_db();
let mut tx = Transaction::new(db.as_ref()).unwrap();
let factory = ShareableDatabase::new(db.as_ref(), MAINNET.clone());
let mut tx = factory.provider_rw().unwrap();
let address = Address::random();
let code = "el buen fla";
@ -689,10 +690,11 @@ mod tests {
balance: U256::from(414241124u32),
bytecode_hash: Some(keccak256(code)),
};
insert_account(&mut *tx, address, account, &Default::default());
insert_account(tx.tx_mut(), address, account, &Default::default());
tx.commit().unwrap();
let got = StorageRoot::new(tx.deref_mut(), address).root().unwrap();
let mut tx = factory.provider_rw().unwrap();
let got = StorageRoot::new(tx.tx_mut(), address).root().unwrap();
assert_eq!(got, EMPTY_ROOT);
}
@ -700,7 +702,8 @@ mod tests {
// This ensures that the walker goes over all the storage slots
fn test_storage_root() {
let db = create_test_rw_db();
let mut tx = Transaction::new(db.as_ref()).unwrap();
let factory = ShareableDatabase::new(db.as_ref(), MAINNET.clone());
let mut tx = factory.provider_rw().unwrap();
let address = Address::random();
let storage = BTreeMap::from([
@ -715,10 +718,11 @@ mod tests {
bytecode_hash: Some(keccak256(code)),
};
insert_account(&mut *tx, address, account, &storage);
insert_account(tx.tx_mut(), address, account, &storage);
tx.commit().unwrap();
let got = StorageRoot::new(tx.deref_mut(), address).root().unwrap();
let mut tx = factory.provider_rw().unwrap();
let got = StorageRoot::new(tx.tx_mut(), address).root().unwrap();
assert_eq!(storage_root(storage.into_iter()), got);
}
@ -742,12 +746,15 @@ mod tests {
state.values().map(|(_, slots)| slots.len()).sum::<usize>();
let db = create_test_rw_db();
let mut tx = Transaction::new(db.as_ref()).unwrap();
let factory = ShareableDatabase::new(db.as_ref(), MAINNET.clone());
let mut tx = factory.provider_rw().unwrap();
for (address, (account, storage)) in &state {
insert_account(&mut *tx, *address, *account, storage)
insert_account(tx.tx_mut(), *address, *account, storage)
}
tx.commit().unwrap();
let mut tx = factory.provider_rw().unwrap();
let expected = state_root(state.into_iter());
let threshold = 10;
@ -756,7 +763,7 @@ mod tests {
let mut intermediate_state: Option<Box<IntermediateStateRootState>> = None;
while got.is_none() {
let calculator = StateRoot::new(tx.deref_mut())
let calculator = StateRoot::new(tx.tx_mut())
.with_threshold(threshold)
.with_intermediate_state(intermediate_state.take().map(|state| *state));
match calculator.root_with_progress().unwrap() {
@ -778,15 +785,17 @@ mod tests {
fn test_state_root_with_state(state: State) {
let db = create_test_rw_db();
let mut tx = Transaction::new(db.as_ref()).unwrap();
let factory = ShareableDatabase::new(db.as_ref(), MAINNET.clone());
let mut tx = factory.provider_rw().unwrap();
for (address, (account, storage)) in &state {
insert_account(&mut *tx, *address, *account, storage)
insert_account(tx.tx_mut(), *address, *account, storage)
}
tx.commit().unwrap();
let expected = state_root(state.into_iter());
let got = StateRoot::new(tx.deref_mut()).root().unwrap();
let mut tx = factory.provider_rw().unwrap();
let got = StateRoot::new(tx.tx_mut()).root().unwrap();
assert_eq!(expected, got);
}
@ -803,7 +812,8 @@ mod tests {
#[test]
fn storage_root_regression() {
let db = create_test_rw_db();
let mut tx = Transaction::new(db.as_ref()).unwrap();
let factory = ShareableDatabase::new(db.as_ref(), MAINNET.clone());
let tx = factory.provider_rw().unwrap();
// Some address whose hash starts with 0xB041
let address3 = Address::from_str("16b07afd1c635f77172e842a000ead9a2a222459").unwrap();
let key3 = keccak256(address3);
@ -820,13 +830,15 @@ mod tests {
.map(|(slot, val)| (H256::from_str(slot).unwrap(), U256::from(val))),
);
let mut hashed_storage_cursor = tx.cursor_dup_write::<tables::HashedStorage>().unwrap();
let mut hashed_storage_cursor =
tx.tx_ref().cursor_dup_write::<tables::HashedStorage>().unwrap();
for (hashed_slot, value) in storage.clone() {
hashed_storage_cursor.upsert(key3, StorageEntry { key: hashed_slot, value }).unwrap();
}
tx.commit().unwrap();
let mut tx = factory.provider_rw().unwrap();
let account3_storage_root = StorageRoot::new(tx.deref_mut(), address3).root().unwrap();
let account3_storage_root = StorageRoot::new(tx.tx_mut(), address3).root().unwrap();
let expected_root = storage_root_prehashed(storage.into_iter());
assert_eq!(expected_root, account3_storage_root);
}
@ -845,10 +857,13 @@ mod tests {
);
let db = create_test_rw_db();
let mut tx = Transaction::new(db.as_ref()).unwrap();
let factory = ShareableDatabase::new(db.as_ref(), MAINNET.clone());
let mut tx = factory.provider_rw().unwrap();
let mut hashed_account_cursor = tx.cursor_write::<tables::HashedAccount>().unwrap();
let mut hashed_storage_cursor = tx.cursor_dup_write::<tables::HashedStorage>().unwrap();
let mut hashed_account_cursor =
tx.tx_ref().cursor_write::<tables::HashedAccount>().unwrap();
let mut hashed_storage_cursor =
tx.tx_ref().cursor_dup_write::<tables::HashedStorage>().unwrap();
let mut hash_builder = HashBuilder::default();
@ -891,7 +906,7 @@ mod tests {
}
hashed_storage_cursor.upsert(key3, StorageEntry { key: hashed_slot, value }).unwrap();
}
let account3_storage_root = StorageRoot::new(tx.deref_mut(), address3).root().unwrap();
let account3_storage_root = StorageRoot::new(tx.tx_mut(), address3).root().unwrap();
hash_builder.add_leaf(
Nibbles::unpack(key3),
&encode_account(account3, Some(account3_storage_root)),
@ -940,7 +955,7 @@ mod tests {
assert_eq!(hash_builder.root(), computed_expected_root);
// Check state root calculation from scratch
let (root, trie_updates) = StateRoot::new(tx.deref()).root_with_updates().unwrap();
let (root, trie_updates) = StateRoot::new(tx.tx_ref()).root_with_updates().unwrap();
assert_eq!(root, computed_expected_root);
// Check account trie
@ -1005,7 +1020,7 @@ mod tests {
H256::from_str("8e263cd4eefb0c3cbbb14e5541a66a755cad25bcfab1e10dd9d706263e811b28")
.unwrap();
let (root, trie_updates) = StateRoot::new(tx.deref())
let (root, trie_updates) = StateRoot::new(tx.tx_ref())
.with_changed_account_prefixes(prefix_set)
.root_with_updates()
.unwrap();
@ -1035,9 +1050,11 @@ mod tests {
assert_eq!(nibbles2b.inner[..], [0xB, 0x0]);
assert_eq!(node2a, node2b);
tx.commit().unwrap();
let tx = factory.provider_rw().unwrap();
{
let mut hashed_account_cursor = tx.cursor_write::<tables::HashedAccount>().unwrap();
let mut hashed_account_cursor =
tx.tx_ref().cursor_write::<tables::HashedAccount>().unwrap();
let account = hashed_account_cursor.seek_exact(key2).unwrap().unwrap();
hashed_account_cursor.delete_current().unwrap();
@ -1055,7 +1072,7 @@ mod tests {
(key6, encode_account(account6, None)),
]);
let (root, trie_updates) = StateRoot::new(tx.deref())
let (root, trie_updates) = StateRoot::new(tx.tx_ref())
.with_changed_account_prefixes(account_prefix_set)
.root_with_updates()
.unwrap();
@ -1085,11 +1102,13 @@ mod tests {
assert_ne!(node1c.hashes[0], node1b.hashes[0]);
assert_eq!(node1c.hashes[1], node1b.hashes[1]);
assert_eq!(node1c.hashes[2], node1b.hashes[2]);
tx.drop().unwrap();
drop(tx);
}
let mut tx = factory.provider_rw().unwrap();
{
let mut hashed_account_cursor = tx.cursor_write::<tables::HashedAccount>().unwrap();
let mut hashed_account_cursor =
tx.tx_ref().cursor_write::<tables::HashedAccount>().unwrap();
let account2 = hashed_account_cursor.seek_exact(key2).unwrap().unwrap();
hashed_account_cursor.delete_current().unwrap();
@ -1110,7 +1129,7 @@ mod tests {
(key6, encode_account(account6, None)),
]);
let (root, trie_updates) = StateRoot::new(tx.deref_mut())
let (root, trie_updates) = StateRoot::new(tx.tx_mut())
.with_changed_account_prefixes(account_prefix_set)
.root_with_updates()
.unwrap();
@ -1145,11 +1164,12 @@ mod tests {
#[test]
fn account_trie_around_extension_node() {
let db = create_test_rw_db();
let mut tx = Transaction::new(db.as_ref()).unwrap();
let factory = ShareableDatabase::new(db.as_ref(), MAINNET.clone());
let mut tx = factory.provider_rw().unwrap();
let expected = extension_node_trie(&mut tx);
let (got, updates) = StateRoot::new(tx.deref_mut()).root_with_updates().unwrap();
let (got, updates) = StateRoot::new(tx.tx_mut()).root_with_updates().unwrap();
assert_eq!(expected, got);
// Check account trie
@ -1170,16 +1190,17 @@ mod tests {
fn account_trie_around_extension_node_with_dbtrie() {
let db = create_test_rw_db();
let mut tx = Transaction::new(db.as_ref()).unwrap();
let factory = ShareableDatabase::new(db.as_ref(), MAINNET.clone());
let mut tx = factory.provider_rw().unwrap();
let expected = extension_node_trie(&mut tx);
let (got, updates) = StateRoot::new(tx.deref_mut()).root_with_updates().unwrap();
let (got, updates) = StateRoot::new(tx.tx_mut()).root_with_updates().unwrap();
assert_eq!(expected, got);
updates.flush(tx.deref_mut()).unwrap();
updates.flush(tx.tx_mut()).unwrap();
// read the account updates from the db
let mut accounts_trie = tx.cursor_read::<tables::AccountsTrie>().unwrap();
let mut accounts_trie = tx.tx_ref().cursor_read::<tables::AccountsTrie>().unwrap();
let walker = accounts_trie.walk(None).unwrap();
let mut account_updates = HashMap::new();
for item in walker {
@ -1197,8 +1218,9 @@ mod tests {
tokio::runtime::Runtime::new().unwrap().block_on(async {
let db = create_test_rw_db();
let mut tx = Transaction::new(db.as_ref()).unwrap();
let mut hashed_account_cursor = tx.cursor_write::<tables::HashedAccount>().unwrap();
let factory = ShareableDatabase::new(db.as_ref(), MAINNET.clone());
let mut tx = factory.provider_rw().unwrap();
let mut hashed_account_cursor = tx.tx_ref().cursor_write::<tables::HashedAccount>().unwrap();
let mut state = BTreeMap::default();
for accounts in account_changes {
@ -1211,7 +1233,7 @@ mod tests {
}
}
let (state_root, trie_updates) = StateRoot::new(tx.deref_mut())
let (state_root, trie_updates) = StateRoot::new(tx.tx_mut())
.with_changed_account_prefixes(changes)
.root_with_updates()
.unwrap();
@ -1221,7 +1243,7 @@ mod tests {
state.clone().into_iter().map(|(key, balance)| (key, (Account { balance, ..Default::default() }, std::iter::empty())))
);
assert_eq!(expected_root, state_root);
trie_updates.flush(tx.deref_mut()).unwrap();
trie_updates.flush(tx.tx_mut()).unwrap();
}
});
}
@ -1230,14 +1252,15 @@ mod tests {
#[test]
fn storage_trie_around_extension_node() {
let db = create_test_rw_db();
let mut tx = Transaction::new(db.as_ref()).unwrap();
let factory = ShareableDatabase::new(db.as_ref(), MAINNET.clone());
let mut tx = factory.provider_rw().unwrap();
let hashed_address = H256::random();
let (expected_root, expected_updates) =
extension_node_storage_trie(&mut tx, hashed_address);
let (got, _, updates) =
StorageRoot::new_hashed(tx.deref_mut(), hashed_address).root_with_updates().unwrap();
StorageRoot::new_hashed(tx.tx_mut(), hashed_address).root_with_updates().unwrap();
assert_eq!(expected_root, got);
// Check account trie
@ -1256,12 +1279,12 @@ mod tests {
}
fn extension_node_storage_trie(
tx: &mut Transaction<'_, Env<WriteMap>>,
tx: &mut DatabaseProviderRW<'_, &Env<WriteMap>>,
hashed_address: H256,
) -> (H256, HashMap<Nibbles, BranchNodeCompact>) {
let value = U256::from(1);
let mut hashed_storage = tx.cursor_write::<tables::HashedStorage>().unwrap();
let mut hashed_storage = tx.tx_ref().cursor_write::<tables::HashedStorage>().unwrap();
let mut hb = HashBuilder::default().with_updates(true);
@ -1282,12 +1305,12 @@ mod tests {
(root, updates)
}
fn extension_node_trie(tx: &mut Transaction<'_, Env<WriteMap>>) -> H256 {
fn extension_node_trie(tx: &mut DatabaseProviderRW<'_, &Env<WriteMap>>) -> H256 {
let a =
Account { nonce: 0, balance: U256::from(1u64), bytecode_hash: Some(H256::random()) };
let val = encode_account(a, None);
let mut hashed_accounts = tx.cursor_write::<tables::HashedAccount>().unwrap();
let mut hashed_accounts = tx.tx_ref().cursor_write::<tables::HashedAccount>().unwrap();
let mut hb = HashBuilder::default();
for key in [

View File

@ -38,6 +38,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use reth_db::{
cursor::{DbCursorRO, DbCursorRW},
@ -45,14 +46,15 @@ mod tests {
tables,
transaction::DbTxMut,
};
use reth_primitives::hex_literal::hex;
use reth_provider::Transaction;
use reth_primitives::{hex_literal::hex, MAINNET};
use reth_provider::ShareableDatabase;
#[test]
fn test_account_trie_order() {
let db = create_test_rw_db();
let tx = Transaction::new(db.as_ref()).unwrap();
let mut cursor = tx.cursor_write::<tables::AccountsTrie>().unwrap();
let factory = ShareableDatabase::new(db.as_ref(), MAINNET.clone());
let provider = factory.provider_rw().unwrap();
let mut cursor = provider.tx_ref().cursor_write::<tables::AccountsTrie>().unwrap();
let data = vec![
hex!("0303040e").to_vec(),

View File

@ -55,19 +55,24 @@ where
#[cfg(test)]
mod tests {
use super::*;
use reth_db::{
cursor::DbCursorRW, mdbx::test_utils::create_test_rw_db, tables, transaction::DbTxMut,
};
use reth_primitives::trie::{BranchNodeCompact, StorageTrieEntry};
use reth_provider::Transaction;
use reth_primitives::{
trie::{BranchNodeCompact, StorageTrieEntry},
MAINNET,
};
use reth_provider::ShareableDatabase;
// tests that upsert and seek match on the storagetrie cursor
#[test]
fn test_storage_cursor_abstraction() {
let db = create_test_rw_db();
let tx = Transaction::new(db.as_ref()).unwrap();
let mut cursor = tx.cursor_dup_write::<tables::StoragesTrie>().unwrap();
let factory = ShareableDatabase::new(db.as_ref(), MAINNET.clone());
let provider = factory.provider_rw().unwrap();
let mut cursor = provider.tx_ref().cursor_dup_write::<tables::StoragesTrie>().unwrap();
let hashed_address = H256::random();
let key = vec![0x2, 0x3];

View File

@ -256,13 +256,14 @@ impl<'a, K: Key + From<Vec<u8>>, C: TrieCursor<K>> TrieWalker<'a, K, C> {
#[cfg(test)]
mod tests {
use super::*;
use crate::trie_cursor::{AccountTrieCursor, StorageTrieCursor};
use reth_db::{
cursor::DbCursorRW, mdbx::test_utils::create_test_rw_db, tables, transaction::DbTxMut,
};
use reth_primitives::trie::StorageTrieEntry;
use reth_provider::Transaction;
use reth_primitives::{trie::StorageTrieEntry, MAINNET};
use reth_provider::ShareableDatabase;
#[test]
fn walk_nodes_with_common_prefix() {
@ -288,8 +289,11 @@ mod tests {
];
let db = create_test_rw_db();
let tx = Transaction::new(db.as_ref()).unwrap();
let mut account_cursor = tx.cursor_write::<tables::AccountsTrie>().unwrap();
let factory = ShareableDatabase::new(db.as_ref(), MAINNET.clone());
let tx = factory.provider_rw().unwrap();
let mut account_cursor = tx.tx_ref().cursor_write::<tables::AccountsTrie>().unwrap();
for (k, v) in &inputs {
account_cursor.upsert(k.clone().into(), v.clone()).unwrap();
}
@ -297,7 +301,7 @@ mod tests {
test_cursor(account_trie, &expected);
let hashed_address = H256::random();
let mut storage_cursor = tx.cursor_dup_write::<tables::StoragesTrie>().unwrap();
let mut storage_cursor = tx.tx_ref().cursor_dup_write::<tables::StoragesTrie>().unwrap();
for (k, v) in &inputs {
storage_cursor
.upsert(
@ -332,8 +336,9 @@ mod tests {
#[test]
fn cursor_rootnode_with_changesets() {
let db = create_test_rw_db();
let tx = Transaction::new(db.as_ref()).unwrap();
let mut cursor = tx.cursor_dup_write::<tables::StoragesTrie>().unwrap();
let factory = ShareableDatabase::new(db.as_ref(), MAINNET.clone());
let tx = factory.provider_rw().unwrap();
let mut cursor = tx.tx_ref().cursor_dup_write::<tables::StoragesTrie>().unwrap();
let nodes = vec![
(