chore(engine): use Arc<rayon::ThreadPool> for StateRootTask (#13755)

This commit is contained in:
Federico Gimenez
2025-01-09 15:25:00 +01:00
committed by GitHub
parent bf65ed45c5
commit 4a8c88f4d0
4 changed files with 39 additions and 31 deletions

View File

@ -23,7 +23,7 @@ use revm_primitives::{
Account as RevmAccount, AccountInfo, AccountStatus, Address, EvmState, EvmStorageSlot, HashMap,
B256, KECCAK_EMPTY, U256,
};
use std::hint::black_box;
use std::{hint::black_box, sync::Arc};
#[derive(Debug, Clone)]
struct BenchParams {
@ -217,11 +217,13 @@ fn bench_state_root(c: &mut Criterion) {
let num_threads = std::thread::available_parallelism()
.map_or(1, |num| (num.get() / 2).max(1));
let state_root_task_pool = rayon::ThreadPoolBuilder::new()
let state_root_task_pool = Arc::new(
rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
.thread_name(|i| format!("proof-worker-{}", i))
.build()
.expect("Failed to create proof worker thread pool");
.expect("Failed to create proof worker thread pool"),
);
(
config,
@ -258,7 +260,7 @@ fn bench_state_root(c: &mut Criterion) {
let task = StateRootTask::new(
config,
blinded_provider_factory,
&state_root_task_pool,
state_root_task_pool,
);
let mut hook = task.state_hook();
let handle = task.spawn(scope);

View File

@ -539,7 +539,7 @@ where
/// The engine API variant of this handler
engine_kind: EngineApiKind,
/// state root task thread pool
state_root_task_pool: rayon::ThreadPool,
state_root_task_pool: Arc<rayon::ThreadPool>,
}
impl<N, P: Debug, E: Debug, T: EngineTypes + Debug, V: Debug> std::fmt::Debug
@ -606,11 +606,13 @@ where
let num_threads =
std::thread::available_parallelism().map_or(1, |num| (num.get() / 2).max(1));
let state_root_task_pool = rayon::ThreadPoolBuilder::new()
let state_root_task_pool = Arc::new(
rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
.thread_name(|i| format!("srt-worker-{}", i))
.build()
.expect("Failed to create proof worker thread pool");
.expect("Failed to create proof worker thread pool"),
);
Self {
provider,
@ -2313,7 +2315,7 @@ where
let state_root_task = StateRootTask::new(
state_root_config,
blinded_provider_factory,
&self.state_root_task_pool,
self.state_root_task_pool.clone(),
);
let state_hook = state_root_task.state_hook();
(Some(state_root_task.spawn(scope)), Box::new(state_hook) as Box<dyn OnStateHook>)

View File

@ -260,7 +260,7 @@ fn evm_state_to_hashed_post_state(update: EvmState) -> HashedPostState {
/// to the tree.
/// Then it updates relevant leaves according to the result of the transaction.
#[derive(Debug)]
pub struct StateRootTask<'env, Factory, BPF: BlindedProviderFactory> {
pub struct StateRootTask<Factory, BPF: BlindedProviderFactory> {
/// Task configuration.
config: StateRootConfig<Factory>,
/// Receiver for state root related messages.
@ -275,10 +275,10 @@ pub struct StateRootTask<'env, Factory, BPF: BlindedProviderFactory> {
/// progress.
sparse_trie: Option<Box<SparseStateTrie<BPF>>>,
/// Reference to the shared thread pool for parallel proof generation
thread_pool: &'env rayon::ThreadPool,
thread_pool: Arc<rayon::ThreadPool>,
}
impl<'env, Factory, BPF> StateRootTask<'env, Factory, BPF>
impl<'env, Factory, BPF> StateRootTask<Factory, BPF>
where
Factory: DatabaseProviderFactory<Provider: BlockReader>
+ StateCommitmentProvider
@ -294,7 +294,7 @@ where
pub fn new(
config: StateRootConfig<Factory>,
blinded_provider: BPF,
thread_pool: &'env rayon::ThreadPool,
thread_pool: Arc<rayon::ThreadPool>,
) -> Self {
let (tx, rx) = channel();
@ -344,7 +344,7 @@ where
fetched_proof_targets: &mut MultiProofTargets,
proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage<BPF>>,
thread_pool: &'env rayon::ThreadPool,
thread_pool: Arc<rayon::ThreadPool>,
) {
let proof_targets =
targets.into_iter().map(|address| (keccak256(address), Default::default())).collect();
@ -371,7 +371,7 @@ where
fetched_proof_targets: &mut MultiProofTargets,
proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage<BPF>>,
thread_pool: &'env rayon::ThreadPool,
thread_pool: Arc<rayon::ThreadPool>,
) {
let hashed_state_update = evm_state_to_hashed_post_state(update);
@ -396,7 +396,7 @@ where
proof_targets: MultiProofTargets,
proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage<BPF>>,
thread_pool: &'env rayon::ThreadPool,
thread_pool: Arc<rayon::ThreadPool>,
) {
// Dispatch proof gathering for this state update
scope.spawn(move |_| {
@ -533,7 +533,7 @@ where
&mut self.fetched_proof_targets,
self.proof_sequencer.next_sequence(),
self.tx.clone(),
self.thread_pool,
self.thread_pool.clone(),
);
}
StateRootMessage::StateUpdate(update) => {
@ -557,7 +557,7 @@ where
&mut self.fetched_proof_targets,
self.proof_sequencer.next_sequence(),
self.tx.clone(),
self.thread_pool,
self.thread_pool.clone(),
);
}
StateRootMessage::FinishedStateUpdates => {
@ -735,7 +735,7 @@ fn get_proof_targets(
/// Calculate multiproof for the targets.
#[inline]
fn calculate_multiproof<Factory>(
thread_pool: &rayon::ThreadPool,
thread_pool: Arc<rayon::ThreadPool>,
config: StateRootConfig<Factory>,
proof_targets: MultiProofTargets,
) -> ProviderResult<MultiProof>
@ -993,7 +993,11 @@ mod tests {
.expect("Failed to create proof worker thread pool");
let (root_from_task, _) = std::thread::scope(|std_scope| {
let task = StateRootTask::new(config, blinded_provider_factory, &state_root_task_pool);
let task = StateRootTask::new(
config,
blinded_provider_factory,
Arc::new(state_root_task_pool),
);
let mut state_hook = task.state_hook();
let handle = task.spawn(std_scope);

View File

@ -32,7 +32,7 @@ use crate::metrics::ParallelStateRootMetrics;
/// TODO:
#[derive(Debug)]
pub struct ParallelProof<'env, Factory> {
pub struct ParallelProof<Factory> {
/// Consistent view of the database.
view: ConsistentDbView<Factory>,
/// The sorted collection of cached in-memory intermediate trie nodes that
@ -47,20 +47,20 @@ pub struct ParallelProof<'env, Factory> {
/// Flag indicating whether to include branch node hash masks in the proof.
collect_branch_node_hash_masks: bool,
/// Thread pool for local tasks
thread_pool: &'env rayon::ThreadPool,
thread_pool: Arc<rayon::ThreadPool>,
/// Parallel state root metrics.
#[cfg(feature = "metrics")]
metrics: ParallelStateRootMetrics,
}
impl<'env, Factory> ParallelProof<'env, Factory> {
impl<Factory> ParallelProof<Factory> {
/// Create new state proof generator.
pub fn new(
view: ConsistentDbView<Factory>,
nodes_sorted: Arc<TrieUpdatesSorted>,
state_sorted: Arc<HashedPostStateSorted>,
prefix_sets: Arc<TriePrefixSetsMut>,
thread_pool: &'env rayon::ThreadPool,
thread_pool: Arc<rayon::ThreadPool>,
) -> Self {
Self {
view,
@ -81,7 +81,7 @@ impl<'env, Factory> ParallelProof<'env, Factory> {
}
}
impl<Factory> ParallelProof<'_, Factory>
impl<Factory> ParallelProof<Factory>
where
Factory: DatabaseProviderFactory<Provider: BlockReader>
+ StateCommitmentProvider
@ -407,7 +407,7 @@ mod tests {
Default::default(),
Default::default(),
Default::default(),
&state_root_task_pool
Arc::new(state_root_task_pool)
)
.multiproof(targets.clone())
.unwrap(),