fix(trie): sparse trie tree masks (#13760)

This commit is contained in:
Alexey Shekhirin
2025-01-10 11:28:54 +00:00
committed by GitHub
parent 986c75434a
commit 69f9e1628a
7 changed files with 245 additions and 157 deletions

View File

@ -748,7 +748,7 @@ where
config.prefix_sets,
thread_pool,
)
.with_branch_node_hash_masks(true)
.with_branch_node_masks(true)
.multiproof(proof_targets)?)
}

View File

@ -198,7 +198,7 @@ fn branch_nodes_equal(
) -> bool {
if let (Some(task), Some(regular)) = (task.as_ref(), regular.as_ref()) {
task.state_mask == regular.state_mask &&
// We do not compare the tree mask because it is known to be mismatching
task.tree_mask == regular.tree_mask &&
task.hash_mask == regular.hash_mask &&
task.hashes == regular.hashes &&
task.root_hash == regular.root_hash

View File

@ -29,6 +29,8 @@ pub struct MultiProof {
pub account_subtree: ProofNodes,
/// The hash masks of the branch nodes in the account proof.
pub branch_node_hash_masks: HashMap<Nibbles, TrieMask>,
/// The tree masks of the branch nodes in the account proof.
pub branch_node_tree_masks: HashMap<Nibbles, TrieMask>,
/// Storage trie multiproofs.
pub storages: B256HashMap<StorageMultiProof>,
}
@ -115,6 +117,7 @@ impl MultiProof {
self.account_subtree.extend_from(other.account_subtree);
self.branch_node_hash_masks.extend(other.branch_node_hash_masks);
self.branch_node_tree_masks.extend(other.branch_node_tree_masks);
for (hashed_address, storage) in other.storages {
match self.storages.entry(hashed_address) {
@ -123,6 +126,7 @@ impl MultiProof {
let entry = entry.get_mut();
entry.subtree.extend_from(storage.subtree);
entry.branch_node_hash_masks.extend(storage.branch_node_hash_masks);
entry.branch_node_tree_masks.extend(storage.branch_node_tree_masks);
}
hash_map::Entry::Vacant(entry) => {
entry.insert(storage);
@ -141,6 +145,8 @@ pub struct StorageMultiProof {
pub subtree: ProofNodes,
/// The hash masks of the branch nodes in the storage proof.
pub branch_node_hash_masks: HashMap<Nibbles, TrieMask>,
/// The tree masks of the branch nodes in the storage proof.
pub branch_node_tree_masks: HashMap<Nibbles, TrieMask>,
}
impl StorageMultiProof {
@ -153,6 +159,7 @@ impl StorageMultiProof {
Bytes::from([EMPTY_STRING_CODE]),
)]),
branch_node_hash_masks: HashMap::default(),
branch_node_tree_masks: HashMap::default(),
}
}
@ -398,6 +405,7 @@ mod tests {
root,
subtree: subtree1,
branch_node_hash_masks: HashMap::default(),
branch_node_tree_masks: HashMap::default(),
},
);
@ -412,6 +420,7 @@ mod tests {
root,
subtree: subtree2,
branch_node_hash_masks: HashMap::default(),
branch_node_tree_masks: HashMap::default(),
},
);

View File

@ -44,8 +44,8 @@ pub struct ParallelProof<Factory> {
/// invalidate the in-memory nodes, not all keys from `state_sorted` might be present here,
/// if we have cached nodes for them.
pub prefix_sets: Arc<TriePrefixSetsMut>,
/// Flag indicating whether to include branch node hash masks in the proof.
collect_branch_node_hash_masks: bool,
/// Flag indicating whether to include branch node masks in the proof.
collect_branch_node_masks: bool,
/// Thread pool for local tasks
thread_pool: Arc<rayon::ThreadPool>,
/// Parallel state root metrics.
@ -67,16 +67,16 @@ impl<Factory> ParallelProof<Factory> {
nodes_sorted,
state_sorted,
prefix_sets,
collect_branch_node_hash_masks: false,
collect_branch_node_masks: false,
thread_pool,
#[cfg(feature = "metrics")]
metrics: ParallelStateRootMetrics::default(),
}
}
/// Set the flag indicating whether to include branch node hash masks in the proof.
pub const fn with_branch_node_hash_masks(mut self, branch_node_hash_masks: bool) -> Self {
self.collect_branch_node_hash_masks = branch_node_hash_masks;
/// Set the flag indicating whether to include branch node masks in the proof.
pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
self.collect_branch_node_masks = branch_node_masks;
self
}
}
@ -137,7 +137,7 @@ where
let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();
let trie_nodes_sorted = self.nodes_sorted.clone();
let hashed_state_sorted = self.state_sorted.clone();
let collect_masks = self.collect_branch_node_hash_masks;
let collect_masks = self.collect_branch_node_masks;
let (tx, rx) = std::sync::mpsc::sync_channel(1);
@ -182,7 +182,7 @@ where
hashed_address,
)
.with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().cloned()))
.with_branch_node_hash_masks(collect_masks)
.with_branch_node_masks(collect_masks)
.storage_multiproof(target_slots)
.map_err(|e| ParallelStateRootError::Other(e.to_string()));
@ -233,7 +233,7 @@ where
let retainer: ProofRetainer = targets.keys().map(Nibbles::unpack).collect();
let mut hash_builder = HashBuilder::default()
.with_proof_retainer(retainer)
.with_updates(self.collect_branch_node_hash_masks);
.with_updates(self.collect_branch_node_masks);
// Initialize all storage multiproofs as empty.
// Storage multiproofs for non empty tries will be overwritten if necessary.
@ -301,18 +301,23 @@ where
self.metrics.record_state_trie(tracker.finish());
let account_subtree = hash_builder.take_proof_nodes();
let branch_node_hash_masks = if self.collect_branch_node_hash_masks {
hash_builder
.updated_branch_nodes
.unwrap_or_default()
.into_iter()
.map(|(path, node)| (path, node.hash_mask))
.collect()
let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
(
updated_branch_nodes
.iter()
.map(|(path, node)| (path.clone(), node.hash_mask))
.collect(),
updated_branch_nodes
.into_iter()
.map(|(path, node)| (path, node.tree_mask))
.collect(),
)
} else {
HashMap::default()
(HashMap::default(), HashMap::default())
};
Ok(MultiProof { account_subtree, branch_node_hash_masks, storages })
Ok(MultiProof { account_subtree, branch_node_hash_masks, branch_node_tree_masks, storages })
}
}

