chore(engine): use Arc<rayon::ThreadPool> for StateRootTask (#13755)

This commit is contained in:
Federico Gimenez
2025-01-09 15:25:00 +01:00
committed by GitHub
parent bf65ed45c5
commit 4a8c88f4d0
4 changed files with 39 additions and 31 deletions

View File

@ -23,7 +23,7 @@ use revm_primitives::{
Account as RevmAccount, AccountInfo, AccountStatus, Address, EvmState, EvmStorageSlot, HashMap, Account as RevmAccount, AccountInfo, AccountStatus, Address, EvmState, EvmStorageSlot, HashMap,
B256, KECCAK_EMPTY, U256, B256, KECCAK_EMPTY, U256,
}; };
use std::hint::black_box; use std::{hint::black_box, sync::Arc};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
struct BenchParams { struct BenchParams {
@ -217,11 +217,13 @@ fn bench_state_root(c: &mut Criterion) {
let num_threads = std::thread::available_parallelism() let num_threads = std::thread::available_parallelism()
.map_or(1, |num| (num.get() / 2).max(1)); .map_or(1, |num| (num.get() / 2).max(1));
let state_root_task_pool = rayon::ThreadPoolBuilder::new() let state_root_task_pool = Arc::new(
.num_threads(num_threads) rayon::ThreadPoolBuilder::new()
.thread_name(|i| format!("proof-worker-{}", i)) .num_threads(num_threads)
.build() .thread_name(|i| format!("proof-worker-{}", i))
.expect("Failed to create proof worker thread pool"); .build()
.expect("Failed to create proof worker thread pool"),
);
( (
config, config,
@ -258,7 +260,7 @@ fn bench_state_root(c: &mut Criterion) {
let task = StateRootTask::new( let task = StateRootTask::new(
config, config,
blinded_provider_factory, blinded_provider_factory,
&state_root_task_pool, state_root_task_pool,
); );
let mut hook = task.state_hook(); let mut hook = task.state_hook();
let handle = task.spawn(scope); let handle = task.spawn(scope);

View File

@ -539,7 +539,7 @@ where
/// The engine API variant of this handler /// The engine API variant of this handler
engine_kind: EngineApiKind, engine_kind: EngineApiKind,
/// state root task thread pool /// state root task thread pool
state_root_task_pool: rayon::ThreadPool, state_root_task_pool: Arc<rayon::ThreadPool>,
} }
impl<N, P: Debug, E: Debug, T: EngineTypes + Debug, V: Debug> std::fmt::Debug impl<N, P: Debug, E: Debug, T: EngineTypes + Debug, V: Debug> std::fmt::Debug
@ -606,11 +606,13 @@ where
let num_threads = let num_threads =
std::thread::available_parallelism().map_or(1, |num| (num.get() / 2).max(1)); std::thread::available_parallelism().map_or(1, |num| (num.get() / 2).max(1));
let state_root_task_pool = rayon::ThreadPoolBuilder::new() let state_root_task_pool = Arc::new(
.num_threads(num_threads) rayon::ThreadPoolBuilder::new()
.thread_name(|i| format!("srt-worker-{}", i)) .num_threads(num_threads)
.build() .thread_name(|i| format!("srt-worker-{}", i))
.expect("Failed to create proof worker thread pool"); .build()
.expect("Failed to create proof worker thread pool"),
);
Self { Self {
provider, provider,
@ -2313,7 +2315,7 @@ where
let state_root_task = StateRootTask::new( let state_root_task = StateRootTask::new(
state_root_config, state_root_config,
blinded_provider_factory, blinded_provider_factory,
&self.state_root_task_pool, self.state_root_task_pool.clone(),
); );
let state_hook = state_root_task.state_hook(); let state_hook = state_root_task.state_hook();
(Some(state_root_task.spawn(scope)), Box::new(state_hook) as Box<dyn OnStateHook>) (Some(state_root_task.spawn(scope)), Box::new(state_hook) as Box<dyn OnStateHook>)

View File

@ -260,7 +260,7 @@ fn evm_state_to_hashed_post_state(update: EvmState) -> HashedPostState {
/// to the tree. /// to the tree.
/// Then it updates relevant leaves according to the result of the transaction. /// Then it updates relevant leaves according to the result of the transaction.
#[derive(Debug)] #[derive(Debug)]
pub struct StateRootTask<'env, Factory, BPF: BlindedProviderFactory> { pub struct StateRootTask<Factory, BPF: BlindedProviderFactory> {
/// Task configuration. /// Task configuration.
config: StateRootConfig<Factory>, config: StateRootConfig<Factory>,
/// Receiver for state root related messages. /// Receiver for state root related messages.
@ -275,10 +275,10 @@ pub struct StateRootTask<'env, Factory, BPF: BlindedProviderFactory> {
/// progress. /// progress.
sparse_trie: Option<Box<SparseStateTrie<BPF>>>, sparse_trie: Option<Box<SparseStateTrie<BPF>>>,
/// Reference to the shared thread pool for parallel proof generation /// Reference to the shared thread pool for parallel proof generation
thread_pool: &'env rayon::ThreadPool, thread_pool: Arc<rayon::ThreadPool>,
} }
impl<'env, Factory, BPF> StateRootTask<'env, Factory, BPF> impl<'env, Factory, BPF> StateRootTask<Factory, BPF>
where where
Factory: DatabaseProviderFactory<Provider: BlockReader> Factory: DatabaseProviderFactory<Provider: BlockReader>
+ StateCommitmentProvider + StateCommitmentProvider
@ -294,7 +294,7 @@ where
pub fn new( pub fn new(
config: StateRootConfig<Factory>, config: StateRootConfig<Factory>,
blinded_provider: BPF, blinded_provider: BPF,
thread_pool: &'env rayon::ThreadPool, thread_pool: Arc<rayon::ThreadPool>,
) -> Self { ) -> Self {
let (tx, rx) = channel(); let (tx, rx) = channel();
@ -344,7 +344,7 @@ where
fetched_proof_targets: &mut MultiProofTargets, fetched_proof_targets: &mut MultiProofTargets,
proof_sequence_number: u64, proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage<BPF>>, state_root_message_sender: Sender<StateRootMessage<BPF>>,
thread_pool: &'env rayon::ThreadPool, thread_pool: Arc<rayon::ThreadPool>,
) { ) {
let proof_targets = let proof_targets =
targets.into_iter().map(|address| (keccak256(address), Default::default())).collect(); targets.into_iter().map(|address| (keccak256(address), Default::default())).collect();
@ -371,7 +371,7 @@ where
fetched_proof_targets: &mut MultiProofTargets, fetched_proof_targets: &mut MultiProofTargets,
proof_sequence_number: u64, proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage<BPF>>, state_root_message_sender: Sender<StateRootMessage<BPF>>,
thread_pool: &'env rayon::ThreadPool, thread_pool: Arc<rayon::ThreadPool>,
) { ) {
let hashed_state_update = evm_state_to_hashed_post_state(update); let hashed_state_update = evm_state_to_hashed_post_state(update);
@ -396,7 +396,7 @@ where
proof_targets: MultiProofTargets, proof_targets: MultiProofTargets,
proof_sequence_number: u64, proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage<BPF>>, state_root_message_sender: Sender<StateRootMessage<BPF>>,
thread_pool: &'env rayon::ThreadPool, thread_pool: Arc<rayon::ThreadPool>,
) { ) {
// Dispatch proof gathering for this state update // Dispatch proof gathering for this state update
scope.spawn(move |_| { scope.spawn(move |_| {
@ -533,7 +533,7 @@ where
&mut self.fetched_proof_targets, &mut self.fetched_proof_targets,
self.proof_sequencer.next_sequence(), self.proof_sequencer.next_sequence(),
self.tx.clone(), self.tx.clone(),
self.thread_pool, self.thread_pool.clone(),
); );
} }
StateRootMessage::StateUpdate(update) => { StateRootMessage::StateUpdate(update) => {
@ -557,7 +557,7 @@ where
&mut self.fetched_proof_targets, &mut self.fetched_proof_targets,
self.proof_sequencer.next_sequence(), self.proof_sequencer.next_sequence(),
self.tx.clone(), self.tx.clone(),
self.thread_pool, self.thread_pool.clone(),
); );
} }
StateRootMessage::FinishedStateUpdates => { StateRootMessage::FinishedStateUpdates => {
@ -735,7 +735,7 @@ fn get_proof_targets(
/// Calculate multiproof for the targets. /// Calculate multiproof for the targets.
#[inline] #[inline]
fn calculate_multiproof<Factory>( fn calculate_multiproof<Factory>(
thread_pool: &rayon::ThreadPool, thread_pool: Arc<rayon::ThreadPool>,
config: StateRootConfig<Factory>, config: StateRootConfig<Factory>,
proof_targets: MultiProofTargets, proof_targets: MultiProofTargets,
) -> ProviderResult<MultiProof> ) -> ProviderResult<MultiProof>
@ -993,7 +993,11 @@ mod tests {
.expect("Failed to create proof worker thread pool"); .expect("Failed to create proof worker thread pool");
let (root_from_task, _) = std::thread::scope(|std_scope| { let (root_from_task, _) = std::thread::scope(|std_scope| {
let task = StateRootTask::new(config, blinded_provider_factory, &state_root_task_pool); let task = StateRootTask::new(
config,
blinded_provider_factory,
Arc::new(state_root_task_pool),
);
let mut state_hook = task.state_hook(); let mut state_hook = task.state_hook();
let handle = task.spawn(std_scope); let handle = task.spawn(std_scope);

View File

@ -32,7 +32,7 @@ use crate::metrics::ParallelStateRootMetrics;
/// TODO: /// TODO:
#[derive(Debug)] #[derive(Debug)]
pub struct ParallelProof<'env, Factory> { pub struct ParallelProof<Factory> {
/// Consistent view of the database. /// Consistent view of the database.
view: ConsistentDbView<Factory>, view: ConsistentDbView<Factory>,
/// The sorted collection of cached in-memory intermediate trie nodes that /// The sorted collection of cached in-memory intermediate trie nodes that
@ -47,20 +47,20 @@ pub struct ParallelProof<'env, Factory> {
/// Flag indicating whether to include branch node hash masks in the proof. /// Flag indicating whether to include branch node hash masks in the proof.
collect_branch_node_hash_masks: bool, collect_branch_node_hash_masks: bool,
/// Thread pool for local tasks /// Thread pool for local tasks
thread_pool: &'env rayon::ThreadPool, thread_pool: Arc<rayon::ThreadPool>,
/// Parallel state root metrics. /// Parallel state root metrics.
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
metrics: ParallelStateRootMetrics, metrics: ParallelStateRootMetrics,
} }
impl<'env, Factory> ParallelProof<'env, Factory> { impl<Factory> ParallelProof<Factory> {
/// Create new state proof generator. /// Create new state proof generator.
pub fn new( pub fn new(
view: ConsistentDbView<Factory>, view: ConsistentDbView<Factory>,
nodes_sorted: Arc<TrieUpdatesSorted>, nodes_sorted: Arc<TrieUpdatesSorted>,
state_sorted: Arc<HashedPostStateSorted>, state_sorted: Arc<HashedPostStateSorted>,
prefix_sets: Arc<TriePrefixSetsMut>, prefix_sets: Arc<TriePrefixSetsMut>,
thread_pool: &'env rayon::ThreadPool, thread_pool: Arc<rayon::ThreadPool>,
) -> Self { ) -> Self {
Self { Self {
view, view,
@ -81,7 +81,7 @@ impl<'env, Factory> ParallelProof<'env, Factory> {
} }
} }
impl<Factory> ParallelProof<'_, Factory> impl<Factory> ParallelProof<Factory>
where where
Factory: DatabaseProviderFactory<Provider: BlockReader> Factory: DatabaseProviderFactory<Provider: BlockReader>
+ StateCommitmentProvider + StateCommitmentProvider
@ -407,7 +407,7 @@ mod tests {
Default::default(), Default::default(),
Default::default(), Default::default(),
Default::default(), Default::default(),
&state_root_task_pool Arc::new(state_root_task_pool)
) )
.multiproof(targets.clone()) .multiproof(targets.clone())
.unwrap(), .unwrap(),