diff --git a/crates/trie/common/src/proofs.rs b/crates/trie/common/src/proofs.rs index 517f9fb7c..99b315d24 100644 --- a/crates/trie/common/src/proofs.rs +++ b/crates/trie/common/src/proofs.rs @@ -11,7 +11,7 @@ use alloy_rlp::{encode_fixed_size, Decodable, EMPTY_STRING_CODE}; use alloy_trie::{ nodes::TrieNode, proof::{verify_proof, ProofNodes, ProofVerificationError}, - EMPTY_ROOT_HASH, + TrieMask, EMPTY_ROOT_HASH, }; use itertools::Itertools; use reth_primitives_traits::Account; @@ -23,6 +23,8 @@ use reth_primitives_traits::Account; pub struct MultiProof { /// State trie multiproof for requested accounts. pub account_subtree: ProofNodes, + /// The hash masks of the branch nodes in the account proof. + pub branch_node_hash_masks: HashMap, /// Storage trie multiproofs. pub storages: HashMap, } @@ -108,11 +110,15 @@ impl MultiProof { pub fn extend(&mut self, other: Self) { self.account_subtree.extend_from(other.account_subtree); + self.branch_node_hash_masks.extend(other.branch_node_hash_masks); + for (hashed_address, storage) in other.storages { match self.storages.entry(hashed_address) { hash_map::Entry::Occupied(mut entry) => { debug_assert_eq!(entry.get().root, storage.root); - entry.get_mut().subtree.extend_from(storage.subtree); + let entry = entry.get_mut(); + entry.subtree.extend_from(storage.subtree); + entry.branch_node_hash_masks.extend(storage.branch_node_hash_masks); } hash_map::Entry::Vacant(entry) => { entry.insert(storage); @@ -129,6 +135,8 @@ pub struct StorageMultiProof { pub root: B256, /// Storage multiproof for requested slots. pub subtree: ProofNodes, + /// The hash masks of the branch nodes in the storage proof. + pub branch_node_hash_masks: HashMap, } impl StorageMultiProof { @@ -140,6 +148,7 @@ impl StorageMultiProof { Nibbles::default(), Bytes::from([EMPTY_STRING_CODE]), )]), + branch_node_hash_masks: HashMap::default(), } } @@ -380,14 +389,28 @@ mod tests { Nibbles::from_nibbles(vec![0]), alloy_rlp::encode_fixed_size(&U256::from(42)).to_vec().into(), ); - proof1.storages.insert(addr, StorageMultiProof { root, subtree: subtree1 }); + proof1.storages.insert( + addr, + StorageMultiProof { + root, + subtree: subtree1, + branch_node_hash_masks: HashMap::default(), + }, + ); let mut subtree2 = ProofNodes::default(); subtree2.insert( Nibbles::from_nibbles(vec![1]), alloy_rlp::encode_fixed_size(&U256::from(43)).to_vec().into(), ); - proof2.storages.insert(addr, StorageMultiProof { root, subtree: subtree2 }); + proof2.storages.insert( + addr, + StorageMultiProof { + root, + subtree: subtree2, + branch_node_hash_masks: HashMap::default(), + }, + ); proof1.extend(proof2); diff --git a/crates/trie/parallel/src/proof.rs b/crates/trie/parallel/src/proof.rs index f90a53fa9..148f7cd5d 100644 --- a/crates/trie/parallel/src/proof.rs +++ b/crates/trie/parallel/src/proof.rs @@ -35,6 +35,8 @@ pub struct ParallelProof { view: ConsistentDbView, /// Trie input. input: Arc, + /// Flag indicating whether to include branch node hash masks in the proof. + collect_branch_node_hash_masks: bool, /// Parallel state root metrics. #[cfg(feature = "metrics")] metrics: ParallelStateRootMetrics, @@ -46,10 +48,17 @@ impl ParallelProof { Self { view, input, + collect_branch_node_hash_masks: false, #[cfg(feature = "metrics")] metrics: ParallelStateRootMetrics::default(), } } + + /// Set the flag indicating whether to include branch node hash masks in the proof. + pub const fn with_branch_node_hash_masks(mut self, branch_node_hash_masks: bool) -> Self { + self.collect_branch_node_hash_masks = branch_node_hash_masks; + self + } } impl ParallelProof @@ -125,6 +134,7 @@ where hashed_address, ) .with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().cloned())) + .with_branch_node_hash_masks(self.collect_branch_node_hash_masks) .storage_multiproof(target_slots) .map_err(|e| { ParallelStateRootError::StorageRoot(StorageRootError::Database( @@ -158,7 +168,9 @@ where // Create a hash builder to rebuild the root node since it is not available in the database. let retainer: ProofRetainer = targets.keys().map(Nibbles::unpack).collect(); - let mut hash_builder = HashBuilder::default().with_proof_retainer(retainer); + let mut hash_builder = HashBuilder::default() + .with_proof_retainer(retainer) + .with_updates(self.collect_branch_node_hash_masks); let mut storages = HashMap::default(); let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE); @@ -222,7 +234,19 @@ where #[cfg(feature = "metrics")] self.metrics.record_state_trie(tracker.finish()); - Ok(MultiProof { account_subtree: hash_builder.take_proof_nodes(), storages }) + let account_subtree = hash_builder.take_proof_nodes(); + let branch_node_hash_masks = if self.collect_branch_node_hash_masks { + hash_builder + .updated_branch_nodes + .unwrap_or_default() + .into_iter() + .map(|(path, node)| (path, node.hash_mask)) + .collect() + } else { + HashMap::default() + }; + + Ok(MultiProof { account_subtree, branch_node_hash_masks, storages }) } } diff --git a/crates/trie/trie/src/proof/mod.rs b/crates/trie/trie/src/proof/mod.rs index c344ec762..8e3d0aec2 100644 --- a/crates/trie/trie/src/proof/mod.rs +++ b/crates/trie/trie/src/proof/mod.rs @@ -33,6 +33,8 @@ pub struct Proof { hashed_cursor_factory: H, /// A set of prefix sets that have changes. prefix_sets: TriePrefixSetsMut, + /// Flag indicating whether to include branch node hash masks in the proof. + collect_branch_node_hash_masks: bool, } impl Proof { @@ -42,6 +44,7 @@ impl Proof { trie_cursor_factory: t, hashed_cursor_factory: h, prefix_sets: TriePrefixSetsMut::default(), + collect_branch_node_hash_masks: false, } } @@ -51,6 +54,7 @@ impl Proof { trie_cursor_factory, hashed_cursor_factory: self.hashed_cursor_factory, prefix_sets: self.prefix_sets, + collect_branch_node_hash_masks: self.collect_branch_node_hash_masks, } } @@ -60,6 +64,7 @@ impl Proof { trie_cursor_factory: self.trie_cursor_factory, hashed_cursor_factory, prefix_sets: self.prefix_sets, + collect_branch_node_hash_masks: self.collect_branch_node_hash_masks, } } @@ -68,6 +73,12 @@ impl Proof { self.prefix_sets = prefix_sets; self } + + /// Set the flag indicating whether to include branch node hash masks in the proof. + pub const fn with_branch_node_hash_masks(mut self, branch_node_hash_masks: bool) -> Self { + self.collect_branch_node_hash_masks = branch_node_hash_masks; + self + } } impl Proof @@ -104,7 +115,9 @@ where // Create a hash builder to rebuild the root node since it is not available in the database. let retainer = targets.keys().map(Nibbles::unpack).collect(); - let mut hash_builder = HashBuilder::default().with_proof_retainer(retainer); + let mut hash_builder = HashBuilder::default() + .with_proof_retainer(retainer) + .with_updates(self.collect_branch_node_hash_masks); // Initialize all storage multiproofs as empty. // Storage multiproofs for non empty tries will be overwritten if necessary. @@ -131,6 +144,7 @@ where hashed_address, ) .with_prefix_set_mut(storage_prefix_set) + .with_branch_node_hash_masks(self.collect_branch_node_hash_masks) .storage_multiproof(proof_targets.unwrap_or_default())?; // Encode account @@ -149,7 +163,19 @@ where } } let _ = hash_builder.root(); - Ok(MultiProof { account_subtree: hash_builder.take_proof_nodes(), storages }) + let account_subtree = hash_builder.take_proof_nodes(); + let branch_node_hash_masks = if self.collect_branch_node_hash_masks { + hash_builder + .updated_branch_nodes + .unwrap_or_default() + .into_iter() + .map(|(path, node)| (path, node.hash_mask)) + .collect() + } else { + HashMap::default() + }; + + Ok(MultiProof { account_subtree, branch_node_hash_masks, storages }) } } @@ -164,6 +190,8 @@ pub struct StorageProof { hashed_address: B256, /// The set of storage slot prefixes that have changed. prefix_set: PrefixSetMut, + /// Flag indicating whether to include branch node hash masks in the proof. + collect_branch_node_hash_masks: bool, } impl StorageProof { @@ -179,6 +207,7 @@ impl StorageProof { hashed_cursor_factory: h, hashed_address, prefix_set: PrefixSetMut::default(), + collect_branch_node_hash_masks: false, } } @@ -189,6 +218,7 @@ impl StorageProof { hashed_cursor_factory: self.hashed_cursor_factory, hashed_address: self.hashed_address, prefix_set: self.prefix_set, + collect_branch_node_hash_masks: self.collect_branch_node_hash_masks, } } @@ -199,6 +229,7 @@ impl StorageProof { hashed_cursor_factory, hashed_address: self.hashed_address, prefix_set: self.prefix_set, + collect_branch_node_hash_masks: self.collect_branch_node_hash_masks, } } @@ -207,6 +238,12 @@ impl StorageProof { self.prefix_set = prefix_set; self } + + /// Set the flag indicating whether to include branch node hash masks in the proof. + pub const fn with_branch_node_hash_masks(mut self, branch_node_hash_masks: bool) -> Self { + self.collect_branch_node_hash_masks = branch_node_hash_masks; + self + } } impl StorageProof @@ -243,7 +280,9 @@ where let walker = TrieWalker::new(trie_cursor, self.prefix_set.freeze()); let retainer = ProofRetainer::from_iter(target_nibbles); - let mut hash_builder = HashBuilder::default().with_proof_retainer(retainer); + let mut hash_builder = HashBuilder::default() + .with_proof_retainer(retainer) + .with_updates(self.collect_branch_node_hash_masks); let mut storage_node_iter = TrieNodeIter::new(walker, hashed_storage_cursor); while let Some(node) = storage_node_iter.try_next()? { match node { @@ -260,6 +299,18 @@ where } let root = hash_builder.root(); - Ok(StorageMultiProof { root, subtree: hash_builder.take_proof_nodes() }) + let subtree = hash_builder.take_proof_nodes(); + let branch_node_hash_masks = if self.collect_branch_node_hash_masks { + hash_builder + .updated_branch_nodes + .unwrap_or_default() + .into_iter() + .map(|(path, node)| (path, node.hash_mask)) + .collect() + } else { + HashMap::default() + }; + + Ok(StorageMultiProof { root, subtree, branch_node_hash_masks }) } }