refactor(trie): struct for passing hash and tree masks in sparse trie (#14468)

This commit is contained in:
Shourya Chaudhry
2025-02-14 18:33:58 +05:30
committed by GitHub
parent 1c09351a93
commit 8c2bcf11db
2 changed files with 127 additions and 73 deletions

View File

@ -1,6 +1,6 @@
use crate::{
blinded::{BlindedProvider, BlindedProviderFactory, DefaultBlindedProviderFactory},
RevealedSparseTrie, SparseTrie,
RevealedSparseTrie, SparseTrie, TrieMasks,
};
use alloy_primitives::{
hex,
@ -174,8 +174,7 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
let trie = self.state.reveal_root_with_provider(
self.provider_factory.account_node_provider(),
root_node,
None,
None,
TrieMasks::none(),
self.retain_updates,
)?;
@ -185,7 +184,7 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
continue
}
let node = TrieNode::decode(&mut &bytes[..])?;
trie.reveal_node(path.clone(), node, None, None)?;
trie.reveal_node(path.clone(), node, TrieMasks::none())?;
// Track the revealed path.
self.revealed_account_paths.insert(path);
@ -219,8 +218,7 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
let trie = self.storages.entry(account).or_default().reveal_root_with_provider(
self.provider_factory.storage_node_provider(account),
root_node,
None,
None,
TrieMasks::none(),
self.retain_updates,
)?;
@ -233,7 +231,7 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
continue
}
let node = TrieNode::decode(&mut &bytes[..])?;
trie.reveal_node(path.clone(), node, None, None)?;
trie.reveal_node(path.clone(), node, TrieMasks::none())?;
// Track the revealed path.
revealed_nodes.insert(path);
@ -253,8 +251,10 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
let trie = self.state.reveal_root_with_provider(
self.provider_factory.account_node_provider(),
root_node,
multiproof.branch_node_hash_masks.get(&Nibbles::default()).copied(),
multiproof.branch_node_tree_masks.get(&Nibbles::default()).copied(),
TrieMasks {
hash_mask: multiproof.branch_node_hash_masks.get(&Nibbles::default()).copied(),
tree_mask: multiproof.branch_node_tree_masks.get(&Nibbles::default()).copied(),
},
self.retain_updates,
)?;
@ -275,7 +275,7 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
};
trace!(target: "trie::sparse", ?path, ?node, ?hash_mask, ?tree_mask, "Revealing account node");
trie.reveal_node(path.clone(), node, tree_mask, hash_mask)?;
trie.reveal_node(path.clone(), node, TrieMasks { hash_mask, tree_mask })?;
// Track the revealed path.
self.revealed_account_paths.insert(path);
@ -291,8 +291,16 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
let trie = self.storages.entry(account).or_default().reveal_root_with_provider(
self.provider_factory.storage_node_provider(account),
root_node,
storage_subtree.branch_node_hash_masks.get(&Nibbles::default()).copied(),
storage_subtree.branch_node_tree_masks.get(&Nibbles::default()).copied(),
TrieMasks {
hash_mask: storage_subtree
.branch_node_hash_masks
.get(&Nibbles::default())
.copied(),
tree_mask: storage_subtree
.branch_node_tree_masks
.get(&Nibbles::default())
.copied(),
},
self.retain_updates,
)?;
let revealed_nodes = self.revealed_storage_paths.entry(account).or_default();
@ -314,7 +322,7 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
};
trace!(target: "trie::sparse", ?account, ?path, ?node, ?hash_mask, ?tree_mask, "Revealing storage node");
trie.reveal_node(path.clone(), node, tree_mask, hash_mask)?;
trie.reveal_node(path.clone(), node, TrieMasks { hash_mask, tree_mask })?;
// Track the revealed path.
revealed_nodes.insert(path);
@ -392,8 +400,7 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
storage_trie_entry.reveal_root_with_provider(
self.provider_factory.storage_node_provider(account),
trie_node,
None,
None,
TrieMasks::none(),
self.retain_updates,
)?;
} else {
@ -401,7 +408,7 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
storage_trie_entry
.as_revealed_mut()
.ok_or(SparseTrieErrorKind::Blind)?
.reveal_node(path.clone(), trie_node, None, None)?;
.reveal_node(path.clone(), trie_node, TrieMasks::none())?;
}
// Track the revealed path.
@ -415,8 +422,7 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
self.state.reveal_root_with_provider(
self.provider_factory.account_node_provider(),
trie_node,
None,
None,
TrieMasks::none(),
self.retain_updates,
)?;
} else {
@ -424,8 +430,7 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
self.state.as_revealed_mut().ok_or(SparseTrieErrorKind::Blind)?.reveal_node(
path.clone(),
trie_node,
None,
None,
TrieMasks::none(),
)?;
}
@ -503,8 +508,7 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
.reveal_root_with_provider(
self.provider_factory.account_node_provider(),
root_node,
hash_mask,
tree_mask,
TrieMasks { hash_mask, tree_mask },
self.retain_updates,
)
.map_err(Into::into)

