feat: share SnapshotProvider through ProviderFactory (#5249)

Co-authored-by: Alexey Shekhirin <a.shekhirin@gmail.com>
This commit is contained in:
joshieDo
2023-11-14 17:50:12 +00:00
committed by GitHub
parent 8ecd90b884
commit d21e346c04
8 changed files with 146 additions and 56 deletions

View File

@ -15,6 +15,5 @@ mod snapshotter;
pub use error::SnapshotterError;
pub use snapshotter::{
HighestSnapshots, HighestSnapshotsTracker, SnapshotTargets, Snapshotter, SnapshotterResult,
SnapshotterWithResult,
HighestSnapshotsTracker, SnapshotTargets, Snapshotter, SnapshotterResult, SnapshotterWithResult,
};

View File

@ -3,7 +3,7 @@
use crate::SnapshotterError;
use reth_db::database::Database;
use reth_interfaces::{RethError, RethResult};
use reth_primitives::{BlockNumber, ChainSpec, TxNumber};
use reth_primitives::{snapshot::HighestSnapshots, BlockNumber, ChainSpec, TxNumber};
use reth_provider::{BlockReader, DatabaseProviderRO, ProviderFactory};
use std::{collections::HashMap, ops::RangeInclusive, sync::Arc};
use tokio::sync::watch;
@ -18,9 +18,14 @@ pub type SnapshotterWithResult<DB> = (Snapshotter<DB>, SnapshotterResult);
/// Snapshotting routine. Main snapshotting logic happens in [Snapshotter::run].
#[derive(Debug)]
pub struct Snapshotter<DB> {
/// Provider factory
provider_factory: ProviderFactory<DB>,
/// Highest snapshotted block numbers for each segment
highest_snapshots: HighestSnapshots,
highest_snapshots_tracker: watch::Sender<Option<HighestSnapshots>>,
/// Channel sender to notify other components of the new highest snapshots
highest_snapshots_notifier: watch::Sender<Option<HighestSnapshots>>,
/// Channel receiver to be cloned and shared that already comes with the newest value
highest_snapshots_tracker: HighestSnapshotsTracker,
/// Block interval after which the snapshot is taken.
block_interval: u64,
}
@ -28,20 +33,6 @@ pub struct Snapshotter<DB> {
/// Tracker for the latest [`HighestSnapshots`] value.
pub type HighestSnapshotsTracker = watch::Receiver<Option<HighestSnapshots>>;
/// Highest snapshotted block numbers, per data part.
#[derive(Debug, Clone, Copy, Default, Eq, PartialEq)]
pub struct HighestSnapshots {
/// Highest snapshotted block of headers, inclusive.
/// If [`None`], no snapshot is available.
pub headers: Option<BlockNumber>,
/// Highest snapshotted block of receipts, inclusive.
/// If [`None`], no snapshot is available.
pub receipts: Option<BlockNumber>,
/// Highest snapshotted block of transactions, inclusive.
/// If [`None`], no snapshot is available.
pub transactions: Option<BlockNumber>,
}
/// Snapshot targets, per data part, measured in [`BlockNumber`] and [`TxNumber`], if applicable.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct SnapshotTargets {
@ -88,16 +79,14 @@ impl SnapshotTargets {
impl<DB: Database> Snapshotter<DB> {
/// Creates a new [Snapshotter].
pub fn new(
db: DB,
chain_spec: Arc<ChainSpec>,
block_interval: u64,
highest_snapshots_tracker: watch::Sender<Option<HighestSnapshots>>,
) -> Self {
pub fn new(db: DB, chain_spec: Arc<ChainSpec>, block_interval: u64) -> Self {
let (highest_snapshots_notifier, highest_snapshots_tracker) = watch::channel(None);
let snapshotter = Self {
provider_factory: ProviderFactory::new(db, chain_spec),
// TODO(alexey): fill from on-disk snapshot data
highest_snapshots: HighestSnapshots::default(),
highest_snapshots_notifier,
highest_snapshots_tracker,
block_interval,
};
@ -121,11 +110,16 @@ impl<DB: Database> Snapshotter<DB> {
}
fn update_highest_snapshots_tracker(&self) {
let _ = self.highest_snapshots_tracker.send(Some(self.highest_snapshots)).map_err(|_| {
let _ = self.highest_snapshots_notifier.send(Some(self.highest_snapshots)).map_err(|_| {
warn!(target: "snapshot", "Highest snapshots channel closed");
});
}
/// Returns a new [`HighestSnapshotsTracker`].
pub fn highest_snapshot_receiver(&self) -> HighestSnapshotsTracker {
self.highest_snapshots_tracker.clone()
}
/// Run the snapshotter
pub fn run(&mut self, targets: SnapshotTargets) -> SnapshotterResult {
debug_assert!(targets.is_multiple_of_block_interval(self.block_interval));
@ -240,25 +234,25 @@ impl<DB: Database> Snapshotter<DB> {
#[cfg(test)]
mod tests {
use crate::{snapshotter::SnapshotTargets, HighestSnapshots, Snapshotter};
use crate::{snapshotter::SnapshotTargets, Snapshotter};
use assert_matches::assert_matches;
use reth_interfaces::{
test_utils::{generators, generators::random_block_range},
RethError,
};
use reth_primitives::{B256, MAINNET};
use reth_primitives::{snapshot::HighestSnapshots, B256, MAINNET};
use reth_stages::test_utils::TestTransaction;
use tokio::sync::watch;
#[test]
fn new() {
let tx = TestTransaction::default();
let (highest_snapshots_tx, highest_snapshots_rx) = watch::channel(None);
assert_eq!(*highest_snapshots_rx.borrow(), None);
let snapshotter = Snapshotter::new(tx.inner_raw(), MAINNET.clone(), 2);
Snapshotter::new(tx.inner_raw(), MAINNET.clone(), 2, highest_snapshots_tx);
assert_eq!(*highest_snapshots_rx.borrow(), Some(HighestSnapshots::default()));
assert_eq!(
*snapshotter.highest_snapshot_receiver().borrow(),
Some(HighestSnapshots::default())
);
}
#[test]
@ -269,8 +263,7 @@ mod tests {
let blocks = random_block_range(&mut rng, 0..=3, B256::ZERO, 2..3);
tx.insert_blocks(blocks.iter(), None).expect("insert blocks");
let mut snapshotter =
Snapshotter::new(tx.inner_raw(), MAINNET.clone(), 2, watch::channel(None).0);
let mut snapshotter = Snapshotter::new(tx.inner_raw(), MAINNET.clone(), 2);
// Snapshot targets has data per part up to the passed finalized block number,
// respecting the block interval