From 7b781eb60271cebc047a183ab79ea4ed8a9ad246 Mon Sep 17 00:00:00 2001 From: joshieDo <93316087+joshieDo@users.noreply.github.com> Date: Tue, 14 Nov 2023 20:54:13 +0000 Subject: [PATCH] feat: add directory paths to `Snapshotter` and `SnapshotProvider` (#5283) Co-authored-by: Alexey Shekhirin Co-authored-by: Matthias Seitz --- Cargo.lock | 1 + bin/reth/src/db/snapshots/headers.rs | 34 ++++--- bin/reth/src/db/snapshots/mod.rs | 7 +- bin/reth/src/db/snapshots/receipts.rs | 34 ++++--- bin/reth/src/db/snapshots/transactions.rs | 33 ++++--- bin/reth/src/node/mod.rs | 5 +- crates/interfaces/src/error.rs | 6 ++ crates/primitives/src/fs.rs | 32 ++++++- crates/primitives/src/snapshot/compression.rs | 8 +- crates/primitives/src/snapshot/filters.rs | 9 +- crates/primitives/src/snapshot/segment.rs | 88 +++++++++++-------- crates/snapshot/Cargo.toml | 2 +- crates/snapshot/src/segments/mod.rs | 4 +- crates/snapshot/src/snapshotter.rs | 69 ++++++++++++--- .../provider/src/providers/database/mod.rs | 5 +- .../src/providers/snapshot/manager.rs | 11 ++- 16 files changed, 247 insertions(+), 101 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5ab359f77..934b8fb06 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6505,6 +6505,7 @@ dependencies = [ "reth-primitives", "reth-provider", "reth-stages", + "tempfile", "thiserror", "tokio", "tracing", diff --git a/bin/reth/src/db/snapshots/headers.rs b/bin/reth/src/db/snapshots/headers.rs index 6533dd881..b09b99ebc 100644 --- a/bin/reth/src/db/snapshots/headers.rs +++ b/bin/reth/src/db/snapshots/headers.rs @@ -12,7 +12,7 @@ use reth_primitives::{ use reth_provider::{ providers::SnapshotProvider, DatabaseProviderRO, HeaderProvider, ProviderError, ProviderFactory, }; -use reth_snapshot::segments::{Headers, Segment}; +use reth_snapshot::{segments, segments::Segment}; use std::{path::Path, sync::Arc}; impl Command { @@ -23,15 +23,22 @@ impl Command { inclusion_filter: InclusionFilter, phf: PerfectHashingFunction, ) -> eyre::Result<()> { - let 
segment = Headers::new( - compression, - if self.with_filters { - Filters::WithFilters(inclusion_filter, phf) - } else { - Filters::WithoutFilters - }, - ); - segment.snapshot::(provider, self.from..=(self.from + self.block_interval - 1))?; + let range = self.block_range(); + let filters = if self.with_filters { + Filters::WithFilters(inclusion_filter, phf) + } else { + Filters::WithoutFilters + }; + + let segment = segments::Headers::new(compression, filters); + + segment.snapshot::(provider, range.clone())?; + + // Default name doesn't have any configuration + reth_primitives::fs::rename( + SnapshotSegment::Headers.filename(&range), + SnapshotSegment::Headers.filename_with_configuration(filters, compression, &range), + )?; Ok(()) } @@ -51,12 +58,13 @@ impl Command { Filters::WithoutFilters }; - let range = self.from..=(self.from + self.block_interval - 1); + let range = self.block_range(); let mut row_indexes = range.clone().collect::>(); let mut rng = rand::thread_rng(); - let path = - SnapshotSegment::Headers.filename_with_configuration(filters, compression, &range); + let path = SnapshotSegment::Headers + .filename_with_configuration(filters, compression, &range) + .into(); let provider = SnapshotProvider::default(); let jar_provider = provider.get_segment_provider(SnapshotSegment::Headers, self.from, Some(path))?; diff --git a/bin/reth/src/db/snapshots/mod.rs b/bin/reth/src/db/snapshots/mod.rs index efce48783..80f0813c5 100644 --- a/bin/reth/src/db/snapshots/mod.rs +++ b/bin/reth/src/db/snapshots/mod.rs @@ -7,7 +7,7 @@ use reth_primitives::{ BlockNumber, ChainSpec, SnapshotSegment, }; use reth_provider::ProviderFactory; -use std::{path::Path, sync::Arc}; +use std::{ops::RangeInclusive, path::Path, sync::Arc}; mod bench; mod headers; @@ -130,4 +130,9 @@ impl Command { Ok(()) } + + /// Gives out the inclusive block range for the snapshot requested by the user. 
+ fn block_range(&self) -> RangeInclusive { + self.from..=(self.from + self.block_interval - 1) + } } diff --git a/bin/reth/src/db/snapshots/receipts.rs b/bin/reth/src/db/snapshots/receipts.rs index ffe09814e..b0475eeff 100644 --- a/bin/reth/src/db/snapshots/receipts.rs +++ b/bin/reth/src/db/snapshots/receipts.rs @@ -24,15 +24,22 @@ impl Command { inclusion_filter: InclusionFilter, phf: PerfectHashingFunction, ) -> eyre::Result<()> { - let segment = segments::Receipts::new( - compression, - if self.with_filters { - Filters::WithFilters(inclusion_filter, phf) - } else { - Filters::WithoutFilters - }, - ); - segment.snapshot::(provider, self.from..=(self.from + self.block_interval - 1))?; + let range = self.block_range(); + let filters = if self.with_filters { + Filters::WithFilters(inclusion_filter, phf) + } else { + Filters::WithoutFilters + }; + + let segment = segments::Receipts::new(compression, filters); + + segment.snapshot::(provider, range.clone())?; + + // Default name doesn't have any configuration + reth_primitives::fs::rename( + SnapshotSegment::Receipts.filename(&range), + SnapshotSegment::Receipts.filename_with_configuration(filters, compression, &range), + )?; Ok(()) } @@ -62,11 +69,10 @@ impl Command { let mut row_indexes = tx_range.clone().collect::>(); - let path = SnapshotSegment::Receipts.filename_with_configuration( - filters, - compression, - &block_range, - ); + let path = SnapshotSegment::Receipts + .filename_with_configuration(filters, compression, &block_range) + .into(); + let provider = SnapshotProvider::default(); let jar_provider = provider.get_segment_provider(SnapshotSegment::Receipts, self.from, Some(path))?; diff --git a/bin/reth/src/db/snapshots/transactions.rs b/bin/reth/src/db/snapshots/transactions.rs index a52c33ddb..9d3530d40 100644 --- a/bin/reth/src/db/snapshots/transactions.rs +++ b/bin/reth/src/db/snapshots/transactions.rs @@ -24,15 +24,22 @@ impl Command { inclusion_filter: InclusionFilter, phf: PerfectHashingFunction, ) 
-> eyre::Result<()> { - let segment = segments::Transactions::new( - compression, - if self.with_filters { - Filters::WithFilters(inclusion_filter, phf) - } else { - Filters::WithoutFilters - }, - ); - segment.snapshot::(provider, self.from..=(self.from + self.block_interval - 1))?; + let range = self.block_range(); + let filters = if self.with_filters { + Filters::WithFilters(inclusion_filter, phf) + } else { + Filters::WithoutFilters + }; + + let segment = segments::Transactions::new(compression, filters); + + segment.snapshot::(provider, range.clone())?; + + // Default name doesn't have any configuration + reth_primitives::fs::rename( + SnapshotSegment::Transactions.filename(&range), + SnapshotSegment::Transactions.filename_with_configuration(filters, compression, &range), + )?; Ok(()) } @@ -62,11 +69,9 @@ impl Command { let mut row_indexes = tx_range.clone().collect::>(); - let path = SnapshotSegment::Transactions.filename_with_configuration( - filters, - compression, - &block_range, - ); + let path = SnapshotSegment::Transactions + .filename_with_configuration(filters, compression, &block_range) + .into(); let provider = SnapshotProvider::default(); let jar_provider = provider.get_segment_provider(SnapshotSegment::Transactions, self.from, Some(path))?; diff --git a/bin/reth/src/node/mod.rs b/bin/reth/src/node/mod.rs index 0dc2cd1ea..3e6b9cd6a 100644 --- a/bin/reth/src/node/mod.rs +++ b/bin/reth/src/node/mod.rs @@ -301,13 +301,14 @@ impl NodeCommand { // configure snapshotter let snapshotter = reth_snapshot::Snapshotter::new( db.clone(), + data_dir.snapshots_path(), self.chain.clone(), self.chain.snapshot_block_interval, - ); + )?; // setup the blockchain provider let factory = ProviderFactory::new(Arc::clone(&db), Arc::clone(&self.chain)) - .with_snapshots(snapshotter.highest_snapshot_receiver()); + .with_snapshots(data_dir.snapshots_path(), snapshotter.highest_snapshot_receiver()); let blockchain_db = BlockchainProvider::new(factory, 
blockchain_tree.clone())?; let blob_store = InMemoryBlobStore::default(); let validator = TransactionValidationTaskExecutor::eth_builder(Arc::clone(&self.chain)) diff --git a/crates/interfaces/src/error.rs b/crates/interfaces/src/error.rs index b972124fd..e40a1abd5 100644 --- a/crates/interfaces/src/error.rs +++ b/crates/interfaces/src/error.rs @@ -39,6 +39,12 @@ impl From for RethError { } } +impl From for RethError { + fn from(err: reth_primitives::fs::FsPathError) -> Self { + RethError::Custom(err.to_string()) + } +} + // We don't want these types to be too large because they're used in a lot of places. const _SIZE_ASSERTIONS: () = { // Main error. diff --git a/crates/primitives/src/fs.rs b/crates/primitives/src/fs.rs index f31b279c5..8e4e50acd 100644 --- a/crates/primitives/src/fs.rs +++ b/crates/primitives/src/fs.rs @@ -1,6 +1,7 @@ //! Wrapper for `std::fs` methods use std::{ - fs, io, + fs::{self, ReadDir}, + io, path::{Path, PathBuf}, }; @@ -30,6 +31,12 @@ pub enum FsPathError { /// Provides additional path context for [`std::fs::remove_dir`]. #[error("failed to remove dir {path:?}: {source}")] RemoveDir { source: io::Error, path: PathBuf }, + /// Provides additional path context for [`std::fs::read_dir`]. + #[error("failed to read dir {path:?}: {source}")] + ReadDir { source: io::Error, path: PathBuf }, + /// Provides additional context for [`std::fs::rename`]. + #[error("failed to rename {from:?} to {to:?}: {source}")] + Rename { source: io::Error, from: PathBuf, to: PathBuf }, /// Provides additional path context for [`std::fs::File::open`]. #[error("failed to open file {path:?}: {source}")] Open { source: io::Error, path: PathBuf }, @@ -77,10 +84,20 @@ impl FsPathError { FsPathError::RemoveDir { source, path: path.into() } } + /// Returns the complementary error variant for [`std::fs::read_dir`]. 
+ pub fn read_dir(source: io::Error, path: impl Into) -> Self { + FsPathError::ReadDir { source, path: path.into() } + } + /// Returns the complementary error variant for [`std::fs::File::open`]. pub fn open(source: io::Error, path: impl Into) -> Self { FsPathError::Open { source, path: path.into() } } + + /// Returns the complementary error variant for [`std::fs::rename`]. + pub fn rename(source: io::Error, from: impl Into, to: impl Into) -> Self { + FsPathError::Rename { source, from: from.into(), to: to.into() } + } } type Result = std::result::Result; @@ -108,3 +125,16 @@ pub fn create_dir_all(path: impl AsRef) -> Result<()> { let path = path.as_ref(); fs::create_dir_all(path).map_err(|err| FsPathError::create_dir(err, path)) } + +/// Wrapper for `std::fs::read_dir` +pub fn read_dir(path: impl AsRef) -> Result { + let path = path.as_ref(); + fs::read_dir(path).map_err(|err| FsPathError::read_dir(err, path)) +} + +/// Wrapper for `std::fs::rename` +pub fn rename(from: impl AsRef, to: impl AsRef) -> Result<()> { + let from = from.as_ref(); + let to = to.as_ref(); + fs::rename(from, to).map_err(|err| FsPathError::rename(err, from, to)) +} diff --git a/crates/primitives/src/snapshot/compression.rs b/crates/primitives/src/snapshot/compression.rs index c67e3f63b..69fe4b2a4 100644 --- a/crates/primitives/src/snapshot/compression.rs +++ b/crates/primitives/src/snapshot/compression.rs @@ -1,11 +1,17 @@ -#[derive(Debug, Copy, Clone, Default)] +use strum::AsRefStr; + +#[derive(Debug, Copy, Clone, Default, AsRefStr)] #[cfg_attr(feature = "clap", derive(clap::ValueEnum))] #[allow(missing_docs)] /// Snapshot compression pub enum Compression { + #[strum(serialize = "lz4")] Lz4, + #[strum(serialize = "zstd")] Zstd, + #[strum(serialize = "zstd-dict")] ZstdWithDictionary, + #[strum(serialize = "uncompressed")] #[default] Uncompressed, } diff --git a/crates/primitives/src/snapshot/filters.rs b/crates/primitives/src/snapshot/filters.rs index e9716ac70..3443d4747 100644 --- 
a/crates/primitives/src/snapshot/filters.rs +++ b/crates/primitives/src/snapshot/filters.rs @@ -1,3 +1,5 @@ +use strum::AsRefStr; + #[derive(Debug, Copy, Clone)] /// Snapshot filters. pub enum Filters { @@ -14,20 +16,23 @@ impl Filters { } } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, AsRefStr)] #[cfg_attr(feature = "clap", derive(clap::ValueEnum))] /// Snapshot inclusion filter. Also see [Filters]. pub enum InclusionFilter { + #[strum(serialize = "cuckoo")] /// Cuckoo filter Cuckoo, } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, AsRefStr)] #[cfg_attr(feature = "clap", derive(clap::ValueEnum))] /// Snapshot perfect hashing function. Also see [Filters]. pub enum PerfectHashingFunction { + #[strum(serialize = "fmph")] /// Fingerprint-Based Minimal Perfect Hash Function Fmph, + #[strum(serialize = "gofmph")] /// Fingerprint-Based Minimal Perfect Hash Function with Group Optimization GoFmph, } diff --git a/crates/primitives/src/snapshot/segment.rs b/crates/primitives/src/snapshot/segment.rs index 8a86768ed..309d2c4ba 100644 --- a/crates/primitives/src/snapshot/segment.rs +++ b/crates/primitives/src/snapshot/segment.rs @@ -1,24 +1,42 @@ -use crate::{snapshot::PerfectHashingFunction, BlockNumber, TxNumber}; +use crate::{ + snapshot::{Compression, Filters, InclusionFilter}, + BlockNumber, TxNumber, +}; use serde::{Deserialize, Serialize}; -use std::{ops::RangeInclusive, path::PathBuf}; +use std::{ops::RangeInclusive, str::FromStr}; +use strum::{AsRefStr, EnumString}; -use super::{Compression, Filters, InclusionFilter}; - -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd, Deserialize, Serialize)] +#[derive( + Debug, + Copy, + Clone, + Eq, + PartialEq, + Hash, + Ord, + PartialOrd, + Deserialize, + Serialize, + EnumString, + AsRefStr, +)] #[cfg_attr(feature = "clap", derive(clap::ValueEnum))] /// Segment of the data that can be snapshotted. 
pub enum SnapshotSegment { + #[strum(serialize = "headers")] /// Snapshot segment responsible for the `CanonicalHeaders`, `Headers`, `HeaderTD` tables. Headers, + #[strum(serialize = "transactions")] /// Snapshot segment responsible for the `Transactions` table. Transactions, + #[strum(serialize = "receipts")] /// Snapshot segment responsible for the `Receipts` table. Receipts, } impl SnapshotSegment { /// Returns the default configuration of the segment. - const fn config(&self) -> (Filters, Compression) { + pub const fn config(&self) -> (Filters, Compression) { let default_config = ( Filters::WithFilters(InclusionFilter::Cuckoo, super::PerfectHashingFunction::Fmph), Compression::Lz4, ); match self { SnapshotSegment::Headers => default_config, SnapshotSegment::Transactions => default_config, SnapshotSegment::Receipts => default_config, } } /// Returns the default file name for the provided segment and range. - pub fn filename(&self, range: &RangeInclusive) -> PathBuf { - let (filters, compression) = self.config(); - self.filename_with_configuration(filters, compression, range) + pub fn filename(&self, range: &RangeInclusive) -> String { + // ATTENTION: if changing the name format, be sure to reflect those changes in + // [`Self::parse_filename`]. + format!("snapshot_{}_{}_{}", self.as_ref(), range.start(), range.end(),) } - /// Returns file name for the provided segment, filters, compression and range. + /// Returns file name for the provided segment and range, alongside filters, compression.
pub fn filename_with_configuration( &self, filters: Filters, compression: Compression, range: &RangeInclusive, - ) -> PathBuf { - let segment_name = match self { - SnapshotSegment::Headers => "headers", - SnapshotSegment::Transactions => "transactions", - SnapshotSegment::Receipts => "receipts", - }; + ) -> String { + let prefix = self.filename(range); + let filters_name = match filters { Filters::WithFilters(inclusion_filter, phf) => { - let inclusion_filter = match inclusion_filter { - InclusionFilter::Cuckoo => "cuckoo", - }; - let phf = match phf { - PerfectHashingFunction::Fmph => "fmph", - PerfectHashingFunction::GoFmph => "gofmph", - }; - format!("{inclusion_filter}-{phf}") + format!("{}-{}", inclusion_filter.as_ref(), phf.as_ref()) } Filters::WithoutFilters => "none".to_string(), }; - let compression_name = match compression { - Compression::Lz4 => "lz4", - Compression::Zstd => "zstd", - Compression::ZstdWithDictionary => "zstd-dict", - Compression::Uncompressed => "uncompressed", - }; - format!( - "snapshot_{segment_name}_{}_{}_{filters_name}_{compression_name}", - range.start(), - range.end(), - ) - .into() + // ATTENTION: if changing the name format, be sure to reflect those changes in + // [`Self::parse_filename`]. + format!("{prefix}_{}_{}", filters_name, compression.as_ref()) + } + + /// Takes a filename and parses the [`SnapshotSegment`] and its inclusive range.
+ pub fn parse_filename(name: &str) -> Option<(Self, RangeInclusive)> { + let parts: Vec<&str> = name.split('_').collect(); + if let (Ok(segment), true) = (Self::from_str(parts[1]), parts.len() >= 4) { + let start: u64 = parts[2].parse().unwrap_or(0); + let end: u64 = parts[3].parse().unwrap_or(0); + + if start > end || parts[0] != "snapshot" { + return None + } + + return Some((segment, start..=end)) + } + None } } diff --git a/crates/snapshot/Cargo.toml b/crates/snapshot/Cargo.toml index 76b48e680..0eed67321 100644 --- a/crates/snapshot/Cargo.toml +++ b/crates/snapshot/Cargo.toml @@ -30,7 +30,7 @@ reth-db = { workspace = true, features = ["test-utils"] } reth-stages = { workspace = true, features = ["test-utils"] } # misc - +tempfile.workspace = true assert_matches.workspace = true [features] diff --git a/crates/snapshot/src/segments/mod.rs b/crates/snapshot/src/segments/mod.rs index 9a8bb4627..6293c3896 100644 --- a/crates/snapshot/src/segments/mod.rs +++ b/crates/snapshot/src/segments/mod.rs @@ -19,7 +19,7 @@ use reth_primitives::{ BlockNumber, SnapshotSegment, }; use reth_provider::{DatabaseProviderRO, TransactionsProviderExt}; -use std::ops::RangeInclusive; +use std::{ops::RangeInclusive, path::Path}; pub(crate) type Rows = [Vec>; COLUMNS]; pub(crate) fn prepare_jar( let tx_range = provider.transaction_range_by_block_range(block_range.clone())?; let mut nippy_jar = NippyJar::new( COLUMNS, - &segment.filename_with_configuration(filters, compression, &block_range), + Path::new(segment.filename(&block_range).as_str()), SegmentHeader::new(block_range, tx_range, segment), ); diff --git a/crates/snapshot/src/snapshotter.rs b/crates/snapshot/src/snapshotter.rs index 6bc722f0f..c8790d336 100644 --- a/crates/snapshot/src/snapshotter.rs +++ b/crates/snapshot/src/snapshotter.rs @@ -3,9 +3,11 @@ use crate::SnapshotterError; use reth_db::database::Database; use reth_interfaces::{RethError, RethResult}; -use reth_primitives::{snapshot::HighestSnapshots,
BlockNumber, ChainSpec, TxNumber}; +use reth_primitives::{ + snapshot::HighestSnapshots, BlockNumber, ChainSpec, SnapshotSegment, TxNumber, +}; use reth_provider::{BlockReader, DatabaseProviderRO, ProviderFactory}; -use std::{collections::HashMap, ops::RangeInclusive, sync::Arc}; +use std::{collections::HashMap, ops::RangeInclusive, path::PathBuf, sync::Arc}; use tokio::sync::watch; use tracing::warn; @@ -20,6 +22,8 @@ pub type SnapshotterWithResult = (Snapshotter, SnapshotterResult); pub struct Snapshotter { /// Provider factory provider_factory: ProviderFactory, + /// Directory where snapshots are located + snapshots_path: PathBuf, /// Highest snapshotted block numbers for each segment highest_snapshots: HighestSnapshots, /// Channel sender to notify other components of the new highest snapshots @@ -79,11 +83,22 @@ impl SnapshotTargets { impl Snapshotter { /// Creates a new [Snapshotter]. - pub fn new(db: DB, chain_spec: Arc, block_interval: u64) -> Self { + pub fn new( + db: DB, + snapshots_path: PathBuf, + chain_spec: Arc, + block_interval: u64, + ) -> RethResult { + // Create directory for snapshots if it doesn't exist. 
+ if !snapshots_path.exists() { + reth_primitives::fs::create_dir_all(&snapshots_path)?; + } + let (highest_snapshots_notifier, highest_snapshots_tracker) = watch::channel(None); - let snapshotter = Self { + let mut snapshotter = Self { provider_factory: ProviderFactory::new(db, chain_spec), + snapshots_path, // TODO(alexey): fill from on-disk snapshot data highest_snapshots: HighestSnapshots::default(), highest_snapshots_notifier, @@ -91,9 +106,9 @@ impl Snapshotter { block_interval, }; - snapshotter.update_highest_snapshots_tracker(); + snapshotter.update_highest_snapshots_tracker()?; - snapshotter + Ok(snapshotter) } #[cfg(test)] @@ -109,10 +124,37 @@ impl Snapshotter { } } - fn update_highest_snapshots_tracker(&self) { + /// Looks into the snapshot directory to find the highest snapshotted block of each segment, and + /// notifies every tracker. + fn update_highest_snapshots_tracker(&mut self) -> RethResult<()> { + // It walks over the directory and parses the snapshot filenames extracting + // `SnapshotSegment` and their inclusive range. It then takes the maximum block + // number for each specific segment. + for (segment, range) in reth_primitives::fs::read_dir(&self.snapshots_path)? 
+ .filter_map(Result::ok) + .filter_map(|entry| { + if let Ok(true) = entry.metadata().map(|metadata| metadata.is_file()) { + return SnapshotSegment::parse_filename(&entry.file_name().to_string_lossy()) + } + None + }) + { + let max_segment_block = match segment { + SnapshotSegment::Headers => &mut self.highest_snapshots.headers, + SnapshotSegment::Transactions => &mut self.highest_snapshots.transactions, + SnapshotSegment::Receipts => &mut self.highest_snapshots.receipts, + }; + + if max_segment_block.map_or(true, |block| block < *range.end()) { + *max_segment_block = Some(*range.end()); + } + } + let _ = self.highest_snapshots_notifier.send(Some(self.highest_snapshots)).map_err(|_| { warn!(target: "snapshot", "Highest snapshots channel closed"); }); + + Ok(()) } /// Returns a new [`HighestSnapshotsTracker`]. @@ -127,7 +169,7 @@ impl Snapshotter { // TODO(alexey): snapshot logic - self.update_highest_snapshots_tracker(); + self.update_highest_snapshots_tracker()?; Ok(targets) } @@ -246,8 +288,10 @@ mod tests { #[test] fn new() { let tx = TestTransaction::default(); - - let snapshotter = Snapshotter::new(tx.inner_raw(), MAINNET.clone(), 2); + let snapshots_dir = tempfile::TempDir::new().unwrap(); + let snapshotter = + Snapshotter::new(tx.inner_raw(), snapshots_dir.into_path(), MAINNET.clone(), 2) + .unwrap(); assert_eq!( *snapshotter.highest_snapshot_receiver().borrow(), @@ -258,12 +302,15 @@ mod tests { #[test] fn get_snapshot_targets() { let tx = TestTransaction::default(); + let snapshots_dir = tempfile::TempDir::new().unwrap(); let mut rng = generators::rng(); let blocks = random_block_range(&mut rng, 0..=3, B256::ZERO, 2..3); tx.insert_blocks(blocks.iter(), None).expect("insert blocks"); - let mut snapshotter = Snapshotter::new(tx.inner_raw(), MAINNET.clone(), 2); + let mut snapshotter = + Snapshotter::new(tx.inner_raw(), snapshots_dir.into_path(), MAINNET.clone(), 2) + .unwrap(); // Snapshot targets has data per part up to the passed finalized block number, 
// respecting the block interval diff --git a/crates/storage/provider/src/providers/database/mod.rs b/crates/storage/provider/src/providers/database/mod.rs index 5336de0bf..4e4f14de8 100644 --- a/crates/storage/provider/src/providers/database/mod.rs +++ b/crates/storage/provider/src/providers/database/mod.rs @@ -21,6 +21,7 @@ use reth_primitives::{ use revm::primitives::{BlockEnv, CfgEnv}; use std::{ ops::{RangeBounds, RangeInclusive}, + path::PathBuf, sync::Arc, }; use tokio::sync::watch; @@ -82,10 +83,12 @@ impl ProviderFactory { /// database provider comes with a shared snapshot provider pub fn with_snapshots( mut self, + snapshots_path: PathBuf, highest_snapshot_tracker: watch::Receiver>, ) -> Self { self.snapshot_provider = Some(Arc::new( - SnapshotProvider::default().with_highest_tracker(Some(highest_snapshot_tracker)), + SnapshotProvider::new(snapshots_path) + .with_highest_tracker(Some(highest_snapshot_tracker)), )); self } diff --git a/crates/storage/provider/src/providers/snapshot/manager.rs b/crates/storage/provider/src/providers/snapshot/manager.rs index 1b26f1db6..f75990059 100644 --- a/crates/storage/provider/src/providers/snapshot/manager.rs +++ b/crates/storage/provider/src/providers/snapshot/manager.rs @@ -20,9 +20,16 @@ pub struct SnapshotProvider { map: DashMap<(BlockNumber, SnapshotSegment), LoadedJar>, /// Tracks the highest snapshot of every segment. highest_tracker: Option>>, + /// Directory where snapshots are located + path: PathBuf, } impl SnapshotProvider { + /// Creates a new [`SnapshotProvider`]. 
+ pub fn new(path: PathBuf) -> Self { + Self { map: Default::default(), highest_tracker: None, path } + } + /// Adds a highest snapshot tracker to the provider pub fn with_highest_tracker( mut self, @@ -50,9 +57,9 @@ impl SnapshotProvider { if let Some(path) = &path { self.map.insert(key, LoadedJar::new(NippyJar::load(path)?)?); } else { - path = Some(segment.filename( + path = Some(self.path.join(segment.filename( &((snapshot * BLOCKS_PER_SNAPSHOT)..=((snapshot + 1) * BLOCKS_PER_SNAPSHOT - 1)), - )); + ))); } self.get_segment_provider(segment, block, path)