From 12cab204b50a824bcfb55f6b6d46f2e9cda34c31 Mon Sep 17 00:00:00 2001 From: Roman Krasiuk Date: Wed, 16 Oct 2024 19:21:25 +0200 Subject: [PATCH] fix(witness): branch node children decoding (#11599) --- crates/trie/db/tests/witness.rs | 52 +++++++++++++++++++++++++++++++++ crates/trie/trie/src/witness.rs | 22 ++++++++++---- 2 files changed, 68 insertions(+), 6 deletions(-) diff --git a/crates/trie/db/tests/witness.rs b/crates/trie/db/tests/witness.rs index 59656383d..20f8cfbb9 100644 --- a/crates/trie/db/tests/witness.rs +++ b/crates/trie/db/tests/witness.rs @@ -6,6 +6,8 @@ use alloy_primitives::{ Address, Bytes, B256, U256, }; use alloy_rlp::EMPTY_STRING_CODE; +use reth_db::{cursor::DbCursorRW, tables}; +use reth_db_api::transaction::DbTxMut; use reth_primitives::{constants::EMPTY_ROOT_HASH, Account, StorageEntry}; use reth_provider::{test_utils::create_test_provider_factory, HashingWriter}; use reth_trie::{proof::Proof, witness::TrieWitness, HashedPostState, HashedStorage, StateRoot}; @@ -91,3 +93,53 @@ fn includes_nodes_for_destroyed_storage_nodes() { assert_eq!(witness.get(&keccak256(node)), Some(node)); } } + +#[test] +fn correctly_decodes_branch_node_values() { + let factory = create_test_provider_factory(); + let provider = factory.provider_rw().unwrap(); + + let address = Address::random(); + let hashed_address = keccak256(address); + let hashed_slot1 = B256::with_last_byte(1); + let hashed_slot2 = B256::with_last_byte(2); + + // Insert account and slots into database + provider.insert_account_for_hashing([(address, Some(Account::default()))]).unwrap(); + let mut hashed_storage_cursor = + provider.tx_ref().cursor_dup_write::().unwrap(); + hashed_storage_cursor + .upsert(hashed_address, StorageEntry { key: hashed_slot1, value: U256::from(1) }) + .unwrap(); + hashed_storage_cursor + .upsert(hashed_address, StorageEntry { key: hashed_slot2, value: U256::from(1) }) + .unwrap(); + + let state_root = StateRoot::from_tx(provider.tx_ref()).root().unwrap(); + let multiproof = Proof::from_tx(provider.tx_ref()) + .multiproof(HashMap::from_iter([( + hashed_address, + HashSet::from_iter([hashed_slot1, hashed_slot2]), + )])) + .unwrap(); + + let witness = TrieWitness::from_tx(provider.tx_ref()) + .compute(HashedPostState { + accounts: HashMap::from([(hashed_address, Some(Account::default()))]), + storages: HashMap::from([( + hashed_address, + HashedStorage::from_iter( + false, + [hashed_slot1, hashed_slot2].map(|hashed_slot| (hashed_slot, U256::from(2))), + ), + )]), + }) + .unwrap(); + assert!(witness.contains_key(&state_root)); + for node in multiproof.account_subtree.values() { + assert_eq!(witness.get(&keccak256(node)), Some(node)); + } + for node in multiproof.storages.iter().flat_map(|(_, storage)| storage.subtree.values()) { + assert_eq!(witness.get(&keccak256(node)), Some(node)); + } +} diff --git a/crates/trie/trie/src/witness.rs b/crates/trie/trie/src/witness.rs index 971f10cfb..f3b70e85a 100644 --- a/crates/trie/trie/src/witness.rs +++ b/crates/trie/trie/src/witness.rs @@ -216,9 +216,14 @@ where TrieNode::Branch(branch) => { next_path.push(key[path.len()]); let children = branch_node_children(path.clone(), &branch); - for (child_path, node_hash) in children { + for (child_path, value) in children { if !key.starts_with(&child_path) { - trie_nodes.insert(child_path, Either::Left(node_hash)); + let value = if value.len() < B256::len_bytes() { + Either::Right(value.to_vec()) + } else { + Either::Left(B256::from_slice(&value[1..])) + }; + trie_nodes.insert(child_path, value); } } } @@ -312,8 +317,13 @@ where match TrieNode::decode(&mut &node[..])? { TrieNode::Branch(branch) => { let children = branch_node_children(path, &branch); - for (child_path, branch_hash) in children { - hash_builder.add_branch(child_path, branch_hash, false); + for (child_path, value) in children { + if value.len() < B256::len_bytes() { + hash_builder.add_leaf(child_path, value); + } else { + let hash = B256::from_slice(&value[1..]); + hash_builder.add_branch(child_path, hash, false); + } } break } @@ -343,14 +353,14 @@ where } /// Returned branch node children with keys in order. -fn branch_node_children(prefix: Nibbles, node: &BranchNode) -> Vec<(Nibbles, B256)> { +fn branch_node_children(prefix: Nibbles, node: &BranchNode) -> Vec<(Nibbles, &[u8])> { let mut children = Vec::with_capacity(node.state_mask.count_ones() as usize); let mut stack_ptr = node.as_ref().first_child_index(); for index in CHILD_INDEX_RANGE { if node.state_mask.is_bit_set(index) { let mut child_path = prefix.clone(); child_path.push(index); - children.push((child_path, B256::from_slice(&node.stack[stack_ptr][1..]))); + children.push((child_path, &node.stack[stack_ptr][..])); stack_ptr += 1; } }