chore: group trie crates in trie folder (#8492)

This commit is contained in:
Matthias Seitz
2024-05-29 22:15:56 +02:00
committed by GitHub
parent a6800771c6
commit fd495eb50b
32 changed files with 4 additions and 4 deletions

crates/trie/parallel/Cargo.toml
@@ -0,0 +1,66 @@
[package]
name = "reth-trie-parallel"
version.workspace = true
edition.workspace = true
rust-version.workspace = true
license.workspace = true
homepage.workspace = true
repository.workspace = true
description = "Parallel implementation of merkle root algorithm"

[lints]
workspace = true

[dependencies]
# reth
reth-primitives.workspace = true
reth-db.workspace = true
reth-trie.workspace = true
reth-execution-errors.workspace = true
reth-provider.workspace = true

# alloy
alloy-rlp.workspace = true

# tracing
tracing.workspace = true

# misc
thiserror.workspace = true
derive_more.workspace = true

# `async` feature
reth-tasks = { workspace = true, optional = true }
tokio = { workspace = true, optional = true, default-features = false }
itertools = { workspace = true, optional = true }

# `parallel` feature
rayon = { workspace = true, optional = true }

# `metrics` feature
reth-metrics = { workspace = true, optional = true }
metrics = { workspace = true, optional = true }

[dev-dependencies]
# reth
reth-primitives = { workspace = true, features = ["test-utils", "arbitrary"] }
reth-provider = { workspace = true, features = ["test-utils"] }
reth-trie = { workspace = true, features = ["test-utils"] }

# misc
rand.workspace = true
tokio = { workspace = true, default-features = false, features = ["sync", "rt", "macros"] }
rayon.workspace = true
criterion = { workspace = true, features = ["async_tokio"] }
proptest.workspace = true

[features]
default = ["metrics", "async", "parallel"]
metrics = ["reth-metrics", "dep:metrics", "reth-trie/metrics"]
async = ["reth-tasks/rayon", "tokio/sync", "itertools"]
parallel = ["rayon"]

[[bench]]
name = "root"
required-features = ["async", "parallel"]
harness = false
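
With the default feature set (`metrics`, `async`, `parallel`) already satisfying the bench target's `required-features`, the benchmark should run with a plain cargo invocation; a minimal sketch, assuming the package sits in the workspace under the name above:

    cargo bench -p reth-trie-parallel --bench root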

crates/trie/parallel/benches/root.rs
@@ -0,0 +1,135 @@
#![allow(missing_docs, unreachable_pub)]
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use proptest::{prelude::*, strategy::ValueTree, test_runner::TestRunner};
use rayon::ThreadPoolBuilder;
use reth_primitives::{Account, B256, U256};
use reth_provider::{
bundle_state::HashedStateChanges, providers::ConsistentDbView,
test_utils::create_test_provider_factory,
};
use reth_tasks::pool::BlockingTaskPool;
use reth_trie::{
hashed_cursor::HashedPostStateCursorFactory, HashedPostState, HashedStorage, StateRoot,
};
use reth_trie_parallel::{async_root::AsyncStateRoot, parallel_root::ParallelStateRoot};
use std::collections::HashMap;

pub fn calculate_state_root(c: &mut Criterion) {
let mut group = c.benchmark_group("Calculate State Root");
group.sample_size(20);
let runtime = tokio::runtime::Runtime::new().unwrap();
let blocking_pool = BlockingTaskPool::new(ThreadPoolBuilder::default().build().unwrap());
for size in [1_000, 3_000, 5_000, 10_000] {
let (db_state, updated_state) = generate_test_data(size);
let provider_factory = create_test_provider_factory();
{
let provider_rw = provider_factory.provider_rw().unwrap();
HashedStateChanges(db_state).write_to_db(provider_rw.tx_ref()).unwrap();
let (_, updates) =
StateRoot::from_tx(provider_rw.tx_ref()).root_with_updates().unwrap();
updates.flush(provider_rw.tx_ref()).unwrap();
provider_rw.commit().unwrap();
}
let view = ConsistentDbView::new(provider_factory.clone(), None);
// state root
group.bench_function(BenchmarkId::new("sync root", size), |b| {
b.to_async(&runtime).iter_with_setup(
|| {
let sorted_state = updated_state.clone().into_sorted();
let prefix_sets = updated_state.construct_prefix_sets();
let provider = provider_factory.provider().unwrap();
(provider, sorted_state, prefix_sets)
},
|(provider, sorted_state, prefix_sets)| async move {
StateRoot::from_tx(provider.tx_ref())
.with_hashed_cursor_factory(HashedPostStateCursorFactory::new(
provider.tx_ref(),
&sorted_state,
))
.with_prefix_sets(prefix_sets)
.root()
},
)
});
// parallel root
group.bench_function(BenchmarkId::new("parallel root", size), |b| {
b.to_async(&runtime).iter_with_setup(
|| ParallelStateRoot::new(view.clone(), updated_state.clone()),
|calculator| async { calculator.incremental_root() },
);
});
// async root
group.bench_function(BenchmarkId::new("async root", size), |b| {
b.to_async(&runtime).iter_with_setup(
|| AsyncStateRoot::new(view.clone(), blocking_pool.clone(), updated_state.clone()),
|calculator| calculator.incremental_root(),
);
});
}
}

fn generate_test_data(size: usize) -> (HashedPostState, HashedPostState) {
let storage_size = 1_000;
let mut runner = TestRunner::new(ProptestConfig::default());
use proptest::{collection::hash_map, sample::subsequence};
let db_state = hash_map(
any::<B256>(),
(
any::<Account>().prop_filter("non empty account", |a| !a.is_empty()),
hash_map(
any::<B256>(),
any::<U256>().prop_filter("non zero value", |v| !v.is_zero()),
storage_size,
),
),
size,
)
.new_tree(&mut runner)
.unwrap()
.current();
let keys = db_state.keys().cloned().collect::<Vec<_>>();
let keys_to_update = subsequence(keys, size / 2).new_tree(&mut runner).unwrap().current();
let updated_storages = keys_to_update
.into_iter()
.map(|address| {
let (_, storage) = db_state.get(&address).unwrap();
let slots = storage.keys().cloned().collect::<Vec<_>>();
let slots_to_update =
subsequence(slots, storage_size / 2).new_tree(&mut runner).unwrap().current();
(
address,
slots_to_update
.into_iter()
.map(|slot| (slot, any::<U256>().new_tree(&mut runner).unwrap().current()))
.collect::<HashMap<_, _>>(),
)
})
.collect::<HashMap<_, _>>();
(
HashedPostState::default()
.with_accounts(
db_state.iter().map(|(address, (account, _))| (*address, Some(*account))),
)
.with_storages(db_state.into_iter().map(|(address, (_, storage))| {
(address, HashedStorage::from_iter(false, storage))
})),
HashedPostState::default().with_storages(
updated_storages
.into_iter()
.map(|(address, storage)| (address, HashedStorage::from_iter(false, storage))),
),
)
}

criterion_group!(state_root, calculate_state_root);
criterion_main!(state_root);

crates/trie/parallel/src/async_root.rs
@@ -0,0 +1,329 @@
use crate::{stats::ParallelTrieTracker, storage_root_targets::StorageRootTargets};
use alloy_rlp::{BufMut, Encodable};
use itertools::Itertools;
use reth_db::database::Database;
use reth_execution_errors::StorageRootError;
use reth_primitives::{
trie::{HashBuilder, Nibbles, TrieAccount},
B256,
};
use reth_provider::{providers::ConsistentDbView, DatabaseProviderFactory, ProviderError};
use reth_tasks::pool::BlockingTaskPool;
use reth_trie::{
hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
node_iter::{TrieElement, TrieNodeIter},
trie_cursor::TrieCursorFactory,
updates::TrieUpdates,
walker::TrieWalker,
HashedPostState, StorageRoot,
};
use std::{collections::HashMap, sync::Arc};
use thiserror::Error;
use tracing::*;

#[cfg(feature = "metrics")]
use crate::metrics::ParallelStateRootMetrics;

/// Async state root calculator.
///
/// The calculator starts off by launching tasks to compute storage roots.
/// Then, it immediately starts walking the state trie, updating the necessary trie
/// nodes in the process. Upon encountering a leaf node, it will poll the storage root
/// task for the corresponding hashed address.
///
/// Internally, the calculator uses a [ConsistentDbView] since
/// it needs the underlying database state to stay the same until
/// the last read transaction is opened.
/// See the [ConsistentDbView] docs for caveats.
///
/// For sync usage, take a look at `ParallelStateRoot`.
#[derive(Debug)]
pub struct AsyncStateRoot<DB, Provider> {
/// Consistent view of the database.
view: ConsistentDbView<DB, Provider>,
/// Blocking task pool.
blocking_pool: BlockingTaskPool,
/// Changed hashed state.
hashed_state: HashedPostState,
/// Parallel state root metrics.
#[cfg(feature = "metrics")]
metrics: ParallelStateRootMetrics,
}

impl<DB, Provider> AsyncStateRoot<DB, Provider> {
/// Create new async state root calculator.
pub fn new(
view: ConsistentDbView<DB, Provider>,
blocking_pool: BlockingTaskPool,
hashed_state: HashedPostState,
) -> Self {
Self {
view,
blocking_pool,
hashed_state,
#[cfg(feature = "metrics")]
metrics: ParallelStateRootMetrics::default(),
}
}
}

impl<DB, Provider> AsyncStateRoot<DB, Provider>
where
DB: Database + Clone + 'static,
Provider: DatabaseProviderFactory<DB> + Clone + Send + Sync + 'static,
{
/// Calculate incremental state root asynchronously.
pub async fn incremental_root(self) -> Result<B256, AsyncStateRootError> {
self.calculate(false).await.map(|(root, _)| root)
}

/// Calculate incremental state root with updates asynchronously.
pub async fn incremental_root_with_updates(
self,
) -> Result<(B256, TrieUpdates), AsyncStateRootError> {
self.calculate(true).await
}

async fn calculate(
self,
retain_updates: bool,
) -> Result<(B256, TrieUpdates), AsyncStateRootError> {
let mut tracker = ParallelTrieTracker::default();
let prefix_sets = self.hashed_state.construct_prefix_sets();
let storage_root_targets = StorageRootTargets::new(
self.hashed_state.accounts.keys().copied(),
prefix_sets.storage_prefix_sets,
);
let hashed_state_sorted = Arc::new(self.hashed_state.into_sorted());
// Pre-calculate storage roots asynchronously for accounts that changed.
tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64);
debug!(target: "trie::async_state_root", len = storage_root_targets.len(), "pre-calculating storage roots");
let mut storage_roots = HashMap::with_capacity(storage_root_targets.len());
for (hashed_address, prefix_set) in
storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
{
let view = self.view.clone();
let hashed_state_sorted = hashed_state_sorted.clone();
#[cfg(feature = "metrics")]
let metrics = self.metrics.storage_trie.clone();
let handle =
self.blocking_pool.spawn_fifo(move || -> Result<_, AsyncStateRootError> {
let provider = view.provider_ro()?;
Ok(StorageRoot::new_hashed(
provider.tx_ref(),
HashedPostStateCursorFactory::new(provider.tx_ref(), &hashed_state_sorted),
hashed_address,
#[cfg(feature = "metrics")]
metrics,
)
.with_prefix_set(prefix_set)
.calculate(retain_updates)?)
});
storage_roots.insert(hashed_address, handle);
}
trace!(target: "trie::async_state_root", "calculating state root");
let mut trie_updates = TrieUpdates::default();
let provider_ro = self.view.provider_ro()?;
let tx = provider_ro.tx_ref();
let hashed_cursor_factory = HashedPostStateCursorFactory::new(tx, &hashed_state_sorted);
let trie_cursor_factory = tx;
let walker = TrieWalker::new(
trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?,
prefix_sets.account_prefix_set,
)
.with_updates(retain_updates);
let mut account_node_iter = TrieNodeIter::new(
walker,
hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?,
);
let mut hash_builder = HashBuilder::default().with_updates(retain_updates);
let mut account_rlp = Vec::with_capacity(128);
while let Some(node) = account_node_iter.try_next().map_err(ProviderError::Database)? {
match node {
TrieElement::Branch(node) => {
hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
}
TrieElement::Leaf(hashed_address, account) => {
let (storage_root, _, updates) = match storage_roots.remove(&hashed_address) {
Some(rx) => rx.await.map_err(|_| {
AsyncStateRootError::StorageRootChannelClosed { hashed_address }
})??,
// Since we do not store all intermediate nodes in the database, it is
// possible that a non-modified leaf gets re-added to the hash builder.
None => {
tracker.inc_missed_leaves();
StorageRoot::new_hashed(
trie_cursor_factory,
hashed_cursor_factory.clone(),
hashed_address,
#[cfg(feature = "metrics")]
self.metrics.storage_trie.clone(),
)
.calculate(retain_updates)?
}
};
if retain_updates {
trie_updates.extend(updates.into_iter());
}
account_rlp.clear();
let account = TrieAccount::from((account, storage_root));
account.encode(&mut account_rlp as &mut dyn BufMut);
hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
}
}
}
let root = hash_builder.root();
trie_updates.finalize_state_updates(
account_node_iter.walker,
hash_builder,
prefix_sets.destroyed_accounts,
);
let stats = tracker.finish();
#[cfg(feature = "metrics")]
self.metrics.record_state_trie(stats);
trace!(
target: "trie::async_state_root",
%root,
duration = ?stats.duration(),
branches_added = stats.branches_added(),
leaves_added = stats.leaves_added(),
missed_leaves = stats.missed_leaves(),
precomputed_storage_roots = stats.precomputed_storage_roots(),
"calculated state root"
);
Ok((root, trie_updates))
}
}

/// Error during async state root calculation.
#[derive(Error, Debug)]
pub enum AsyncStateRootError {
/// Storage root channel for a given address was closed.
#[error("storage root channel for {hashed_address} got closed")]
StorageRootChannelClosed {
/// The hashed address for which the channel was closed.
hashed_address: B256,
},
/// Error while calculating storage root.
#[error(transparent)]
StorageRoot(#[from] StorageRootError),
/// Provider error.
#[error(transparent)]
Provider(#[from] ProviderError),
}

#[cfg(test)]
mod tests {
use super::*;
use rand::Rng;
use rayon::ThreadPoolBuilder;
use reth_primitives::{keccak256, Account, Address, StorageEntry, U256};
use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
use reth_trie::{test_utils, HashedStorage};

#[tokio::test]
async fn random_async_root() {
let blocking_pool = BlockingTaskPool::new(ThreadPoolBuilder::default().build().unwrap());
let factory = create_test_provider_factory();
let consistent_view = ConsistentDbView::new(factory.clone(), None);
let mut rng = rand::thread_rng();
let mut state = (0..100)
.map(|_| {
let address = Address::random();
let account =
Account { balance: U256::from(rng.gen::<u64>()), ..Default::default() };
let mut storage = HashMap::<B256, U256>::default();
let has_storage = rng.gen_bool(0.7);
if has_storage {
for _ in 0..100 {
storage.insert(
B256::from(U256::from(rng.gen::<u64>())),
U256::from(rng.gen::<u64>()),
);
}
}
(address, (account, storage))
})
.collect::<HashMap<_, _>>();
{
let provider_rw = factory.provider_rw().unwrap();
provider_rw
.insert_account_for_hashing(
state.iter().map(|(address, (account, _))| (*address, Some(*account))),
)
.unwrap();
provider_rw
.insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| {
(
*address,
storage
.iter()
.map(|(slot, value)| StorageEntry { key: *slot, value: *value }),
)
}))
.unwrap();
provider_rw.commit().unwrap();
}
assert_eq!(
AsyncStateRoot::new(
consistent_view.clone(),
blocking_pool.clone(),
HashedPostState::default()
)
.incremental_root()
.await
.unwrap(),
test_utils::state_root(state.clone())
);
let mut hashed_state = HashedPostState::default();
for (address, (account, storage)) in state.iter_mut() {
let hashed_address = keccak256(address);
let should_update_account = rng.gen_bool(0.5);
if should_update_account {
*account = Account { balance: U256::from(rng.gen::<u64>()), ..*account };
hashed_state.accounts.insert(hashed_address, Some(*account));
}
let should_update_storage = rng.gen_bool(0.3);
if should_update_storage {
for (slot, value) in storage.iter_mut() {
let hashed_slot = keccak256(slot);
*value = U256::from(rng.gen::<u64>());
hashed_state
.storages
.entry(hashed_address)
.or_insert_with(|| HashedStorage::new(false))
.storage
.insert(hashed_slot, *value);
}
}
}
assert_eq!(
AsyncStateRoot::new(consistent_view.clone(), blocking_pool.clone(), hashed_state)
.incremental_root()
.await
.unwrap(),
test_utils::state_root(state)
);
}
}
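
A hedged sketch of the write-back path that the test above leaves out: computing the root together with retained trie updates, then flushing them the same way the benchmark file does with `TrieUpdates::flush`. The test module's imports and setup are assumed; the function name is illustrative.

#[tokio::test]
async fn async_root_with_updates_sketch() {
    let blocking_pool = BlockingTaskPool::new(ThreadPoolBuilder::default().build().unwrap());
    let factory = create_test_provider_factory();
    let consistent_view = ConsistentDbView::new(factory.clone(), None);

    // An empty post state still exercises the calculate-and-flush path end to end.
    let (_root, trie_updates) =
        AsyncStateRoot::new(consistent_view, blocking_pool, HashedPostState::default())
            .incremental_root_with_updates()
            .await
            .unwrap();

    // Persist the retained updates in a read-write transaction, as the benchmark does.
    let provider_rw = factory.provider_rw().unwrap();
    trie_updates.flush(provider_rw.tx_ref()).unwrap();
    provider_rw.commit().unwrap();
}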

crates/trie/parallel/src/lib.rs
@@ -0,0 +1,26 @@
//! Implementation of exotic state root computation approaches.

#![doc(
    html_logo_url = "https://raw.githubusercontent.com/paradigmxyz/reth/main/assets/reth-docs.png",
    html_favicon_url = "https://avatars0.githubusercontent.com/u/97369466?s=256",
    issue_tracker_base_url = "https://github.com/paradigmxyz/reth/issues/"
)]
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]

mod storage_root_targets;
pub use storage_root_targets::StorageRootTargets;

/// Parallel trie calculation stats.
pub mod stats;

/// Implementation of async state root computation.
#[cfg(feature = "async")]
pub mod async_root;

/// Implementation of parallel state root computation.
#[cfg(feature = "parallel")]
pub mod parallel_root;

/// Parallel state root metrics.
#[cfg(feature = "metrics")]
pub mod metrics;

crates/trie/parallel/src/metrics.rs
@@ -0,0 +1,44 @@
use crate::stats::ParallelTrieStats;
use metrics::Histogram;
use reth_metrics::Metrics;
use reth_trie::metrics::{TrieRootMetrics, TrieType};

/// Parallel state root metrics.
#[derive(Debug)]
pub struct ParallelStateRootMetrics {
    /// State trie metrics.
    pub state_trie: TrieRootMetrics,
    /// Parallel trie metrics.
    pub parallel: ParallelTrieMetrics,
    /// Storage trie metrics.
    pub storage_trie: TrieRootMetrics,
}

impl Default for ParallelStateRootMetrics {
    fn default() -> Self {
        Self {
            state_trie: TrieRootMetrics::new(TrieType::State),
            parallel: ParallelTrieMetrics::default(),
            storage_trie: TrieRootMetrics::new(TrieType::Storage),
        }
    }
}

impl ParallelStateRootMetrics {
    /// Record state trie metrics.
    pub fn record_state_trie(&self, stats: ParallelTrieStats) {
        self.state_trie.record(stats.trie_stats());
        self.parallel.precomputed_storage_roots.record(stats.precomputed_storage_roots() as f64);
        self.parallel.missed_leaves.record(stats.missed_leaves() as f64);
    }
}

/// Parallel trie metrics.
#[derive(Metrics)]
#[metrics(scope = "trie_parallel")]
pub struct ParallelTrieMetrics {
    /// The number of storage roots computed in parallel.
    pub precomputed_storage_roots: Histogram,
    /// The number of leaves for which we did not pre-compute the storage root.
    pub missed_leaves: Histogram,
}
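
A minimal sketch of how these types compose, using only items from this crate: build the default metrics (state and storage tries registered under their respective `TrieType` labels) and record the stats of a finished tracker run. The function name is illustrative.

use crate::{metrics::ParallelStateRootMetrics, stats::ParallelTrieTracker};

fn record_metrics_sketch() {
    let metrics = ParallelStateRootMetrics::default();
    let mut tracker = ParallelTrieTracker::default();
    tracker.set_precomputed_storage_roots(42);
    // Records the general trie stats plus the two parallel histograms scoped `trie_parallel`.
    metrics.record_state_trie(tracker.finish());
}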

crates/trie/parallel/src/parallel_root.rs
@@ -0,0 +1,305 @@
use crate::{stats::ParallelTrieTracker, storage_root_targets::StorageRootTargets};
use alloy_rlp::{BufMut, Encodable};
use rayon::prelude::*;
use reth_db::database::Database;
use reth_execution_errors::StorageRootError;
use reth_primitives::{
trie::{HashBuilder, Nibbles, TrieAccount},
B256,
};
use reth_provider::{providers::ConsistentDbView, DatabaseProviderFactory, ProviderError};
use reth_trie::{
hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
node_iter::{TrieElement, TrieNodeIter},
trie_cursor::TrieCursorFactory,
updates::TrieUpdates,
walker::TrieWalker,
HashedPostState, StorageRoot,
};
use std::collections::HashMap;
use thiserror::Error;
use tracing::*;

#[cfg(feature = "metrics")]
use crate::metrics::ParallelStateRootMetrics;

/// Parallel incremental state root calculator.
///
/// The calculator starts off by pre-computing storage roots of changed
/// accounts in parallel. Once that's done, it proceeds to walking the state
/// trie, retrieving the pre-computed storage roots when needed.
///
/// Internally, the calculator uses a [ConsistentDbView] since
/// it needs the underlying database state to stay the same until
/// the last read transaction is opened.
/// See the [ConsistentDbView] docs for caveats.
///
/// If possible, use the more optimized `AsyncStateRoot` instead.
#[derive(Debug)]
pub struct ParallelStateRoot<DB, Provider> {
/// Consistent view of the database.
view: ConsistentDbView<DB, Provider>,
/// Changed hashed state.
hashed_state: HashedPostState,
/// Parallel state root metrics.
#[cfg(feature = "metrics")]
metrics: ParallelStateRootMetrics,
}

impl<DB, Provider> ParallelStateRoot<DB, Provider> {
/// Create new parallel state root calculator.
pub fn new(view: ConsistentDbView<DB, Provider>, hashed_state: HashedPostState) -> Self {
Self {
view,
hashed_state,
#[cfg(feature = "metrics")]
metrics: ParallelStateRootMetrics::default(),
}
}
}

impl<DB, Provider> ParallelStateRoot<DB, Provider>
where
DB: Database,
Provider: DatabaseProviderFactory<DB> + Send + Sync,
{
/// Calculate incremental state root in parallel.
pub fn incremental_root(self) -> Result<B256, ParallelStateRootError> {
self.calculate(false).map(|(root, _)| root)
}

/// Calculate incremental state root with updates in parallel.
pub fn incremental_root_with_updates(
self,
) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
self.calculate(true)
}

fn calculate(
self,
retain_updates: bool,
) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
let mut tracker = ParallelTrieTracker::default();
let prefix_sets = self.hashed_state.construct_prefix_sets();
let storage_root_targets = StorageRootTargets::new(
self.hashed_state.accounts.keys().copied(),
prefix_sets.storage_prefix_sets,
);
let hashed_state_sorted = self.hashed_state.into_sorted();
// Pre-calculate storage roots in parallel for accounts that changed.
tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64);
debug!(target: "trie::parallel_state_root", len = storage_root_targets.len(), "pre-calculating storage roots");
let mut storage_roots = storage_root_targets
.into_par_iter()
.map(|(hashed_address, prefix_set)| {
let provider_ro = self.view.provider_ro()?;
let storage_root_result = StorageRoot::new_hashed(
provider_ro.tx_ref(),
HashedPostStateCursorFactory::new(provider_ro.tx_ref(), &hashed_state_sorted),
hashed_address,
#[cfg(feature = "metrics")]
self.metrics.storage_trie.clone(),
)
.with_prefix_set(prefix_set)
.calculate(retain_updates);
Ok((hashed_address, storage_root_result?))
})
.collect::<Result<HashMap<_, _>, ParallelStateRootError>>()?;
trace!(target: "trie::parallel_state_root", "calculating state root");
let mut trie_updates = TrieUpdates::default();
let provider_ro = self.view.provider_ro()?;
let hashed_cursor_factory =
HashedPostStateCursorFactory::new(provider_ro.tx_ref(), &hashed_state_sorted);
let trie_cursor_factory = provider_ro.tx_ref();
let walker = TrieWalker::new(
trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?,
prefix_sets.account_prefix_set,
)
.with_updates(retain_updates);
let mut account_node_iter = TrieNodeIter::new(
walker,
hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?,
);
let mut hash_builder = HashBuilder::default().with_updates(retain_updates);
let mut account_rlp = Vec::with_capacity(128);
while let Some(node) = account_node_iter.try_next().map_err(ProviderError::Database)? {
match node {
TrieElement::Branch(node) => {
hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
}
TrieElement::Leaf(hashed_address, account) => {
let (storage_root, _, updates) = match storage_roots.remove(&hashed_address) {
Some(result) => result,
// Since we do not store all intermediate nodes in the database, it is
// possible that a non-modified leaf gets re-added to the hash builder.
None => {
tracker.inc_missed_leaves();
StorageRoot::new_hashed(
trie_cursor_factory,
hashed_cursor_factory.clone(),
hashed_address,
#[cfg(feature = "metrics")]
self.metrics.storage_trie.clone(),
)
.calculate(retain_updates)?
}
};
if retain_updates {
trie_updates.extend(updates.into_iter());
}
account_rlp.clear();
let account = TrieAccount::from((account, storage_root));
account.encode(&mut account_rlp as &mut dyn BufMut);
hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
}
}
}
let root = hash_builder.root();
trie_updates.finalize_state_updates(
account_node_iter.walker,
hash_builder,
prefix_sets.destroyed_accounts,
);
let stats = tracker.finish();
#[cfg(feature = "metrics")]
self.metrics.record_state_trie(stats);
trace!(
target: "trie::parallel_state_root",
%root,
duration = ?stats.duration(),
branches_added = stats.branches_added(),
leaves_added = stats.leaves_added(),
missed_leaves = stats.missed_leaves(),
precomputed_storage_roots = stats.precomputed_storage_roots(),
"calculated state root"
);
Ok((root, trie_updates))
}
}

/// Error during parallel state root calculation.
#[derive(Error, Debug)]
pub enum ParallelStateRootError {
/// Error while calculating storage root.
#[error(transparent)]
StorageRoot(#[from] StorageRootError),
/// Provider error.
#[error(transparent)]
Provider(#[from] ProviderError),
}

impl From<ParallelStateRootError> for ProviderError {
fn from(error: ParallelStateRootError) -> Self {
match error {
ParallelStateRootError::Provider(error) => error,
ParallelStateRootError::StorageRoot(StorageRootError::DB(error)) => {
Self::Database(error)
}
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use rand::Rng;
use reth_primitives::{keccak256, Account, Address, StorageEntry, U256};
use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
use reth_trie::{test_utils, HashedStorage};

#[tokio::test]
async fn random_parallel_root() {
let factory = create_test_provider_factory();
let consistent_view = ConsistentDbView::new(factory.clone(), None);
let mut rng = rand::thread_rng();
let mut state = (0..100)
.map(|_| {
let address = Address::random();
let account =
Account { balance: U256::from(rng.gen::<u64>()), ..Default::default() };
let mut storage = HashMap::<B256, U256>::default();
let has_storage = rng.gen_bool(0.7);
if has_storage {
for _ in 0..100 {
storage.insert(
B256::from(U256::from(rng.gen::<u64>())),
U256::from(rng.gen::<u64>()),
);
}
}
(address, (account, storage))
})
.collect::<HashMap<_, _>>();
{
let provider_rw = factory.provider_rw().unwrap();
provider_rw
.insert_account_for_hashing(
state.iter().map(|(address, (account, _))| (*address, Some(*account))),
)
.unwrap();
provider_rw
.insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| {
(
*address,
storage
.iter()
.map(|(slot, value)| StorageEntry { key: *slot, value: *value }),
)
}))
.unwrap();
provider_rw.commit().unwrap();
}
assert_eq!(
ParallelStateRoot::new(consistent_view.clone(), HashedPostState::default())
.incremental_root()
.unwrap(),
test_utils::state_root(state.clone())
);
let mut hashed_state = HashedPostState::default();
for (address, (account, storage)) in state.iter_mut() {
let hashed_address = keccak256(address);
let should_update_account = rng.gen_bool(0.5);
if should_update_account {
*account = Account { balance: U256::from(rng.gen::<u64>()), ..*account };
hashed_state.accounts.insert(hashed_address, Some(*account));
}
let should_update_storage = rng.gen_bool(0.3);
if should_update_storage {
for (slot, value) in storage.iter_mut() {
let hashed_slot = keccak256(slot);
*value = U256::from(rng.gen::<u64>());
hashed_state
.storages
.entry(hashed_address)
.or_insert_with(|| HashedStorage::new(false))
.storage
.insert(hashed_slot, *value);
}
}
}
assert_eq!(
ParallelStateRoot::new(consistent_view, hashed_state).incremental_root().unwrap(),
test_utils::state_root(state)
);
}
}
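
Unlike `AsyncStateRoot`, this calculator is synchronous, so no runtime or blocking pool is required. A minimal sketch mirroring the test setup above, with an illustrative function name:

#[test]
fn parallel_root_sketch() {
    let factory = create_test_provider_factory();
    let consistent_view = ConsistentDbView::new(factory.clone(), None);
    let (_root, _trie_updates) =
        ParallelStateRoot::new(consistent_view, HashedPostState::default())
            .incremental_root_with_updates()
            .unwrap();
}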

crates/trie/parallel/src/stats.rs
@@ -0,0 +1,68 @@
use derive_more::Deref;
use reth_trie::stats::{TrieStats, TrieTracker};

/// Parallel trie stats.
#[derive(Deref, Clone, Copy, Debug)]
pub struct ParallelTrieStats {
    #[deref]
    trie: TrieStats,
    precomputed_storage_roots: u64,
    missed_leaves: u64,
}

impl ParallelTrieStats {
    /// Return the general trie stats.
    pub fn trie_stats(&self) -> TrieStats {
        self.trie
    }

    /// The number of pre-computed storage roots.
    pub fn precomputed_storage_roots(&self) -> u64 {
        self.precomputed_storage_roots
    }

    /// The number of added leaf nodes for which we did not pre-compute the storage root.
    pub fn missed_leaves(&self) -> u64 {
        self.missed_leaves
    }
}

/// Parallel trie metrics tracker.
#[derive(Deref, Default, Debug)]
pub struct ParallelTrieTracker {
    #[deref]
    trie: TrieTracker,
    precomputed_storage_roots: u64,
    missed_leaves: u64,
}

impl ParallelTrieTracker {
    /// Set the number of pre-computed storage roots.
    pub fn set_precomputed_storage_roots(&mut self, count: u64) {
        self.precomputed_storage_roots = count;
    }

    /// Increment the number of branches added to the hash builder during the calculation.
    pub fn inc_branch(&mut self) {
        self.trie.inc_branch();
    }

    /// Increment the number of leaves added to the hash builder during the calculation.
    pub fn inc_leaf(&mut self) {
        self.trie.inc_leaf();
    }

    /// Increment the number of added leaf nodes for which we did not pre-compute the storage root.
    pub fn inc_missed_leaves(&mut self) {
        self.missed_leaves += 1;
    }

    /// Called when the root calculation is finished to return the trie statistics.
    pub fn finish(self) -> ParallelTrieStats {
        ParallelTrieStats {
            trie: self.trie.finish(),
            precomputed_storage_roots: self.precomputed_storage_roots,
            missed_leaves: self.missed_leaves,
        }
    }
}
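
The tracker's lifecycle mirrors its use in both calculators: seed the pre-computed count up front, bump counters during the trie walk, and consume the tracker at the end. A minimal sketch using only this file's API:

let mut tracker = ParallelTrieTracker::default();
tracker.set_precomputed_storage_roots(10); // one per storage root target
tracker.inc_missed_leaves();               // a leaf seen without a pre-computed root
let stats = tracker.finish();
assert_eq!(stats.precomputed_storage_roots(), 10);
assert_eq!(stats.missed_leaves(), 1);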

crates/trie/parallel/src/storage_root_targets.rs
@@ -0,0 +1,47 @@
use derive_more::{Deref, DerefMut};
use reth_primitives::B256;
use reth_trie::prefix_set::PrefixSet;
use std::collections::HashMap;

/// Target accounts with corresponding prefix sets for storage root calculation.
#[derive(Deref, DerefMut, Debug)]
pub struct StorageRootTargets(HashMap<B256, PrefixSet>);

impl StorageRootTargets {
    /// Create new storage root targets from updated post state accounts
    /// and storage prefix sets.
    ///
    /// NOTE: Since updated accounts and storage prefix sets can overlap,
    /// it is important that the iterator over storage prefix sets takes precedence.
    pub fn new(
        changed_accounts: impl IntoIterator<Item = B256>,
        storage_prefix_sets: impl IntoIterator<Item = (B256, PrefixSet)>,
    ) -> Self {
        Self(
            changed_accounts
                .into_iter()
                .map(|address| (address, PrefixSet::default()))
                .chain(storage_prefix_sets)
                .collect(),
        )
    }
}

impl IntoIterator for StorageRootTargets {
    type Item = (B256, PrefixSet);
    type IntoIter = std::collections::hash_map::IntoIter<B256, PrefixSet>;

    fn into_iter(self) -> Self::IntoIter {
        self.0.into_iter()
    }
}

#[cfg(feature = "parallel")]
impl rayon::iter::IntoParallelIterator for StorageRootTargets {
    type Iter = rayon::collections::hash_map::IntoIter<B256, PrefixSet>;
    type Item = (B256, PrefixSet);

    fn into_par_iter(self) -> Self::Iter {
        self.0.into_par_iter()
    }
}
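
The precedence note in `new` follows from `collect` into a `HashMap` keeping the last value per key: changed accounts are inserted first with empty prefix sets, and any storage prefix set chained afterwards overwrites them. A hedged sketch; `PrefixSetMut` and its `freeze` method are assumptions about the `reth_trie::prefix_set` API:

use reth_primitives::{trie::Nibbles, B256};
use reth_trie::prefix_set::PrefixSetMut;

let address = B256::with_last_byte(1);
let mut prefix_set = PrefixSetMut::default(); // assumed mutable builder for PrefixSet
prefix_set.insert(Nibbles::unpack(address));

let targets = StorageRootTargets::new(
    [address],                        // changed account → empty default prefix set
    [(address, prefix_set.freeze())], // chained last, so it wins for `address`
);
assert_eq!(targets.len(), 1);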