test: reenable test_state_root_task test (#13911)

Co-authored-by: Federico Gimenez <federico.gimenez@gmail.com>
Co-authored-by: Federico Gimenez <fgimenez@users.noreply.github.com>
This commit is contained in:
Roman Krasiuk
2025-01-24 19:28:02 +01:00
committed by GitHub
parent 203fed0f64
commit 0cd63cdf4b
2 changed files with 172 additions and 116 deletions

View File

@ -594,10 +594,7 @@ where
) -> Self { ) -> Self {
let (incoming_tx, incoming) = std::sync::mpsc::channel(); let (incoming_tx, incoming) = std::sync::mpsc::channel();
// The thread pool requires at least 2 threads as it contains a long running sparse trie let num_threads = root::thread_pool_size();
// task.
let num_threads =
std::thread::available_parallelism().map_or(2, |num| (num.get() / 2).max(2));
let state_root_task_pool = Arc::new( let state_root_task_pool = Arc::new(
rayon::ThreadPoolBuilder::new() rayon::ThreadPoolBuilder::new()

View File

@ -27,7 +27,7 @@ use reth_trie_sparse::{
}; };
use revm_primitives::{keccak256, EvmState, B256}; use revm_primitives::{keccak256, EvmState, B256};
use std::{ use std::{
collections::BTreeMap, collections::{BTreeMap, VecDeque},
sync::{ sync::{
mpsc::{self, channel, Receiver, Sender}, mpsc::{self, channel, Receiver, Sender},
Arc, Arc,
@ -39,6 +39,16 @@ use tracing::{debug, error, trace};
/// The level below which the sparse trie hashes are calculated in [`update_sparse_trie`]. /// The level below which the sparse trie hashes are calculated in [`update_sparse_trie`].
const SPARSE_TRIE_INCREMENTAL_LEVEL: usize = 2; const SPARSE_TRIE_INCREMENTAL_LEVEL: usize = 2;
/// Determines the size of the thread pool to be used in [`StateRootTask`].
/// It should be at least three: one for multiproof calculations plus two to be
/// used internally in [`StateRootTask`].
///
/// NOTE: this value can be greater than the available cores in the host; it
/// represents the maximum number of threads that can be handled by the pool.
pub(crate) fn thread_pool_size() -> usize {
    // Use half of the logical cores, clamped to a minimum of 3; fall back to
    // the minimum when the available parallelism cannot be queried.
    match std::thread::available_parallelism() {
        Ok(parallelism) => (parallelism.get() / 2).max(3),
        Err(_) => 3,
    }
}
/// Outcome of the state root computation, including the state root itself with /// Outcome of the state root computation, including the state root itself with
/// the trie updates and the total time spent. /// the trie updates and the total time spent.
#[derive(Debug)] #[derive(Debug)]
@ -296,6 +306,129 @@ fn evm_state_to_hashed_post_state(update: EvmState) -> HashedPostState {
hashed_state hashed_state
} }
/// Input parameters for spawning a multiproof calculation.
#[derive(Debug)]
struct MultiproofInput<Factory> {
    /// Shared configuration passed through to the multiproof calculation.
    config: StateRootConfig<Factory>,
    /// Hashed state update the proof is being calculated for; empty for
    /// prefetch requests.
    hashed_state_update: HashedPostState,
    /// Accounts and storage slots to generate proofs for.
    proof_targets: MultiProofTargets,
    /// Sequence number used to order completed proofs back into submission
    /// order.
    proof_sequence_number: u64,
    /// Channel used to deliver the calculation outcome (`ProofCalculated` or
    /// `ProofCalculationError`) back to the state root task.
    state_root_message_sender: Sender<StateRootMessage>,
    /// Origin of the proof request (prefetch or state update).
    source: ProofFetchSource,
}
/// Manages concurrent multiproof calculations.
///
/// Takes care of not having more calculations in flight than a given thread
/// pool size; further calculation requests are queued and spawned later, after
/// availability has been signaled.
#[derive(Debug)]
struct MultiproofManager<Factory> {
    /// Maximum number of concurrent calculations.
    max_concurrent: usize,
    /// Currently running calculations.
    inflight: usize,
    /// Queued calculations.
    pending: VecDeque<MultiproofInput<Factory>>,
    /// Thread pool to spawn multiproof calculations.
    thread_pool: Arc<rayon::ThreadPool>,
}
impl<Factory> MultiproofManager<Factory>
where
Factory: DatabaseProviderFactory<Provider: BlockReader>
+ StateCommitmentProvider
+ Clone
+ Send
+ Sync
+ 'static,
{
/// Creates a new [`MultiproofManager`].
fn new(thread_pool: Arc<rayon::ThreadPool>, thread_pool_size: usize) -> Self {
// we keep 2 threads to be used internally by [`StateRootTask`]
let max_concurrent = thread_pool_size.saturating_sub(2);
debug_assert!(max_concurrent != 0);
Self {
thread_pool,
max_concurrent,
inflight: 0,
pending: VecDeque::with_capacity(max_concurrent),
}
}
/// Spawns a new multiproof calculation or enqueues it for later if
/// `max_concurrent` are already inflight.
fn spawn_or_queue(&mut self, input: MultiproofInput<Factory>) {
if self.inflight >= self.max_concurrent {
self.pending.push_back(input);
return;
}
self.spawn_multiproof(input);
}
/// Signals that a multiproof calculation has finished and there's room to
/// spawn a new calculation if needed.
fn on_calculation_complete(&mut self) {
self.inflight = self.inflight.saturating_sub(1);
if let Some(input) = self.pending.pop_front() {
self.spawn_multiproof(input);
}
}
/// Spawns a multiproof calculation.
fn spawn_multiproof(&mut self, input: MultiproofInput<Factory>) {
let MultiproofInput {
config,
hashed_state_update,
proof_targets,
proof_sequence_number,
state_root_message_sender,
source,
} = input;
let thread_pool = self.thread_pool.clone();
self.thread_pool.spawn(move || {
trace!(
target: "engine::root",
proof_sequence_number,
?proof_targets,
"Starting multiproof calculation",
);
let start = Instant::now();
let result = calculate_multiproof(thread_pool, config, proof_targets.clone());
trace!(
target: "engine::root",
proof_sequence_number,
elapsed = ?start.elapsed(),
"Multiproof calculated",
);
match result {
Ok(proof) => {
let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated(
Box::new(ProofCalculated {
sequence_number: proof_sequence_number,
update: SparseTrieUpdate {
state: hashed_state_update,
targets: proof_targets,
multiproof: proof,
},
source,
}),
));
}
Err(error) => {
let _ = state_root_message_sender
.send(StateRootMessage::ProofCalculationError(error));
}
}
});
self.inflight += 1;
}
}
/// Standalone task that receives a transaction state stream and updates relevant /// Standalone task that receives a transaction state stream and updates relevant
/// data structures to calculate state root. /// data structures to calculate state root.
/// ///
@ -316,8 +449,10 @@ pub struct StateRootTask<Factory> {
fetched_proof_targets: MultiProofTargets, fetched_proof_targets: MultiProofTargets,
/// Proof sequencing handler. /// Proof sequencing handler.
proof_sequencer: ProofSequencer, proof_sequencer: ProofSequencer,
/// Reference to the shared thread pool for parallel proof generation /// Reference to the shared thread pool for parallel proof generation.
thread_pool: Arc<rayon::ThreadPool>, thread_pool: Arc<rayon::ThreadPool>,
/// Manages calculation of multiproofs.
multiproof_manager: MultiproofManager<Factory>,
} }
impl<Factory> StateRootTask<Factory> impl<Factory> StateRootTask<Factory>
@ -338,7 +473,8 @@ where
tx, tx,
fetched_proof_targets: Default::default(), fetched_proof_targets: Default::default(),
proof_sequencer: ProofSequencer::new(), proof_sequencer: ProofSequencer::new(),
thread_pool, thread_pool: thread_pool.clone(),
multiproof_manager: MultiproofManager::new(thread_pool, thread_pool_size()),
} }
} }
@ -397,99 +533,34 @@ where
} }
/// Handles request for proof prefetch. /// Handles request for proof prefetch.
fn on_prefetch_proof( fn on_prefetch_proof(&mut self, targets: MultiProofTargets) {
config: StateRootConfig<Factory>, extend_multi_proof_targets_ref(&mut self.fetched_proof_targets, &targets);
targets: MultiProofTargets,
fetched_proof_targets: &mut MultiProofTargets,
proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage>,
thread_pool: Arc<rayon::ThreadPool>,
) {
extend_multi_proof_targets_ref(fetched_proof_targets, &targets);
Self::spawn_multiproof( self.multiproof_manager.spawn_or_queue(MultiproofInput {
config, config: self.config.clone(),
Default::default(), hashed_state_update: Default::default(),
targets, proof_targets: targets,
proof_sequence_number, proof_sequence_number: self.proof_sequencer.next_sequence(),
state_root_message_sender, state_root_message_sender: self.tx.clone(),
thread_pool, source: ProofFetchSource::Prefetch,
ProofFetchSource::Prefetch, });
);
} }
/// Handles state updates. /// Handles state updates.
/// ///
/// Returns proof targets derived from the state update. /// Returns proof targets derived from the state update.
fn on_state_update( fn on_state_update(&mut self, update: EvmState, proof_sequence_number: u64) {
config: StateRootConfig<Factory>,
update: EvmState,
fetched_proof_targets: &mut MultiProofTargets,
proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage>,
thread_pool: Arc<rayon::ThreadPool>,
) {
let hashed_state_update = evm_state_to_hashed_post_state(update); let hashed_state_update = evm_state_to_hashed_post_state(update);
let proof_targets = get_proof_targets(&hashed_state_update, &self.fetched_proof_targets);
extend_multi_proof_targets_ref(&mut self.fetched_proof_targets, &proof_targets);
let proof_targets = get_proof_targets(&hashed_state_update, fetched_proof_targets); self.multiproof_manager.spawn_or_queue(MultiproofInput {
extend_multi_proof_targets_ref(fetched_proof_targets, &proof_targets); config: self.config.clone(),
Self::spawn_multiproof(
config,
hashed_state_update, hashed_state_update,
proof_targets, proof_targets,
proof_sequence_number, proof_sequence_number,
state_root_message_sender, state_root_message_sender: self.tx.clone(),
thread_pool, source: ProofFetchSource::StateUpdate,
ProofFetchSource::StateUpdate,
);
}
fn spawn_multiproof(
config: StateRootConfig<Factory>,
hashed_state_update: HashedPostState,
proof_targets: MultiProofTargets,
proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage>,
thread_pool: Arc<rayon::ThreadPool>,
source: ProofFetchSource,
) {
// Dispatch proof gathering for this state update
thread_pool.clone().spawn(move || {
trace!(
target: "engine::root",
proof_sequence_number,
?proof_targets,
"Starting multiproof calculation",
);
let start = Instant::now();
let result = calculate_multiproof(thread_pool, config, proof_targets.clone());
trace!(
target: "engine::root",
proof_sequence_number,
elapsed = ?start.elapsed(),
"Multiproof calculated",
);
match result {
Ok(proof) => {
let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated(
Box::new(ProofCalculated {
sequence_number: proof_sequence_number,
update: SparseTrieUpdate {
state: hashed_state_update,
targets: proof_targets,
multiproof: proof,
},
source,
}),
));
}
Err(error) => {
let _ = state_root_message_sender
.send(StateRootMessage::ProofCalculationError(error));
}
}
}); });
} }
@ -526,24 +597,20 @@ where
let mut last_update_time = None; let mut last_update_time = None;
loop { loop {
trace!(target: "engine::root", "entering main channel receiving loop");
match self.rx.recv() { match self.rx.recv() {
Ok(message) => match message { Ok(message) => match message {
StateRootMessage::PrefetchProofs(targets) => { StateRootMessage::PrefetchProofs(targets) => {
trace!(target: "engine::root", "processing StateRootMessage::PrefetchProofs");
debug!( debug!(
target: "engine::root", target: "engine::root",
len = targets.len(), len = targets.len(),
"Prefetching proofs" "Prefetching proofs"
); );
Self::on_prefetch_proof( self.on_prefetch_proof(targets);
self.config.clone(),
targets,
&mut self.fetched_proof_targets,
self.proof_sequencer.next_sequence(),
self.tx.clone(),
self.thread_pool.clone(),
);
} }
StateRootMessage::StateUpdate(update) => { StateRootMessage::StateUpdate(update) => {
trace!(target: "engine::root", "processing StateRootMessage::StateUpdate");
if updates_received == 0 { if updates_received == 0 {
first_update_time = Some(Instant::now()); first_update_time = Some(Instant::now());
debug!(target: "engine::root", "Started state root calculation"); debug!(target: "engine::root", "Started state root calculation");
@ -557,23 +624,19 @@ where
total_updates = updates_received, total_updates = updates_received,
"Received new state update" "Received new state update"
); );
Self::on_state_update( let next_sequence = self.proof_sequencer.next_sequence();
self.config.clone(), self.on_state_update(update, next_sequence);
update,
&mut self.fetched_proof_targets,
self.proof_sequencer.next_sequence(),
self.tx.clone(),
self.thread_pool.clone(),
);
} }
StateRootMessage::FinishedStateUpdates => { StateRootMessage::FinishedStateUpdates => {
trace!(target: "engine::root", "Finished state updates"); trace!(target: "engine::root", "processing StateRootMessage::FinishedStateUpdates");
updates_finished = true; updates_finished = true;
} }
StateRootMessage::ProofCalculated(proof_calculated) => { StateRootMessage::ProofCalculated(proof_calculated) => {
trace!(target: "engine::root", "processing StateRootMessage::ProofCalculated");
if proof_calculated.is_from_state_update() { if proof_calculated.is_from_state_update() {
proofs_processed += 1; proofs_processed += 1;
} }
debug!( debug!(
target: "engine::root", target: "engine::root",
sequence = proof_calculated.sequence_number, sequence = proof_calculated.sequence_number,
@ -581,6 +644,8 @@ where
"Processing calculated proof" "Processing calculated proof"
); );
self.multiproof_manager.on_calculation_complete();
if let Some(combined_update) = if let Some(combined_update) =
self.on_proof(proof_calculated.sequence_number, proof_calculated.update) self.on_proof(proof_calculated.sequence_number, proof_calculated.update)
{ {
@ -599,6 +664,7 @@ where
} }
} }
StateRootMessage::RootCalculated { state_root, trie_updates, iterations } => { StateRootMessage::RootCalculated { state_root, trie_updates, iterations } => {
trace!(target: "engine::root", "processing StateRootMessage::RootCalculated");
let total_time = let total_time =
first_update_time.expect("first update time should be set").elapsed(); first_update_time.expect("first update time should be set").elapsed();
let time_from_last_update = let time_from_last_update =
@ -694,7 +760,7 @@ where
let elapsed = update_sparse_trie(&mut trie, update).map_err(|e| { let elapsed = update_sparse_trie(&mut trie, update).map_err(|e| {
ParallelStateRootError::Other(format!("could not calculate state root: {e:?}")) ParallelStateRootError::Other(format!("could not calculate state root: {e:?}"))
})?; })?;
trace!(target: "engine::root", ?elapsed, "Root calculation completed"); trace!(target: "engine::root", ?elapsed, num_iterations, "Root calculation completed");
} }
debug!(target: "engine::root", num_iterations, "All proofs processed, ending calculation"); debug!(target: "engine::root", num_iterations, "All proofs processed, ending calculation");
@ -853,7 +919,6 @@ mod tests {
}; };
use std::sync::Arc; use std::sync::Arc;
#[allow(dead_code)]
fn convert_revm_to_reth_account(revm_account: &RevmAccount) -> RethAccount { fn convert_revm_to_reth_account(revm_account: &RevmAccount) -> RethAccount {
RethAccount { RethAccount {
balance: revm_account.info.balance, balance: revm_account.info.balance,
@ -866,7 +931,6 @@ mod tests {
} }
} }
#[allow(dead_code)]
fn create_mock_state_updates(num_accounts: usize, updates_per_account: usize) -> Vec<EvmState> { fn create_mock_state_updates(num_accounts: usize, updates_per_account: usize) -> Vec<EvmState> {
let mut rng = generators::rng(); let mut rng = generators::rng();
let all_addresses: Vec<Address> = (0..num_accounts).map(|_| rng.gen()).collect(); let all_addresses: Vec<Address> = (0..num_accounts).map(|_| rng.gen()).collect();
@ -910,9 +974,7 @@ mod tests {
updates updates
} }
// TODO: re-enable test once gh worker hang is figured out. #[test]
// #[test]
#[allow(dead_code)]
fn test_state_root_task() { fn test_state_root_task() {
reth_tracing::init_test_tracing(); reth_tracing::init_test_tracing();
@ -973,10 +1035,7 @@ mod tests {
prefix_sets: Arc::new(input.prefix_sets), prefix_sets: Arc::new(input.prefix_sets),
}; };
// The thread pool requires at least 2 threads as it contains a long running sparse trie let num_threads = thread_pool_size();
// task.
let num_threads =
std::thread::available_parallelism().map_or(2, |num| (num.get() / 2).max(2));
let state_root_task_pool = rayon::ThreadPoolBuilder::new() let state_root_task_pool = rayon::ThreadPoolBuilder::new()
.num_threads(num_threads) .num_threads(num_threads)