perf(engine): cache proof targets in proof sequencer for state root task (#13310)

This commit is contained in:
Alexey Shekhirin
2024-12-12 16:28:12 +00:00
committed by GitHub
parent 5ef21cdfec
commit 6ff2510ad9
8 changed files with 206 additions and 88 deletions

View File

@ -1,6 +1,6 @@
//! State root task related functionality.
use alloy_primitives::map::{HashMap, HashSet};
use alloy_primitives::map::HashSet;
use rayon::iter::{ParallelBridge, ParallelIterator};
use reth_evm::system_calls::OnStateHook;
use reth_execution_errors::StateProofError;
@ -75,14 +75,7 @@ pub enum StateRootMessage<BPF: BlindedProviderFactory> {
/// New state update from transaction execution
StateUpdate(EvmState),
/// Proof calculation completed for a specific state update
ProofCalculated {
/// The calculated proof
proof: MultiProof,
/// The state update that was used to calculate the proof
state_update: HashedPostState,
/// The index of this proof in the sequence of state updates
sequence_number: u64,
},
ProofCalculated(Box<ProofCalculated>),
/// Error during proof calculation
ProofCalculationError(StateProofError),
/// State root calculation completed
@ -98,6 +91,19 @@ pub enum StateRootMessage<BPF: BlindedProviderFactory> {
FinishedStateUpdates,
}
/// Message about completion of proof calculation for a specific state update
#[derive(Debug)]
pub struct ProofCalculated {
/// The state update that was used to calculate the proof
state_update: HashedPostState,
/// The proof targets
targets: MultiProofTargets,
/// The calculated proof
proof: MultiProof,
/// The index of this proof in the sequence of state updates
sequence_number: u64,
}
/// Handle to track proof calculation ordering
#[derive(Debug, Default)]
pub(crate) struct ProofSequencer {
@ -106,7 +112,7 @@ pub(crate) struct ProofSequencer {
/// The next sequence number expected to be delivered.
next_to_deliver: u64,
/// Buffer for out-of-order proofs and corresponding state updates
pending_proofs: BTreeMap<u64, (MultiProof, HashedPostState)>,
pending_proofs: BTreeMap<u64, (HashedPostState, MultiProofTargets, MultiProof)>,
}
impl ProofSequencer {
@ -127,11 +133,12 @@ impl ProofSequencer {
pub(crate) fn add_proof(
&mut self,
sequence: u64,
proof: MultiProof,
state_update: HashedPostState,
) -> Vec<(MultiProof, HashedPostState)> {
targets: MultiProofTargets,
proof: MultiProof,
) -> Vec<(HashedPostState, MultiProofTargets, MultiProof)> {
if sequence >= self.next_to_deliver {
self.pending_proofs.insert(sequence, (proof, state_update));
self.pending_proofs.insert(sequence, (state_update, targets, proof));
}
// return early if we don't have the next expected proof
@ -143,8 +150,8 @@ impl ProofSequencer {
let mut current_sequence = self.next_to_deliver;
// keep collecting proofs and state updates as long as we have consecutive sequence numbers
while let Some((proof, state_update)) = self.pending_proofs.remove(&current_sequence) {
consecutive_proofs.push((proof, state_update));
while let Some(pending) = self.pending_proofs.remove(&current_sequence) {
consecutive_proofs.push(pending);
current_sequence += 1;
// if we don't have the next number, stop collecting
@ -319,9 +326,7 @@ where
let hashed_state_update = evm_state_to_hashed_post_state(update);
let proof_targets = get_proof_targets(&hashed_state_update, fetched_proof_targets);
for (address, slots) in &proof_targets {
fetched_proof_targets.entry(*address).or_default().extend(slots)
}
fetched_proof_targets.extend_ref(&proof_targets);
// Dispatch proof gathering for this state update
scope.spawn(move |_| {
@ -338,15 +343,18 @@ where
provider.tx_ref(),
// TODO(alexey): this clone can be expensive, we should avoid it
input.as_ref().clone(),
proof_targets,
proof_targets.clone(),
);
match result {
Ok(proof) => {
let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated {
proof,
let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated(
Box::new(ProofCalculated {
state_update: hashed_state_update,
targets: proof_targets,
proof,
sequence_number: proof_sequence_number,
});
}),
));
}
Err(e) => {
let _ =
@ -360,18 +368,21 @@ where
fn on_proof(
&mut self,
sequence_number: u64,
proof: MultiProof,
state_update: HashedPostState,
) -> Option<(MultiProof, HashedPostState)> {
let ready_proofs = self.proof_sequencer.add_proof(sequence_number, proof, state_update);
targets: MultiProofTargets,
proof: MultiProof,
) -> Option<(HashedPostState, MultiProofTargets, MultiProof)> {
let ready_proofs =
self.proof_sequencer.add_proof(sequence_number, state_update, targets, proof);
if ready_proofs.is_empty() {
None
} else {
// Merge all ready proofs and state updates
ready_proofs.into_iter().reduce(|mut acc, (proof, state_update)| {
acc.0.extend(proof);
acc.1.extend(state_update);
ready_proofs.into_iter().reduce(|mut acc, (state_update, targets, proof)| {
acc.0.extend(state_update);
acc.1.extend(targets);
acc.2.extend(proof);
acc
})
}
@ -382,6 +393,7 @@ where
&mut self,
scope: &rayon::Scope<'env>,
state: HashedPostState,
targets: MultiProofTargets,
multiproof: MultiProof,
) {
let Some(trie) = self.sparse_trie.take() else { return };
@ -394,7 +406,7 @@ where
);
// TODO(alexey): store proof targets in `ProofSequecner` to avoid recomputing them
let targets = get_proof_targets(&state, &HashMap::default());
let targets = get_proof_targets(&state, &targets);
let tx = self.tx.clone();
scope.spawn(move |_| {
@ -417,6 +429,7 @@ where
fn run(mut self, scope: &rayon::Scope<'env>) -> StateRootResult {
let mut current_state_update = HashedPostState::default();
let mut current_proof_targets = MultiProofTargets::default();
let mut current_multiproof = MultiProof::default();
let mut updates_received = 0;
let mut proofs_processed = 0;
@ -447,27 +460,36 @@ where
StateRootMessage::FinishedStateUpdates => {
updates_finished = true;
}
StateRootMessage::ProofCalculated { proof, state_update, sequence_number } => {
StateRootMessage::ProofCalculated(proof_calculated) => {
proofs_processed += 1;
trace!(
target: "engine::root",
sequence = sequence_number,
sequence = proof_calculated.sequence_number,
total_proofs = proofs_processed,
"Processing calculated proof"
);
trace!(target: "engine::root", ?proof, "Proof calculated");
trace!(target: "engine::root", proof = ?proof_calculated.proof, "Proof calculated");
if let Some((combined_proof, combined_state_update)) =
self.on_proof(sequence_number, proof, state_update)
{
if let Some((
combined_state_update,
combined_proof_targets,
combined_proof,
)) = self.on_proof(
proof_calculated.sequence_number,
proof_calculated.state_update,
proof_calculated.targets,
proof_calculated.proof,
) {
if self.sparse_trie.is_none() {
current_multiproof.extend(combined_proof);
current_state_update.extend(combined_state_update);
current_proof_targets.extend(combined_proof_targets);
current_multiproof.extend(combined_proof);
} else {
self.spawn_root_calculation(
scope,
combined_state_update,
combined_proof_targets,
combined_proof,
);
}
@ -509,6 +531,7 @@ where
self.spawn_root_calculation(
scope,
std::mem::take(&mut current_state_update),
std::mem::take(&mut current_proof_targets),
std::mem::take(&mut current_multiproof),
);
} else if all_proofs_received && no_pending && updates_finished {
@ -564,7 +587,7 @@ fn get_proof_targets(
state_update: &HashedPostState,
fetched_proof_targets: &MultiProofTargets,
) -> MultiProofTargets {
let mut targets = HashMap::default();
let mut targets = MultiProofTargets::default();
// first collect all new accounts (not previously fetched)
for &hashed_address in state_update.accounts.keys() {
@ -830,11 +853,21 @@ mod tests {
let proof2 = MultiProof::default();
sequencer.next_sequence = 2;
let ready = sequencer.add_proof(0, proof1, HashedPostState::default());
let ready = sequencer.add_proof(
0,
HashedPostState::default(),
MultiProofTargets::default(),
proof1,
);
assert_eq!(ready.len(), 1);
assert!(!sequencer.has_pending());
let ready = sequencer.add_proof(1, proof2, HashedPostState::default());
let ready = sequencer.add_proof(
1,
HashedPostState::default(),
MultiProofTargets::default(),
proof2,
);
assert_eq!(ready.len(), 1);
assert!(!sequencer.has_pending());
}
@ -847,15 +880,30 @@ mod tests {
let proof3 = MultiProof::default();
sequencer.next_sequence = 3;
let ready = sequencer.add_proof(2, proof3, HashedPostState::default());
let ready = sequencer.add_proof(
2,
HashedPostState::default(),
MultiProofTargets::default(),
proof3,
);
assert_eq!(ready.len(), 0);
assert!(sequencer.has_pending());
let ready = sequencer.add_proof(0, proof1, HashedPostState::default());
let ready = sequencer.add_proof(
0,
HashedPostState::default(),
MultiProofTargets::default(),
proof1,
);
assert_eq!(ready.len(), 1);
assert!(sequencer.has_pending());
let ready = sequencer.add_proof(1, proof2, HashedPostState::default());
let ready = sequencer.add_proof(
1,
HashedPostState::default(),
MultiProofTargets::default(),
proof2,
);
assert_eq!(ready.len(), 2);
assert!(!sequencer.has_pending());
}
@ -867,10 +915,20 @@ mod tests {
let proof3 = MultiProof::default();
sequencer.next_sequence = 3;
let ready = sequencer.add_proof(0, proof1, HashedPostState::default());
let ready = sequencer.add_proof(
0,
HashedPostState::default(),
MultiProofTargets::default(),
proof1,
);
assert_eq!(ready.len(), 1);
let ready = sequencer.add_proof(2, proof3, HashedPostState::default());
let ready = sequencer.add_proof(
2,
HashedPostState::default(),
MultiProofTargets::default(),
proof3,
);
assert_eq!(ready.len(), 0);
assert!(sequencer.has_pending());
}
@ -881,10 +939,20 @@ mod tests {
let proof1 = MultiProof::default();
let proof2 = MultiProof::default();
let ready = sequencer.add_proof(0, proof1, HashedPostState::default());
let ready = sequencer.add_proof(
0,
HashedPostState::default(),
MultiProofTargets::default(),
proof1,
);
assert_eq!(ready.len(), 1);
let ready = sequencer.add_proof(0, proof2, HashedPostState::default());
let ready = sequencer.add_proof(
0,
HashedPostState::default(),
MultiProofTargets::default(),
proof2,
);
assert_eq!(ready.len(), 0);
assert!(!sequencer.has_pending());
}
@ -895,12 +963,37 @@ mod tests {
let proofs: Vec<_> = (0..5).map(|_| MultiProof::default()).collect();
sequencer.next_sequence = 5;
sequencer.add_proof(4, proofs[4].clone(), HashedPostState::default());
sequencer.add_proof(2, proofs[2].clone(), HashedPostState::default());
sequencer.add_proof(1, proofs[1].clone(), HashedPostState::default());
sequencer.add_proof(3, proofs[3].clone(), HashedPostState::default());
sequencer.add_proof(
4,
HashedPostState::default(),
MultiProofTargets::default(),
proofs[4].clone(),
);
sequencer.add_proof(
2,
HashedPostState::default(),
MultiProofTargets::default(),
proofs[2].clone(),
);
sequencer.add_proof(
1,
HashedPostState::default(),
MultiProofTargets::default(),
proofs[1].clone(),
);
sequencer.add_proof(
3,
HashedPostState::default(),
MultiProofTargets::default(),
proofs[3].clone(),
);
let ready = sequencer.add_proof(0, proofs[0].clone(), HashedPostState::default());
let ready = sequencer.add_proof(
0,
HashedPostState::default(),
MultiProofTargets::default(),
proofs[0].clone(),
);
assert_eq!(ready.len(), 5);
assert!(!sequencer.has_pending());
}
@ -926,7 +1019,7 @@ mod tests {
#[test]
fn test_get_proof_targets_new_account_targets() {
let state = create_get_proof_targets_state();
let fetched = HashMap::default();
let fetched = MultiProofTargets::default();
let targets = get_proof_targets(&state, &fetched);
@ -940,7 +1033,7 @@ mod tests {
#[test]
fn test_get_proof_targets_new_storage_targets() {
let state = create_get_proof_targets_state();
let fetched = HashMap::default();
let fetched = MultiProofTargets::default();
let targets = get_proof_targets(&state, &fetched);
@ -958,7 +1051,7 @@ mod tests {
#[test]
fn test_get_proof_targets_filter_already_fetched_accounts() {
let state = create_get_proof_targets_state();
let mut fetched = HashMap::default();
let mut fetched = MultiProofTargets::default();
// select an account that has no storage updates
let fetched_addr = state
@ -981,7 +1074,7 @@ mod tests {
#[test]
fn test_get_proof_targets_filter_already_fetched_storage() {
let state = create_get_proof_targets_state();
let mut fetched = HashMap::default();
let mut fetched = MultiProofTargets::default();
// mark one storage slot as already fetched
let (addr, storage) = state.storages.iter().next().unwrap();
@ -1001,7 +1094,7 @@ mod tests {
#[test]
fn test_get_proof_targets_empty_state() {
let state = HashedPostState::default();
let fetched = HashMap::default();
let fetched = MultiProofTargets::default();
let targets = get_proof_targets(&state, &fetched);
@ -1011,7 +1104,7 @@ mod tests {
#[test]
fn test_get_proof_targets_mixed_fetched_state() {
let mut state = HashedPostState::default();
let mut fetched = HashMap::default();
let mut fetched = MultiProofTargets::default();
let addr1 = B256::random();
let addr2 = B256::random();
@ -1040,7 +1133,7 @@ mod tests {
#[test]
fn test_get_proof_targets_unmodified_account_with_storage() {
let mut state = HashedPostState::default();
let fetched = HashMap::default();
let fetched = MultiProofTargets::default();
let addr = B256::random();
let slot1 = B256::random();

View File

@ -13,11 +13,29 @@ use alloy_trie::{
proof::{verify_proof, ProofNodes, ProofVerificationError},
TrieMask, EMPTY_ROOT_HASH,
};
use derive_more::derive::{Deref, DerefMut, From, Into, IntoIterator};
use itertools::Itertools;
use reth_primitives_traits::Account;
/// Proof targets map.
pub type MultiProofTargets = B256HashMap<B256HashSet>;
#[derive(Debug, Default, Clone, Deref, DerefMut, From, Into, IntoIterator)]
pub struct MultiProofTargets(B256HashMap<B256HashSet>);
impl MultiProofTargets {
/// Extends the proof targets map with another one.
pub fn extend(&mut self, other: Self) {
for (address, slots) in other.0 {
self.0.entry(address).or_default().extend(slots);
}
}
/// Extends the proof targets map with another one by reference.
pub fn extend_ref(&mut self, other: &Self) {
for (address, slots) in &other.0 {
self.0.entry(*address).or_default().extend(slots);
}
}
}
/// The state multiproof of target accounts and multiproofs of their storage tries.
/// Multiproof is effectively a state subtrie that only contains the nodes

View File

@ -1,16 +1,13 @@
use crate::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use alloy_primitives::{
keccak256,
map::{B256HashMap, B256HashSet, HashMap},
Address, B256,
};
use alloy_primitives::{keccak256, map::HashMap, Address, B256};
use reth_db_api::transaction::DbTx;
use reth_execution_errors::StateProofError;
use reth_trie::{
hashed_cursor::HashedPostStateCursorFactory,
proof::{Proof, StorageProof},
trie_cursor::InMemoryTrieCursorFactory,
AccountProof, HashedPostStateSorted, HashedStorage, MultiProof, StorageMultiProof, TrieInput,
AccountProof, HashedPostStateSorted, HashedStorage, MultiProof, MultiProofTargets,
StorageMultiProof, TrieInput,
};
/// Extends [`Proof`] with operations specific for working with a database transaction.
@ -30,7 +27,7 @@ pub trait DatabaseProof<'a, TX> {
fn overlay_multiproof(
tx: &'a TX,
input: TrieInput,
targets: B256HashMap<B256HashSet>,
targets: MultiProofTargets,
) -> Result<MultiProof, StateProofError>;
}
@ -66,7 +63,7 @@ impl<'a, TX: DbTx> DatabaseProof<'a, TX>
fn overlay_multiproof(
tx: &'a TX,
input: TrieInput,
targets: B256HashMap<B256HashSet>,
targets: MultiProofTargets,
) -> Result<MultiProof, StateProofError> {
let nodes_sorted = input.nodes.into_sorted();
let state_sorted = input.state.into_sorted();

View File

@ -39,7 +39,9 @@ fn includes_empty_node_preimage() {
let state_root = StateRoot::from_tx(provider.tx_ref()).root().unwrap();
let multiproof = Proof::from_tx(provider.tx_ref())
.multiproof(HashMap::from_iter([(hashed_address, HashSet::from_iter([hashed_slot]))]))
.multiproof(
HashMap::from_iter([(hashed_address, HashSet::from_iter([hashed_slot]))]).into(),
)
.unwrap();
let witness = TrieWitness::from_tx(provider.tx_ref())
@ -77,7 +79,9 @@ fn includes_nodes_for_destroyed_storage_nodes() {
let state_root = StateRoot::from_tx(provider.tx_ref()).root().unwrap();
let multiproof = Proof::from_tx(provider.tx_ref())
.multiproof(HashMap::from_iter([(hashed_address, HashSet::from_iter([hashed_slot]))]))
.multiproof(
HashMap::from_iter([(hashed_address, HashSet::from_iter([hashed_slot]))]).into(),
)
.unwrap();
let witness =
@ -122,10 +126,13 @@ fn correctly_decodes_branch_node_values() {
let state_root = StateRoot::from_tx(provider.tx_ref()).root().unwrap();
let multiproof = Proof::from_tx(provider.tx_ref())
.multiproof(HashMap::from_iter([(
.multiproof(
HashMap::from_iter([(
hashed_address,
HashSet::from_iter([hashed_slot1, hashed_slot2]),
)]))
)])
.into(),
)
.unwrap();
let witness = TrieWitness::from_tx(provider.tx_ref())

View File

@ -15,7 +15,8 @@ use reth_primitives_traits::Account;
use reth_tracing::tracing::trace;
use reth_trie_common::{
updates::{StorageTrieUpdates, TrieUpdates},
MultiProof, Nibbles, TrieAccount, TrieNode, EMPTY_ROOT_HASH, TRIE_ACCOUNT_RLP_MAX_SIZE,
MultiProof, MultiProofTargets, Nibbles, TrieAccount, TrieNode, EMPTY_ROOT_HASH,
TRIE_ACCOUNT_RLP_MAX_SIZE,
};
use std::{fmt, iter::Peekable};
@ -206,7 +207,7 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
/// NOTE: This method does not extensively validate the proof.
pub fn reveal_multiproof(
&mut self,
targets: B256HashMap<B256HashSet>,
targets: MultiProofTargets,
multiproof: MultiProof,
) -> SparseStateTrieResult<()> {
let account_subtree = multiproof.account_subtree.into_nodes_sorted();
@ -559,7 +560,8 @@ mod tests {
HashMap::from_iter([
(address_1, HashSet::from_iter([slot_1, slot_2])),
(address_2, HashSet::from_iter([slot_1, slot_2])),
]),
])
.into(),
MultiProof {
account_subtree: proof_nodes,
branch_node_hash_masks: HashMap::from_iter([(

View File

@ -91,7 +91,7 @@ where
let proof =
Proof::new(self.trie_cursor_factory.clone(), self.hashed_cursor_factory.clone())
.with_prefix_sets_mut(self.prefix_sets.as_ref().clone())
.multiproof(targets)
.multiproof(targets.into())
.map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?;
Ok(proof.account_subtree.into_inner().remove(path))

View File

@ -14,7 +14,8 @@ use alloy_primitives::{
use alloy_rlp::{BufMut, Encodable};
use reth_execution_errors::trie::StateProofError;
use reth_trie_common::{
proof::ProofRetainer, AccountProof, MultiProof, StorageMultiProof, TrieAccount,
proof::ProofRetainer, AccountProof, MultiProof, MultiProofTargets, StorageMultiProof,
TrieAccount,
};
mod blinded;
@ -93,17 +94,17 @@ where
slots: &[B256],
) -> Result<AccountProof, StateProofError> {
Ok(self
.multiproof(HashMap::from_iter([(
keccak256(address),
slots.iter().map(keccak256).collect(),
)]))?
.multiproof(
HashMap::from_iter([(keccak256(address), slots.iter().map(keccak256).collect())])
.into(),
)?
.account_proof(address, slots)?)
}
/// Generate a state multiproof according to specified targets.
pub fn multiproof(
mut self,
mut targets: B256HashMap<B256HashSet>,
mut targets: MultiProofTargets,
) -> Result<MultiProof, StateProofError> {
let hashed_account_cursor = self.hashed_cursor_factory.hashed_account_cursor()?;
let trie_cursor = self.trie_cursor_factory.account_trie_cursor()?;

View File

@ -15,7 +15,7 @@ use reth_execution_errors::{
SparseStateTrieError, SparseStateTrieErrorKind, SparseTrieError, SparseTrieErrorKind,
StateProofError, TrieWitnessError,
};
use reth_trie_common::Nibbles;
use reth_trie_common::{MultiProofTargets, Nibbles};
use reth_trie_sparse::{
blinded::{BlindedProvider, BlindedProviderFactory},
SparseStateTrie,
@ -171,8 +171,8 @@ where
fn get_proof_targets(
&self,
state: &HashedPostState,
) -> Result<B256HashMap<B256HashSet>, StateProofError> {
let mut proof_targets = B256HashMap::default();
) -> Result<MultiProofTargets, StateProofError> {
let mut proof_targets = MultiProofTargets::default();
for hashed_address in state.accounts.keys() {
proof_targets.insert(*hashed_address, B256HashSet::default());
}