diff --git a/crates/chain-state/src/in_memory.rs b/crates/chain-state/src/in_memory.rs index 00b4f1b7b..a16fdca43 100644 --- a/crates/chain-state/src/in_memory.rs +++ b/crates/chain-state/src/in_memory.rs @@ -563,7 +563,7 @@ mod tests { use reth_primitives::Receipt; fn create_mock_state(block_number: u64) -> BlockState { - BlockState::new(get_executed_block_with_number(block_number)) + BlockState::new(get_executed_block_with_number(block_number, B256::random())) } #[test] @@ -643,7 +643,7 @@ mod tests { #[test] fn test_state_new() { let number = rand::thread_rng().gen::(); - let block = get_executed_block_with_number(number); + let block = get_executed_block_with_number(number, B256::random()); let state = BlockState::new(block.clone()); @@ -653,7 +653,7 @@ mod tests { #[test] fn test_state_block() { let number = rand::thread_rng().gen::(); - let block = get_executed_block_with_number(number); + let block = get_executed_block_with_number(number, B256::random()); let state = BlockState::new(block.clone()); @@ -663,17 +663,17 @@ mod tests { #[test] fn test_state_hash() { let number = rand::thread_rng().gen::(); - let block = get_executed_block_with_number(number); + let block = get_executed_block_with_number(number, B256::random()); let state = BlockState::new(block.clone()); - assert_eq!(state.hash(), block.block().hash()); + assert_eq!(state.hash(), block.block.hash()); } #[test] fn test_state_number() { let number = rand::thread_rng().gen::(); - let block = get_executed_block_with_number(number); + let block = get_executed_block_with_number(number, B256::random()); let state = BlockState::new(block); @@ -683,7 +683,7 @@ mod tests { #[test] fn test_state_state_root() { let number = rand::thread_rng().gen::(); - let block = get_executed_block_with_number(number); + let block = get_executed_block_with_number(number, B256::random()); let state = BlockState::new(block.clone()); @@ -694,7 +694,7 @@ mod tests { fn test_state_receipts() { let receipts = Receipts { 
receipt_vec: vec![vec![Some(Receipt::default())]] }; - let block = get_executed_block_with_receipts(receipts.clone()); + let block = get_executed_block_with_receipts(receipts.clone(), B256::random()); let state = BlockState::new(block); @@ -704,8 +704,8 @@ mod tests { #[test] fn test_in_memory_state_chain_update() { let state = CanonicalInMemoryState::new(HashMap::new(), HashMap::new(), None); - let block1 = get_executed_block_with_number(0); - let block2 = get_executed_block_with_number(0); + let block1 = get_executed_block_with_number(0, B256::random()); + let block2 = get_executed_block_with_number(0, B256::random()); let chain = NewCanonicalChain::Commit { new: vec![block1.clone()] }; state.update_chain(chain); assert_eq!(state.head_state().unwrap().block().block().hash(), block1.block().hash()); diff --git a/crates/chain-state/src/test_utils.rs b/crates/chain-state/src/test_utils.rs index 23c2bf71f..4cb2d270a 100644 --- a/crates/chain-state/src/test_utils.rs +++ b/crates/chain-state/src/test_utils.rs @@ -6,6 +6,7 @@ use rand::Rng; use reth_execution_types::{Chain, ExecutionOutcome}; use reth_primitives::{ Address, Block, BlockNumber, Receipts, Requests, SealedBlockWithSenders, TransactionSigned, + B256, }; use reth_trie::{updates::TrieUpdates, HashedPostState}; use revm::db::BundleState; @@ -15,18 +16,22 @@ use std::{ }; use tokio::sync::broadcast::{self, Sender}; -fn get_executed_block(block_number: BlockNumber, receipts: Receipts) -> ExecutedBlock { +fn get_executed_block( + block_number: BlockNumber, + receipts: Receipts, + parent_hash: B256, +) -> ExecutedBlock { let mut block = Block::default(); let mut header = block.header.clone(); header.number = block_number; + header.parent_hash = parent_hash; + header.ommers_hash = B256::random(); block.header = header; - - let sender = Address::random(); let tx = TransactionSigned::default(); block.body.push(tx); let sealed = block.seal_slow(); + let sender = Address::random(); let sealed_with_senders = 
SealedBlockWithSenders::new(sealed.clone(), vec![sender]).unwrap(); - ExecutedBlock::new( Arc::new(sealed), Arc::new(sealed_with_senders.senders), @@ -42,20 +47,27 @@ fn get_executed_block(block_number: BlockNumber, receipts: Receipts) -> Executed } /// Generates an `ExecutedBlock` that includes the given `Receipts`. -pub fn get_executed_block_with_receipts(receipts: Receipts) -> ExecutedBlock { +pub fn get_executed_block_with_receipts(receipts: Receipts, parent_hash: B256) -> ExecutedBlock { let number = rand::thread_rng().gen::(); - - get_executed_block(number, receipts) + get_executed_block(number, receipts, parent_hash) } /// Generates an `ExecutedBlock` with the given `BlockNumber`. -pub fn get_executed_block_with_number(block_number: BlockNumber) -> ExecutedBlock { - get_executed_block(block_number, Receipts { receipt_vec: vec![vec![]] }) +pub fn get_executed_block_with_number( + block_number: BlockNumber, + parent_hash: B256, +) -> ExecutedBlock { + get_executed_block(block_number, Receipts { receipt_vec: vec![vec![]] }, parent_hash) } /// Generates a range of executed blocks with ascending block numbers. 
pub fn get_executed_blocks(range: Range) -> impl Iterator { - range.map(get_executed_block_with_number) + let mut parent_hash = B256::default(); + range.map(move |number| { + let block = get_executed_block_with_number(number, parent_hash); + parent_hash = block.block.hash(); + block + }) } /// A test `ChainEventSubscriptions` diff --git a/crates/engine/tree/src/persistence.rs b/crates/engine/tree/src/persistence.rs index 55d6a75b2..b85881491 100644 --- a/crates/engine/tree/src/persistence.rs +++ b/crates/engine/tree/src/persistence.rs @@ -442,7 +442,7 @@ mod tests { reth_tracing::init_test_tracing(); let persistence_handle = default_persistence_handle(); let block_number = 0; - let executed = get_executed_block_with_number(block_number); + let executed = get_executed_block_with_number(block_number, B256::random()); let block_hash = executed.block().hash(); let blocks = vec![executed]; diff --git a/crates/engine/tree/src/tree/mod.rs b/crates/engine/tree/src/tree/mod.rs index 68ece2d3c..b158654c2 100644 --- a/crates/engine/tree/src/tree/mod.rs +++ b/crates/engine/tree/src/tree/mod.rs @@ -40,7 +40,7 @@ use reth_rpc_types::{ use reth_stages_api::ControlFlow; use reth_trie::HashedPostState; use std::{ - collections::{BTreeMap, HashMap}, + collections::{BTreeMap, HashMap, HashSet}, sync::{mpsc::Receiver, Arc}, }; use tokio::sync::{ @@ -75,6 +75,8 @@ pub struct TreeState { blocks_by_number: BTreeMap>, /// Currently tracked canonical head of the chain. current_canonical_head: BlockNumHash, + /// Map of any parent block hash to its children. + parent_to_child: HashMap>, } impl TreeState { @@ -84,6 +86,7 @@ impl TreeState { blocks_by_hash: HashMap::new(), blocks_by_number: BTreeMap::new(), current_canonical_head, + parent_to_child: HashMap::new(), } } @@ -101,27 +104,55 @@ impl TreeState { /// Insert executed block into the state. 
fn insert_executed(&mut self, executed: ExecutedBlock) { - self.blocks_by_number.entry(executed.block.number).or_default().push(executed.clone()); - let existing = self.blocks_by_hash.insert(executed.block.hash(), executed); - debug_assert!(existing.is_none(), "inserted duplicate block"); + let hash = executed.block.hash(); + let parent_hash = executed.block.parent_hash; + let block_number = executed.block.number; + + if self.blocks_by_hash.contains_key(&hash) { + return; + } + + self.blocks_by_hash.insert(hash, executed.clone()); + + self.blocks_by_number.entry(block_number).or_default().push(executed); + + self.parent_to_child.entry(parent_hash).or_default().insert(hash); + + if let Some(existing_blocks) = self.blocks_by_number.get(&block_number) { + if existing_blocks.len() > 1 { + self.parent_to_child.entry(parent_hash).or_default().insert(hash); + } + } + + for children in self.parent_to_child.values_mut() { + children.retain(|child| self.blocks_by_hash.contains_key(child)); + } } /// Remove blocks before specified block number. 
///
/// Every block numbered strictly below `block_number` is dropped from the by-number and by-hash
/// indexes. Each dropped block is also unlinked from its parent's child set (a set emptied this
/// way is removed), and the child set keyed by the dropped block itself is discarded.
pub(crate) fn remove_before(&mut self, block_number: BlockNumber) {
    // `split_off` returns every entry at or above `block_number`; swapping that back in leaves
    // the entries below the threshold in `pruned`, still in ascending order.
    let retained = self.blocks_by_number.split_off(&block_number);
    let pruned = std::mem::replace(&mut self.blocks_by_number, retained);

    for block in pruned.into_values().flatten() {
        let block_hash = block.block.hash();
        let parent_hash = block.block.parent_hash;

        self.blocks_by_hash.remove(&block_hash);

        // Unlink the block from its parent and note whether the parent is now childless.
        let parent_is_childless = match self.parent_to_child.get_mut(&parent_hash) {
            Some(siblings) => {
                siblings.remove(&block_hash);
                siblings.is_empty()
            }
            None => false,
        };
        if parent_is_childless {
            self.parent_to_child.remove(&parent_hash);
        }

        self.parent_to_child.remove(&block_hash);
    }
}
@@ -155,33 +186,57 @@ impl TreeState { /// /// This also handles reorgs.
///
/// Returns [`NewCanonicalChain::Commit`] when `new_head` descends directly from the current
/// canonical head, and [`NewCanonicalChain::Reorg`] when its branch joins the canonical chain
/// below the head. Both block lists are ordered oldest first.
///
/// Returns `None` when `new_head` is not tracked, when it is already an ancestor of the
/// canonical head, or when the tracked blocks contain no common ancestor of the two chains.
fn on_new_head(&self, new_head: B256) -> Option<NewCanonicalChain> {
    let canonical_hash = self.canonical_block_hash();

    // Walk back from the new head through every tracked ancestor (newest first), stopping
    // early if the walk lands on the canonical head.
    let mut new_chain = Vec::new();
    let mut current_hash = new_head;
    while current_hash != canonical_hash {
        match self.blocks_by_hash.get(&current_hash) {
            Some(block) => {
                new_chain.push(block.clone());
                current_hash = block.block.parent_hash;
            }
            None => break,
        }
    }

    if current_hash == canonical_hash {
        // the new head extends the canonical head (or is the canonical head itself)
        new_chain.reverse();
        return Some(NewCanonicalChain::Commit { new: new_chain });
    }

    if new_chain.is_empty() {
        // the requested head is not tracked at all
        return None;
    }

    // For each parent hash seen on the new branch, the number of leading `new_chain` entries
    // that sit on top of it. Sibling forks elsewhere in the tree play no role here: only a
    // block that the canonical chain also passes through can be the fork point.
    let keep_above: HashMap<B256, usize> = new_chain
        .iter()
        .enumerate()
        .map(|(idx, block)| (block.block.parent_hash, idx + 1))
        .collect();

    // Walk back the canonical chain until it meets the new branch; every block passed on the
    // way is reorged out.
    let mut reorged = Vec::new();
    let mut old_hash = self.current_canonical_head.hash;
    let keep = loop {
        if old_hash == new_head {
            // the new head is an ancestor of the canonical head: there is no new chain
            return None;
        }
        if let Some(&keep) = keep_above.get(&old_hash) {
            break keep;
        }
        match self.blocks_by_hash.get(&old_hash) {
            Some(block) => {
                reorged.push(block.clone());
                old_hash = block.block.parent_hash;
            }
            None => {
                // current hash not found in memory
                warn!(target: "consensus::engine", invalid_hash=?old_hash, "Block not found in TreeState while walking back fork");
                return None;
            }
        }
    };

    // Discard the shared ancestors that were collected below the fork point.
    new_chain.truncate(keep);
    new_chain.reverse();
    reorged.reverse();

    if reorged.is_empty() {
        Some(NewCanonicalChain::Commit { new: new_chain })
    } else {
        Some(NewCanonicalChain::Reorg { new: new_chain, old: reorged })
    }
}
} @@ -1495,7 +1550,10 @@ mod tests { use crate::persistence::PersistenceAction; use alloy_rlp::Decodable; use reth_beacon_consensus::EthBeaconConsensus; - use reth_chain_state::{test_utils::get_executed_blocks, BlockState}; + use reth_chain_state::{ + test_utils::{get_executed_block_with_number, get_executed_blocks}, + BlockState, + }; use reth_chainspec::{ChainSpecBuilder, HOLESKY, MAINNET}; use reth_ethereum_engine_primitives::EthEngineTypes; use reth_evm::test_utils::MockExecutorProvider; @@ -1719,4 +1777,148 @@ mod tests { let resp = rx.await.unwrap().unwrap(); assert!(resp.is_syncing()); } + + #[tokio::test] + async fn test_tree_state_insert_executed() { + let mut tree_state = TreeState::new(BlockNumHash::default()); + let blocks: Vec<_> = get_executed_blocks(1..4).collect(); + + tree_state.insert_executed(blocks[0].clone()); + tree_state.insert_executed(blocks[1].clone()); + + assert_eq!( + tree_state.parent_to_child.get(&blocks[0].block.hash()), + Some(&HashSet::from([blocks[1].block.hash()])) + ); + + assert!(!tree_state.parent_to_child.contains_key(&blocks[1].block.hash())); + + tree_state.insert_executed(blocks[2].clone()); + + assert_eq!( + tree_state.parent_to_child.get(&blocks[1].block.hash()), + Some(&HashSet::from([blocks[2].block.hash()])) + ); + assert!(tree_state.parent_to_child.contains_key(&blocks[1].block.hash())); + + assert!(!tree_state.parent_to_child.contains_key(&blocks[2].block.hash())); + } + + #[tokio::test] + async fn test_tree_state_insert_executed_with_reorg() { + let mut tree_state = TreeState::new(BlockNumHash::default()); + let blocks: Vec<_> = get_executed_blocks(1..6).collect(); + + for block in &blocks { +
tree_state.insert_executed(block.clone()); + } + assert_eq!(tree_state.blocks_by_hash.len(), 5); + + let fork_block_3 = get_executed_block_with_number(3, blocks[1].block.hash()); + let fork_block_4 = get_executed_block_with_number(4, fork_block_3.block.hash()); + let fork_block_5 = get_executed_block_with_number(5, fork_block_4.block.hash()); + + tree_state.insert_executed(fork_block_3.clone()); + tree_state.insert_executed(fork_block_4.clone()); + tree_state.insert_executed(fork_block_5.clone()); + + assert_eq!(tree_state.blocks_by_hash.len(), 8); + assert_eq!(tree_state.blocks_by_number[&3].len(), 2); // two blocks at height 3 (original and fork) + assert_eq!(tree_state.parent_to_child[&blocks[1].block.hash()].len(), 2); // block 2 should have two children + + // verify that we can insert the same block again without issues + tree_state.insert_executed(fork_block_4.clone()); + assert_eq!(tree_state.blocks_by_hash.len(), 8); + + assert!(tree_state.parent_to_child[&fork_block_3.block.hash()] + .contains(&fork_block_4.block.hash())); + assert!(tree_state.parent_to_child[&fork_block_4.block.hash()] + .contains(&fork_block_5.block.hash())); + + assert_eq!(tree_state.blocks_by_number[&4].len(), 2); + assert_eq!(tree_state.blocks_by_number[&5].len(), 2); + } + + #[tokio::test] + async fn test_tree_state_remove_before() { + let mut tree_state = TreeState::new(BlockNumHash::default()); + let blocks: Vec<_> = get_executed_blocks(1..6).collect(); + + for block in &blocks { + tree_state.insert_executed(block.clone()); + } + + tree_state.remove_before(3); + + assert!(!tree_state.blocks_by_hash.contains_key(&blocks[0].block.hash())); + assert!(!tree_state.blocks_by_hash.contains_key(&blocks[1].block.hash())); + assert!(!tree_state.blocks_by_number.contains_key(&1)); + assert!(!tree_state.blocks_by_number.contains_key(&2)); + + assert!(tree_state.blocks_by_hash.contains_key(&blocks[2].block.hash())); + assert!(tree_state.blocks_by_hash.contains_key(&blocks[3].block.hash())); + 
assert!(tree_state.blocks_by_hash.contains_key(&blocks[4].block.hash())); + assert!(tree_state.blocks_by_number.contains_key(&3)); + assert!(tree_state.blocks_by_number.contains_key(&4)); + assert!(tree_state.blocks_by_number.contains_key(&5)); + + assert!(!tree_state.parent_to_child.contains_key(&blocks[0].block.hash())); + assert!(!tree_state.parent_to_child.contains_key(&blocks[1].block.hash())); + assert!(tree_state.parent_to_child.contains_key(&blocks[2].block.hash())); + assert!(tree_state.parent_to_child.contains_key(&blocks[3].block.hash())); + assert!(!tree_state.parent_to_child.contains_key(&blocks[4].block.hash())); + + assert_eq!( + tree_state.parent_to_child.get(&blocks[2].block.hash()), + Some(&HashSet::from([blocks[3].block.hash()])) + ); + assert_eq!( + tree_state.parent_to_child.get(&blocks[3].block.hash()), + Some(&HashSet::from([blocks[4].block.hash()])) + ); + } + + #[tokio::test] + async fn test_tree_state_on_new_head() { + let mut tree_state = TreeState::new(BlockNumHash::default()); + let blocks: Vec<_> = get_executed_blocks(1..6).collect(); + + for block in &blocks { + tree_state.insert_executed(block.clone()); + } + + // set block 3 as the current canonical head + tree_state.set_canonical_head(blocks[2].block.num_hash()); + + // create a fork from block 2 + let fork_block_3 = get_executed_block_with_number(3, blocks[1].block.hash()); + let fork_block_4 = get_executed_block_with_number(4, fork_block_3.block.hash()); + let fork_block_5 = get_executed_block_with_number(5, fork_block_4.block.hash()); + + tree_state.insert_executed(fork_block_3.clone()); + tree_state.insert_executed(fork_block_4.clone()); + tree_state.insert_executed(fork_block_5.clone()); + + // normal (non-reorg) case + let result = tree_state.on_new_head(blocks[4].block.hash()); + assert!(matches!(result, Some(NewCanonicalChain::Commit { .. 
}))); + if let Some(NewCanonicalChain::Commit { new }) = result { + assert_eq!(new.len(), 2); + assert_eq!(new[0].block.hash(), blocks[3].block.hash()); + assert_eq!(new[1].block.hash(), blocks[4].block.hash()); + } + + // reorg case + let result = tree_state.on_new_head(fork_block_5.block.hash()); + assert!(matches!(result, Some(NewCanonicalChain::Reorg { .. }))); + if let Some(NewCanonicalChain::Reorg { new, old }) = result { + assert_eq!(new.len(), 3); + assert_eq!(new[0].block.hash(), fork_block_3.block.hash()); + assert_eq!(new[1].block.hash(), fork_block_4.block.hash()); + assert_eq!(new[2].block.hash(), fork_block_5.block.hash()); + + assert_eq!(old.len(), 1); + assert_eq!(old[0].block.hash(), blocks[2].block.hash()); + } + } }