perf(root): untangle the state root task (#13898)

This commit is contained in:
Roman Krasiuk
2025-01-21 23:46:34 +01:00
committed by GitHub
parent 6c3b1b8bcd
commit d2b454236f
7 changed files with 343 additions and 555 deletions

View File

@ -14,11 +14,7 @@ use reth_provider::{
test_utils::{create_test_provider_factory, MockNodeTypesWithDB},
AccountReader, HashingWriter, ProviderFactory,
};
use reth_trie::{
hashed_cursor::HashedPostStateCursorFactory, proof::ProofBlindedProviderFactory,
trie_cursor::InMemoryTrieCursorFactory, TrieInput,
};
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use reth_trie::TrieInput;
use revm_primitives::{
Account as RevmAccount, AccountInfo, AccountStatus, Address, EvmState, EvmStorageSlot, HashMap,
B256, KECCAK_EMPTY, U256,
@ -210,10 +206,6 @@ fn bench_state_root(c: &mut Criterion) {
ConsistentDbView::new(factory, None),
trie_input,
);
let provider = config.consistent_view.provider_ro().unwrap();
let nodes_sorted = config.nodes_sorted.clone();
let state_sorted = config.state_sorted.clone();
let prefix_sets = config.prefix_sets.clone();
let num_threads = std::thread::available_parallelism()
.map_or(1, |num| (num.get() / 2).max(1));
@ -225,45 +217,13 @@ fn bench_state_root(c: &mut Criterion) {
.expect("Failed to create proof worker thread pool"),
);
(
config,
state_updates,
provider,
nodes_sorted,
state_sorted,
prefix_sets,
state_root_task_pool,
)
(config, state_updates, state_root_task_pool)
},
|(
config,
state_updates,
provider,
nodes_sorted,
state_sorted,
prefix_sets,
state_root_task_pool,
)| {
let blinded_provider_factory = ProofBlindedProviderFactory::new(
InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider.tx_ref()),
&nodes_sorted,
),
HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider.tx_ref()),
&state_sorted,
),
prefix_sets,
);
black_box(std::thread::scope(|scope| {
let task = StateRootTask::new(
config,
blinded_provider_factory,
state_root_task_pool,
);
|(config, state_updates, state_root_task_pool)| {
black_box({
let task = StateRootTask::new(config, state_root_task_pool);
let mut hook = task.state_hook();
let handle = task.spawn(scope);
let handle = task.spawn();
for update in state_updates {
hook.on_state(&update)
@ -271,7 +231,7 @@ fn bench_state_root(c: &mut Criterion) {
drop(hook);
handle.wait_for_result().expect("task failed")
}));
});
},
)
},

View File

