diff --git a/crates/consensus/beacon/src/engine/forkchoice.rs b/crates/consensus/beacon/src/engine/forkchoice.rs index d8b3344f6..7ce42ce7e 100644 --- a/crates/consensus/beacon/src/engine/forkchoice.rs +++ b/crates/consensus/beacon/src/engine/forkchoice.rs @@ -75,11 +75,55 @@ impl ForkchoiceStateTracker { self.last_syncing.as_ref().map(|s| s.head_block_hash) } + /// Returns the latest received `ForkchoiceState`. + /// + /// Caution: this can be invalid. + pub const fn latest_state(&self) -> Option<ForkchoiceState> { + self.last_valid + } + + /// Returns the last valid `ForkchoiceState`. + pub const fn last_valid_state(&self) -> Option<ForkchoiceState> { + self.last_valid + } + + /// Returns the last valid finalized hash. + /// + /// This will return [`None`], if either there is no valid finalized forkchoice state, or the + /// finalized hash for the latest valid forkchoice state is zero. + #[inline] + pub fn last_valid_finalized(&self) -> Option<B256> { + self.last_valid.and_then(|state| { + // if the hash is zero then we should act like there is no finalized hash + if state.finalized_block_hash.is_zero() { + None + } else { + Some(state.finalized_block_hash) + } + }) + } + /// Returns the last received `ForkchoiceState` to which we need to sync. pub const fn sync_target_state(&self) -> Option<ForkchoiceState> { self.last_syncing } + /// Returns the sync target finalized hash. + /// + /// This will return [`None`], if either there is no sync target forkchoice state, or the + /// finalized hash for the sync target forkchoice state is zero. + #[inline] + pub fn sync_target_finalized(&self) -> Option<B256> { + self.last_syncing.and_then(|state| { + // if the hash is zero then we should act like there is no finalized hash + if state.finalized_block_hash.is_zero() { + None + } else { + Some(state.finalized_block_hash) + } + }) + } + /// Returns true if no forkchoice state has been received yet. 
pub const fn is_empty(&self) -> bool { self.latest.is_none() diff --git a/crates/engine/tree/src/tree/mod.rs b/crates/engine/tree/src/tree/mod.rs index 96af7f027..e9b9ade96 100644 --- a/crates/engine/tree/src/tree/mod.rs +++ b/crates/engine/tree/src/tree/mod.rs @@ -41,7 +41,7 @@ use reth_rpc_types::{ use reth_stages_api::ControlFlow; use reth_trie::HashedPostState; use std::{ - collections::{BTreeMap, HashMap, HashSet}, + collections::{btree_map, hash_map, BTreeMap, HashMap, HashSet, VecDeque}, ops::Bound, sync::{ mpsc::{Receiver, RecvError, RecvTimeoutError, Sender}, @@ -176,31 +176,112 @@ impl TreeState { true } - /// Remove all blocks up to __and including__ the given block number. - pub(crate) fn remove_before(&mut self, upper_bound: BlockNumber) { - let mut numbers_to_remove = Vec::new(); - for (&number, _) in - self.blocks_by_number.range((Bound::Unbounded, Bound::Included(upper_bound))) - { - numbers_to_remove.push(number); + /// Remove single executed block by its hash. + /// + /// ## Returns + /// + /// The removed block and the block hashes of its children. + fn remove_by_hash(&mut self, hash: B256) -> Option<(ExecutedBlock, HashSet<B256>)> { + let executed = self.blocks_by_hash.remove(&hash)?; + + // Remove this block from collection of children of its parent block. + let parent_entry = self.parent_to_child.entry(executed.block.parent_hash); + if let hash_map::Entry::Occupied(mut entry) = parent_entry { + entry.get_mut().remove(&hash); + + if entry.get().is_empty() { + entry.remove(); + } } - for number in numbers_to_remove { - if let Some(blocks) = self.blocks_by_number.remove(&number) { - for block in blocks { - let block_hash = block.block.hash(); - self.blocks_by_hash.remove(&block_hash); + // Remove pointer to children of this block. 
+ let children = self.parent_to_child.remove(&hash).unwrap_or_default(); - if let Some(parent_children) = - self.parent_to_child.get_mut(&block.block.parent_hash) - { - parent_children.remove(&block_hash); - if parent_children.is_empty() { - self.parent_to_child.remove(&block.block.parent_hash); - } - } + // Remove this block from `blocks_by_number`. + let block_number_entry = self.blocks_by_number.entry(executed.block.number); + if let btree_map::Entry::Occupied(mut entry) = block_number_entry { + // We have to find the index of the block since it exists in a vec + if let Some(index) = entry.get().iter().position(|b| b.block.hash() == hash) { + entry.get_mut().swap_remove(index); - self.parent_to_child.remove(&block_hash); + // If there are no blocks left then remove the entry for this block + if entry.get().is_empty() { + entry.remove(); + } + } + } + + Some((executed, children)) + } + + /// Remove all blocks up to __and including__ the given block number. + /// + /// If a finalized hash is provided, the only non-canonical blocks which will be removed are + /// those which have a fork point at or below the finalized hash. + /// + /// Canonical blocks below the upper bound will still be removed. 
+ /// + /// NOTE: This assumes that the `finalized_num` is below or equal to the `upper_bound` + pub(crate) fn remove_until( + &mut self, + upper_bound: BlockNumber, + finalized_num: Option<BlockNumber>, + ) { + debug_assert!(Some(upper_bound) >= finalized_num); + // We want to do two things: + // * remove canonical blocks that are persisted + // * remove forks whose roots are below the finalized block + // We can do this in 2 steps: + // * remove all canonical blocks below the upper bound + // * fetch the number of the finalized hash, removing any sidechains that are __below__ the + // finalized block + + // TODO: move trie updates here + // First, let's walk back the canonical chain and remove canonical blocks lower than the + // upper bound + let mut current_block = self.current_canonical_head.hash; + while let Some(executed) = self.blocks_by_hash.get(&current_block) { + current_block = executed.block.parent_hash; + if executed.block.number <= upper_bound { + self.remove_by_hash(executed.block.hash()); + } + } + + // Now, we have removed canonical blocks (assuming the upper bound is above the finalized + // block) and only have sidechains below the finalized block. + if let Some(finalized) = finalized_num { + // We remove disconnected sidechains in three steps: + // * first, remove everything with a block number __below__ the finalized block. + // * next, we populate a vec with parents __at__ the finalized block. + // * finally, we iterate through the vec, removing children until the vec is empty + // (BFS). + + // We _exclude_ the finalized block because we will be dealing with the blocks __at__ + // the finalized block later. 
+ // TODO: remove trie updates whose roots are below the finalized block + let blocks_to_remove = self + .blocks_by_number + .range((Bound::Unbounded, Bound::Excluded(finalized))) + .flat_map(|(_, blocks)| blocks.iter().map(|b| b.block.hash())) + .collect::<Vec<_>>(); + for hash in blocks_to_remove { + self.remove_by_hash(hash); + } + + // The only blocks that exist at `finalized_num` now, are blocks in sidechains that + // should be removed. + // + // We first put their children into this vec. + // Then, we will iterate over them, removing them, adding their children, etc etc, + // until the vec is empty. + let mut blocks_to_remove = self + .blocks_by_number + .remove(&finalized) + .map(|blocks| blocks.into_iter().map(|e| e.block.hash()).collect::<VecDeque<_>>()) + .unwrap_or_default(); + while let Some(block) = blocks_to_remove.pop_front() { + if let Some((_, children)) = self.remove_by_hash(block) { + blocks_to_remove.extend(children); } } } @@ -1015,7 +1096,10 @@ where // state house keeping after backfill sync // remove all executed blocks below the backfill height - self.state.tree_state.remove_before(backfill_height); + // + // We set the `finalized_num` to `Some(backfill_height)` to ensure we remove all state + // before that + self.state.tree_state.remove_until(backfill_height, Some(backfill_height)); self.metrics.executed_blocks.set(self.state.tree_state.block_count() as f64); // remove all buffered blocks below the backfill height @@ -1189,7 +1273,9 @@ where /// /// Assumes that `finish` has been called on the `persistence_state` at least once fn on_new_persisted_block(&mut self) { - self.state.tree_state.remove_before(self.persistence_state.last_persisted_block_number); + let finalized = self.state.forkchoice_state_tracker.last_valid_finalized(); + self.remove_before(self.persistence_state.last_persisted_block_number, finalized) + .expect("todo: error handling"); self.canonical_in_memory_state .remove_persisted_blocks(self.persistence_state.last_persisted_block_number); } 
@@ -2009,6 +2095,26 @@ where Err(_) => OnForkChoiceUpdated::invalid_payload_attributes(), } } + + /// Remove all blocks up to __and including__ the given block number. + /// + /// If a finalized hash is provided, the only non-canonical blocks which will be removed are + /// those which have a fork point at or below the finalized hash. + /// + /// Canonical blocks below the upper bound will still be removed. + pub(crate) fn remove_before( + &mut self, + upper_bound: BlockNumber, + finalized_hash: Option<B256>, + ) -> ProviderResult<()> { + // first fetch the finalized block number and then call the `remove_until` method on + // tree_state + let num = + if let Some(hash) = finalized_hash { self.provider.block_number(hash)? } else { None }; + + self.state.tree_state.remove_until(upper_bound, num); + Ok(()) + } } /// The state of the persistence task. @@ -2650,8 +2756,103 @@ mod tests { tree_state.insert_executed(block.clone()); } + let last = blocks.last().unwrap(); + + // set the canonical head + tree_state.set_canonical_head(last.block.num_hash()); + // inclusive bound, so we should remove anything up to and including 2 - tree_state.remove_before(2); + tree_state.remove_until(2, Some(2)); + + assert!(!tree_state.blocks_by_hash.contains_key(&blocks[0].block.hash())); + assert!(!tree_state.blocks_by_hash.contains_key(&blocks[1].block.hash())); + assert!(!tree_state.blocks_by_number.contains_key(&1)); + assert!(!tree_state.blocks_by_number.contains_key(&2)); + + assert!(tree_state.blocks_by_hash.contains_key(&blocks[2].block.hash())); + assert!(tree_state.blocks_by_hash.contains_key(&blocks[3].block.hash())); + assert!(tree_state.blocks_by_hash.contains_key(&blocks[4].block.hash())); + assert!(tree_state.blocks_by_number.contains_key(&3)); + assert!(tree_state.blocks_by_number.contains_key(&4)); + assert!(tree_state.blocks_by_number.contains_key(&5)); + + assert!(!tree_state.parent_to_child.contains_key(&blocks[0].block.hash())); + 
assert!(!tree_state.parent_to_child.contains_key(&blocks[1].block.hash())); + assert!(tree_state.parent_to_child.contains_key(&blocks[2].block.hash())); + assert!(tree_state.parent_to_child.contains_key(&blocks[3].block.hash())); + assert!(!tree_state.parent_to_child.contains_key(&blocks[4].block.hash())); + + assert_eq!( + tree_state.parent_to_child.get(&blocks[2].block.hash()), + Some(&HashSet::from([blocks[3].block.hash()])) + ); + assert_eq!( + tree_state.parent_to_child.get(&blocks[3].block.hash()), + Some(&HashSet::from([blocks[4].block.hash()])) + ); + } + + #[tokio::test] + async fn test_tree_state_remove_before_finalized() { + let mut tree_state = TreeState::new(BlockNumHash::default()); + let blocks: Vec<_> = TestBlockBuilder::default().get_executed_blocks(1..6).collect(); + + for block in &blocks { + tree_state.insert_executed(block.clone()); + } + + let last = blocks.last().unwrap(); + + // set the canonical head + tree_state.set_canonical_head(last.block.num_hash()); + + // we should still remove everything up to and including 2 + tree_state.remove_until(2, None); + + assert!(!tree_state.blocks_by_hash.contains_key(&blocks[0].block.hash())); + assert!(!tree_state.blocks_by_hash.contains_key(&blocks[1].block.hash())); + assert!(!tree_state.blocks_by_number.contains_key(&1)); + assert!(!tree_state.blocks_by_number.contains_key(&2)); + + assert!(tree_state.blocks_by_hash.contains_key(&blocks[2].block.hash())); + assert!(tree_state.blocks_by_hash.contains_key(&blocks[3].block.hash())); + assert!(tree_state.blocks_by_hash.contains_key(&blocks[4].block.hash())); + assert!(tree_state.blocks_by_number.contains_key(&3)); + assert!(tree_state.blocks_by_number.contains_key(&4)); + assert!(tree_state.blocks_by_number.contains_key(&5)); + + assert!(!tree_state.parent_to_child.contains_key(&blocks[0].block.hash())); + assert!(!tree_state.parent_to_child.contains_key(&blocks[1].block.hash())); + 
assert!(tree_state.parent_to_child.contains_key(&blocks[2].block.hash())); + assert!(tree_state.parent_to_child.contains_key(&blocks[3].block.hash())); + assert!(!tree_state.parent_to_child.contains_key(&blocks[4].block.hash())); + + assert_eq!( + tree_state.parent_to_child.get(&blocks[2].block.hash()), + Some(&HashSet::from([blocks[3].block.hash()])) + ); + assert_eq!( + tree_state.parent_to_child.get(&blocks[3].block.hash()), + Some(&HashSet::from([blocks[4].block.hash()])) + ); + } + + #[tokio::test] + async fn test_tree_state_remove_before_lower_finalized() { + let mut tree_state = TreeState::new(BlockNumHash::default()); + let blocks: Vec<_> = TestBlockBuilder::default().get_executed_blocks(1..6).collect(); + + for block in &blocks { + tree_state.insert_executed(block.clone()); + } + + let last = blocks.last().unwrap(); + + // set the canonical head + tree_state.set_canonical_head(last.block.num_hash()); + + // we have no forks so we should still remove anything up to and including 2 + tree_state.remove_until(2, Some(1)); assert!(!tree_state.blocks_by_hash.contains_key(&blocks[0].block.hash())); assert!(!tree_state.blocks_by_hash.contains_key(&blocks[1].block.hash()));