feat(trie): collect branch node hash masks when calculating a proof (#13129)

This commit is contained in:
Alexey Shekhirin
2024-12-04 20:34:37 +00:00
committed by GitHub
parent 337272c88b
commit 27dab59ceb
3 changed files with 108 additions and 10 deletions

View File

@ -11,7 +11,7 @@ use alloy_rlp::{encode_fixed_size, Decodable, EMPTY_STRING_CODE};
use alloy_trie::{
nodes::TrieNode,
proof::{verify_proof, ProofNodes, ProofVerificationError},
EMPTY_ROOT_HASH,
TrieMask, EMPTY_ROOT_HASH,
};
use itertools::Itertools;
use reth_primitives_traits::Account;
@ -23,6 +23,8 @@ use reth_primitives_traits::Account;
pub struct MultiProof {
/// State trie multiproof for requested accounts.
pub account_subtree: ProofNodes,
/// The hash masks of the branch nodes in the account proof.
pub branch_node_hash_masks: HashMap<Nibbles, TrieMask>,
/// Storage trie multiproofs.
pub storages: HashMap<B256, StorageMultiProof>,
}
@ -108,11 +110,15 @@ impl MultiProof {
pub fn extend(&mut self, other: Self) {
self.account_subtree.extend_from(other.account_subtree);
self.branch_node_hash_masks.extend(other.branch_node_hash_masks);
for (hashed_address, storage) in other.storages {
match self.storages.entry(hashed_address) {
hash_map::Entry::Occupied(mut entry) => {
debug_assert_eq!(entry.get().root, storage.root);
entry.get_mut().subtree.extend_from(storage.subtree);
let entry = entry.get_mut();
entry.subtree.extend_from(storage.subtree);
entry.branch_node_hash_masks.extend(storage.branch_node_hash_masks);
}
hash_map::Entry::Vacant(entry) => {
entry.insert(storage);
@ -129,6 +135,8 @@ pub struct StorageMultiProof {
pub root: B256,
/// Storage multiproof for requested slots.
pub subtree: ProofNodes,
/// The hash masks of the branch nodes in the storage proof.
pub branch_node_hash_masks: HashMap<Nibbles, TrieMask>,
}
impl StorageMultiProof {
@ -140,6 +148,7 @@ impl StorageMultiProof {
Nibbles::default(),
Bytes::from([EMPTY_STRING_CODE]),
)]),
branch_node_hash_masks: HashMap::default(),
}
}
@ -380,14 +389,28 @@ mod tests {
Nibbles::from_nibbles(vec![0]),
alloy_rlp::encode_fixed_size(&U256::from(42)).to_vec().into(),
);
proof1.storages.insert(addr, StorageMultiProof { root, subtree: subtree1 });
proof1.storages.insert(
addr,
StorageMultiProof {
root,
subtree: subtree1,
branch_node_hash_masks: HashMap::default(),
},
);
let mut subtree2 = ProofNodes::default();
subtree2.insert(
Nibbles::from_nibbles(vec![1]),
alloy_rlp::encode_fixed_size(&U256::from(43)).to_vec().into(),
);
proof2.storages.insert(addr, StorageMultiProof { root, subtree: subtree2 });
proof2.storages.insert(
addr,
StorageMultiProof {
root,
subtree: subtree2,
branch_node_hash_masks: HashMap::default(),
},
);
proof1.extend(proof2);

View File

@ -35,6 +35,8 @@ pub struct ParallelProof<Factory> {
view: ConsistentDbView<Factory>,
/// Trie input.
input: Arc<TrieInput>,
/// Flag indicating whether to include branch node hash masks in the proof.
collect_branch_node_hash_masks: bool,
/// Parallel state root metrics.
#[cfg(feature = "metrics")]
metrics: ParallelStateRootMetrics,
@ -46,10 +48,17 @@ impl<Factory> ParallelProof<Factory> {
Self {
view,
input,
collect_branch_node_hash_masks: false,
#[cfg(feature = "metrics")]
metrics: ParallelStateRootMetrics::default(),
}
}
/// Set the flag indicating whether to include branch node hash masks in the proof.
pub const fn with_branch_node_hash_masks(mut self, branch_node_hash_masks: bool) -> Self {
self.collect_branch_node_hash_masks = branch_node_hash_masks;
self
}
}
impl<Factory> ParallelProof<Factory>
@ -125,6 +134,7 @@ where
hashed_address,
)
.with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().cloned()))
.with_branch_node_hash_masks(self.collect_branch_node_hash_masks)
.storage_multiproof(target_slots)
.map_err(|e| {
ParallelStateRootError::StorageRoot(StorageRootError::Database(
@ -158,7 +168,9 @@ where
// Create a hash builder to rebuild the root node since it is not available in the database.
let retainer: ProofRetainer = targets.keys().map(Nibbles::unpack).collect();
let mut hash_builder = HashBuilder::default().with_proof_retainer(retainer);
let mut hash_builder = HashBuilder::default()
.with_proof_retainer(retainer)
.with_updates(self.collect_branch_node_hash_masks);
let mut storages = HashMap::default();
let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
@ -222,7 +234,19 @@ where
#[cfg(feature = "metrics")]
self.metrics.record_state_trie(tracker.finish());
Ok(MultiProof { account_subtree: hash_builder.take_proof_nodes(), storages })
let account_subtree = hash_builder.take_proof_nodes();
let branch_node_hash_masks = if self.collect_branch_node_hash_masks {
hash_builder
.updated_branch_nodes
.unwrap_or_default()
.into_iter()
.map(|(path, node)| (path, node.hash_mask))
.collect()
} else {
HashMap::default()
};
Ok(MultiProof { account_subtree, branch_node_hash_masks, storages })
}
}

View File

@ -33,6 +33,8 @@ pub struct Proof<T, H> {
hashed_cursor_factory: H,
/// A set of prefix sets that have changes.
prefix_sets: TriePrefixSetsMut,
/// Flag indicating whether to include branch node hash masks in the proof.
collect_branch_node_hash_masks: bool,
}
impl<T, H> Proof<T, H> {
@ -42,6 +44,7 @@ impl<T, H> Proof<T, H> {
trie_cursor_factory: t,
hashed_cursor_factory: h,
prefix_sets: TriePrefixSetsMut::default(),
collect_branch_node_hash_masks: false,
}
}
@ -51,6 +54,7 @@ impl<T, H> Proof<T, H> {
trie_cursor_factory,
hashed_cursor_factory: self.hashed_cursor_factory,
prefix_sets: self.prefix_sets,
collect_branch_node_hash_masks: self.collect_branch_node_hash_masks,
}
}
@ -60,6 +64,7 @@ impl<T, H> Proof<T, H> {
trie_cursor_factory: self.trie_cursor_factory,
hashed_cursor_factory,
prefix_sets: self.prefix_sets,
collect_branch_node_hash_masks: self.collect_branch_node_hash_masks,
}
}
@ -68,6 +73,12 @@ impl<T, H> Proof<T, H> {
self.prefix_sets = prefix_sets;
self
}
/// Set the flag indicating whether to include branch node hash masks in the proof.
pub const fn with_branch_node_hash_masks(mut self, branch_node_hash_masks: bool) -> Self {
self.collect_branch_node_hash_masks = branch_node_hash_masks;
self
}
}
impl<T, H> Proof<T, H>
@ -104,7 +115,9 @@ where
// Create a hash builder to rebuild the root node since it is not available in the database.
let retainer = targets.keys().map(Nibbles::unpack).collect();
let mut hash_builder = HashBuilder::default().with_proof_retainer(retainer);
let mut hash_builder = HashBuilder::default()
.with_proof_retainer(retainer)
.with_updates(self.collect_branch_node_hash_masks);
// Initialize all storage multiproofs as empty.
// Storage multiproofs for non empty tries will be overwritten if necessary.
@ -131,6 +144,7 @@ where
hashed_address,
)
.with_prefix_set_mut(storage_prefix_set)
.with_branch_node_hash_masks(self.collect_branch_node_hash_masks)
.storage_multiproof(proof_targets.unwrap_or_default())?;
// Encode account
@ -149,7 +163,19 @@ where
}
}
let _ = hash_builder.root();
Ok(MultiProof { account_subtree: hash_builder.take_proof_nodes(), storages })
let account_subtree = hash_builder.take_proof_nodes();
let branch_node_hash_masks = if self.collect_branch_node_hash_masks {
hash_builder
.updated_branch_nodes
.unwrap_or_default()
.into_iter()
.map(|(path, node)| (path, node.hash_mask))
.collect()
} else {
HashMap::default()
};
Ok(MultiProof { account_subtree, branch_node_hash_masks, storages })
}
}
@ -164,6 +190,8 @@ pub struct StorageProof<T, H> {
hashed_address: B256,
/// The set of storage slot prefixes that have changed.
prefix_set: PrefixSetMut,
/// Flag indicating whether to include branch node hash masks in the proof.
collect_branch_node_hash_masks: bool,
}
impl<T, H> StorageProof<T, H> {
@ -179,6 +207,7 @@ impl<T, H> StorageProof<T, H> {
hashed_cursor_factory: h,
hashed_address,
prefix_set: PrefixSetMut::default(),
collect_branch_node_hash_masks: false,
}
}
@ -189,6 +218,7 @@ impl<T, H> StorageProof<T, H> {
hashed_cursor_factory: self.hashed_cursor_factory,
hashed_address: self.hashed_address,
prefix_set: self.prefix_set,
collect_branch_node_hash_masks: self.collect_branch_node_hash_masks,
}
}
@ -199,6 +229,7 @@ impl<T, H> StorageProof<T, H> {
hashed_cursor_factory,
hashed_address: self.hashed_address,
prefix_set: self.prefix_set,
collect_branch_node_hash_masks: self.collect_branch_node_hash_masks,
}
}
@ -207,6 +238,12 @@ impl<T, H> StorageProof<T, H> {
self.prefix_set = prefix_set;
self
}
/// Set the flag indicating whether to include branch node hash masks in the proof.
pub const fn with_branch_node_hash_masks(mut self, branch_node_hash_masks: bool) -> Self {
self.collect_branch_node_hash_masks = branch_node_hash_masks;
self
}
}
impl<T, H> StorageProof<T, H>
@ -243,7 +280,9 @@ where
let walker = TrieWalker::new(trie_cursor, self.prefix_set.freeze());
let retainer = ProofRetainer::from_iter(target_nibbles);
let mut hash_builder = HashBuilder::default().with_proof_retainer(retainer);
let mut hash_builder = HashBuilder::default()
.with_proof_retainer(retainer)
.with_updates(self.collect_branch_node_hash_masks);
let mut storage_node_iter = TrieNodeIter::new(walker, hashed_storage_cursor);
while let Some(node) = storage_node_iter.try_next()? {
match node {
@ -260,6 +299,18 @@ where
}
let root = hash_builder.root();
Ok(StorageMultiProof { root, subtree: hash_builder.take_proof_nodes() })
let subtree = hash_builder.take_proof_nodes();
let branch_node_hash_masks = if self.collect_branch_node_hash_masks {
hash_builder
.updated_branch_nodes
.unwrap_or_default()
.into_iter()
.map(|(path, node)| (path, node.hash_mask))
.collect()
} else {
HashMap::default()
};
Ok(StorageMultiProof { root, subtree, branch_node_hash_masks })
}
}