From 4a8c88f4d0e319d2ec61cffbbfb014ab0886a5a4 Mon Sep 17 00:00:00 2001 From: Federico Gimenez Date: Thu, 9 Jan 2025 15:25:00 +0100 Subject: [PATCH] chore(engine): use Arc for StateRootTask (#13755) --- crates/engine/tree/benches/state_root_task.rs | 16 +++++++----- crates/engine/tree/src/tree/mod.rs | 16 +++++++----- crates/engine/tree/src/tree/root.rs | 26 +++++++++++-------- crates/trie/parallel/src/proof.rs | 12 ++++----- 4 files changed, 39 insertions(+), 31 deletions(-) diff --git a/crates/engine/tree/benches/state_root_task.rs b/crates/engine/tree/benches/state_root_task.rs index 9958cf0ca..8c5b87138 100644 --- a/crates/engine/tree/benches/state_root_task.rs +++ b/crates/engine/tree/benches/state_root_task.rs @@ -23,7 +23,7 @@ use revm_primitives::{ Account as RevmAccount, AccountInfo, AccountStatus, Address, EvmState, EvmStorageSlot, HashMap, B256, KECCAK_EMPTY, U256, }; -use std::hint::black_box; +use std::{hint::black_box, sync::Arc}; #[derive(Debug, Clone)] struct BenchParams { @@ -217,11 +217,13 @@ fn bench_state_root(c: &mut Criterion) { let num_threads = std::thread::available_parallelism() .map_or(1, |num| (num.get() / 2).max(1)); - let state_root_task_pool = rayon::ThreadPoolBuilder::new() - .num_threads(num_threads) - .thread_name(|i| format!("proof-worker-{}", i)) - .build() - .expect("Failed to create proof worker thread pool"); + let state_root_task_pool = Arc::new( + rayon::ThreadPoolBuilder::new() + .num_threads(num_threads) + .thread_name(|i| format!("proof-worker-{}", i)) + .build() + .expect("Failed to create proof worker thread pool"), + ); ( config, @@ -258,7 +260,7 @@ fn bench_state_root(c: &mut Criterion) { let task = StateRootTask::new( config, blinded_provider_factory, - &state_root_task_pool, + state_root_task_pool, ); let mut hook = task.state_hook(); let handle = task.spawn(scope); diff --git a/crates/engine/tree/src/tree/mod.rs b/crates/engine/tree/src/tree/mod.rs index 1103c569f..429dda728 100644 --- a/crates/engine/tree/src/tree/mod.rs +++ b/crates/engine/tree/src/tree/mod.rs @@ -539,7 +539,7 @@ where /// The engine API variant of this handler engine_kind: EngineApiKind, /// state root task thread pool - state_root_task_pool: rayon::ThreadPool, + state_root_task_pool: Arc, } impl std::fmt::Debug @@ -606,11 +606,13 @@ where let num_threads = std::thread::available_parallelism().map_or(1, |num| (num.get() / 2).max(1)); - let state_root_task_pool = rayon::ThreadPoolBuilder::new() - .num_threads(num_threads) - .thread_name(|i| format!("srt-worker-{}", i)) - .build() - .expect("Failed to create proof worker thread pool"); + let state_root_task_pool = Arc::new( + rayon::ThreadPoolBuilder::new() + .num_threads(num_threads) + .thread_name(|i| format!("srt-worker-{}", i)) + .build() + .expect("Failed to create proof worker thread pool"), + ); Self { provider, @@ -2313,7 +2315,7 @@ where let state_root_task = StateRootTask::new( state_root_config, blinded_provider_factory, - &self.state_root_task_pool, + self.state_root_task_pool.clone(), ); let state_hook = state_root_task.state_hook(); (Some(state_root_task.spawn(scope)), Box::new(state_hook) as Box) diff --git a/crates/engine/tree/src/tree/root.rs b/crates/engine/tree/src/tree/root.rs index b41de299a..6bb213bce 100644 --- a/crates/engine/tree/src/tree/root.rs +++ b/crates/engine/tree/src/tree/root.rs @@ -260,7 +260,7 @@ fn evm_state_to_hashed_post_state(update: EvmState) -> HashedPostState { /// to the tree. /// Then it updates relevant leaves according to the result of the transaction. #[derive(Debug)] -pub struct StateRootTask<'env, Factory, BPF: BlindedProviderFactory> { +pub struct StateRootTask { /// Task configuration. config: StateRootConfig, /// Receiver for state root related messages. @@ -275,10 +275,10 @@ pub struct StateRootTask<'env, Factory, BPF: BlindedProviderFactory> { /// progress. sparse_trie: Option>>, /// Reference to the shared thread pool for parallel proof generation - thread_pool: &'env rayon::ThreadPool, + thread_pool: Arc, } -impl<'env, Factory, BPF> StateRootTask<'env, Factory, BPF> +impl<'env, Factory, BPF> StateRootTask where Factory: DatabaseProviderFactory + StateCommitmentProvider @@ -294,7 +294,7 @@ where pub fn new( config: StateRootConfig, blinded_provider: BPF, - thread_pool: &'env rayon::ThreadPool, + thread_pool: Arc, ) -> Self { let (tx, rx) = channel(); @@ -344,7 +344,7 @@ where fetched_proof_targets: &mut MultiProofTargets, proof_sequence_number: u64, state_root_message_sender: Sender>, - thread_pool: &'env rayon::ThreadPool, + thread_pool: Arc, ) { let proof_targets = targets.into_iter().map(|address| (keccak256(address), Default::default())).collect(); @@ -371,7 +371,7 @@ where fetched_proof_targets: &mut MultiProofTargets, proof_sequence_number: u64, state_root_message_sender: Sender>, - thread_pool: &'env rayon::ThreadPool, + thread_pool: Arc, ) { let hashed_state_update = evm_state_to_hashed_post_state(update); @@ -396,7 +396,7 @@ where proof_targets: MultiProofTargets, proof_sequence_number: u64, state_root_message_sender: Sender>, - thread_pool: &'env rayon::ThreadPool, + thread_pool: Arc, ) { // Dispatch proof gathering for this state update scope.spawn(move |_| { @@ -533,7 +533,7 @@ where &mut self.fetched_proof_targets, self.proof_sequencer.next_sequence(), self.tx.clone(), - self.thread_pool, + self.thread_pool.clone(), ); } StateRootMessage::StateUpdate(update) => { @@ -557,7 +557,7 @@ where &mut self.fetched_proof_targets, self.proof_sequencer.next_sequence(), self.tx.clone(), - self.thread_pool, + self.thread_pool.clone(), ); } StateRootMessage::FinishedStateUpdates => { @@ -735,7 +735,7 @@ fn get_proof_targets( /// Calculate multiproof for the targets. #[inline] fn calculate_multiproof( - thread_pool: &rayon::ThreadPool, + thread_pool: Arc, config: StateRootConfig, proof_targets: MultiProofTargets, ) -> ProviderResult @@ -993,7 +993,11 @@ mod tests { .expect("Failed to create proof worker thread pool"); let (root_from_task, _) = std::thread::scope(|std_scope| { - let task = StateRootTask::new(config, blinded_provider_factory, &state_root_task_pool); + let task = StateRootTask::new( + config, + blinded_provider_factory, + Arc::new(state_root_task_pool), + ); let mut state_hook = task.state_hook(); let handle = task.spawn(std_scope); diff --git a/crates/trie/parallel/src/proof.rs b/crates/trie/parallel/src/proof.rs index ce0c185e1..31df5f232 100644 --- a/crates/trie/parallel/src/proof.rs +++ b/crates/trie/parallel/src/proof.rs @@ -32,7 +32,7 @@ use crate::metrics::ParallelStateRootMetrics; /// TODO: #[derive(Debug)] -pub struct ParallelProof<'env, Factory> { +pub struct ParallelProof { /// Consistent view of the database. view: ConsistentDbView, /// The sorted collection of cached in-memory intermediate trie nodes that @@ -47,20 +47,20 @@ pub struct ParallelProof<'env, Factory> { /// Flag indicating whether to include branch node hash masks in the proof. collect_branch_node_hash_masks: bool, /// Thread pool for local tasks - thread_pool: &'env rayon::ThreadPool, + thread_pool: Arc, /// Parallel state root metrics. #[cfg(feature = "metrics")] metrics: ParallelStateRootMetrics, } -impl<'env, Factory> ParallelProof<'env, Factory> { +impl ParallelProof { /// Create new state proof generator. pub fn new( view: ConsistentDbView, nodes_sorted: Arc, state_sorted: Arc, prefix_sets: Arc, - thread_pool: &'env rayon::ThreadPool, + thread_pool: Arc, ) -> Self { Self { view, @@ -81,7 +81,7 @@ impl<'env, Factory> ParallelProof<'env, Factory> { } } -impl ParallelProof<'_, Factory> +impl ParallelProof where Factory: DatabaseProviderFactory + StateCommitmentProvider @@ -407,7 +407,7 @@ mod tests { Default::default(), Default::default(), Default::default(), - &state_root_task_pool + Arc::new(state_root_task_pool) ) .multiproof(targets.clone()) .unwrap(),