feat(trie): witness (#9803)

This commit is contained in:
Roman Krasiuk
2024-07-30 13:18:20 -07:00
committed by GitHub
parent d90f2396e5
commit 2c2a782bb8
8 changed files with 324 additions and 11 deletions

1
Cargo.lock generated
View File

@ -7297,6 +7297,7 @@ dependencies = [
"alloy-eips", "alloy-eips",
"alloy-primitives", "alloy-primitives",
"alloy-rlp", "alloy-rlp",
"nybbles",
"reth-consensus", "reth-consensus",
"reth-prune-types", "reth-prune-types",
"reth-storage-errors", "reth-storage-errors",

View File

@ -20,6 +20,7 @@ alloy-primitives.workspace = true
alloy-rlp.workspace = true alloy-rlp.workspace = true
alloy-eips.workspace = true alloy-eips.workspace = true
revm-primitives.workspace = true revm-primitives.workspace = true
nybbles.workspace = true
thiserror-no-std = { workspace = true, default-features = false } thiserror-no-std = { workspace = true, default-features = false }

View File

@ -1,5 +1,7 @@
//! Errors when computing the state root. //! Errors when computing the state root.
use alloy_primitives::B256;
use nybbles::Nibbles;
use reth_storage_errors::{db::DatabaseError, provider::ProviderError}; use reth_storage_errors::{db::DatabaseError, provider::ProviderError};
use thiserror_no_std::Error; use thiserror_no_std::Error;
@ -23,6 +25,26 @@ impl From<StateProofError> for ProviderError {
} }
} }
/// Trie witness errors.
#[derive(Error, Debug, PartialEq, Eq, Clone)]
pub enum TrieWitnessError {
/// Error gather proofs.
#[error(transparent)]
Proof(#[from] StateProofError),
/// RLP decoding error.
#[error(transparent)]
Rlp(#[from] alloy_rlp::Error),
/// Missing storage multiproof.
#[error("missing storage multiproof for {0}")]
MissingStorageMultiProof(B256),
/// Missing account.
#[error("missing account {0}")]
MissingAccount(B256),
/// Missing target node.
#[error("target node missing from proof {0:?}")]
MissingTargetNode(Nibbles),
}
/// State root errors. /// State root errors.
#[derive(Error, Debug, PartialEq, Eq, Clone)] #[derive(Error, Debug, PartialEq, Eq, Clone)]
pub enum StateRootError { pub enum StateRootError {

View File

@ -17,7 +17,7 @@ pub struct MultiProof {
/// State trie multiproof for requested accounts. /// State trie multiproof for requested accounts.
pub account_subtree: BTreeMap<Nibbles, Bytes>, pub account_subtree: BTreeMap<Nibbles, Bytes>,
/// Storage trie multiproofs. /// Storage trie multiproofs.
pub storage_multiproofs: HashMap<B256, StorageMultiProof>, pub storages: HashMap<B256, StorageMultiProof>,
} }
impl MultiProof { impl MultiProof {
@ -58,7 +58,7 @@ impl MultiProof {
}; };
// Retrieve proofs for requested storage slots. // Retrieve proofs for requested storage slots.
let storage_multiproof = self.storage_multiproofs.get(&hashed_address); let storage_multiproof = self.storages.get(&hashed_address);
let storage_root = storage_multiproof.map(|m| m.root).unwrap_or(EMPTY_ROOT_HASH); let storage_root = storage_multiproof.map(|m| m.root).unwrap_or(EMPTY_ROOT_HASH);
let mut storage_proofs = Vec::with_capacity(slots.len()); let mut storage_proofs = Vec::with_capacity(slots.len());
for slot in slots { for slot in slots {

View File

@ -39,6 +39,9 @@ pub use state::*;
/// Merkle proof generation. /// Merkle proof generation.
pub mod proof; pub mod proof;
/// Trie witness generation.
pub mod witness;
/// The implementation of the Merkle Patricia Trie. /// The implementation of the Merkle Patricia Trie.
mod trie; mod trie;
pub use trie::{StateRoot, StorageRoot}; pub use trie::{StateRoot, StorageRoot};

View File

@ -9,7 +9,7 @@ mod loader;
pub use loader::PrefixSetLoader; pub use loader::PrefixSetLoader;
/// Collection of mutable prefix sets. /// Collection of mutable prefix sets.
#[derive(Default, Debug)] #[derive(Clone, Default, Debug)]
pub struct TriePrefixSetsMut { pub struct TriePrefixSetsMut {
/// A set of account prefixes that have changed. /// A set of account prefixes that have changed.
pub account_prefix_set: PrefixSetMut, pub account_prefix_set: PrefixSetMut,
@ -75,7 +75,7 @@ pub struct TriePrefixSets {
/// assert!(prefix_set.contains(&[0xa, 0xb])); /// assert!(prefix_set.contains(&[0xa, 0xb]));
/// assert!(prefix_set.contains(&[0xa, 0xb, 0xc])); /// assert!(prefix_set.contains(&[0xa, 0xb, 0xc]));
/// ``` /// ```
#[derive(Debug, Default, Clone)] #[derive(Clone, Default, Debug)]
pub struct PrefixSetMut { pub struct PrefixSetMut {
keys: Vec<Nibbles>, keys: Vec<Nibbles>,
sorted: bool, sorted: bool,

View File

@ -21,10 +21,10 @@ use std::collections::HashMap;
/// See `StateRoot::root` for more info. /// See `StateRoot::root` for more info.
#[derive(Debug)] #[derive(Debug)]
pub struct Proof<T, H> { pub struct Proof<T, H> {
/// The factory for traversing trie nodes.
trie_cursor_factory: T,
/// The factory for hashed cursors. /// The factory for hashed cursors.
hashed_cursor_factory: H, hashed_cursor_factory: H,
/// Creates cursor for traversing trie entities.
trie_cursor_factory: T,
/// A set of prefix sets that have changes. /// A set of prefix sets that have changes.
prefix_sets: TriePrefixSetsMut, prefix_sets: TriePrefixSetsMut,
/// Proof targets. /// Proof targets.
@ -81,12 +81,12 @@ where
keccak256(address), keccak256(address),
slots.iter().map(keccak256).collect(), slots.iter().map(keccak256).collect(),
)])) )]))
.multi_proof()? .multiproof()?
.account_proof(address, slots)?) .account_proof(address, slots)?)
} }
/// Generate a state multiproof according to specified targets. /// Generate a state multiproof according to specified targets.
pub fn multi_proof(&self) -> Result<MultiProof, StateProofError> { pub fn multiproof(&self) -> Result<MultiProof, StateProofError> {
let hashed_account_cursor = self.hashed_cursor_factory.hashed_account_cursor()?; let hashed_account_cursor = self.hashed_cursor_factory.hashed_account_cursor()?;
let trie_cursor = self.trie_cursor_factory.account_trie_cursor()?; let trie_cursor = self.trie_cursor_factory.account_trie_cursor()?;
@ -99,7 +99,7 @@ where
let retainer = ProofRetainer::from_iter(self.targets.keys().map(Nibbles::unpack)); let retainer = ProofRetainer::from_iter(self.targets.keys().map(Nibbles::unpack));
let mut hash_builder = HashBuilder::default().with_proof_retainer(retainer); let mut hash_builder = HashBuilder::default().with_proof_retainer(retainer);
let mut storage_multiproofs = HashMap::default(); let mut storages = HashMap::default();
let mut account_rlp = Vec::with_capacity(128); let mut account_rlp = Vec::with_capacity(128);
let mut account_node_iter = TrieNodeIter::new(walker, hashed_account_cursor); let mut account_node_iter = TrieNodeIter::new(walker, hashed_account_cursor);
while let Some(account_node) = account_node_iter.try_next()? { while let Some(account_node) = account_node_iter.try_next()? {
@ -116,12 +116,12 @@ where
account.encode(&mut account_rlp as &mut dyn BufMut); account.encode(&mut account_rlp as &mut dyn BufMut);
hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp); hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
storage_multiproofs.insert(hashed_address, storage_multiproof); storages.insert(hashed_address, storage_multiproof);
} }
} }
} }
let _ = hash_builder.root(); let _ = hash_builder.root();
Ok(MultiProof { account_subtree: hash_builder.take_proofs(), storage_multiproofs }) Ok(MultiProof { account_subtree: hash_builder.take_proofs(), storages })
} }
/// Generate a storage multiproof according to specified targets. /// Generate a storage multiproof according to specified targets.