View File

@ -15,6 +15,22 @@ use reth_trie_common::{
use smallvec::SmallVec;
use std::{borrow::Cow, fmt};
/// Struct for passing around `hash_mask` and `tree_mask`
#[derive(Debug)]
pub struct TrieMasks {
/// Branch node hash mask, if any.
pub hash_mask: Option<TrieMask>,
/// Branch node tree mask, if any.
pub tree_mask: Option<TrieMask>,
}
impl TrieMasks {
/// Helper function, returns both fields `hash_mask` and `tree_mask` as [`None`]
pub fn none() -> Self {
Self { hash_mask: None, tree_mask: None }
}
}
/// Inner representation of the sparse trie.
/// Sparse trie is blind by default until nodes are revealed.
#[derive(PartialEq, Eq)]
@ -59,17 +75,10 @@ impl SparseTrie {
pub fn reveal_root(
&mut self,
root: TrieNode,
hash_mask: Option<TrieMask>,
tree_mask: Option<TrieMask>,
masks: TrieMasks,
retain_updates: bool,
) -> SparseTrieResult<&mut RevealedSparseTrie> {
self.reveal_root_with_provider(
Default::default(),
root,
hash_mask,
tree_mask,
retain_updates,
)
self.reveal_root_with_provider(Default::default(), root, masks, retain_updates)
}
}
@ -106,16 +115,14 @@ impl<P> SparseTrie<P> {
&mut self,
provider: P,
root: TrieNode,
hash_mask: Option<TrieMask>,
tree_mask: Option<TrieMask>,
masks: TrieMasks,
retain_updates: bool,
) -> SparseTrieResult<&mut RevealedSparseTrie<P>> {
if self.is_blind() {
*self = Self::Revealed(Box::new(RevealedSparseTrie::from_provider_and_root(
provider,
root,
hash_mask,
tree_mask,
masks,
retain_updates,
)?))
}
@ -218,8 +225,7 @@ impl RevealedSparseTrie {
/// Create new revealed sparse trie from the given root node.
pub fn from_root(
node: TrieNode,
hash_mask: Option<TrieMask>,
tree_mask: Option<TrieMask>,
masks: TrieMasks,
retain_updates: bool,
) -> SparseTrieResult<Self> {
let mut this = Self {
@ -233,7 +239,7 @@ impl RevealedSparseTrie {
updates: None,
}
.with_updates(retain_updates);
this.reveal_node(Nibbles::default(), node, tree_mask, hash_mask)?;
this.reveal_node(Nibbles::default(), node, masks)?;
Ok(this)
}
}
@ -243,8 +249,7 @@ impl<P> RevealedSparseTrie<P> {
pub fn from_provider_and_root(
provider: P,
node: TrieNode,
hash_mask: Option<TrieMask>,
tree_mask: Option<TrieMask>,
masks: TrieMasks,
retain_updates: bool,
) -> SparseTrieResult<Self> {
let mut this = Self {
@ -258,7 +263,7 @@ impl<P> RevealedSparseTrie<P> {
updates: None,
}
.with_updates(retain_updates);
this.reveal_node(Nibbles::default(), node, tree_mask, hash_mask)?;
this.reveal_node(Nibbles::default(), node, masks)?;
Ok(this)
}
@ -309,18 +314,17 @@ impl<P> RevealedSparseTrie<P> {
&mut self,
path: Nibbles,
node: TrieNode,
tree_mask: Option<TrieMask>,
hash_mask: Option<TrieMask>,
masks: TrieMasks,
) -> SparseTrieResult<()> {
// If the node is already revealed and it's not a hash node, do nothing.
if self.nodes.get(&path).is_some_and(|node| !node.is_hash()) {
return Ok(())
}
if let Some(tree_mask) = tree_mask {
if let Some(tree_mask) = masks.tree_mask {
self.branch_node_tree_masks.insert(path.clone(), tree_mask);
}
if let Some(hash_mask) = hash_mask {
if let Some(hash_mask) = masks.hash_mask {
self.branch_node_hash_masks.insert(path.clone(), hash_mask);
}
@ -350,8 +354,8 @@ impl<P> RevealedSparseTrie<P> {
// node.
hash: Some(*hash),
store_in_db_trie: Some(
hash_mask.is_some_and(|mask| !mask.is_empty()) ||
tree_mask.is_some_and(|mask| !mask.is_empty()),
masks.hash_mask.is_some_and(|mask| !mask.is_empty()) ||
masks.tree_mask.is_some_and(|mask| !mask.is_empty()),
),
});
}
@ -465,7 +469,7 @@ impl<P> RevealedSparseTrie<P> {
return Ok(())
}
self.reveal_node(path, TrieNode::decode(&mut &child[..])?, None, None)
self.reveal_node(path, TrieNode::decode(&mut &child[..])?, TrieMasks::none())
}
/// Traverse trie nodes down to the leaf node and collect all nodes along the path.
@ -1057,8 +1061,7 @@ impl<P: BlindedProvider> RevealedSparseTrie<P> {
self.reveal_node(
current.clone(),
decoded,
tree_mask,
hash_mask,
TrieMasks { hash_mask, tree_mask },
)?;
}
}
@ -1221,8 +1224,7 @@ impl<P: BlindedProvider> RevealedSparseTrie<P> {
self.reveal_node(
child_path.clone(),
decoded,
tree_mask,
hash_mask,
TrieMasks { hash_mask, tree_mask },
)?;
}
}
@ -2107,17 +2109,28 @@ mod tests {
TrieMask::new(0b11),
));
let mut sparse =
RevealedSparseTrie::from_root(branch.clone(), Some(TrieMask::new(0b01)), None, false)
.unwrap();
let mut sparse = RevealedSparseTrie::from_root(
branch.clone(),
TrieMasks { hash_mask: Some(TrieMask::new(0b01)), tree_mask: None },
false,
)
.unwrap();
// Reveal a branch node and one of its children
//
// Branch (Mask = 11)
// ├── 0 -> Hash (Path = 0)
// └── 1 -> Leaf (Path = 1)
sparse.reveal_node(Nibbles::default(), branch, Some(TrieMask::new(0b01)), None).unwrap();
sparse.reveal_node(Nibbles::from_nibbles([0x1]), TrieNode::Leaf(leaf), None, None).unwrap();
sparse
.reveal_node(
Nibbles::default(),
branch,
TrieMasks { hash_mask: None, tree_mask: Some(TrieMask::new(0b01)) },
)
.unwrap();
sparse
.reveal_node(Nibbles::from_nibbles([0x1]), TrieNode::Leaf(leaf), TrieMasks::none())
.unwrap();
// Removing a blinded leaf should result in an error
assert_matches!(
@ -2140,17 +2153,28 @@ mod tests {
TrieMask::new(0b11),
));
let mut sparse =
RevealedSparseTrie::from_root(branch.clone(), Some(TrieMask::new(0b01)), None, false)
.unwrap();
let mut sparse = RevealedSparseTrie::from_root(
branch.clone(),
TrieMasks { hash_mask: Some(TrieMask::new(0b01)), tree_mask: None },
false,
)
.unwrap();
// Reveal a branch node and one of its children
//
// Branch (Mask = 11)
// ├── 0 -> Hash (Path = 0)
// └── 1 -> Leaf (Path = 1)
sparse.reveal_node(Nibbles::default(), branch, Some(TrieMask::new(0b01)), None).unwrap();
sparse.reveal_node(Nibbles::from_nibbles([0x1]), TrieNode::Leaf(leaf), None, None).unwrap();
sparse
.reveal_node(
Nibbles::default(),
branch,
TrieMasks { hash_mask: None, tree_mask: Some(TrieMask::new(0b01)) },
)
.unwrap();
sparse
.reveal_node(Nibbles::from_nibbles([0x1]), TrieNode::Leaf(leaf), TrieMasks::none())
.unwrap();
// Removing a non-existent leaf should be a noop
let sparse_old = sparse.clone();
@ -2329,8 +2353,10 @@ mod tests {
);
let mut sparse = RevealedSparseTrie::from_root(
TrieNode::decode(&mut &hash_builder_proof_nodes.nodes_sorted()[0].1[..]).unwrap(),
branch_node_hash_masks.get(&Nibbles::default()).copied(),
branch_node_tree_masks.get(&Nibbles::default()).copied(),
TrieMasks {
hash_mask: branch_node_hash_masks.get(&Nibbles::default()).copied(),
tree_mask: branch_node_tree_masks.get(&Nibbles::default()).copied(),
},
false,
)
.unwrap();
@ -2347,7 +2373,11 @@ mod tests {
let hash_mask = branch_node_hash_masks.get(&path).copied();
let tree_mask = branch_node_tree_masks.get(&path).copied();
sparse
.reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), tree_mask, hash_mask)
.reveal_node(
path,
TrieNode::decode(&mut &node[..]).unwrap(),
TrieMasks { hash_mask, tree_mask },
)
.unwrap();
}
@ -2378,7 +2408,11 @@ mod tests {
let hash_mask = branch_node_hash_masks.get(&path).copied();
let tree_mask = branch_node_tree_masks.get(&path).copied();
sparse
.reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), tree_mask, hash_mask)
.reveal_node(
path,
TrieNode::decode(&mut &node[..]).unwrap(),
TrieMasks { hash_mask, tree_mask },
)
.unwrap();
}
@ -2427,8 +2461,10 @@ mod tests {
);
let mut sparse = RevealedSparseTrie::from_root(
TrieNode::decode(&mut &hash_builder_proof_nodes.nodes_sorted()[0].1[..]).unwrap(),
branch_node_hash_masks.get(&Nibbles::default()).copied(),
branch_node_tree_masks.get(&Nibbles::default()).copied(),
TrieMasks {
hash_mask: branch_node_hash_masks.get(&Nibbles::default()).copied(),
tree_mask: branch_node_tree_masks.get(&Nibbles::default()).copied(),
},
false,
)
.unwrap();
@ -2446,7 +2482,11 @@ mod tests {
let hash_mask = branch_node_hash_masks.get(&path).copied();
let tree_mask = branch_node_tree_masks.get(&path).copied();
sparse
.reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), tree_mask, hash_mask)
.reveal_node(
path,
TrieNode::decode(&mut &node[..]).unwrap(),
TrieMasks { hash_mask, tree_mask },
)
.unwrap();
}
@ -2477,7 +2517,11 @@ mod tests {
let hash_mask = branch_node_hash_masks.get(&path).copied();
let tree_mask = branch_node_tree_masks.get(&path).copied();
sparse
.reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), tree_mask, hash_mask)
.reveal_node(
path,
TrieNode::decode(&mut &node[..]).unwrap(),
TrieMasks { hash_mask, tree_mask },
)
.unwrap();
}
@ -2518,8 +2562,10 @@ mod tests {
);
let mut sparse = RevealedSparseTrie::from_root(
TrieNode::decode(&mut &hash_builder_proof_nodes.nodes_sorted()[0].1[..]).unwrap(),
branch_node_hash_masks.get(&Nibbles::default()).copied(),
branch_node_tree_masks.get(&Nibbles::default()).copied(),
TrieMasks {
hash_mask: branch_node_hash_masks.get(&Nibbles::default()).copied(),
tree_mask: branch_node_tree_masks.get(&Nibbles::default()).copied(),
},
false,
)
.unwrap();
@ -2551,7 +2597,11 @@ mod tests {
let hash_mask = branch_node_hash_masks.get(&path).copied();
let tree_mask = branch_node_tree_masks.get(&path).copied();
sparse
.reveal_node(path, TrieNode::decode(&mut &node[..]).unwrap(), tree_mask, hash_mask)
.reveal_node(
path,
TrieNode::decode(&mut &node[..]).unwrap(),
TrieMasks { hash_mask, tree_mask },
)
.unwrap();
}