feat(trie): use global thread pool in async state root calculation (#11057)

This commit is contained in:
Federico Gimenez
2024-09-23 19:45:12 +02:00
committed by GitHub
parent b29ff1f6cc
commit efa5d45e4e
4 changed files with 21 additions and 45 deletions

1
Cargo.lock generated
View File

@ -9080,7 +9080,6 @@ dependencies = [
"reth-metrics", "reth-metrics",
"reth-primitives", "reth-primitives",
"reth-provider", "reth-provider",
"reth-tasks",
"reth-trie", "reth-trie",
"reth-trie-db", "reth-trie-db",
"thiserror", "thiserror",

View File

@ -33,7 +33,6 @@ thiserror.workspace = true
derive_more.workspace = true derive_more.workspace = true
# `async` feature # `async` feature
reth-tasks = { workspace = true, optional = true }
tokio = { workspace = true, optional = true, default-features = false } tokio = { workspace = true, optional = true, default-features = false }
itertools = { workspace = true, optional = true } itertools = { workspace = true, optional = true }
@ -61,7 +60,7 @@ proptest-arbitrary-interop.workspace = true
[features] [features]
default = ["metrics", "async", "parallel"] default = ["metrics", "async", "parallel"]
metrics = ["reth-metrics", "dep:metrics", "reth-trie/metrics"] metrics = ["reth-metrics", "dep:metrics", "reth-trie/metrics"]
async = ["reth-tasks/rayon", "tokio/sync", "itertools"] async = ["tokio/sync", "itertools"]
parallel = ["rayon"] parallel = ["rayon"]
[[bench]] [[bench]]

View File

@ -3,13 +3,11 @@ use alloy_primitives::{B256, U256};
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use proptest::{prelude::*, strategy::ValueTree, test_runner::TestRunner}; use proptest::{prelude::*, strategy::ValueTree, test_runner::TestRunner};
use proptest_arbitrary_interop::arb; use proptest_arbitrary_interop::arb;
use rayon::ThreadPoolBuilder;
use reth_primitives::Account; use reth_primitives::Account;
use reth_provider::{ use reth_provider::{
providers::ConsistentDbView, test_utils::create_test_provider_factory, StateChangeWriter, providers::ConsistentDbView, test_utils::create_test_provider_factory, StateChangeWriter,
TrieWriter, TrieWriter,
}; };
use reth_tasks::pool::BlockingTaskPool;
use reth_trie::{ use reth_trie::{
hashed_cursor::HashedPostStateCursorFactory, HashedPostState, HashedStorage, StateRoot, hashed_cursor::HashedPostStateCursorFactory, HashedPostState, HashedStorage, StateRoot,
TrieInput, TrieInput,
@ -23,7 +21,6 @@ pub fn calculate_state_root(c: &mut Criterion) {
group.sample_size(20); group.sample_size(20);
let runtime = tokio::runtime::Runtime::new().unwrap(); let runtime = tokio::runtime::Runtime::new().unwrap();
let blocking_pool = BlockingTaskPool::new(ThreadPoolBuilder::default().build().unwrap());
for size in [1_000, 3_000, 5_000, 10_000] { for size in [1_000, 3_000, 5_000, 10_000] {
let (db_state, updated_state) = generate_test_data(size); let (db_state, updated_state) = generate_test_data(size);
@ -77,13 +74,7 @@ pub fn calculate_state_root(c: &mut Criterion) {
// async root // async root
group.bench_function(BenchmarkId::new("async root", size), |b| { group.bench_function(BenchmarkId::new("async root", size), |b| {
b.to_async(&runtime).iter_with_setup( b.to_async(&runtime).iter_with_setup(
|| { || AsyncStateRoot::new(view.clone(), TrieInput::from_state(updated_state.clone())),
AsyncStateRoot::new(
view.clone(),
blocking_pool.clone(),
TrieInput::from_state(updated_state.clone()),
)
},
|calculator| calculator.incremental_root(), |calculator| calculator.incremental_root(),
); );
}); });

View File

@ -8,7 +8,6 @@ use reth_execution_errors::StorageRootError;
use reth_provider::{ use reth_provider::{
providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, ProviderError, providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, ProviderError,
}; };
use reth_tasks::pool::BlockingTaskPool;
use reth_trie::{ use reth_trie::{
hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory}, hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
node_iter::{TrieElement, TrieNodeIter}, node_iter::{TrieElement, TrieNodeIter},
@ -20,6 +19,7 @@ use reth_trie::{
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory}; use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use thiserror::Error; use thiserror::Error;
use tokio::sync::oneshot;
use tracing::*; use tracing::*;
/// Async state root calculator. /// Async state root calculator.
@ -39,8 +39,6 @@ use tracing::*;
pub struct AsyncStateRoot<Factory> { pub struct AsyncStateRoot<Factory> {
/// Consistent view of the database. /// Consistent view of the database.
view: ConsistentDbView<Factory>, view: ConsistentDbView<Factory>,
/// Blocking task pool.
blocking_pool: BlockingTaskPool,
/// Trie input. /// Trie input.
input: TrieInput, input: TrieInput,
/// Parallel state root metrics. /// Parallel state root metrics.
@ -50,14 +48,9 @@ pub struct AsyncStateRoot<Factory> {
impl<Factory> AsyncStateRoot<Factory> { impl<Factory> AsyncStateRoot<Factory> {
/// Create new async state root calculator. /// Create new async state root calculator.
pub fn new( pub fn new(view: ConsistentDbView<Factory>, input: TrieInput) -> Self {
view: ConsistentDbView<Factory>,
blocking_pool: BlockingTaskPool,
input: TrieInput,
) -> Self {
Self { Self {
view, view,
blocking_pool,
input, input,
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
metrics: ParallelStateRootMetrics::default(), metrics: ParallelStateRootMetrics::default(),
@ -106,8 +99,11 @@ where
let trie_nodes_sorted = trie_nodes_sorted.clone(); let trie_nodes_sorted = trie_nodes_sorted.clone();
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
let metrics = self.metrics.storage_trie.clone(); let metrics = self.metrics.storage_trie.clone();
let handle =
self.blocking_pool.spawn_fifo(move || -> Result<_, AsyncStateRootError> { let (tx, rx) = oneshot::channel();
rayon::spawn_fifo(move || {
let result = (|| -> Result<_, AsyncStateRootError> {
let provider_ro = view.provider_ro()?; let provider_ro = view.provider_ro()?;
let trie_cursor_factory = InMemoryTrieCursorFactory::new( let trie_cursor_factory = InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider_ro.tx_ref()), DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
@ -126,8 +122,10 @@ where
) )
.with_prefix_set(prefix_set) .with_prefix_set(prefix_set)
.calculate(retain_updates)?) .calculate(retain_updates)?)
}); })();
storage_roots.insert(hashed_address, handle); let _ = tx.send(result);
});
storage_roots.insert(hashed_address, rx);
} }
trace!(target: "trie::async_state_root", "calculating state root"); trace!(target: "trie::async_state_root", "calculating state root");
@ -242,15 +240,12 @@ mod tests {
use super::*; use super::*;
use alloy_primitives::{keccak256, Address, U256}; use alloy_primitives::{keccak256, Address, U256};
use rand::Rng; use rand::Rng;
use rayon::ThreadPoolBuilder;
use reth_primitives::{Account, StorageEntry}; use reth_primitives::{Account, StorageEntry};
use reth_provider::{test_utils::create_test_provider_factory, HashingWriter}; use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
use reth_trie::{test_utils, HashedPostState, HashedStorage}; use reth_trie::{test_utils, HashedPostState, HashedStorage};
#[tokio::test] #[tokio::test]
async fn random_async_root() { async fn random_async_root() {
let blocking_pool = BlockingTaskPool::new(ThreadPoolBuilder::default().build().unwrap());
let factory = create_test_provider_factory(); let factory = create_test_provider_factory();
let consistent_view = ConsistentDbView::new(factory.clone(), None); let consistent_view = ConsistentDbView::new(factory.clone(), None);
@ -295,14 +290,10 @@ mod tests {
} }
assert_eq!( assert_eq!(
AsyncStateRoot::new( AsyncStateRoot::new(consistent_view.clone(), Default::default(),)
consistent_view.clone(), .incremental_root()
blocking_pool.clone(), .await
Default::default(), .unwrap(),
)
.incremental_root()
.await
.unwrap(),
test_utils::state_root(state.clone()) test_utils::state_root(state.clone())
); );
@ -332,14 +323,10 @@ mod tests {
} }
assert_eq!( assert_eq!(
AsyncStateRoot::new( AsyncStateRoot::new(consistent_view.clone(), TrieInput::from_state(hashed_state))
consistent_view.clone(), .incremental_root()
blocking_pool.clone(), .await
TrieInput::from_state(hashed_state) .unwrap(),
)
.incremental_root()
.await
.unwrap(),
test_utils::state_root(state) test_utils::state_root(state)
); );
} }