View File

@ -155,13 +155,14 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
self.provider_factory.account_node_provider(),
root_node,
None,
None,
self.retain_updates,
)?;
// Reveal the remaining proof nodes.
for (path, bytes) in proof {
let node = TrieNode::decode(&mut &bytes[..])?;
trie.reveal_node(path, node, None)?;
trie.reveal_node(path, node, None, None)?;
}
// Mark leaf path as revealed.
@ -196,13 +197,14 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
self.provider_factory.storage_node_provider(account),
root_node,
None,
None,
self.retain_updates,
)?;
// Reveal the remaining proof nodes.
for (path, bytes) in proof {
let node = TrieNode::decode(&mut &bytes[..])?;
trie.reveal_node(path, node, None)?;
trie.reveal_node(path, node, None, None)?;
}
// Mark leaf path as revealed.
@ -227,20 +229,24 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
self.provider_factory.account_node_provider(),
root_node,
multiproof.branch_node_hash_masks.get(&Nibbles::default()).copied(),
multiproof.branch_node_tree_masks.get(&Nibbles::default()).copied(),
self.retain_updates,
)?;
// Reveal the remaining proof nodes.
for (path, bytes) in account_nodes {
let node = TrieNode::decode(&mut &bytes[..])?;
let hash_mask = if let TrieNode::Branch(_) = node {
multiproof.branch_node_hash_masks.get(&path).copied()
let (hash_mask, tree_mask) = if let TrieNode::Branch(_) = node {
(
multiproof.branch_node_hash_masks.get(&path).copied(),
multiproof.branch_node_tree_masks.get(&path).copied(),
)
} else {
None
(None, None)
};
trace!(target: "trie::sparse", ?path, ?node, ?hash_mask, "Revealing account node");
trie.reveal_node(path, node, hash_mask)?;
trace!(target: "trie::sparse", ?path, ?node, ?hash_mask, ?tree_mask, "Revealing account node");
trie.reveal_node(path, node, hash_mask, tree_mask)?;
}
}
@ -254,20 +260,24 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
self.provider_factory.storage_node_provider(account),
root_node,
storage_subtree.branch_node_hash_masks.get(&Nibbles::default()).copied(),
storage_subtree.branch_node_tree_masks.get(&Nibbles::default()).copied(),
self.retain_updates,
)?;
// Reveal the remaining proof nodes.
for (path, bytes) in nodes {
let node = TrieNode::decode(&mut &bytes[..])?;
let hash_mask = if let TrieNode::Branch(_) = node {
storage_subtree.branch_node_hash_masks.get(&path).copied()
let (hash_mask, tree_mask) = if let TrieNode::Branch(_) = node {
(
storage_subtree.branch_node_hash_masks.get(&path).copied(),
storage_subtree.branch_node_tree_masks.get(&path).copied(),
)
} else {
None
(None, None)
};
trace!(target: "trie::sparse", ?account, ?path, ?node, ?hash_mask, "Revealing storage node");
trie.reveal_node(path, node, hash_mask)?;
trace!(target: "trie::sparse", ?account, ?path, ?node, ?hash_mask, ?tree_mask, "Revealing storage node");
trie.reveal_node(path, node, hash_mask, tree_mask)?;
}
}
}
@ -348,6 +358,7 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
self.provider_factory.storage_node_provider(account),
trie_node,
None,
None,
self.retain_updates,
)?;
} else {
@ -355,7 +366,7 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
storage_trie_entry
.as_revealed_mut()
.ok_or(SparseTrieErrorKind::Blind)?
.reveal_node(path, trie_node, None)?;
.reveal_node(path, trie_node, None, None)?;
}
} else if path.is_empty() {
// Handle special state root node case.
@ -363,6 +374,7 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
self.provider_factory.account_node_provider(),
trie_node,
None,
None,
self.retain_updates,
)?;
} else {
@ -370,7 +382,7 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
self.state
.as_revealed_mut()
.ok_or(SparseTrieErrorKind::Blind)?
.reveal_node(path, trie_node, None)?;
.reveal_node(path, trie_node, None, None)?;
}
}
@ -668,6 +680,7 @@ mod tests {
Nibbles::from_nibbles([0x1]),
TrieMask::new(0b00),
)]),
branch_node_tree_masks: HashMap::default(),
storages: HashMap::from_iter([
(
address_1,
@ -675,6 +688,7 @@ mod tests {
root,
subtree: storage_proof_nodes.clone(),
branch_node_hash_masks: storage_branch_node_hash_masks.clone(),
branch_node_tree_masks: HashMap::default(),
},
),
(
@ -683,6 +697,7 @@ mod tests {
root,
subtree: storage_proof_nodes,
branch_node_hash_masks: storage_branch_node_hash_masks,
branch_node_tree_masks: HashMap::default(),
},
),
]),