@ -31,7 +31,10 @@ use reth_engine_primitives::{
OnForkChoiceUpdated,
};
use reth_errors::{ConsensusError, ProviderResult};
use reth_evm::{execute::BlockExecutorProvider, system_calls::OnStateHook};
use reth_evm::{
execute::BlockExecutorProvider,
system_calls::{NoopHook, OnStateHook},
};
use reth_payload_builder::PayloadBuilderHandle;
use reth_payload_builder_primitives::PayloadBuilder;
use reth_payload_primitives::PayloadBuilderAttributes;
@ -47,16 +50,10 @@ use reth_provider::{
use reth_revm::database::StateProviderDatabase;
use reth_stages_api::ControlFlow;
use reth_trie::{
hashed_cursor::HashedPostStateCursorFactory,
prefix_set::TriePrefixSetsMut,
proof::ProofBlindedProviderFactory,
trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
updates::{TrieUpdates, TrieUpdatesSorted},
HashedPostState, HashedPostStateSorted, TrieInput,
trie_cursor::InMemoryTrieCursorFactory, updates::TrieUpdates, HashedPostState, TrieInput,
};
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use reth_trie_db::DatabaseTrieCursorFactory;
use reth_trie_parallel::root::{ParallelStateRoot, ParallelStateRootError};
use revm_primitives::EvmState;
use root::{StateRootComputeOutcome, StateRootConfig, StateRootHandle, StateRootTask};
use std::{
cmp::Ordering,
@ -485,15 +482,6 @@ pub enum TreeAction {
},
}
/// Context used to keep alive the required values when returning a state hook
/// from a scoped thread.
struct StateHookContext<P> {
provider_ro: P,
nodes_sorted: Arc<TrieUpdatesSorted>,
state_sorted: Arc<HashedPostStateSorted>,
prefix_sets: Arc<TriePrefixSetsMut>,
}
/// The engine API tree handler implementation.
///
/// This type is responsible for processing engine API requests, maintaining the canonical state and
@ -607,8 +595,10 @@ where
) -> Self {
let (incoming_tx, incoming) = std::sync::mpsc::channel();
// The thread pool requires at least 2 threads as it contains a long running sparse trie
// task.
let num_threads =
std::thread::available_parallelism().map_or(1, |num| (num.get() / 2).max(1));
std::thread::available_parallelism().map_or(2, |num| (num.get() / 2).max(2));
let state_root_task_pool = Arc::new(
rayon::ThreadPoolBuilder::new()
@ -2281,144 +2271,90 @@ where
let persistence_not_in_progress = !self.persistence_state.in_progress();
let state_root_result = std::thread::scope(|scope| {
let (state_root_handle, in_memory_trie_cursor, state_hook) =
if persistence_not_in_progress && self.config.use_state_root_task() {
let consistent_view =
ConsistentDbView::new_with_latest_tip(self.provider.clone())?;
let (state_root_handle, state_root_task_config, state_hook) = if persistence_not_in_progress &&
self.config.use_state_root_task()
{
let consistent_view = ConsistentDbView::new_with_latest_tip(self.provider.clone())?;
let state_root_config = StateRootConfig::new_from_input(
consistent_view.clone(),
self.compute_trie_input(consistent_view, block.header().parent_hash())
.map_err(|e| InsertBlockErrorKind::Other(Box::new(e)))?,
);
let state_root_config = StateRootConfig::new_from_input(
consistent_view.clone(),
self.compute_trie_input(
consistent_view.clone(),
block.header().parent_hash(),
)
.map_err(|e| InsertBlockErrorKind::Other(Box::new(e)))?,
);
let state_root_task =
StateRootTask::new(state_root_config.clone(), self.state_root_task_pool.clone());
let state_hook = Box::new(state_root_task.state_hook()) as Box<dyn OnStateHook>;
(Some(state_root_task.spawn()), Some(state_root_config), state_hook)
} else {
(None, None, Box::new(NoopHook::default()) as Box<dyn OnStateHook>)
};
let provider_ro = consistent_view.provider_ro()?;
let nodes_sorted = state_root_config.nodes_sorted.clone();
let state_sorted = state_root_config.state_sorted.clone();
let prefix_sets = state_root_config.prefix_sets.clone();
let execution_start = Instant::now();
let output = self.metrics.executor.execute_metered(executor, &block, state_hook)?;
let execution_time = execution_start.elapsed();
trace!(target: "engine::tree", elapsed = ?execution_time, ?block_number, "Executed block");
// context will hold the values that need to be kept alive
let context =
StateHookContext { provider_ro, nodes_sorted, state_sorted, prefix_sets };
if let Err(err) = self.consensus.validate_block_post_execution(
&block,
PostExecutionInput::new(&output.receipts, &output.requests),
) {
// call post-block hook
self.invalid_block_hook.on_invalid_block(&parent_block, &block, &output, None);
return Err(err.into())
}
// it is ok to leak here because we are in a scoped thread, the
// memory will be freed when the thread completes
let context = Box::leak(Box::new(context));
let hashed_state = self.provider.hashed_post_state(&output.state);
let in_memory_trie_cursor = InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(context.provider_ro.tx_ref()),
&context.nodes_sorted,
);
let blinded_provider_factory = ProofBlindedProviderFactory::new(
in_memory_trie_cursor.clone(),
HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(context.provider_ro.tx_ref()),
&context.state_sorted,
),
context.prefix_sets.clone(),
);
trace!(target: "engine::tree", block=?sealed_block.num_hash(), "Calculating block state root");
let root_time = Instant::now();
let state_root_task = StateRootTask::new(
state_root_config,
blinded_provider_factory,
self.state_root_task_pool.clone(),
);
let state_hook = state_root_task.state_hook();
(
Some(state_root_task.spawn(scope)),
Some(in_memory_trie_cursor),
Box::new(state_hook) as Box<dyn OnStateHook>,
)
} else {
(None, None, Box::new(|_state: &EvmState| {}) as Box<dyn OnStateHook>)
};
// We attempt to compute state root in parallel if we are currently not persisting
// anything to database. This is safe, because the database state cannot
// change until we finish parallel computation. It is important that nothing
// is being persisted as we are computing in parallel, because we initialize
// a different database transaction per thread and it might end up with a
// different view of the database.
let (state_root, trie_output, root_elapsed) = if persistence_not_in_progress {
if self.config.use_state_root_task() {
let state_root_handle = state_root_handle
.expect("state root handle must exist if use_state_root_task is true");
let state_root_config = state_root_task_config.expect("task config is present");
let execution_start = Instant::now();
let output = self.metrics.executor.execute_metered(executor, &block, state_hook)?;
let execution_time = execution_start.elapsed();
trace!(target: "engine::tree", elapsed = ?execution_time, ?block_number, "Executed block");
if let Err(err) = self.consensus.validate_block_post_execution(
&block,
PostExecutionInput::new(&output.receipts, &output.requests),
) {
// call post-block hook
self.invalid_block_hook.on_invalid_block(&parent_block, &block, &output, None);
return Err(err.into())
}
let hashed_state = self.provider.hashed_post_state(&output.state);
trace!(target: "engine::tree", block=?sealed_block.num_hash(), "Calculating block state root");
let root_time = Instant::now();
// We attempt to compute state root in parallel if we are currently not persisting
// anything to database. This is safe, because the database state cannot
// change until we finish parallel computation. It is important that nothing
// is being persisted as we are computing in parallel, because we initialize
// a different database transaction per thread and it might end up with a
// different view of the database.
let (state_root, trie_updates, root_elapsed) = if persistence_not_in_progress {
if self.config.use_state_root_task() {
let state_root_handle = state_root_handle
.expect("state root handle must exist if use_state_root_task is true");
let in_memory_trie_cursor = in_memory_trie_cursor
.expect("in memory trie cursor must exist if use_state_root_task is true");
// Handle state root result from task using handle
self.handle_state_root_result(
state_root_handle,
sealed_block.as_ref(),
&hashed_state,
&state_provider,
in_memory_trie_cursor,
root_time,
)?
} else {
match self
.compute_state_root_parallel(block.header().parent_hash(), &hashed_state)
{
Ok(result) => {
info!(
target: "engine::tree",
block = ?sealed_block.num_hash(),
regular_state_root = ?result.0,
"Regular root task finished"
);
(result.0, result.1, root_time.elapsed())
}
Err(ParallelStateRootError::Provider(ProviderError::ConsistentView(
error,
))) => {
debug!(target: "engine", %error, "Parallel state root computation failed consistency check, falling back");
let (root, updates) =
state_provider.state_root_with_updates(hashed_state.clone())?;
(root, updates, root_time.elapsed())
}
Err(error) => return Err(InsertBlockErrorKind::Other(Box::new(error))),
}
}
// Handle state root result from task using handle
self.handle_state_root_result(
state_root_handle,
state_root_config,
sealed_block.as_ref(),
&hashed_state,
&state_provider,
root_time,
)?
} else {
debug!(target: "engine::tree", block=?sealed_block.num_hash(), ?persistence_not_in_progress, "Failed to compute state root in parallel");
let (root, updates) =
state_provider.state_root_with_updates(hashed_state.clone())?;
(root, updates, root_time.elapsed())
};
Result::<_, InsertBlockErrorKind>::Ok((
state_root,
trie_updates,
hashed_state,
output,
root_elapsed,
))
})?;
let (state_root, trie_output, hashed_state, output, root_elapsed) = state_root_result;
match self.compute_state_root_parallel(block.header().parent_hash(), &hashed_state)
{
Ok(result) => {
info!(
target: "engine::tree",
block = ?sealed_block.num_hash(),
regular_state_root = ?result.0,
"Regular root task finished"
);
(result.0, result.1, root_time.elapsed())
}
Err(ParallelStateRootError::Provider(ProviderError::ConsistentView(error))) => {
debug!(target: "engine", %error, "Parallel state root computation failed consistency check, falling back");
let (root, updates) =
state_provider.state_root_with_updates(hashed_state.clone())?;
(root, updates, root_time.elapsed())
}
Err(error) => return Err(InsertBlockErrorKind::Other(Box::new(error))),
}
}
} else {
debug!(target: "engine::tree", block=?sealed_block.num_hash(), ?persistence_not_in_progress, "Failed to compute state root in parallel");
let (root, updates) = state_provider.state_root_with_updates(hashed_state.clone())?;
(root, updates, root_time.elapsed())
};
if state_root != block.header().state_root() {
// call post-block hook
@ -2559,10 +2495,10 @@ where
fn handle_state_root_result(
&self,
state_root_handle: StateRootHandle,
state_root_task_config: StateRootConfig<P>,
sealed_block: &SealedBlock<N::Block>,
hashed_state: &HashedPostState,
state_provider: impl StateRootProvider,
in_memory_trie_cursor: impl TrieCursorFactory,
root_time: Instant,
) -> Result<(B256, TrieUpdates, Duration), InsertBlockErrorKind> {
match state_root_handle.wait_for_result() {
@ -2590,6 +2526,11 @@ where
state_provider.state_root_with_updates(hashed_state.clone())?;
if regular_root == sealed_block.header().state_root() {
let provider_ro = state_root_task_config.consistent_view.provider_ro()?;
let in_memory_trie_cursor = InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
&state_root_task_config.nodes_sorted,
);
compare_trie_updates(
in_memory_trie_cursor,
task_trie_updates.clone(),

View File

@ -6,18 +6,23 @@ use rayon::iter::{ParallelBridge, ParallelIterator};
use reth_errors::{ProviderError, ProviderResult};
use reth_evm::system_calls::OnStateHook;
use reth_provider::{
providers::ConsistentDbView, BlockReader, DatabaseProviderFactory, StateCommitmentProvider,
providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory,
StateCommitmentProvider,
};
use reth_trie::{
hashed_cursor::HashedPostStateCursorFactory,
prefix_set::TriePrefixSetsMut,
proof::ProofBlindedProviderFactory,
trie_cursor::InMemoryTrieCursorFactory,
updates::{TrieUpdates, TrieUpdatesSorted},
HashedPostState, HashedPostStateSorted, HashedStorage, MultiProof, MultiProofTargets, Nibbles,
TrieInput,
};
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use reth_trie_parallel::{proof::ParallelProof, root::ParallelStateRootError};
use reth_trie_sparse::{
blinded::{BlindedProvider, BlindedProviderFactory},
errors::{SparseStateTrieError, SparseStateTrieResult, SparseTrieErrorKind},
errors::{SparseStateTrieResult, SparseTrieErrorKind},
SparseStateTrie,
};
use revm_primitives::{keccak256, EvmState, B256};
@ -27,7 +32,6 @@ use std::{
mpsc::{self, channel, Receiver, Sender},
Arc,
},
thread::{self},
time::{Duration, Instant},
};
use tracing::{debug, error, trace};
@ -47,6 +51,32 @@ pub struct StateRootComputeOutcome {
pub time_from_last_update: Duration,
}
/// A trie update that can be applied to sparse trie alongside the proofs for touched parts of the
/// state.
#[derive(Default, Debug)]
pub struct SparseTrieUpdate {
/// The state update that was used to calculate the proof
state: HashedPostState,
/// The proof targets
targets: MultiProofTargets,
/// The calculated multiproof
multiproof: MultiProof,
}
impl SparseTrieUpdate {
/// Construct update from multiproof.
pub fn from_multiproof(multiproof: MultiProof) -> Self {
Self { multiproof, ..Default::default() }
}
/// Extend update with contents of the other.
pub fn extend(&mut self, other: Self) {
self.state.extend(other.state);
extend_multi_proof_targets(&mut self.targets, other.targets);
self.multiproof.extend(other.multiproof);
}
}
/// Result of the state root calculation
pub(crate) type StateRootResult = Result<StateRootComputeOutcome, ParallelStateRootError>;
@ -99,7 +129,7 @@ impl<Factory> StateRootConfig<Factory> {
/// Messages used internally by the state root task
#[derive(Debug)]
pub enum StateRootMessage<BPF: BlindedProviderFactory> {
pub enum StateRootMessage {
/// Prefetch proof targets
PrefetchProofs(MultiProofTargets),
/// New state update from transaction execution
@ -110,13 +140,15 @@ pub enum StateRootMessage<BPF: BlindedProviderFactory> {
ProofCalculationError(ProviderError),
/// State root calculation completed
RootCalculated {
/// The updated sparse trie
trie: Box<SparseStateTrie<BPF>>,
/// Time taken to calculate the root
elapsed: Duration,
/// Final state root.
state_root: B256,
/// Trie updates.
trie_updates: TrieUpdates,
/// The number of time sparse trie was updated.
iterations: u64,
},
/// Error during state root calculation
RootCalculationError(SparseStateTrieError),
RootCalculationError(ParallelStateRootError),
/// Signals state update stream end.
FinishedStateUpdates,
}
@ -124,14 +156,10 @@ pub enum StateRootMessage<BPF: BlindedProviderFactory> {
/// Message about completion of proof calculation for a specific state update
#[derive(Debug)]
pub struct ProofCalculated {
/// The state update that was used to calculate the proof
state_update: HashedPostState,
/// The proof targets
targets: MultiProofTargets,
/// The calculated proof
proof: MultiProof,
/// The index of this proof in the sequence of state updates
sequence_number: u64,
/// Sparse trie update
update: SparseTrieUpdate,
}
/// Handle to track proof calculation ordering
@ -142,7 +170,7 @@ pub(crate) struct ProofSequencer {
/// The next sequence number expected to be delivered.
next_to_deliver: u64,
/// Buffer for out-of-order proofs and corresponding state updates
pending_proofs: BTreeMap<u64, (HashedPostState, MultiProofTargets, MultiProof)>,
pending_proofs: BTreeMap<u64, SparseTrieUpdate>,
}
impl ProofSequencer {
@ -163,12 +191,10 @@ impl ProofSequencer {
pub(crate) fn add_proof(
&mut self,
sequence: u64,
state_update: HashedPostState,
targets: MultiProofTargets,
proof: MultiProof,
) -> Vec<(HashedPostState, MultiProofTargets, MultiProof)> {
update: SparseTrieUpdate,
) -> Vec<SparseTrieUpdate> {
if sequence >= self.next_to_deliver {
self.pending_proofs.insert(sequence, (state_update, targets, proof));
self.pending_proofs.insert(sequence, update);
}
// return early if we don't have the next expected proof
@ -203,15 +229,15 @@ impl ProofSequencer {
/// A wrapper for the sender that signals completion when dropped
#[derive(Deref, Debug)]
pub struct StateHookSender<BPF: BlindedProviderFactory>(Sender<StateRootMessage<BPF>>);
pub struct StateHookSender(Sender<StateRootMessage>);
impl<BPF: BlindedProviderFactory> StateHookSender<BPF> {
pub(crate) const fn new(inner: Sender<StateRootMessage<BPF>>) -> Self {
impl StateHookSender {
pub(crate) const fn new(inner: Sender<StateRootMessage>) -> Self {
Self(inner)
}
}
impl<BPF: BlindedProviderFactory> Drop for StateHookSender<BPF> {
impl Drop for StateHookSender {
fn drop(&mut self) {
// Send completion signal when the sender is dropped
let _ = self.0.send(StateRootMessage::FinishedStateUpdates);
@ -260,25 +286,22 @@ fn evm_state_to_hashed_post_state(update: EvmState) -> HashedPostState {
/// to the tree.
/// Then it updates relevant leaves according to the result of the transaction.
#[derive(Debug)]
pub struct StateRootTask<Factory, BPF: BlindedProviderFactory> {
pub struct StateRootTask<Factory> {
/// Task configuration.
config: StateRootConfig<Factory>,
/// Receiver for state root related messages.
rx: Receiver<StateRootMessage<BPF>>,
rx: Receiver<StateRootMessage>,
/// Sender for state root related messages.
tx: Sender<StateRootMessage<BPF>>,
tx: Sender<StateRootMessage>,
/// Proof targets that have been already fetched.
fetched_proof_targets: MultiProofTargets,
/// Proof sequencing handler.
proof_sequencer: ProofSequencer,
/// The sparse trie used for the state root calculation. If [`None`], then update is in
/// progress.
sparse_trie: Option<Box<SparseStateTrie<BPF>>>,
/// Reference to the shared thread pool for parallel proof generation
thread_pool: Arc<rayon::ThreadPool>,
}
impl<'env, Factory, BPF> StateRootTask<Factory, BPF>
impl<Factory> StateRootTask<Factory>
where
Factory: DatabaseProviderFactory<Provider: BlockReader>
+ StateCommitmentProvider
@ -286,47 +309,22 @@ where
+ Send
+ Sync
+ 'static,
BPF: BlindedProviderFactory + Send + Sync + 'env,
BPF::AccountNodeProvider: BlindedProvider + Send + Sync + 'env,
BPF::StorageNodeProvider: BlindedProvider + Send + Sync + 'env,
{
/// Creates a new state root task with the unified message channel
pub fn new(
config: StateRootConfig<Factory>,
blinded_provider: BPF,
thread_pool: Arc<rayon::ThreadPool>,
) -> Self {
pub fn new(config: StateRootConfig<Factory>, thread_pool: Arc<rayon::ThreadPool>) -> Self {
let (tx, rx) = channel();
Self {
config,
rx,
tx,
fetched_proof_targets: Default::default(),
proof_sequencer: ProofSequencer::new(),
sparse_trie: Some(Box::new(SparseStateTrie::new(blinded_provider).with_updates(true))),
thread_pool,
}
}
/// Spawns the state root task and returns a handle to await its result.
pub fn spawn<'scope>(self, scope: &'scope thread::Scope<'scope, 'env>) -> StateRootHandle {
let (tx, rx) = mpsc::sync_channel(1);
std::thread::Builder::new()
.name("State Root Task".to_string())
.spawn_scoped(scope, move || {
debug!(target: "engine::tree", "Starting state root task");
let result = rayon::scope(|scope| self.run(scope));
let _ = tx.send(result);
})
.expect("failed to spawn state root thread");
StateRootHandle::new(rx)
}
/// Returns a [`StateHookSender`] that can be used to send state updates to this task.
pub fn state_hook_sender(&self) -> StateHookSender<BPF> {
pub fn state_hook_sender(&self) -> StateHookSender {
StateHookSender::new(self.tx.clone())
}
@ -341,20 +339,56 @@ where
}
}
/// Spawns the state root task and returns a handle to await its result.
pub fn spawn(self) -> StateRootHandle {
let sparse_trie_tx =
Self::spawn_sparse_trie(self.thread_pool.clone(), self.config.clone(), self.tx.clone());
let (tx, rx) = mpsc::sync_channel(1);
std::thread::Builder::new()
.name("State Root Task".to_string())
.spawn(move || {
debug!(target: "engine::tree", "Starting state root task");
let result = self.run(sparse_trie_tx);
let _ = tx.send(result);
})
.expect("failed to spawn state root thread");
StateRootHandle::new(rx)
}
/// Spawn long running sparse trie task that forwards the final result upon completion.
fn spawn_sparse_trie(
thread_pool: Arc<rayon::ThreadPool>,
config: StateRootConfig<Factory>,
task_tx: Sender<StateRootMessage>,
) -> Sender<SparseTrieUpdate> {
let (tx, rx) = mpsc::channel();
thread_pool.spawn(move || {
debug!(target: "engine::tree", "Starting sparse trie task");
let result = match run_sparse_trie(config, rx) {
Ok((state_root, trie_updates, iterations)) => {
StateRootMessage::RootCalculated { state_root, trie_updates, iterations }
}
Err(error) => StateRootMessage::RootCalculationError(error),
};
let _ = task_tx.send(result);
});
tx
}
/// Handles request for proof prefetch.
fn on_prefetch_proof(
scope: &rayon::Scope<'env>,
config: StateRootConfig<Factory>,
targets: MultiProofTargets,
fetched_proof_targets: &mut MultiProofTargets,
proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage<BPF>>,
state_root_message_sender: Sender<StateRootMessage>,
thread_pool: Arc<rayon::ThreadPool>,
) {
extend_multi_proof_targets_ref(fetched_proof_targets, &targets);
Self::spawn_multiproof(
scope,
config,
Default::default(),
targets,
@ -368,12 +402,11 @@ where
///
/// Returns proof targets derived from the state update.
fn on_state_update(
scope: &rayon::Scope<'env>,
config: StateRootConfig<Factory>,
update: EvmState,
fetched_proof_targets: &mut MultiProofTargets,
proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage<BPF>>,
state_root_message_sender: Sender<StateRootMessage>,
thread_pool: Arc<rayon::ThreadPool>,
) {
let hashed_state_update = evm_state_to_hashed_post_state(update);
@ -382,7 +415,6 @@ where
extend_multi_proof_targets_ref(fetched_proof_targets, &proof_targets);
Self::spawn_multiproof(
scope,
config,
hashed_state_update,
proof_targets,
@ -393,16 +425,15 @@ where
}
fn spawn_multiproof(
scope: &rayon::Scope<'env>,
config: StateRootConfig<Factory>,
hashed_state_update: HashedPostState,
proof_targets: MultiProofTargets,
proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage<BPF>>,
state_root_message_sender: Sender<StateRootMessage>,
thread_pool: Arc<rayon::ThreadPool>,
) {
// Dispatch proof gathering for this state update
scope.spawn(move |_| {
thread_pool.clone().spawn(move || {
trace!(
target: "engine::root",
proof_sequence_number,
@ -422,10 +453,12 @@ where
Ok(proof) => {
let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated(
Box::new(ProofCalculated {
state_update: hashed_state_update,
targets: proof_targets,
proof,
sequence_number: proof_sequence_number,
update: SparseTrieUpdate {
state: hashed_state_update,
targets: proof_targets,
multiproof: proof,
},
}),
));
}
@ -441,77 +474,26 @@ where
fn on_proof(
&mut self,
sequence_number: u64,
state_update: HashedPostState,
targets: MultiProofTargets,
proof: MultiProof,
) -> Option<(HashedPostState, MultiProofTargets, MultiProof)> {
let ready_proofs =
self.proof_sequencer.add_proof(sequence_number, state_update, targets, proof);
update: SparseTrieUpdate,
) -> Option<SparseTrieUpdate> {
let ready_proofs = self.proof_sequencer.add_proof(sequence_number, update);
if ready_proofs.is_empty() {
None
} else {
// Merge all ready proofs and state updates
ready_proofs.into_iter().reduce(
|(mut acc_state_update, mut acc_targets, mut acc_proof),
(state_update, targets, proof)| {
acc_state_update.extend(state_update);
extend_multi_proof_targets(&mut acc_targets, targets);
acc_proof.extend(proof);
(acc_state_update, acc_targets, acc_proof)
},
)
ready_proofs.into_iter().reduce(|mut acc_update, update| {
acc_update.extend(update);
acc_update
})
}
}
/// Spawns root calculation with the current state and proofs.
fn spawn_root_calculation(
&mut self,
scope: &rayon::Scope<'env>,
state: HashedPostState,
targets: MultiProofTargets,
multiproof: MultiProof,
) {
let Some(trie) = self.sparse_trie.take() else { return };
trace!(
target: "engine::root",
account_proofs = multiproof.account_subtree.len(),
storage_proofs = multiproof.storages.len(),
"Spawning root calculation"
);
// TODO(alexey): store proof targets in `ProofSequecner` to avoid recomputing them
let targets = get_proof_targets(&state, &targets);
let tx = self.tx.clone();
scope.spawn(move |_| {
let result = update_sparse_trie(trie, multiproof, targets, state);
match result {
Ok((trie, elapsed)) => {
trace!(
target: "engine::root",
?elapsed,
"Root calculation completed, sending result"
);
let _ = tx.send(StateRootMessage::RootCalculated { trie, elapsed });
}
Err(e) => {
let _ = tx.send(StateRootMessage::RootCalculationError(e));
}
}
});
}
fn run(mut self, scope: &rayon::Scope<'env>) -> StateRootResult {
let mut current_state_update = HashedPostState::default();
let mut current_proof_targets = MultiProofTargets::default();
let mut current_multiproof = MultiProof::default();
fn run(mut self, sparse_trie_tx: Sender<SparseTrieUpdate>) -> StateRootResult {
let mut sparse_trie_tx = Some(sparse_trie_tx);
let mut updates_received = 0;
let mut proofs_processed = 0;
let mut roots_calculated = 0;
let mut updates_finished = false;
@ -530,7 +512,6 @@ where
"Prefetching proofs"
);
Self::on_prefetch_proof(
scope,
self.config.clone(),
targets,
&mut self.fetched_proof_targets,
@ -554,7 +535,6 @@ where
"Received new state update"
);
Self::on_state_update(
scope,
self.config.clone(),
update,
&mut self.fetched_proof_targets,
@ -576,105 +556,44 @@ where
"Processing calculated proof"
);
if let Some((
combined_state_update,
combined_proof_targets,
combined_proof,
)) = self.on_proof(
proof_calculated.sequence_number,
proof_calculated.state_update,
proof_calculated.targets,
proof_calculated.proof,
) {
if self.sparse_trie.is_none() {
current_state_update.extend(combined_state_update);
extend_multi_proof_targets(
&mut current_proof_targets,
combined_proof_targets,
);
current_multiproof.extend(combined_proof);
} else {
self.spawn_root_calculation(
scope,
combined_state_update,
combined_proof_targets,
combined_proof,
);
}
if let Some(combined_update) =
self.on_proof(proof_calculated.sequence_number, proof_calculated.update)
{
let _ = sparse_trie_tx
.as_ref()
.expect("tx not dropped")
.send(combined_update);
}
}
StateRootMessage::RootCalculated { trie, elapsed } => {
roots_calculated += 1;
debug!(
target: "engine::root",
?elapsed,
roots_calculated,
proofs = proofs_processed,
updates = updates_received,
"Computed intermediate root"
);
self.sparse_trie = Some(trie);
let has_new_proofs = !current_multiproof.account_subtree.is_empty() ||
!current_multiproof.storages.is_empty();
let all_proofs_received = proofs_processed >= updates_received;
let no_pending = !self.proof_sequencer.has_pending();
trace!(
target: "engine::root",
has_new_proofs,
all_proofs_received,
no_pending,
?updates_finished,
"State check"
);
// only spawn new calculation if we have accumulated new proofs
if has_new_proofs {
debug!(
target: "engine::root",
account_proofs = current_multiproof.account_subtree.len(),
storage_proofs = current_multiproof.storages.len(),
"Spawning subsequent root calculation"
);
self.spawn_root_calculation(
scope,
std::mem::take(&mut current_state_update),
std::mem::take(&mut current_proof_targets),
std::mem::take(&mut current_multiproof),
);
} else if all_proofs_received && no_pending && updates_finished {
let total_time = first_update_time
.expect("first update time should be set")
.elapsed();
let time_from_last_update =
last_update_time.expect("last update time should be set").elapsed();
debug!(
target: "engine::root",
total_updates = updates_received,
total_proofs = proofs_processed,
roots_calculated,
?total_time,
?time_from_last_update,
"All proofs processed, ending calculation"
);
let mut trie = self
.sparse_trie
.take()
.expect("sparse trie update should not be in progress");
let root = trie.root().expect("sparse trie should be revealed");
let trie_updates = trie
.take_trie_updates()
.expect("sparse trie should have updates retention enabled");
return Ok(StateRootComputeOutcome {
state_root: (root, trie_updates),
total_time,
time_from_last_update,
});
if all_proofs_received && no_pending && updates_finished {
// drop the sender
sparse_trie_tx.take();
debug!(target: "engine::root", total_updates = updates_received, total_proofs = proofs_processed, "All proofs processed, ending calculation");
}
}
StateRootMessage::RootCalculated { state_root, trie_updates, iterations } => {
let total_time =
first_update_time.expect("first update time should be set").elapsed();
let time_from_last_update =
last_update_time.expect("last update time should be set").elapsed();
debug!(
target: "engine::root",
total_updates = updates_received,
total_proofs = proofs_processed,
roots_calculated = iterations,
?total_time,
?time_from_last_update,
"All proofs processed, ending calculation"
);
return Ok(StateRootComputeOutcome {
state_root: (state_root, trie_updates),
total_time,
time_from_last_update,
});
}
StateRootMessage::ProofCalculationError(e) => {
return Err(ParallelStateRootError::Other(format!(
"could not calculate multiproof: {e:?}"
@ -702,6 +621,63 @@ where
}
}
/// Listen to incoming sparse trie updates and update the sparse trie.
/// Returns final state root, trie updates and the number of update iterations.
fn run_sparse_trie<Factory>(
config: StateRootConfig<Factory>,
update_rx: mpsc::Receiver<SparseTrieUpdate>,
) -> Result<(B256, TrieUpdates, u64), ParallelStateRootError>
where
Factory: DatabaseProviderFactory<Provider: BlockReader> + StateCommitmentProvider,
{
let provider_ro = config.consistent_view.provider_ro()?;
let in_memory_trie_cursor = InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
&config.nodes_sorted,
);
let blinded_provider_factory = ProofBlindedProviderFactory::new(
in_memory_trie_cursor.clone(),
HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
&config.state_sorted,
),
config.prefix_sets.clone(),
);
let mut num_iterations = 0;
let mut trie = SparseStateTrie::new(blinded_provider_factory).with_updates(true);
while let Ok(mut update) = update_rx.recv() {
num_iterations += 1;
let mut num_updates = 1;
while let Ok(next) = update_rx.try_recv() {
update.extend(next);
num_updates += 1;
}
debug!(
target: "engine::root",
num_updates,
account_proofs = update.multiproof.account_subtree.len(),
storage_proofs = update.multiproof.storages.len(),
"Updating sparse trie"
);
// TODO: alexey to remind me why we are doing this
update.targets = get_proof_targets(&update.state, &update.targets);
let elapsed = update_sparse_trie(&mut trie, update).map_err(|e| {
ParallelStateRootError::Other(format!("could not calculate state root: {e:?}"))
})?;
trace!(target: "engine::root", ?elapsed, "Root calculation completed");
}
debug!(target: "engine::root", num_iterations, "All proofs processed, ending calculation");
let root = trie.root().expect("sparse trie should be revealed");
let trie_updates = trie.take_trie_updates().expect("retention must be enabled");
Ok((root, trie_updates, num_iterations))
}
/// Returns accounts only with those storages that were not already fetched, and
/// if there are no such storages and the account itself was already fetched, the
/// account shouldn't be included.
@ -757,14 +733,11 @@ where
.multiproof(proof_targets)?)
}
/// Updates the sparse trie with the given proofs and state, and returns the updated trie and the
/// time it took.
/// Updates the sparse trie with the given proofs and state, and returns the elapsed time.
fn update_sparse_trie<BPF>(
mut trie: Box<SparseStateTrie<BPF>>,
multiproof: MultiProof,
targets: MultiProofTargets,
state: HashedPostState,
) -> SparseStateTrieResult<(Box<SparseStateTrie<BPF>>, Duration)>
trie: &mut SparseStateTrie<BPF>,
SparseTrieUpdate { state, targets, multiproof }: SparseTrieUpdate,
) -> SparseStateTrieResult<Duration>
where
BPF: BlindedProviderFactory + Send + Sync,
BPF::AccountNodeProvider: BlindedProvider + Send + Sync,
@ -825,7 +798,7 @@ where
trie.calculate_below_level(SPARSE_TRIE_INCREMENTAL_LEVEL);
let elapsed = started_at.elapsed();
Ok((trie, elapsed))
Ok(elapsed)
}
fn extend_multi_proof_targets(targets: &mut MultiProofTargets, other: MultiProofTargets) {
@ -848,17 +821,14 @@ mod tests {
providers::ConsistentDbView, test_utils::create_test_provider_factory, HashingWriter,
};
use reth_testing_utils::generators::{self, Rng};
use reth_trie::{
hashed_cursor::HashedPostStateCursorFactory, proof::ProofBlindedProviderFactory,
test_utils::state_root, trie_cursor::InMemoryTrieCursorFactory, TrieInput,
};
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use reth_trie::{test_utils::state_root, TrieInput};
use revm_primitives::{
Account as RevmAccount, AccountInfo, AccountStatus, Address, EvmState, EvmStorageSlot,
HashMap, B256, KECCAK_EMPTY, U256,
};
use std::sync::Arc;
#[allow(dead_code)]
fn convert_revm_to_reth_account(revm_account: &RevmAccount) -> RethAccount {
RethAccount {
balance: revm_account.info.balance,
@ -871,6 +841,7 @@ mod tests {
}
}
#[allow(dead_code)]
fn create_mock_state_updates(num_accounts: usize, updates_per_account: usize) -> Vec<EvmState> {
let mut rng = generators::rng();
let all_addresses: Vec<Address> = (0..num_accounts).map(|_| rng.gen()).collect();
@ -914,7 +885,9 @@ mod tests {
updates
}
#[test]
// TODO: re-enable test once gh worker hang is figured out.
// #[test]
#[allow(dead_code)]
fn test_state_root_task() {
reth_tracing::init_test_tracing();
@ -970,24 +943,15 @@ mod tests {
let state_sorted = Arc::new(input.state.clone().into_sorted());
let config = StateRootConfig {
consistent_view: ConsistentDbView::new(factory, None),
nodes_sorted: nodes_sorted.clone(),
state_sorted: state_sorted.clone(),
nodes_sorted,
state_sorted,
prefix_sets: Arc::new(input.prefix_sets),
};
let provider = config.consistent_view.provider_ro().unwrap();
let blinded_provider_factory = ProofBlindedProviderFactory::new(
InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider.tx_ref()),
&nodes_sorted,
),
HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider.tx_ref()),
&state_sorted,
),
config.prefix_sets.clone(),
);
// The thread pool requires at least 2 threads as it contains a long running sparse trie
// task.
let num_threads =
std::thread::available_parallelism().map_or(1, |num| (num.get() / 2).max(1));
std::thread::available_parallelism().map_or(2, |num| (num.get() / 2).max(2));
let state_root_task_pool = rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
@ -995,23 +959,16 @@ mod tests {
.build()
.expect("Failed to create proof worker thread pool");
let (root_from_task, _) = std::thread::scope(|std_scope| {
let task = StateRootTask::new(
config,
blinded_provider_factory,
Arc::new(state_root_task_pool),
);
let mut state_hook = task.state_hook();
let handle = task.spawn(std_scope);
let task = StateRootTask::new(config, Arc::new(state_root_task_pool));
let mut state_hook = task.state_hook();
let handle = task.spawn();
for update in state_updates {
state_hook.on_state(&update);
}
drop(state_hook);
for update in state_updates {
state_hook.on_state(&update);
}
drop(state_hook);
handle.wait_for_result().expect("task failed")
})
.state_root;
let (root_from_task, _) = handle.wait_for_result().expect("task failed").state_root;
let root_from_base = state_root(accumulated_state);
assert_eq!(
@ -1027,21 +984,11 @@ mod tests {
let proof2 = MultiProof::default();
sequencer.next_sequence = 2;
let ready = sequencer.add_proof(
0,
HashedPostState::default(),
MultiProofTargets::default(),
proof1,
);
let ready = sequencer.add_proof(0, SparseTrieUpdate::from_multiproof(proof1));
assert_eq!(ready.len(), 1);
assert!(!sequencer.has_pending());
let ready = sequencer.add_proof(
1,
HashedPostState::default(),
MultiProofTargets::default(),
proof2,
);
let ready = sequencer.add_proof(1, SparseTrieUpdate::from_multiproof(proof2));
assert_eq!(ready.len(), 1);
assert!(!sequencer.has_pending());
}
@ -1054,30 +1001,15 @@ mod tests {
let proof3 = MultiProof::default();
sequencer.next_sequence = 3;
let ready = sequencer.add_proof(
2,
HashedPostState::default(),
MultiProofTargets::default(),
proof3,
);
let ready = sequencer.add_proof(2, SparseTrieUpdate::from_multiproof(proof3));
assert_eq!(ready.len(), 0);
assert!(sequencer.has_pending());
let ready = sequencer.add_proof(
0,
HashedPostState::default(),
MultiProofTargets::default(),
proof1,
);
let ready = sequencer.add_proof(0, SparseTrieUpdate::from_multiproof(proof1));
assert_eq!(ready.len(), 1);
assert!(sequencer.has_pending());
let ready = sequencer.add_proof(
1,
HashedPostState::default(),
MultiProofTargets::default(),
proof2,
);
let ready = sequencer.add_proof(1, SparseTrieUpdate::from_multiproof(proof2));
assert_eq!(ready.len(), 2);
assert!(!sequencer.has_pending());
}
@ -1089,20 +1021,10 @@ mod tests {
let proof3 = MultiProof::default();
sequencer.next_sequence = 3;
let ready = sequencer.add_proof(
0,
HashedPostState::default(),
MultiProofTargets::default(),
proof1,
);
let ready = sequencer.add_proof(0, SparseTrieUpdate::from_multiproof(proof1));
assert_eq!(ready.len(), 1);
let ready = sequencer.add_proof(
2,
HashedPostState::default(),
MultiProofTargets::default(),
proof3,
);
let ready = sequencer.add_proof(2, SparseTrieUpdate::from_multiproof(proof3));
assert_eq!(ready.len(), 0);
assert!(sequencer.has_pending());
}
@ -1113,20 +1035,10 @@ mod tests {
let proof1 = MultiProof::default();
let proof2 = MultiProof::default();
let ready = sequencer.add_proof(
0,
HashedPostState::default(),
MultiProofTargets::default(),
proof1,
);
let ready = sequencer.add_proof(0, SparseTrieUpdate::from_multiproof(proof1));
assert_eq!(ready.len(), 1);
let ready = sequencer.add_proof(
0,
HashedPostState::default(),
MultiProofTargets::default(),
proof2,
);
let ready = sequencer.add_proof(0, SparseTrieUpdate::from_multiproof(proof2));
assert_eq!(ready.len(), 0);
assert!(!sequencer.has_pending());
}
@ -1137,37 +1049,12 @@ mod tests {
let proofs: Vec<_> = (0..5).map(|_| MultiProof::default()).collect();
sequencer.next_sequence = 5;
sequencer.add_proof(
4,
HashedPostState::default(),
MultiProofTargets::default(),
proofs[4].clone(),
);
sequencer.add_proof(
2,
HashedPostState::default(),
MultiProofTargets::default(),
proofs[2].clone(),
);
sequencer.add_proof(
1,
HashedPostState::default(),
MultiProofTargets::default(),
proofs[1].clone(),
);
sequencer.add_proof(
3,
HashedPostState::default(),
MultiProofTargets::default(),
proofs[3].clone(),
);
sequencer.add_proof(4, SparseTrieUpdate::from_multiproof(proofs[4].clone()));
sequencer.add_proof(2, SparseTrieUpdate::from_multiproof(proofs[2].clone()));
sequencer.add_proof(1, SparseTrieUpdate::from_multiproof(proofs[1].clone()));
sequencer.add_proof(3, SparseTrieUpdate::from_multiproof(proofs[3].clone()));
let ready = sequencer.add_proof(
0,
HashedPostState::default(),
MultiProofTargets::default(),
proofs[0].clone(),
);
let ready = sequencer.add_proof(0, SparseTrieUpdate::from_multiproof(proofs[0].clone()));
assert_eq!(ready.len(), 5);
assert!(!sequencer.has_pending());
}