From efa5d45e4ea500b3c4b21755899c22233eb6d61c Mon Sep 17 00:00:00 2001 From: Federico Gimenez Date: Mon, 23 Sep 2024 19:45:12 +0200 Subject: [PATCH] feat(trie): use global thread pool in async state root calculation (#11057) --- Cargo.lock | 1 - crates/trie/parallel/Cargo.toml | 3 +- crates/trie/parallel/benches/root.rs | 11 +----- crates/trie/parallel/src/async_root.rs | 51 ++++++++++---------------- 4 files changed, 21 insertions(+), 45 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b6426cc1a..60feb38ac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9080,7 +9080,6 @@ dependencies = [ "reth-metrics", "reth-primitives", "reth-provider", - "reth-tasks", "reth-trie", "reth-trie-db", "thiserror", diff --git a/crates/trie/parallel/Cargo.toml b/crates/trie/parallel/Cargo.toml index e53d15c14..80fa0a70d 100644 --- a/crates/trie/parallel/Cargo.toml +++ b/crates/trie/parallel/Cargo.toml @@ -33,7 +33,6 @@ thiserror.workspace = true derive_more.workspace = true # `async` feature -reth-tasks = { workspace = true, optional = true } tokio = { workspace = true, optional = true, default-features = false } itertools = { workspace = true, optional = true } @@ -61,7 +60,7 @@ proptest-arbitrary-interop.workspace = true [features] default = ["metrics", "async", "parallel"] metrics = ["reth-metrics", "dep:metrics", "reth-trie/metrics"] -async = ["reth-tasks/rayon", "tokio/sync", "itertools"] +async = ["tokio/sync", "itertools"] parallel = ["rayon"] [[bench]] diff --git a/crates/trie/parallel/benches/root.rs b/crates/trie/parallel/benches/root.rs index b8a4d25e5..470222e3e 100644 --- a/crates/trie/parallel/benches/root.rs +++ b/crates/trie/parallel/benches/root.rs @@ -3,13 +3,11 @@ use alloy_primitives::{B256, U256}; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use proptest::{prelude::*, strategy::ValueTree, test_runner::TestRunner}; use proptest_arbitrary_interop::arb; -use rayon::ThreadPoolBuilder; use reth_primitives::Account; use reth_provider::{ providers::ConsistentDbView, test_utils::create_test_provider_factory, StateChangeWriter, TrieWriter, }; -use reth_tasks::pool::BlockingTaskPool; use reth_trie::{ hashed_cursor::HashedPostStateCursorFactory, HashedPostState, HashedStorage, StateRoot, TrieInput, @@ -23,7 +21,6 @@ pub fn calculate_state_root(c: &mut Criterion) { group.sample_size(20); let runtime = tokio::runtime::Runtime::new().unwrap(); - let blocking_pool = BlockingTaskPool::new(ThreadPoolBuilder::default().build().unwrap()); for size in [1_000, 3_000, 5_000, 10_000] { let (db_state, updated_state) = generate_test_data(size); @@ -77,13 +74,7 @@ pub fn calculate_state_root(c: &mut Criterion) { // async root group.bench_function(BenchmarkId::new("async root", size), |b| { b.to_async(&runtime).iter_with_setup( - || { - AsyncStateRoot::new( - view.clone(), - blocking_pool.clone(), - TrieInput::from_state(updated_state.clone()), - ) - }, + || AsyncStateRoot::new(view.clone(), TrieInput::from_state(updated_state.clone())), |calculator| calculator.incremental_root(), ); }); diff --git a/crates/trie/parallel/src/async_root.rs b/crates/trie/parallel/src/async_root.rs index ed12accb4..74481f09e 100644 --- a/crates/trie/parallel/src/async_root.rs +++ b/crates/trie/parallel/src/async_root.rs @@ -8,7 +8,6 @@ use reth_execution_errors::StorageRootError; use reth_provider::{ providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, ProviderError, }; -use reth_tasks::pool::BlockingTaskPool; use reth_trie::{ hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory}, node_iter::{TrieElement, TrieNodeIter}, @@ -20,6 +19,7 @@ use reth_trie::{ use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory}; use std::{collections::HashMap, sync::Arc}; use thiserror::Error; +use tokio::sync::oneshot; use tracing::*; /// Async state root calculator. @@ -39,8 +39,6 @@ use tracing::*; pub struct AsyncStateRoot { /// Consistent view of the database. view: ConsistentDbView, - /// Blocking task pool. - blocking_pool: BlockingTaskPool, /// Trie input. input: TrieInput, /// Parallel state root metrics. @@ -50,14 +48,9 @@ pub struct AsyncStateRoot { impl AsyncStateRoot { /// Create new async state root calculator. - pub fn new( - view: ConsistentDbView, - blocking_pool: BlockingTaskPool, - input: TrieInput, - ) -> Self { + pub fn new(view: ConsistentDbView, input: TrieInput) -> Self { Self { view, - blocking_pool, input, #[cfg(feature = "metrics")] metrics: ParallelStateRootMetrics::default(), @@ -106,8 +99,11 @@ where let trie_nodes_sorted = trie_nodes_sorted.clone(); #[cfg(feature = "metrics")] let metrics = self.metrics.storage_trie.clone(); - let handle = - self.blocking_pool.spawn_fifo(move || -> Result<_, AsyncStateRootError> { + + let (tx, rx) = oneshot::channel(); + + rayon::spawn_fifo(move || { + let result = (|| -> Result<_, AsyncStateRootError> { let provider_ro = view.provider_ro()?; let trie_cursor_factory = InMemoryTrieCursorFactory::new( DatabaseTrieCursorFactory::new(provider_ro.tx_ref()), @@ -126,8 +122,10 @@ where ) .with_prefix_set(prefix_set) .calculate(retain_updates)?) - }); - storage_roots.insert(hashed_address, handle); + })(); + let _ = tx.send(result); + }); + storage_roots.insert(hashed_address, rx); } trace!(target: "trie::async_state_root", "calculating state root"); @@ -242,15 +240,12 @@ mod tests { use super::*; use alloy_primitives::{keccak256, Address, U256}; use rand::Rng; - use rayon::ThreadPoolBuilder; use reth_primitives::{Account, StorageEntry}; use reth_provider::{test_utils::create_test_provider_factory, HashingWriter}; use reth_trie::{test_utils, HashedPostState, HashedStorage}; #[tokio::test] async fn random_async_root() { - let blocking_pool = BlockingTaskPool::new(ThreadPoolBuilder::default().build().unwrap()); - let factory = create_test_provider_factory(); let consistent_view = ConsistentDbView::new(factory.clone(), None); @@ -295,14 +290,10 @@ mod tests { } assert_eq!( - AsyncStateRoot::new( - consistent_view.clone(), - blocking_pool.clone(), - Default::default(), - ) - .incremental_root() - .await - .unwrap(), + AsyncStateRoot::new(consistent_view.clone(), Default::default(),) + .incremental_root() + .await + .unwrap(), test_utils::state_root(state.clone()) ); @@ -332,14 +323,10 @@ mod tests { } assert_eq!( - AsyncStateRoot::new( - consistent_view.clone(), - blocking_pool.clone(), - TrieInput::from_state(hashed_state) - ) - .incremental_root() - .await - .unwrap(), + AsyncStateRoot::new(consistent_view.clone(), TrieInput::from_state(hashed_state)) + .incremental_root() + .await + .unwrap(), test_utils::state_root(state) ); }