View File

@ -0,0 +1,286 @@
use crate::{
hashed_cursor::HashedCursorFactory, prefix_set::TriePrefixSetsMut, proof::Proof,
trie_cursor::TrieCursorFactory, HashedPostState,
};
use alloy_rlp::{BufMut, Decodable, Encodable};
use itertools::Either;
use reth_execution_errors::{StateProofError, TrieWitnessError};
use reth_primitives::{constants::EMPTY_ROOT_HASH, keccak256, Bytes, B256};
use reth_trie_common::{
BranchNode, HashBuilder, Nibbles, TrieAccount, TrieNode, CHILD_INDEX_RANGE,
};
use std::collections::{BTreeMap, HashMap, HashSet};
/// State transition witness for the trie.
#[derive(Debug)]
pub struct TrieWitness<T, H> {
/// The cursor factory for traversing trie nodes.
trie_cursor_factory: T,
/// The factory for hashed cursors.
hashed_cursor_factory: H,
/// A set of prefix sets that have changes.
prefix_sets: TriePrefixSetsMut,
/// Recorded witness.
witness: HashMap<B256, Bytes>,
}
impl<T, H> TrieWitness<T, H> {
/// Creates a new witness generator.
pub fn new(trie_cursor_factory: T, hashed_cursor_factory: H) -> Self {
Self {
trie_cursor_factory,
hashed_cursor_factory,
prefix_sets: TriePrefixSetsMut::default(),
witness: HashMap::default(),
}
}
/// Set the hashed cursor factory.
pub fn with_hashed_cursor_factory<HF>(self, hashed_cursor_factory: HF) -> TrieWitness<T, HF> {
TrieWitness {
trie_cursor_factory: self.trie_cursor_factory,
hashed_cursor_factory,
prefix_sets: self.prefix_sets,
witness: self.witness,
}
}
/// Set the prefix sets. They have to be mutable in order to allow extension with proof target.
pub fn with_prefix_sets_mut(mut self, prefix_sets: TriePrefixSetsMut) -> Self {
self.prefix_sets = prefix_sets;
self
}
}
impl<T, H> TrieWitness<T, H>
where
T: TrieCursorFactory + Clone,
H: HashedCursorFactory + Clone,
{
/// Compute the state transition witness for the trie. Gather all required nodes
/// to apply `state` on top of the current trie state.
///
/// # Arguments
///
/// `state` - state transition containing both modified and touched accounts and storage slots.
pub fn compute(
mut self,
state: HashedPostState,
) -> Result<HashMap<B256, Bytes>, TrieWitnessError> {
let proof_targets = HashMap::from_iter(
state.accounts.keys().map(|hashed_address| (*hashed_address, Vec::new())).chain(
state.storages.iter().map(|(hashed_address, storage)| {
(*hashed_address, storage.storage.keys().copied().collect())
}),
),
);
let account_multiproof =
Proof::new(self.trie_cursor_factory.clone(), self.hashed_cursor_factory.clone())
.with_prefix_sets_mut(self.prefix_sets.clone())
.with_targets(proof_targets.clone())
.multiproof()?;
// Attempt to compute state root from proofs and gather additional
// information for the witness.
let mut account_rlp = Vec::with_capacity(128);
let mut account_trie_nodes = BTreeMap::default();
for (hashed_address, hashed_slots) in proof_targets {
let key = Nibbles::unpack(hashed_address);
let storage_multiproof = account_multiproof
.storages
.get(&hashed_address)
.ok_or(TrieWitnessError::MissingStorageMultiProof(hashed_address))?;
// Gather and record account trie nodes.
let account = state
.accounts
.get(&hashed_address)
.ok_or(TrieWitnessError::MissingAccount(hashed_address))?;
let value = if account.is_some() || storage_multiproof.root != EMPTY_ROOT_HASH {
account_rlp.clear();
TrieAccount::from((account.unwrap_or_default(), storage_multiproof.root))
.encode(&mut account_rlp as &mut dyn BufMut);
Some(account_rlp.clone())
} else {
None
};
let proof = account_multiproof.account_subtree.iter().filter(|e| key.starts_with(e.0));
account_trie_nodes.extend(self.target_nodes(key.clone(), value, proof)?);
// Gather and record storage trie nodes for this account.
let mut storage_trie_nodes = BTreeMap::default();
let storage = state.storages.get(&hashed_address);
for hashed_slot in hashed_slots {
let slot_key = Nibbles::unpack(hashed_slot);
let slot_value = storage
.and_then(|s| s.storage.get(&hashed_slot))
.filter(|v| !v.is_zero())
.map(|v| alloy_rlp::encode_fixed_size(v).to_vec());
let proof = storage_multiproof.subtree.iter().filter(|e| slot_key.starts_with(e.0));
storage_trie_nodes.extend(self.target_nodes(
slot_key.clone(),
slot_value,
proof,
)?);
}
let root = Self::next_root_from_proofs(storage_trie_nodes, |key: Nibbles| {
// Right pad the target with 0s.
let mut padded_key = key.pack();
padded_key.resize(32, 0);
let mut proof = Proof::new(
self.trie_cursor_factory.clone(),
self.hashed_cursor_factory.clone(),
)
.with_prefix_sets_mut(self.prefix_sets.clone())
.with_targets(HashMap::from([(B256::from_slice(&padded_key), Vec::new())]))
.storage_multiproof(hashed_address)?;
// The subtree only contains the proof for a single target.
let node =
proof.subtree.remove(&key).ok_or(TrieWitnessError::MissingTargetNode(key))?;
self.witness.insert(keccak256(node.as_ref()), node.clone()); // record in witness
Ok(node)
})?;
debug_assert_eq!(storage_multiproof.root, root);
}
Self::next_root_from_proofs(account_trie_nodes, |key: Nibbles| {
// Right pad the target with 0s.
let mut padded_key = key.pack();
padded_key.resize(32, 0);
let mut proof =
Proof::new(self.trie_cursor_factory.clone(), self.hashed_cursor_factory.clone())
.with_prefix_sets_mut(self.prefix_sets.clone())
.with_targets(HashMap::from([(B256::from_slice(&padded_key), Vec::new())]))
.multiproof()?;
// The subtree only contains the proof for a single target.
let node = proof
.account_subtree
.remove(&key)
.ok_or(TrieWitnessError::MissingTargetNode(key))?;
self.witness.insert(keccak256(node.as_ref()), node.clone()); // record in witness
Ok(node)
})?;
Ok(self.witness)
}
/// Decodes and unrolls all nodes from the proof. Returns only sibling nodes
/// in the path of the target and the final leaf node with updated value.
fn target_nodes<'b>(
&mut self,
key: Nibbles,
value: Option<Vec<u8>>,
proof: impl IntoIterator<Item = (&'b Nibbles, &'b Bytes)>,
) -> Result<BTreeMap<Nibbles, Either<B256, Vec<u8>>>, StateProofError> {
let mut trie_nodes = BTreeMap::default();
for (path, encoded) in proof {
// Record the node in witness.
self.witness.insert(keccak256(encoded.as_ref()), encoded.clone());
let mut next_path = path.clone();
match TrieNode::decode(&mut &encoded[..])? {
TrieNode::Branch(branch) => {
next_path.push(key[path.len()]);
let children = branch_node_children(path.clone(), &branch);
for (child_path, node_hash) in children {
if !key.starts_with(&child_path) {
trie_nodes.insert(child_path, Either::Left(node_hash));
}
}
}
TrieNode::Extension(extension) => {
next_path.extend_from_slice(&extension.key);
}
TrieNode::Leaf(leaf) => {
next_path.extend_from_slice(&leaf.key);
if next_path != key {
trie_nodes.insert(next_path.clone(), Either::Right(leaf.value.clone()));
}
}
};
}
if let Some(value) = value {
trie_nodes.insert(key, Either::Right(value));
}
Ok(trie_nodes)
}
fn next_root_from_proofs(
trie_nodes: BTreeMap<Nibbles, Either<B256, Vec<u8>>>,
mut trie_node_provider: impl FnMut(Nibbles) -> Result<Bytes, TrieWitnessError>,
) -> Result<B256, TrieWitnessError> {
// Ignore branch child hashes in the path of leaves or lower child hashes.
let mut keys = trie_nodes.keys().peekable();
let mut ignored = HashSet::<Nibbles>::default();
while let Some(key) = keys.next() {
if keys.peek().map_or(false, |next| next.starts_with(key)) {
ignored.insert(key.clone());
}
}
let mut hash_builder = HashBuilder::default();
let mut trie_nodes = trie_nodes.into_iter().filter(|e| !ignored.contains(&e.0)).peekable();
while let Some((path, value)) = trie_nodes.next() {
match value {
Either::Left(branch_hash) => {
let parent_branch_path = path.slice(..path.len() - 1);
if hash_builder.key.starts_with(&parent_branch_path) ||
trie_nodes
.peek()
.map_or(false, |next| next.0.starts_with(&parent_branch_path))
{
hash_builder.add_branch(path, branch_hash, false);
} else {
// Parent is a branch node that needs to be turned into an extension node.
let mut path = path.clone();
loop {
let node = trie_node_provider(path.clone())?;
match TrieNode::decode(&mut &node[..])? {
TrieNode::Branch(branch) => {
let children = branch_node_children(path, &branch);
for (child_path, branch_hash) in children {
hash_builder.add_branch(child_path, branch_hash, false);
}
break
}
TrieNode::Leaf(leaf) => {
let mut child_path = path;
child_path.extend_from_slice(&leaf.key);
hash_builder.add_leaf(child_path, &leaf.value);
break
}
TrieNode::Extension(ext) => {
path.extend_from_slice(&ext.key);
}
}
}
}
}
Either::Right(leaf_value) => {
hash_builder.add_leaf(path, &leaf_value);
}
}
}
Ok(hash_builder.root())
}
}
/// Returned branch node children with keys in order.
fn branch_node_children(prefix: Nibbles, node: &BranchNode) -> Vec<(Nibbles, B256)> {
let mut children = Vec::with_capacity(node.state_mask.count_ones() as usize);
let mut stack_ptr = node.as_ref().first_child_index();
for index in CHILD_INDEX_RANGE {
if node.state_mask.is_bit_set(index) {
let mut child_path = prefix.clone();
child_path.push(index);
children.push((child_path, B256::from_slice(&node.stack[stack_ptr][1..])));
stack_ptr += 1;
}
}
children
}