View File

@ -60,9 +60,16 @@ impl SparseTrie {
&mut self,
root: TrieNode,
hash_mask: Option<TrieMask>,
tree_mask: Option<TrieMask>,
retain_updates: bool,
) -> SparseTrieResult<&mut RevealedSparseTrie> {
self.reveal_root_with_provider(Default::default(), root, hash_mask, retain_updates)
self.reveal_root_with_provider(
Default::default(),
root,
hash_mask,
tree_mask,
retain_updates,
)
}
}
@ -100,6 +107,7 @@ impl<P> SparseTrie<P> {
provider: P,
root: TrieNode,
hash_mask: Option<TrieMask>,
tree_mask: Option<TrieMask>,
retain_updates: bool,
) -> SparseTrieResult<&mut RevealedSparseTrie<P>> {
if self.is_blind() {
@ -107,6 +115,7 @@ impl<P> SparseTrie<P> {
provider,
root,
hash_mask,
tree_mask,
retain_updates,
)?))
}
@ -163,6 +172,8 @@ pub struct RevealedSparseTrie<P = DefaultBlindedProvider> {
nodes: HashMap<Nibbles, SparseNode>,
/// All branch node hash masks.
branch_node_hash_masks: HashMap<Nibbles, TrieMask>,
/// All branch node tree masks.
branch_node_tree_masks: HashMap<Nibbles, TrieMask>,
/// All leaf values.
values: HashMap<Nibbles, Vec<u8>>,
/// Prefix set.
@ -178,6 +189,7 @@ impl<P> fmt::Debug for RevealedSparseTrie<P> {
f.debug_struct("RevealedSparseTrie")
.field("nodes", &self.nodes)
.field("branch_hash_masks", &self.branch_node_hash_masks)
.field("branch_tree_masks", &self.branch_node_tree_masks)
.field("values", &self.values)
.field("prefix_set", &self.prefix_set)
.field("updates", &self.updates)
@ -192,6 +204,7 @@ impl Default for RevealedSparseTrie {
provider: Default::default(),
nodes: HashMap::from_iter([(Nibbles::default(), SparseNode::Empty)]),
branch_node_hash_masks: HashMap::default(),
branch_node_tree_masks: HashMap::default(),
values: HashMap::default(),
prefix_set: PrefixSetMut::default(),
updates: None,
@ -205,19 +218,21 @@ impl RevealedSparseTrie {
pub fn from_root(
node: TrieNode,
hash_mask: Option<TrieMask>,
tree_mask: Option<TrieMask>,
retain_updates: bool,
) -> SparseTrieResult<Self> {
let mut this = Self {
provider: Default::default(),
nodes: HashMap::default(),
branch_node_hash_masks: HashMap::default(),
branch_node_tree_masks: HashMap::default(),
values: HashMap::default(),
prefix_set: PrefixSetMut::default(),
rlp_buf: Vec::new(),
updates: None,
}
.with_updates(retain_updates);
this.reveal_node(Nibbles::default(), node, hash_mask)?;
this.reveal_node(Nibbles::default(), node, hash_mask, tree_mask)?;
Ok(this)
}
}
@ -228,19 +243,21 @@ impl<P> RevealedSparseTrie<P> {
provider: P,
node: TrieNode,
hash_mask: Option<TrieMask>,
tree_mask: Option<TrieMask>,
retain_updates: bool,
) -> SparseTrieResult<Self> {
let mut this = Self {
provider,
nodes: HashMap::default(),
branch_node_hash_masks: HashMap::default(),
branch_node_tree_masks: HashMap::default(),
values: HashMap::default(),
prefix_set: PrefixSetMut::default(),
rlp_buf: Vec::new(),
updates: None,
}
.with_updates(retain_updates);
this.reveal_node(Nibbles::default(), node, hash_mask)?;
this.reveal_node(Nibbles::default(), node, hash_mask, tree_mask)?;
Ok(this)
}
@ -250,6 +267,7 @@ impl<P> RevealedSparseTrie<P> {
provider,
nodes: self.nodes,
branch_node_hash_masks: self.branch_node_hash_masks,
branch_node_tree_masks: self.branch_node_tree_masks,
values: self.values,
prefix_set: self.prefix_set,
updates: self.updates,
@ -286,6 +304,7 @@ impl<P> RevealedSparseTrie<P> {
path: Nibbles,
node: TrieNode,
hash_mask: Option<TrieMask>,
tree_mask: Option<TrieMask>,
) -> SparseTrieResult<()> {
// If the node is already revealed and it's not a hash node, do nothing.
if self.nodes.get(&path).is_some_and(|node| !node.is_hash()) {
@ -295,6 +314,9 @@ impl<P> RevealedSparseTrie<P> {
if let Some(hash_mask) = hash_mask {
self.branch_node_hash_masks.insert(path.clone(), hash_mask);
}
if let Some(tree_mask) = tree_mask {
self.branch_node_tree_masks.insert(path.clone(), tree_mask);
}
match node {
TrieNode::EmptyRoot => {
@ -321,7 +343,10 @@ impl<P> RevealedSparseTrie<P> {
// Memoize the hash of a previously blinded node in a new branch
// node.
hash: Some(*hash),
store_in_db_trie: None,
store_in_db_trie: Some(
hash_mask.is_some_and(|mask| !mask.is_empty()) ||
tree_mask.is_some_and(|mask| !mask.is_empty()),
),
});
}
// Branch node already exists, or an extension node was placed where a
@ -433,7 +458,7 @@ impl<P> RevealedSparseTrie<P> {
return Ok(())
}
self.reveal_node(path, TrieNode::decode(&mut &child[..])?, None)
self.reveal_node(path, TrieNode::decode(&mut &child[..])?, None, None)
}
/// Traverse trie nodes down to the leaf node and collect all nodes along the path.
@ -627,22 +652,20 @@ impl<P> RevealedSparseTrie<P> {
let mut prefix_set_contains =
|path: &Nibbles| *is_in_prefix_set.get_or_insert_with(|| prefix_set.contains(path));
let (rlp_node, calculated, node_type) = match self.nodes.get_mut(&path).unwrap() {
SparseNode::Empty => {
(RlpNode::word_rlp(&EMPTY_ROOT_HASH), false, SparseNodeType::Empty)
}
SparseNode::Hash(hash) => (RlpNode::word_rlp(hash), false, SparseNodeType::Hash),
let (rlp_node, node_type) = match self.nodes.get_mut(&path).unwrap() {
SparseNode::Empty => (RlpNode::word_rlp(&EMPTY_ROOT_HASH), SparseNodeType::Empty),
SparseNode::Hash(hash) => (RlpNode::word_rlp(hash), SparseNodeType::Hash),
SparseNode::Leaf { key, hash } => {
let mut path = path.clone();
path.extend_from_slice_unchecked(key);
if let Some(hash) = hash.filter(|_| !prefix_set_contains(&path)) {
(RlpNode::word_rlp(&hash), false, SparseNodeType::Leaf)
(RlpNode::word_rlp(&hash), SparseNodeType::Leaf)
} else {
let value = self.values.get(&path).unwrap();
self.rlp_buf.clear();
let rlp_node = LeafNodeRef { key, value }.rlp(&mut self.rlp_buf);
*hash = rlp_node.as_hash();
(rlp_node, true, SparseNodeType::Leaf)
(rlp_node, SparseNodeType::Leaf)
}
}
SparseNode::Extension { key, hash } => {
@ -651,22 +674,20 @@ impl<P> RevealedSparseTrie<P> {
if let Some(hash) = hash.filter(|_| !prefix_set_contains(&path)) {
(
RlpNode::word_rlp(&hash),
false,
SparseNodeType::Extension { store_in_db_trie: true },
)
} else if buffers.rlp_node_stack.last().is_some_and(|e| e.0 == child_path) {
let (_, child, _, node_type) = buffers.rlp_node_stack.pop().unwrap();
let (_, child, child_node_type) = buffers.rlp_node_stack.pop().unwrap();
self.rlp_buf.clear();
let rlp_node = ExtensionNodeRef::new(key, &child).rlp(&mut self.rlp_buf);
*hash = rlp_node.as_hash();
(
rlp_node,
true,
SparseNodeType::Extension {
// Inherit the `store_in_db_trie` flag from the child node, which is
// always the branch node
store_in_db_trie: node_type.store_in_db_trie(),
store_in_db_trie: child_node_type.store_in_db_trie(),
},
)
} else {
@ -682,7 +703,6 @@ impl<P> RevealedSparseTrie<P> {
buffers.rlp_node_stack.push((
path,
RlpNode::word_rlp(&hash),
false,
SparseNodeType::Branch { store_in_db_trie },
));
continue
@ -710,8 +730,7 @@ impl<P> RevealedSparseTrie<P> {
let mut hashes = Vec::new();
for (i, child_path) in buffers.branch_child_buf.iter().enumerate() {
if buffers.rlp_node_stack.last().is_some_and(|e| &e.0 == child_path) {
let (_, child, calculated, node_type) =
buffers.rlp_node_stack.pop().unwrap();
let (_, child, child_node_type) = buffers.rlp_node_stack.pop().unwrap();
// Update the masks only if we need to retain trie updates
if retain_updates {
@ -720,13 +739,16 @@ impl<P> RevealedSparseTrie<P> {
// Determine whether we need to set trie mask bit.
let should_set_tree_mask_bit =
// A blinded node has the tree mask bit set
(
child_node_type.is_hash() &&
self.branch_node_tree_masks
.get(&path)
.is_some_and(|mask| mask.is_bit_set(last_child_nibble))
) ||
// A branch or an extension node explicitly set the
// `store_in_db_trie` flag
node_type.store_in_db_trie() ||
// Set the flag according to whether a child node was
// pre-calculated (`calculated = false`), meaning that it wasn't
// in the database
!calculated;
child_node_type.store_in_db_trie();
if should_set_tree_mask_bit {
tree_mask.set_bit(last_child_nibble);
}
@ -735,8 +757,8 @@ impl<P> RevealedSparseTrie<P> {
// is a blinded node that has its hash mask bit set according to the
// database, set the hash mask bit and save the hash.
let hash = child.as_hash().filter(|_| {
node_type.is_branch() ||
(node_type.is_hash() &&
child_node_type.is_branch() ||
(child_node_type.is_hash() &&
self.branch_node_hash_masks
.get(&path)
.is_some_and(|mask| {
@ -806,14 +828,10 @@ impl<P> RevealedSparseTrie<P> {
};
*store_in_db_trie = Some(store_in_db_trie_value);
(
rlp_node,
true,
SparseNodeType::Branch { store_in_db_trie: store_in_db_trie_value },
)
(rlp_node, SparseNodeType::Branch { store_in_db_trie: store_in_db_trie_value })
}
};
buffers.rlp_node_stack.push((path, rlp_node, calculated, node_type));
buffers.rlp_node_stack.push((path, rlp_node, node_type));
}
debug_assert_eq!(buffers.rlp_node_stack.len(), 1);
@ -894,7 +912,7 @@ impl<P: BlindedProvider> RevealedSparseTrie<P> {
// remove or do nothing, so
// we can safely ignore the hash mask here and
// pass `None`.
self.reveal_node(current.clone(), decoded, None)?;
self.reveal_node(current.clone(), decoded, None, None)?;
}
}
}
@ -1046,7 +1064,7 @@ impl<P: BlindedProvider> RevealedSparseTrie<P> {
// We'll never have to update the revealed branch node, only remove
// or do nothing, so we can safely ignore the hash mask here and
// pass `None`.
self.reveal_node(child_path.clone(), decoded, None)?;
self.reveal_node(child_path.clone(), decoded, None, None)?;
}
}
@ -1251,7 +1269,7 @@ struct RlpNodeBuffers {
/// Stack of paths we need rlp nodes for and whether the path is in the prefix set.
path_stack: Vec<(Nibbles, Option<bool>)>,
/// Stack of rlp nodes
rlp_node_stack: Vec<(Nibbles, RlpNode, bool, SparseNodeType)>,
rlp_node_stack: Vec<(Nibbles, RlpNode, SparseNodeType)>,
/// Reusable branch child path
branch_child_buf: SmallVec<[Nibbles; 16]>,
/// Reusable branch value stack
@ -1336,7 +1354,8 @@ mod tests {
state: impl IntoIterator<Item = (Nibbles, Account)> + Clone,
destroyed_accounts: B256HashSet,
proof_targets: impl IntoIterator<Item = Nibbles>,
) -> (B256, TrieUpdates, ProofNodes, HashMap<Nibbles, TrieMask>) {
) -> (B256, TrieUpdates, ProofNodes, HashMap<Nibbles, TrieMask>, HashMap<Nibbles, TrieMask>)
{
let mut account_rlp = Vec::new();
let mut hash_builder = HashBuilder::default()
@ -1383,12 +1402,19 @@ mod tests {
.iter()
.map(|(path, node)| (path.clone(), node.hash_mask))
.collect();
let branch_node_tree_masks = hash_builder
.updated_branch_nodes
.clone()
.unwrap_or_default()
.iter()
.map(|(path, node)| (path.clone(), node.tree_mask))
.collect();
let mut trie_updates = TrieUpdates::default();
let removed_keys = node_iter.walker.take_removed_keys();
trie_updates.finalize(hash_builder, removed_keys, destroyed_accounts);
(root, trie_updates, proof_nodes, branch_node_hash_masks)
(root, trie_updates, proof_nodes, branch_node_hash_masks, branch_node_tree_masks)
}
/// Assert that the sparse trie nodes and the proof nodes from the hash builder are equal.
@ -1450,7 +1476,7 @@ mod tests {
account_rlp
};
let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _) =
let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _, _) =
run_hash_builder([(key.clone(), value())], Default::default(), [key.clone()]);
let mut sparse = RevealedSparseTrie::default().with_updates(true);
@ -1475,7 +1501,7 @@ mod tests {
account_rlp
};
let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _) =
let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _, _) =
run_hash_builder(
paths.iter().cloned().zip(std::iter::repeat_with(value)),
Default::default(),
@ -1504,7 +1530,7 @@ mod tests {
account_rlp
};
let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _) =
let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _, _) =
run_hash_builder(
paths.iter().cloned().zip(std::iter::repeat_with(value)),
Default::default(),
@ -1541,7 +1567,7 @@ mod tests {
account_rlp
};
let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _) =
let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _, _) =
run_hash_builder(
paths.iter().sorted_unstable().cloned().zip(std::iter::repeat_with(value)),
Default::default(),
@ -1579,7 +1605,7 @@ mod tests {
account_rlp
};
let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _) =
let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _, _) =
run_hash_builder(
paths.iter().cloned().zip(std::iter::repeat_with(|| old_value)),
Default::default(),
@ -1597,7 +1623,7 @@ mod tests {
assert_eq!(sparse_updates.updated_nodes, hash_builder_updates.account_nodes);
assert_eq_sparse_trie_proof_nodes(&sparse, hash_builder_proof_nodes);
let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _) =
let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _, _) =
run_hash_builder(
paths.iter().cloned().zip(std::iter::repeat_with(|| new_value)),
Default::default(),
@ -1871,7 +1897,7 @@ mod tests {
));
let mut sparse =
RevealedSparseTrie::from_root(branch.clone(), Some(TrieMask::new(0b01)), false)
RevealedSparseTrie::from_root(branch.clone(), Some(TrieMask::new(0b01)), None, false)
.unwrap();
// Reveal a branch node and one of its children
@ -1879,8 +1905,8 @@ mod tests {
// Branch (Mask = 11)
// ├── 0 -> Hash (Path = 0)
// └── 1 -> Leaf (Path = 1)
sparse.reveal_node(Nibbles::default(), branch, Some(TrieMask::new(0b01))).unwrap();
sparse.reveal_node(Nibbles::from_nibbles([0x1]), TrieNode::Leaf(leaf), None).unwrap();
sparse.reveal_node(Nibbles::default(), branch, Some(TrieMask::new(0b01)), None).unwrap();
sparse.reveal_node(Nibbles::from_nibbles([0x1]), TrieNode::Leaf(leaf), None, None).unwrap();
// Removing a blinded leaf should result in an error
assert_matches!(
@ -1904,7 +1930,7 @@ mod tests {
));
let mut sparse =
RevealedSparseTrie::from_root(branch.clone(), Some(TrieMask::new(0b01)), false)
RevealedSparseTrie::from_root(branch.clone(), Some(TrieMask::new(0b01)), None, false)
.unwrap();
// Reveal a branch node and one of its children
@ -1912,8 +1938,8 @@ mod tests {
// Branch (Mask = 11)
// ├── 0 -> Hash (Path = 0)
// └── 1 -> Leaf (Path = 1)
sparse.reveal_node(Nibbles::default(), branch, Some(TrieMask::new(0b01))).unwrap();
sparse.reveal_node(Nibbles::from_nibbles([0x1]), TrieNode::Leaf(leaf), None).unwrap();
sparse.reveal_node(Nibbles::default(), branch, Some(TrieMask::new(0b01)), None).unwrap();
sparse.reveal_node(Nibbles::from_nibbles([0x1]), TrieNode::Leaf(leaf), None, None).unwrap();
// Removing a non-existent leaf should be a noop
let sparse_old = sparse.clone();
@ -1951,7 +1977,7 @@ mod tests {
// Insert state updates into the hash builder and calculate the root
state.extend(update);
let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _) =
let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _, _) =
run_hash_builder(
state.clone(),
Default::default(),
@ -1982,7 +2008,7 @@ mod tests {
let sparse_root = updated_sparse.root();
let sparse_updates = updated_sparse.take_updates();
let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _) =
let (hash_builder_root, hash_builder_updates, hash_builder_proof_nodes, _, _) =
run_hash_builder(
state.clone(),
Default::default(),
@ -2063,24 +2089,29 @@ mod tests {
};
// Generate the proof for the root node and initialize the sparse trie with it
let (_, _, hash_builder_proof_nodes, branch_node_hash_masks) = run_hash_builder(
[(key1(), value()), (key3(), value())],
Default::default(),
[Nibbles::default()],
);
let (_, _, hash_builder_proof_nodes, branch_node_hash_masks, branch_node_tree_masks) =
run_hash_builder(
[(key1(), value()), (key3(), value())],
Default::default(),
[Nibbles::default()],
);
let mut sparse = RevealedSparseTrie::from_root(
TrieNode::decode(&mut &hash_builder_proof_nodes.nodes_sorted()[0].1[..]).unwrap(),
branch_node_hash_masks.get(&Nibbles::default()).copied(),
branch_node_tree_masks.get(&Nibbles::default()).copied(),
false,
)
.unwrap();
// Generate the proof for the first key and reveal it in the sparse trie
let (_, _, hash_builder_proof_nodes, branch_node_hash_masks) =
let (_, _, hash_builder_proof_nodes, branch_node_hash_masks, branch_node_tree_masks) =
run_hash_builder([(key1(), value()), (key3(), value())], Default::default(), [key1()]);
for (path, node) in hash_builder_proof_nodes.nodes_sorted() {
let hash_mask = branch_node_hash_masks.get(&path).copied();
sparse.reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), hash_mask).unwrap();
let tree_mask = branch_node_tree_masks.get(&path).copied();
sparse
.reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), hash_mask, tree_mask)
.unwrap();
}
// Check that the branch node exists with only two nibbles set
@ -2099,11 +2130,14 @@ mod tests {
);
// Generate the proof for the third key and reveal it in the sparse trie
let (_, _, hash_builder_proof_nodes, branch_node_hash_masks) =
let (_, _, hash_builder_proof_nodes, branch_node_hash_masks, branch_node_tree_masks) =
run_hash_builder([(key1(), value()), (key3(), value())], Default::default(), [key3()]);
for (path, node) in hash_builder_proof_nodes.nodes_sorted() {
let hash_mask = branch_node_hash_masks.get(&path).copied();
sparse.reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), hash_mask).unwrap();
let tree_mask = branch_node_tree_masks.get(&path).copied();
sparse
.reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), hash_mask, tree_mask)
.unwrap();
}
// Check that nothing changed in the branch node
@ -2114,7 +2148,7 @@ mod tests {
// Generate the nodes for the full trie with all three key using the hash builder, and
// compare them to the sparse trie
let (_, _, hash_builder_proof_nodes, _) = run_hash_builder(
let (_, _, hash_builder_proof_nodes, _, _) = run_hash_builder(
[(key1(), value()), (key2(), value()), (key3(), value())],
Default::default(),
[key1(), key2(), key3()],
@ -2141,28 +2175,34 @@ mod tests {
let value = || Account::default();
// Generate the proof for the root node and initialize the sparse trie with it
let (_, _, hash_builder_proof_nodes, branch_node_hash_masks) = run_hash_builder(
[(key1(), value()), (key2(), value()), (key3(), value())],
Default::default(),
[Nibbles::default()],
);
let (_, _, hash_builder_proof_nodes, branch_node_hash_masks, branch_node_tree_masks) =
run_hash_builder(
[(key1(), value()), (key2(), value()), (key3(), value())],
Default::default(),
[Nibbles::default()],
);
let mut sparse = RevealedSparseTrie::from_root(
TrieNode::decode(&mut &hash_builder_proof_nodes.nodes_sorted()[0].1[..]).unwrap(),
branch_node_hash_masks.get(&Nibbles::default()).copied(),
branch_node_tree_masks.get(&Nibbles::default()).copied(),
false,
)
.unwrap();
// Generate the proof for the children of the root branch node and reveal it in the sparse
// trie
let (_, _, hash_builder_proof_nodes, branch_node_hash_masks) = run_hash_builder(
[(key1(), value()), (key2(), value()), (key3(), value())],
Default::default(),
[key1(), Nibbles::from_nibbles_unchecked([0x01])],
);
let (_, _, hash_builder_proof_nodes, branch_node_hash_masks, branch_node_tree_masks) =
run_hash_builder(
[(key1(), value()), (key2(), value()), (key3(), value())],
Default::default(),
[key1(), Nibbles::from_nibbles_unchecked([0x01])],
);
for (path, node) in hash_builder_proof_nodes.nodes_sorted() {
let hash_mask = branch_node_hash_masks.get(&path).copied();
sparse.reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), hash_mask).unwrap();
let tree_mask = branch_node_tree_masks.get(&path).copied();
sparse
.reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), hash_mask, tree_mask)
.unwrap();
}
// Check that the branch node exists
@ -2181,14 +2221,18 @@ mod tests {
);
// Generate the proof for the third key and reveal it in the sparse trie
let (_, _, hash_builder_proof_nodes, branch_node_hash_masks) = run_hash_builder(
[(key1(), value()), (key2(), value()), (key3(), value())],
Default::default(),
[key2()],
);
let (_, _, hash_builder_proof_nodes, branch_node_hash_masks, branch_node_tree_masks) =
run_hash_builder(
[(key1(), value()), (key2(), value()), (key3(), value())],
Default::default(),
[key2()],
);
for (path, node) in hash_builder_proof_nodes.nodes_sorted() {
let hash_mask = branch_node_hash_masks.get(&path).copied();
sparse.reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), hash_mask).unwrap();
let tree_mask = branch_node_tree_masks.get(&path).copied();
sparse
.reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), hash_mask, tree_mask)
.unwrap();
}
// Check that nothing changed in the extension node
@ -2219,14 +2263,16 @@ mod tests {
};
// Generate the proof for the root node and initialize the sparse trie with it
let (_, _, hash_builder_proof_nodes, branch_node_hash_masks) = run_hash_builder(
[(key1(), value()), (key2(), value())],
Default::default(),
[Nibbles::default()],
);
let (_, _, hash_builder_proof_nodes, branch_node_hash_masks, branch_node_tree_masks) =
run_hash_builder(
[(key1(), value()), (key2(), value())],
Default::default(),
[Nibbles::default()],
);
let mut sparse = RevealedSparseTrie::from_root(
TrieNode::decode(&mut &hash_builder_proof_nodes.nodes_sorted()[0].1[..]).unwrap(),
branch_node_hash_masks.get(&Nibbles::default()).copied(),
branch_node_tree_masks.get(&Nibbles::default()).copied(),
false,
)
.unwrap();
@ -2247,11 +2293,14 @@ mod tests {
);
// Generate the proof for the first key and reveal it in the sparse trie
let (_, _, hash_builder_proof_nodes, branch_node_hash_masks) =
let (_, _, hash_builder_proof_nodes, branch_node_hash_masks, branch_node_tree_masks) =
run_hash_builder([(key1(), value()), (key2(), value())], Default::default(), [key1()]);
for (path, node) in hash_builder_proof_nodes.nodes_sorted() {
let hash_mask = branch_node_hash_masks.get(&path).copied();
sparse.reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), hash_mask).unwrap();
let tree_mask = branch_node_tree_masks.get(&path).copied();
sparse
.reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), hash_mask, tree_mask)
.unwrap();
}
// Check that the branch node wasn't overwritten by the extension node in the proof
@ -2345,7 +2394,7 @@ mod tests {
account_rlp
};
let (hash_builder_root, hash_builder_updates, _, _) = run_hash_builder(
let (hash_builder_root, hash_builder_updates, _, _, _) = run_hash_builder(
[(key1(), value()), (key2(), value())],
Default::default(),
[Nibbles::default()],

View File

@ -33,8 +33,8 @@ pub struct Proof<T, H> {
hashed_cursor_factory: H,
/// A set of prefix sets that have changes.
prefix_sets: TriePrefixSetsMut,
/// Flag indicating whether to include branch node hash masks in the proof.
collect_branch_node_hash_masks: bool,
/// Flag indicating whether to include branch node masks in the proof.
collect_branch_node_masks: bool,
}
impl<T, H> Proof<T, H> {
@ -44,7 +44,7 @@ impl<T, H> Proof<T, H> {
trie_cursor_factory: t,
hashed_cursor_factory: h,
prefix_sets: TriePrefixSetsMut::default(),
collect_branch_node_hash_masks: false,
collect_branch_node_masks: false,
}
}
@ -54,7 +54,7 @@ impl<T, H> Proof<T, H> {
trie_cursor_factory,
hashed_cursor_factory: self.hashed_cursor_factory,
prefix_sets: self.prefix_sets,
collect_branch_node_hash_masks: self.collect_branch_node_hash_masks,
collect_branch_node_masks: self.collect_branch_node_masks,
}
}
@ -64,7 +64,7 @@ impl<T, H> Proof<T, H> {
trie_cursor_factory: self.trie_cursor_factory,
hashed_cursor_factory,
prefix_sets: self.prefix_sets,
collect_branch_node_hash_masks: self.collect_branch_node_hash_masks,
collect_branch_node_masks: self.collect_branch_node_masks,
}
}
@ -74,9 +74,9 @@ impl<T, H> Proof<T, H> {
self
}
/// Set the flag indicating whether to include branch node hash masks in the proof.
pub const fn with_branch_node_hash_masks(mut self, branch_node_hash_masks: bool) -> Self {
self.collect_branch_node_hash_masks = branch_node_hash_masks;
/// Set the flag indicating whether to include branch node masks in the proof.
pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
self.collect_branch_node_masks = branch_node_masks;
self
}
}
@ -117,7 +117,7 @@ where
let retainer = targets.keys().map(Nibbles::unpack).collect();
let mut hash_builder = HashBuilder::default()
.with_proof_retainer(retainer)
.with_updates(self.collect_branch_node_hash_masks);
.with_updates(self.collect_branch_node_masks);
// Initialize all storage multiproofs as empty.
// Storage multiproofs for non empty tries will be overwritten if necessary.
@ -144,7 +144,7 @@ where
hashed_address,
)
.with_prefix_set_mut(storage_prefix_set)
.with_branch_node_hash_masks(self.collect_branch_node_hash_masks)
.with_branch_node_masks(self.collect_branch_node_masks)
.storage_multiproof(proof_targets.unwrap_or_default())?;
// Encode account
@ -164,18 +164,23 @@ where
}
let _ = hash_builder.root();
let account_subtree = hash_builder.take_proof_nodes();
let branch_node_hash_masks = if self.collect_branch_node_hash_masks {
hash_builder
.updated_branch_nodes
.unwrap_or_default()
.into_iter()
.map(|(path, node)| (path, node.hash_mask))
.collect()
let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
(
updated_branch_nodes
.iter()
.map(|(path, node)| (path.clone(), node.hash_mask))
.collect(),
updated_branch_nodes
.into_iter()
.map(|(path, node)| (path, node.tree_mask))
.collect(),
)
} else {
HashMap::default()
(HashMap::default(), HashMap::default())
};
Ok(MultiProof { account_subtree, branch_node_hash_masks, storages })
Ok(MultiProof { account_subtree, branch_node_hash_masks, branch_node_tree_masks, storages })
}
}
@ -190,8 +195,8 @@ pub struct StorageProof<T, H> {
hashed_address: B256,
/// The set of storage slot prefixes that have changed.
prefix_set: PrefixSetMut,
/// Flag indicating whether to include branch node hash masks in the proof.
collect_branch_node_hash_masks: bool,
/// Flag indicating whether to include branch node masks in the proof.
collect_branch_node_masks: bool,
}
impl<T, H> StorageProof<T, H> {
@ -207,7 +212,7 @@ impl<T, H> StorageProof<T, H> {
hashed_cursor_factory: h,
hashed_address,
prefix_set: PrefixSetMut::default(),
collect_branch_node_hash_masks: false,
collect_branch_node_masks: false,
}
}
@ -218,7 +223,7 @@ impl<T, H> StorageProof<T, H> {
hashed_cursor_factory: self.hashed_cursor_factory,
hashed_address: self.hashed_address,
prefix_set: self.prefix_set,
collect_branch_node_hash_masks: self.collect_branch_node_hash_masks,
collect_branch_node_masks: self.collect_branch_node_masks,
}
}
@ -229,7 +234,7 @@ impl<T, H> StorageProof<T, H> {
hashed_cursor_factory,
hashed_address: self.hashed_address,
prefix_set: self.prefix_set,
collect_branch_node_hash_masks: self.collect_branch_node_hash_masks,
collect_branch_node_masks: self.collect_branch_node_masks,
}
}
@ -239,9 +244,9 @@ impl<T, H> StorageProof<T, H> {
self
}
/// Set the flag indicating whether to include branch node hash masks in the proof.
pub const fn with_branch_node_hash_masks(mut self, branch_node_hash_masks: bool) -> Self {
self.collect_branch_node_hash_masks = branch_node_hash_masks;
/// Set the flag indicating whether to include branch node masks in the proof.
pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
self.collect_branch_node_masks = branch_node_masks;
self
}
}
@ -282,7 +287,7 @@ where
let retainer = ProofRetainer::from_iter(target_nibbles);
let mut hash_builder = HashBuilder::default()
.with_proof_retainer(retainer)
.with_updates(self.collect_branch_node_hash_masks);
.with_updates(self.collect_branch_node_masks);
let mut storage_node_iter = TrieNodeIter::new(walker, hashed_storage_cursor);
while let Some(node) = storage_node_iter.try_next()? {
match node {
@ -300,17 +305,22 @@ where
let root = hash_builder.root();
let subtree = hash_builder.take_proof_nodes();
let branch_node_hash_masks = if self.collect_branch_node_hash_masks {
hash_builder
.updated_branch_nodes
.unwrap_or_default()
.into_iter()
.map(|(path, node)| (path, node.hash_mask))
.collect()
let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
(
updated_branch_nodes
.iter()
.map(|(path, node)| (path.clone(), node.hash_mask))
.collect(),
updated_branch_nodes
.into_iter()
.map(|(path, node)| (path, node.tree_mask))
.collect(),
)
} else {
HashMap::default()
(HashMap::default(), HashMap::default())
};
Ok(StorageMultiProof { root, subtree, branch_node_hash_masks })
Ok(StorageMultiProof { root, subtree, branch_node_hash_masks, branch_node_tree_masks })
}
}