Mirror of https://github.com/hl-archive-node/nanoreth.git (synced 2025-12-06 10:59:55 +00:00)
test(sync): stage test suite (#204)
* test(sync): stage test suite
* cleanup txindex tests
* nit
* start revamping bodies testing
* revamp body testing
* add comments to suite tests
* fmt
* cleanup dup code
* cleanup insert_headers helper fn
* fix tests
* linter
* switch mutex to atomic
* cleanup
* revert
* test: make unwind runner return value instead of channel
* test: make execute runner return value instead of channel
* Revert "test: make execute runner return value instead of channel"

This reverts commit f8608654f2e4cf97f60ce6aa95c28009f71d5331.

Co-authored-by: Georgios Konstantopoulos <me@gakonst.com>
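The heart of this change is a reusable stage-testing harness: a `StageTestRunner` trait for building the stage under test, `ExecuteStageTestRunner`/`UnwindStageTestRunner` traits for seeding and validating state, and a `stage_test_suite!` macro that stamps out the shared test cases (the diff below wires this up for `BodyTestRunner`, `HeadersTestRunner`, and the tx-index tests). The following is a minimal, self-contained sketch of that pattern — the trait and macro names mirror the diff, but the signatures are simplified stand-ins, not reth's real API:

// Simplified sketch (not reth's real API): trait-based stage test
// runners plus a macro-generated common test suite.

/// Every stage test module defines a runner that can build the stage under test.
trait StageTestRunner {
    fn stage_name(&self) -> &'static str;
}

/// Execution half of a runner: seed state, run the stage, validate the result.
trait ExecuteStageTestRunner: StageTestRunner {
    type Seed;
    fn seed_execution(&mut self) -> Self::Seed;
    fn execute(&self, seed: &Self::Seed) -> Result<u64, String>;
    fn validate_execution(&self, stage_progress: u64) -> Result<(), String>;
}

/// In the same spirit as the diff's `stage_test_suite!`: generate the
/// common happy-path test for any runner implementing the traits above.
macro_rules! stage_test_suite {
    ($runner:ident) => {
        #[test]
        fn execute_and_validate() {
            let mut runner = $runner::default();
            let seed = runner.seed_execution();
            let progress = runner.execute(&seed).expect("execution failed");
            runner.validate_execution(progress).expect("validation failed");
        }
    };
}

/// A stand-in runner; a real one would wrap a test database and a stage.
#[derive(Default)]
struct DummyRunner;

impl StageTestRunner for DummyRunner {
    fn stage_name(&self) -> &'static str {
        "Dummy"
    }
}

impl ExecuteStageTestRunner for DummyRunner {
    type Seed = Vec<u64>;
    fn seed_execution(&mut self) -> Self::Seed {
        (0..10).collect()
    }
    fn execute(&self, seed: &Self::Seed) -> Result<u64, String> {
        Ok(seed.len() as u64)
    }
    fn validate_execution(&self, stage_progress: u64) -> Result<(), String> {
        if stage_progress == 10 {
            Ok(())
        } else {
            Err(format!("unexpected progress: {stage_progress}"))
        }
    }
}

stage_test_suite!(DummyRunner);

With this shape, each concrete runner only supplies seeding and validation, while the suite macro owns boilerplate cases like the empty-database and already-reached-target tests that this commit deletes from the individual stage modules.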
@@ -61,7 +61,7 @@ impl<E: EnvironmentKind> Env<E> {
             inner: Environment::new()
                 .set_max_dbs(TABLES.len())
                 .set_geometry(Geometry {
-                    size: Some(0..0x100000), // TODO: reevaluate
+                    size: Some(0..0x1000000), // TODO: reevaluate
                     growth_step: Some(0x100000), // TODO: reevaluate
                     shrink_threshold: None,
                     page_size: Some(PageSize::Set(default_page_size())),
@@ -1,6 +1,6 @@
 //! Testing support for headers related interfaces.
 use crate::{
-    consensus::{self, Consensus, Error},
+    consensus::{self, Consensus},
     p2p::headers::{
         client::{HeadersClient, HeadersRequest, HeadersResponse, HeadersStream},
         downloader::HeaderDownloader,
@@ -9,20 +9,28 @@ use crate::{
 };
 use reth_primitives::{BlockLocked, Header, SealedHeader, H256, H512};
 use reth_rpc_types::engine::ForkchoiceState;
-use std::{collections::HashSet, sync::Arc, time::Duration};
+use std::{
+    collections::HashSet,
+    sync::{
+        atomic::{AtomicBool, Ordering},
+        Arc,
+    },
+    time::Duration,
+};
 use tokio::sync::{broadcast, mpsc, watch};
 use tokio_stream::{wrappers::BroadcastStream, StreamExt};

 /// A test downloader which just returns the values that have been pushed to it.
 #[derive(Debug)]
 pub struct TestHeaderDownloader {
-    result: Result<Vec<SealedHeader>, DownloadError>,
+    client: Arc<TestHeadersClient>,
+    consensus: Arc<TestConsensus>,
 }

 impl TestHeaderDownloader {
     /// Instantiates the downloader with the mock responses
-    pub fn new(result: Result<Vec<SealedHeader>, DownloadError>) -> Self {
-        Self { result }
+    pub fn new(client: Arc<TestHeadersClient>, consensus: Arc<TestConsensus>) -> Self {
+        Self { client, consensus }
     }
 }

@@ -36,11 +44,11 @@ impl HeaderDownloader for TestHeaderDownloader {
     }

     fn consensus(&self) -> &Self::Consensus {
-        unimplemented!()
+        &self.consensus
     }

     fn client(&self) -> &Self::Client {
-        unimplemented!()
+        &self.client
     }

     async fn download(
@@ -48,7 +56,27 @@ impl HeaderDownloader for TestHeaderDownloader {
         _: &SealedHeader,
         _: &ForkchoiceState,
     ) -> Result<Vec<SealedHeader>, DownloadError> {
-        self.result.clone()
+        // call consensus stub first. fails if the flag is set
+        let empty = SealedHeader::default();
+        self.consensus
+            .validate_header(&empty, &empty)
+            .map_err(|error| DownloadError::HeaderValidation { hash: empty.hash(), error })?;
+
+        let stream = self.client.stream_headers().await;
+        let stream = stream.timeout(Duration::from_secs(1));
+
+        match Box::pin(stream).try_next().await {
+            Ok(Some(res)) => {
+                let mut headers = res.headers.iter().map(|h| h.clone().seal()).collect::<Vec<_>>();
+                if !headers.is_empty() {
+                    headers.sort_unstable_by_key(|h| h.number);
+                    headers.remove(0); // remove head from response
+                    headers.reverse();
+                }
+                Ok(headers)
+            }
+            _ => Err(DownloadError::Timeout { request_id: 0 }),
+        }
     }
 }

@@ -93,6 +121,12 @@ impl TestHeadersClient {
     pub fn send_header_response(&self, id: u64, headers: Vec<Header>) {
         self.res_tx.send((id, headers).into()).expect("failed to send header response");
     }
+
+    /// Helper for pushing responses to the client
+    pub async fn send_header_response_delayed(&self, id: u64, headers: Vec<Header>, secs: u64) {
+        tokio::time::sleep(Duration::from_secs(secs)).await;
+        self.send_header_response(id, headers);
+    }
 }

 #[async_trait::async_trait]
@@ -106,6 +140,9 @@ impl HeadersClient for TestHeadersClient {
     }

     async fn stream_headers(&self) -> HeadersStream {
+        if !self.res_rx.is_empty() {
+            println!("WARNING: broadcast receiver already contains messages.")
+        }
         Box::pin(BroadcastStream::new(self.res_rx.resubscribe()).filter_map(|e| e.ok()))
     }
 }
@@ -116,7 +153,7 @@ pub struct TestConsensus {
     /// Watcher over the forkchoice state
     channel: (watch::Sender<ForkchoiceState>, watch::Receiver<ForkchoiceState>),
     /// Flag whether the header validation should purposefully fail
-    fail_validation: bool,
+    fail_validation: AtomicBool,
 }

 impl Default for TestConsensus {
@@ -127,7 +164,7 @@ impl Default for TestConsensus {
                 finalized_block_hash: H256::zero(),
                 safe_block_hash: H256::zero(),
             }),
-            fail_validation: false,
+            fail_validation: AtomicBool::new(false),
         }
     }
 }
@@ -143,9 +180,14 @@ impl TestConsensus {
         self.channel.0.send(state).expect("updating fork choice state failed");
     }

+    /// Get the failed validation flag
+    pub fn fail_validation(&self) -> bool {
+        self.fail_validation.load(Ordering::SeqCst)
+    }
+
     /// Update the validation flag
-    pub fn set_fail_validation(&mut self, val: bool) {
-        self.fail_validation = val;
+    pub fn set_fail_validation(&self, val: bool) {
+        self.fail_validation.store(val, Ordering::SeqCst)
     }
 }

@@ -160,15 +202,15 @@ impl Consensus for TestConsensus {
         _header: &SealedHeader,
         _parent: &SealedHeader,
     ) -> Result<(), consensus::Error> {
-        if self.fail_validation {
+        if self.fail_validation() {
             Err(consensus::Error::BaseFeeMissing)
         } else {
             Ok(())
         }
     }

-    fn pre_validate_block(&self, _block: &BlockLocked) -> Result<(), Error> {
-        if self.fail_validation {
+    fn pre_validate_block(&self, _block: &BlockLocked) -> Result<(), consensus::Error> {
+        if self.fail_validation() {
             Err(consensus::Error::BaseFeeMissing)
         } else {
             Ok(())

@@ -215,7 +215,7 @@ mod tests {

     static CONSENSUS: Lazy<Arc<TestConsensus>> = Lazy::new(|| Arc::new(TestConsensus::default()));
     static CONSENSUS_FAIL: Lazy<Arc<TestConsensus>> = Lazy::new(|| {
-        let mut consensus = TestConsensus::default();
+        let consensus = TestConsensus::default();
         consensus.set_fail_validation(true);
         Arc::new(consensus)
     });
@@ -20,6 +20,9 @@ mod pipeline;
 mod stage;
 mod util;

+#[cfg(test)]
+mod test_utils;
+
 /// Implementations of stages.
 pub mod stages;

@@ -15,7 +15,7 @@ use reth_primitives::{
     proofs::{EMPTY_LIST_HASH, EMPTY_ROOT},
     BlockLocked, BlockNumber, SealedHeader, H256,
 };
-use std::fmt::Debug;
+use std::{fmt::Debug, sync::Arc};
 use tracing::warn;

 const BODIES: StageId = StageId("Bodies");
@@ -51,9 +51,9 @@ const BODIES: StageId = StageId("Bodies");
 #[derive(Debug)]
 pub struct BodyStage<D: BodyDownloader, C: Consensus> {
     /// The body downloader.
-    pub downloader: D,
+    pub downloader: Arc<D>,
     /// The consensus engine.
-    pub consensus: C,
+    pub consensus: Arc<C>,
     /// The maximum amount of block bodies to process in one stage execution.
     ///
     /// Smaller batch sizes result in less memory usage, but more disk I/O. Larger batch sizes
@@ -81,6 +81,9 @@ impl<DB: Database, D: BodyDownloader, C: Consensus> Stage<DB> for BodyStage<D, C
             input.previous_stage.as_ref().map(|(_, block)| *block).unwrap_or_default();
+        if previous_stage_progress == 0 {
+            warn!("The body stage seems to be running first, no work can be completed.");
+            return Err(StageError::DatabaseIntegrity(DatabaseIntegrityError::BlockBody {
+                number: 0,
+            }))
+        }

         // The block we ended at last sync, and the one we are starting on now
@@ -230,67 +233,36 @@ impl<D: BodyDownloader, C: Consensus> BodyStage<D, C> {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::util::test_utils::StageTestRunner;
-    use assert_matches::assert_matches;
-    use reth_eth_wire::BlockBody;
-    use reth_interfaces::{
-        consensus,
-        p2p::bodies::error::DownloadError,
-        test_utils::generators::{random_block, random_block_range},
+    use crate::test_utils::{
+        stage_test_suite, ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner,
+        PREV_STAGE_ID,
     };
-    use reth_primitives::{BlockNumber, H256};
+    use assert_matches::assert_matches;
+    use reth_interfaces::{consensus, p2p::bodies::error::DownloadError};
     use std::collections::HashMap;
     use test_utils::*;

-    /// Check that the execution is short-circuited if the database is empty.
-    #[tokio::test]
-    async fn empty_db() {
-        let runner = BodyTestRunner::new(TestBodyDownloader::default);
-        let rx = runner.execute(ExecInput::default());
-        assert_matches!(
-            rx.await.unwrap(),
-            Ok(ExecOutput { stage_progress: 0, reached_tip: true, done: true })
-        )
-    }
-
-    /// Check that the execution is short-circuited if the target was already reached.
-    #[tokio::test]
-    async fn already_reached_target() {
-        let runner = BodyTestRunner::new(TestBodyDownloader::default);
-        let rx = runner.execute(ExecInput {
-            previous_stage: Some((StageId("Headers"), 100)),
-            stage_progress: Some(100),
-        });
-        assert_matches!(
-            rx.await.unwrap(),
-            Ok(ExecOutput { stage_progress: 100, reached_tip: true, done: true })
-        )
-    }
+    stage_test_suite!(BodyTestRunner);

     /// Checks that the stage downloads at most `batch_size` blocks.
     #[tokio::test]
     async fn partial_body_download() {
-        // Generate blocks
-        let blocks = random_block_range(1..200, GENESIS_HASH);
-        let bodies: HashMap<H256, Result<BlockBody, DownloadError>> =
-            blocks.iter().map(body_by_hash).collect();
-        let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone()));
+        let (stage_progress, previous_stage) = (1, 200);
+
+        // Set up test runner
+        let mut runner = BodyTestRunner::default();
+        let input = ExecInput {
+            previous_stage: Some((PREV_STAGE_ID, previous_stage)),
+            stage_progress: Some(stage_progress),
+        };
+        runner.seed_execution(input).expect("failed to seed execution");

         // Set the batch size (max we sync per stage execution) to less than the number of blocks
         // the previous stage synced (10 vs 20)
         runner.set_batch_size(10);

-        // Insert required state
-        runner.insert_genesis().expect("Could not insert genesis block");
-        runner
-            .insert_headers(blocks.iter().map(|block| &block.header))
-            .expect("Could not insert headers");
-
         // Run the stage
-        let rx = runner.execute(ExecInput {
-            previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)),
-            stage_progress: None,
-        });
+        let rx = runner.execute(input);

         // Check that we only synced around `batch_size` blocks even though the number of blocks
         // synced by the previous stage is higher
@@ -299,34 +271,27 @@ mod tests {
             output,
             Ok(ExecOutput { stage_progress, reached_tip: true, done: false }) if stage_progress < 200
         );
-        runner
-            .validate_db_blocks(output.unwrap().stage_progress)
-            .expect("Written block data invalid");
+        assert!(runner.validate_execution(input, output.ok()).is_ok(), "execution validation");
     }

     /// Same as [partial_body_download] except the `batch_size` is not hit.
     #[tokio::test]
     async fn full_body_download() {
-        // Generate blocks #1-20
-        let blocks = random_block_range(1..21, GENESIS_HASH);
-        let bodies: HashMap<H256, Result<BlockBody, DownloadError>> =
-            blocks.iter().map(body_by_hash).collect();
-        let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone()));
+        let (stage_progress, previous_stage) = (1, 20);
+
+        // Set up test runner
+        let mut runner = BodyTestRunner::default();
+        let input = ExecInput {
+            previous_stage: Some((PREV_STAGE_ID, previous_stage)),
+            stage_progress: Some(stage_progress),
+        };
+        runner.seed_execution(input).expect("failed to seed execution");

         // Set the batch size to more than what the previous stage synced (40 vs 20)
         runner.set_batch_size(40);

-        // Insert required state
-        runner.insert_genesis().expect("Could not insert genesis block");
-        runner
-            .insert_headers(blocks.iter().map(|block| &block.header))
-            .expect("Could not insert headers");
-
         // Run the stage
-        let rx = runner.execute(ExecInput {
-            previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)),
-            stage_progress: None,
-        });
+        let rx = runner.execute(input);

         // Check that we synced all blocks successfully, even though our `batch_size` allows us to
         // sync more (if there were more headers)
@@ -335,31 +300,26 @@ mod tests {
             output,
             Ok(ExecOutput { stage_progress: 20, reached_tip: true, done: true })
         );
-        runner
-            .validate_db_blocks(output.unwrap().stage_progress)
-            .expect("Written block data invalid");
+        assert!(runner.validate_execution(input, output.ok()).is_ok(), "execution validation");
     }

     /// Same as [full_body_download] except we have made progress before
     #[tokio::test]
     async fn sync_from_previous_progress() {
-        // Generate blocks #1-20
-        let blocks = random_block_range(1..21, GENESIS_HASH);
-        let bodies: HashMap<H256, Result<BlockBody, DownloadError>> =
-            blocks.iter().map(body_by_hash).collect();
-        let runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone()));
+        let (stage_progress, previous_stage) = (1, 21);

-        // Insert required state
-        runner.insert_genesis().expect("Could not insert genesis block");
-        runner
-            .insert_headers(blocks.iter().map(|block| &block.header))
-            .expect("Could not insert headers");
+        // Set up test runner
+        let mut runner = BodyTestRunner::default();
+        let input = ExecInput {
+            previous_stage: Some((PREV_STAGE_ID, previous_stage)),
+            stage_progress: Some(stage_progress),
+        };
+        runner.seed_execution(input).expect("failed to seed execution");

         runner.set_batch_size(10);

         // Run the stage
-        let rx = runner.execute(ExecInput {
-            previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)),
-            stage_progress: None,
-        });
+        let rx = runner.execute(input);

         // Check that we synced at least 10 blocks
         let first_run = rx.await.unwrap();
@@ -370,10 +330,11 @@ mod tests {
         let first_run_progress = first_run.unwrap().stage_progress;

         // Execute again on top of the previous run
-        let rx = runner.execute(ExecInput {
-            previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)),
+        let input = ExecInput {
+            previous_stage: Some((PREV_STAGE_ID, previous_stage)),
             stage_progress: Some(first_run_progress),
-        });
+        };
+        let rx = runner.execute(input);

         // Check that we synced more blocks
         let output = rx.await.unwrap();
@@ -381,175 +342,86 @@ mod tests {
             output,
             Ok(ExecOutput { stage_progress, reached_tip: true, done: true }) if stage_progress > first_run_progress
         );
-        runner
-            .validate_db_blocks(output.unwrap().stage_progress)
-            .expect("Written block data invalid");
+        assert!(runner.validate_execution(input, output.ok()).is_ok(), "execution validation");
     }

     /// Checks that the stage asks to unwind if pre-validation of the block fails.
     #[tokio::test]
     async fn pre_validation_failure() {
-        // Generate blocks #1-19
-        let blocks = random_block_range(1..20, GENESIS_HASH);
-        let bodies: HashMap<H256, Result<BlockBody, DownloadError>> =
-            blocks.iter().map(body_by_hash).collect();
-        let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone()));
+        let (stage_progress, previous_stage) = (1, 20);
+
+        // Set up test runner
+        let mut runner = BodyTestRunner::default();
+        let input = ExecInput {
+            previous_stage: Some((PREV_STAGE_ID, previous_stage)),
+            stage_progress: Some(stage_progress),
+        };
+        runner.seed_execution(input).expect("failed to seed execution");

         // Fail validation
-        runner.set_fail_validation(true);
-
-        // Insert required state
-        runner.insert_genesis().expect("Could not insert genesis block");
-        runner
-            .insert_headers(blocks.iter().map(|block| &block.header))
-            .expect("Could not insert headers");
+        runner.consensus.set_fail_validation(true);

         // Run the stage
-        let rx = runner.execute(ExecInput {
-            previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)),
-            stage_progress: None,
-        });
+        let rx = runner.execute(input);

         // Check that the error bubbles up
         assert_matches!(
             rx.await.unwrap(),
-            Err(StageError::Validation { block: 1, error: consensus::Error::BaseFeeMissing })
+            Err(StageError::Validation { error: consensus::Error::BaseFeeMissing, .. })
         );
     }

-    /// Checks that the stage unwinds correctly with no data.
-    #[tokio::test]
-    async fn unwind_empty_db() {
-        let unwind_to = 10;
-        let runner = BodyTestRunner::new(TestBodyDownloader::default);
-        let rx = runner.unwind(UnwindInput { bad_block: None, stage_progress: 20, unwind_to });
-        assert_matches!(
-            rx.await.unwrap(),
-            Ok(UnwindOutput { stage_progress }) if stage_progress == unwind_to
-        )
-    }
-
-    /// Checks that the stage unwinds correctly with data.
-    #[tokio::test]
-    async fn unwind() {
-        // Generate blocks #1-20
-        let blocks = random_block_range(1..21, GENESIS_HASH);
-        let bodies: HashMap<H256, Result<BlockBody, DownloadError>> =
-            blocks.iter().map(body_by_hash).collect();
-        let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone()));
-
-        // Set the batch size to more than what the previous stage synced (40 vs 20)
-        runner.set_batch_size(40);
-
-        // Insert required state
-        runner.insert_genesis().expect("Could not insert genesis block");
-        runner
-            .insert_headers(blocks.iter().map(|block| &block.header))
-            .expect("Could not insert headers");
-
-        // Run the stage
-        let rx = runner.execute(ExecInput {
-            previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)),
-            stage_progress: None,
-        });
-
-        // Check that we synced all blocks successfully, even though our `batch_size` allows us to
-        // sync more (if there were more headers)
-        let output = rx.await.unwrap();
-        assert_matches!(
-            output,
-            Ok(ExecOutput { stage_progress: 20, reached_tip: true, done: true })
-        );
-        let stage_progress = output.unwrap().stage_progress;
-        runner.validate_db_blocks(stage_progress).expect("Written block data invalid");
-
-        // Unwind all of it
-        let unwind_to = 1;
-        let rx = runner.unwind(UnwindInput { bad_block: None, stage_progress, unwind_to });
-        assert_matches!(
-            rx.await.unwrap(),
-            Ok(UnwindOutput { stage_progress }) if stage_progress == 1
-        );
-
-        let last_body = runner.last_body().expect("Could not read last body");
-        let last_tx_id = last_body.base_tx_id + last_body.tx_amount;
-        runner
-            .db()
-            .check_no_entry_above::<tables::BlockBodies, _>(unwind_to, |key| key.number())
-            .expect("Did not unwind block bodies correctly.");
-        runner
-            .db()
-            .check_no_entry_above::<tables::Transactions, _>(last_tx_id, |key| key)
-            .expect("Did not unwind transactions correctly.")
+        assert!(runner.validate_execution(input, None).is_ok(), "execution validation");
     }

     /// Checks that the stage unwinds correctly, even if a transaction in a block is missing.
     #[tokio::test]
     async fn unwind_missing_tx() {
-        // Generate blocks #1-20
-        let blocks = random_block_range(1..21, GENESIS_HASH);
-        let bodies: HashMap<H256, Result<BlockBody, DownloadError>> =
-            blocks.iter().map(body_by_hash).collect();
-        let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone()));
+        let (stage_progress, previous_stage) = (1, 20);
+
+        // Set up test runner
+        let mut runner = BodyTestRunner::default();
+        let input = ExecInput {
+            previous_stage: Some((PREV_STAGE_ID, previous_stage)),
+            stage_progress: Some(stage_progress),
+        };
+        runner.seed_execution(input).expect("failed to seed execution");

         // Set the batch size to more than what the previous stage synced (40 vs 20)
         runner.set_batch_size(40);

-        // Insert required state
-        runner.insert_genesis().expect("Could not insert genesis block");
-        runner
-            .insert_headers(blocks.iter().map(|block| &block.header))
-            .expect("Could not insert headers");
-
         // Run the stage
-        let rx = runner.execute(ExecInput {
-            previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)),
-            stage_progress: None,
-        });
+        let rx = runner.execute(input);

         // Check that we synced all blocks successfully, even though our `batch_size` allows us to
         // sync more (if there were more headers)
         let output = rx.await.unwrap();
         assert_matches!(
             output,
-            Ok(ExecOutput { stage_progress: 20, reached_tip: true, done: true })
+            Ok(ExecOutput { stage_progress, reached_tip: true, done: true }) if stage_progress == previous_stage
         );
         let stage_progress = output.unwrap().stage_progress;
         runner.validate_db_blocks(stage_progress).expect("Written block data invalid");

         // Delete a transaction
-        {
-            let mut db = runner.db().container();
-            let mut tx_cursor = db
-                .get_mut()
-                .cursor_mut::<tables::Transactions>()
-                .expect("Could not get transaction cursor");
-            tx_cursor
-                .last()
-                .expect("Could not read database")
-                .expect("Could not read last transaction");
-            tx_cursor.delete_current().expect("Could not delete last transaction");
-            db.commit().expect("Could not commit database");
-        }
+        runner
+            .db()
+            .commit(|tx| {
+                let mut tx_cursor = tx.cursor_mut::<tables::Transactions>()?;
+                tx_cursor.last()?.expect("Could not read last transaction");
+                tx_cursor.delete_current()?;
+                Ok(())
+            })
+            .expect("Could not delete a transaction");

         // Unwind all of it
         let unwind_to = 1;
-        let rx = runner.unwind(UnwindInput { bad_block: None, stage_progress, unwind_to });
+        let input = UnwindInput { bad_block: None, stage_progress, unwind_to };
+        let res = runner.unwind(input).await;
         assert_matches!(
-            rx.await.unwrap(),
+            res,
             Ok(UnwindOutput { stage_progress }) if stage_progress == 1
         );
-
-        let last_body = runner.last_body().expect("Could not read last body");
-        let last_tx_id = last_body.base_tx_id + last_body.tx_amount;
-        runner
-            .db()
-            .check_no_entry_above::<tables::BlockBodies, _>(unwind_to, |key| key.number())
-            .expect("Did not unwind block bodies correctly.");
-        runner
-            .db()
-            .check_no_entry_above::<tables::Transactions, _>(last_tx_id, |key| key)
-            .expect("Did not unwind transactions correctly.")
+        assert!(runner.validate_unwind(input).is_ok(), "unwind validation");
     }

     /// Checks that the stage exits if the downloader times out
@@ -557,54 +429,53 @@ mod tests {
     /// try again?
     #[tokio::test]
     async fn downloader_timeout() {
-        // Generate a header
-        let header = random_block(1, Some(GENESIS_HASH)).header;
-        let runner = BodyTestRunner::new(|| {
-            TestBodyDownloader::new(HashMap::from([(
-                header.hash(),
-                Err(DownloadError::Timeout { header_hash: header.hash() }),
-            )]))
-        });
+        let (stage_progress, previous_stage) = (1, 2);

-        // Insert required state
-        runner.insert_genesis().expect("Could not insert genesis block");
-        runner.insert_header(&header).expect("Could not insert header");
+        // Set up test runner
+        let mut runner = BodyTestRunner::default();
+        let input = ExecInput {
+            previous_stage: Some((PREV_STAGE_ID, previous_stage)),
+            stage_progress: Some(stage_progress),
+        };
+        let blocks = runner.seed_execution(input).expect("failed to seed execution");
+
+        // overwrite responses
+        let header = blocks.last().unwrap();
+        runner.set_responses(HashMap::from([(
+            header.hash(),
+            Err(DownloadError::Timeout { header_hash: header.hash() }),
+        )]));

         // Run the stage
-        let rx = runner.execute(ExecInput {
-            previous_stage: Some((StageId("Headers"), 1)),
-            stage_progress: None,
-        });
+        let rx = runner.execute(input);

         // Check that the error bubbles up
         assert_matches!(rx.await.unwrap(), Err(StageError::Internal(_)));
+        assert!(runner.validate_execution(input, None).is_ok(), "execution validation");
     }

     mod test_utils {
         use crate::{
             stages::bodies::BodyStage,
-            util::test_utils::{StageTestDB, StageTestRunner},
+            test_utils::{
+                ExecuteStageTestRunner, StageTestDB, StageTestRunner, TestRunnerError,
+                UnwindStageTestRunner,
+            },
             ExecInput, ExecOutput, UnwindInput,
         };
         use assert_matches::assert_matches;
-        use async_trait::async_trait;
         use reth_eth_wire::BlockBody;
         use reth_interfaces::{
-            db,
-            db::{
-                models::{BlockNumHash, StoredBlockBody},
-                tables, DbCursorRO, DbTx, DbTxMut,
-            },
+            db::{models::StoredBlockBody, tables, DbCursorRO, DbTx, DbTxMut},
             p2p::bodies::{
                 client::BodiesClient,
                 downloader::{BodiesStream, BodyDownloader},
                 error::{BodiesClientError, DownloadError},
             },
-            test_utils::TestConsensus,
+            test_utils::{generators::random_block_range, TestConsensus},
         };
-        use reth_primitives::{
-            BigEndianHash, BlockLocked, BlockNumber, Header, SealedHeader, H256, U256,
-        };
-        use std::{collections::HashMap, ops::Deref, time::Duration};
+        use reth_primitives::{BlockLocked, BlockNumber, Header, SealedHeader, H256};
+        use std::{collections::HashMap, sync::Arc, time::Duration};

         /// The block hash of the genesis block.
         pub(crate) const GENESIS_HASH: H256 = H256::zero();
@@ -623,43 +494,38 @@ mod tests {
         }

         /// A helper struct for running the [BodyStage].
-        pub(crate) struct BodyTestRunner<F>
-        where
-            F: Fn() -> TestBodyDownloader,
-        {
-            downloader_builder: F,
+        pub(crate) struct BodyTestRunner {
+            pub(crate) consensus: Arc<TestConsensus>,
+            responses: HashMap<H256, Result<BlockBody, DownloadError>>,
             db: StageTestDB,
             batch_size: u64,
-            fail_validation: bool,
         }

-        impl<F> BodyTestRunner<F>
-        where
-            F: Fn() -> TestBodyDownloader,
-        {
-            /// Build a new test runner.
-            pub(crate) fn new(downloader_builder: F) -> Self {
-                BodyTestRunner {
-                    downloader_builder,
+        impl Default for BodyTestRunner {
+            fn default() -> Self {
+                Self {
+                    consensus: Arc::new(TestConsensus::default()),
+                    responses: HashMap::default(),
                     db: StageTestDB::default(),
-                    batch_size: 10,
-                    fail_validation: false,
+                    batch_size: 1000,
                 }
             }
         }

+        impl BodyTestRunner {
             pub(crate) fn set_batch_size(&mut self, batch_size: u64) {
                 self.batch_size = batch_size;
             }

-            pub(crate) fn set_fail_validation(&mut self, fail_validation: bool) {
-                self.fail_validation = fail_validation;
+            pub(crate) fn set_responses(
+                &mut self,
+                responses: HashMap<H256, Result<BlockBody, DownloadError>>,
+            ) {
+                self.responses = responses;
             }
         }

-        impl<F> StageTestRunner for BodyTestRunner<F>
-        where
-            F: Fn() -> TestBodyDownloader,
-        {
+        impl StageTestRunner for BodyTestRunner {
             type S = BodyStage<TestBodyDownloader, TestConsensus>;

             fn db(&self) -> &StageTestDB {
@@ -667,115 +533,115 @@ mod tests {
             }

             fn stage(&self) -> Self::S {
-                let mut consensus = TestConsensus::default();
-                consensus.set_fail_validation(self.fail_validation);
-
                 BodyStage {
-                    downloader: (self.downloader_builder)(),
-                    consensus,
+                    downloader: Arc::new(TestBodyDownloader::new(self.responses.clone())),
+                    consensus: self.consensus.clone(),
                     batch_size: self.batch_size,
                 }
             }
         }

-        impl<F> BodyTestRunner<F>
-        where
-            F: Fn() -> TestBodyDownloader,
-        {
+        #[async_trait::async_trait]
+        impl ExecuteStageTestRunner for BodyTestRunner {
+            type Seed = Vec<BlockLocked>;
+
+            fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
+                let start = input.stage_progress.unwrap_or_default();
+                let end =
+                    input.previous_stage.as_ref().map(|(_, num)| *num + 1).unwrap_or_default();
+                let blocks = random_block_range(start..end, GENESIS_HASH);
+                self.insert_genesis()?;
+                self.db.insert_headers(blocks.iter().map(|block| &block.header))?;
+                self.set_responses(blocks.iter().map(body_by_hash).collect());
+                Ok(blocks)
+            }
+
+            fn validate_execution(
+                &self,
+                input: ExecInput,
+                output: Option<ExecOutput>,
+            ) -> Result<(), TestRunnerError> {
+                let highest_block = match output.as_ref() {
+                    Some(output) => output.stage_progress,
+                    None => input.stage_progress.unwrap_or_default(),
+                };
+                self.validate_db_blocks(highest_block)
+            }
+        }
+
+        impl UnwindStageTestRunner for BodyTestRunner {
+            fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
+                self.db.check_no_entry_above::<tables::BlockBodies, _>(input.unwind_to, |key| {
+                    key.number()
+                })?;
+                if let Some(last_body) = self.last_body() {
+                    let last_tx_id = last_body.base_tx_id + last_body.tx_amount;
+                    self.db
+                        .check_no_entry_above::<tables::Transactions, _>(last_tx_id, |key| key)?;
+                }
+                Ok(())
+            }
+        }
+
+        impl BodyTestRunner {
             /// Insert the genesis block into the appropriate tables
             ///
             /// The genesis block always has no transactions and no ommers, and it always has the
             /// same hash.
-            pub(crate) fn insert_genesis(&self) -> Result<(), db::Error> {
-                self.insert_header(&SealedHeader::new(Header::default(), GENESIS_HASH))?;
-                let mut db = self.db.container();
-                let tx = db.get_mut();
-                tx.put::<tables::BlockBodies>(
-                    (0, GENESIS_HASH).into(),
-                    StoredBlockBody { base_tx_id: 0, tx_amount: 0, ommers: vec![] },
-                )?;
-                db.commit()?;
-
-                Ok(())
-            }
-
-            /// Insert header into tables
-            pub(crate) fn insert_header(&self, header: &SealedHeader) -> Result<(), db::Error> {
-                self.insert_headers(std::iter::once(header))
-            }
-
-            /// Insert headers into tables
-            pub(crate) fn insert_headers<'a, I>(&self, headers: I) -> Result<(), db::Error>
-            where
-                I: Iterator<Item = &'a SealedHeader>,
-            {
-                let headers = headers.collect::<Vec<_>>();
-                self.db
-                    .map_put::<tables::HeaderNumbers, _, _>(&headers, |h| (h.hash(), h.number))?;
-                self.db.map_put::<tables::Headers, _, _>(&headers, |h| {
-                    (BlockNumHash((h.number, h.hash())), h.deref().clone().unseal())
-                })?;
-                self.db.map_put::<tables::CanonicalHeaders, _, _>(&headers, |h| {
-                    (h.number, h.hash())
-                })?;
-
-                self.db.transform_append::<tables::HeaderTD, _, _>(&headers, |prev, h| {
-                    let prev_td = U256::from_big_endian(&prev.clone().unwrap_or_default());
-                    (
-                        BlockNumHash((h.number, h.hash())),
-                        H256::from_uint(&(prev_td + h.difficulty)).as_bytes().to_vec(),
+            pub(crate) fn insert_genesis(&self) -> Result<(), TestRunnerError> {
+                let header = SealedHeader::new(Header::default(), GENESIS_HASH);
+                self.db.insert_headers(std::iter::once(&header))?;
+                self.db.commit(|tx| {
+                    tx.put::<tables::BlockBodies>(
+                        (0, GENESIS_HASH).into(),
+                        StoredBlockBody { base_tx_id: 0, tx_amount: 0, ommers: vec![] },
                     )
                 })?;

                 Ok(())
             }

             /// Retrieve the last body from the database
             pub(crate) fn last_body(&self) -> Option<StoredBlockBody> {
-                Some(
-                    self.db()
-                        .container()
-                        .get()
-                        .cursor::<tables::BlockBodies>()
-                        .ok()?
-                        .last()
-                        .ok()??
-                        .1,
-                )
+                self.db
+                    .query(|tx| Ok(tx.cursor::<tables::BlockBodies>()?.last()?.map(|e| e.1)))
+                    .ok()
+                    .flatten()
             }

             /// Validate that the inserted block data is valid
             pub(crate) fn validate_db_blocks(
                 &self,
                 highest_block: BlockNumber,
-            ) -> Result<(), db::Error> {
-                let db = self.db.container();
-                let tx = db.get();
+            ) -> Result<(), TestRunnerError> {
+                self.db.query(|tx| {
+                    let mut block_body_cursor = tx.cursor::<tables::BlockBodies>()?;
+                    let mut transaction_cursor = tx.cursor::<tables::Transactions>()?;

-                let mut block_body_cursor = tx.cursor::<tables::BlockBodies>()?;
-                let mut transaction_cursor = tx.cursor::<tables::Transactions>()?;
-
-                let mut entry = block_body_cursor.first()?;
-                let mut prev_max_tx_id = 0;
-                while let Some((key, body)) = entry {
-                    assert!(
-                        key.number() <= highest_block,
-                        "We wrote a block body outside of our synced range. Found block with number {}, highest block according to stage is {}",
-                        key.number(), highest_block
-                    );
-
-                    assert!(prev_max_tx_id == body.base_tx_id, "Transaction IDs are malformed.");
-                    for num in 0..body.tx_amount {
-                        let tx_id = body.base_tx_id + num;
-                        assert_matches!(
-                            transaction_cursor.seek_exact(tx_id),
-                            Ok(Some(_)),
-                            "A transaction is missing."
+                    let mut entry = block_body_cursor.first()?;
+                    let mut prev_max_tx_id = 0;
+                    while let Some((key, body)) = entry {
+                        assert!(
+                            key.number() <= highest_block,
+                            "We wrote a block body outside of our synced range. Found block with number {}, highest block according to stage is {}",
+                            key.number(), highest_block
                         );
-                    }
-                    prev_max_tx_id = body.base_tx_id + body.tx_amount;
-                    entry = block_body_cursor.next()?;
-                }
-
+
+                        assert!(prev_max_tx_id == body.base_tx_id, "Transaction IDs are malformed.");
+                        for num in 0..body.tx_amount {
+                            let tx_id = body.base_tx_id + num;
+                            assert_matches!(
+                                transaction_cursor.seek_exact(tx_id),
+                                Ok(Some(_)),
+                                "A transaction is missing."
+                            );
+                        }
+                        prev_max_tx_id = body.base_tx_id + body.tx_amount;
+                        entry = block_body_cursor.next()?;
+                    }
+
+                    Ok(())
+                })?;
                 Ok(())
             }
         }
@@ -785,7 +651,7 @@ mod tests {
         #[derive(Debug)]
         pub(crate) struct NoopClient;

-        #[async_trait]
+        #[async_trait::async_trait]
         impl BodiesClient for NoopClient {
             async fn get_block_body(&self, _: H256) -> Result<BlockBody, BodiesClientError> {
                 panic!("Noop client should not be called")
@@ -794,7 +660,7 @@ mod tests {

         // TODO(onbjerg): Move
         /// A [BodyDownloader] that is backed by an internal [HashMap] for testing.
-        #[derive(Debug, Default)]
+        #[derive(Debug, Default, Clone)]
        pub(crate) struct TestBodyDownloader {
            responses: HashMap<H256, Result<BlockBody, DownloadError>>,
        }
@@ -824,14 +690,12 @@ mod tests {
         {
             Box::pin(futures_util::stream::iter(hashes.into_iter().map(
                 |(block_number, hash)| {
-                    Ok((
-                        *block_number,
-                        *hash,
-                        self.responses
-                            .get(hash)
-                            .expect("Stage tried downloading a block we do not have.")
-                            .clone()?,
-                    ))
+                    let result = self
+                        .responses
+                        .get(hash)
+                        .expect("Stage tried downloading a block we do not have.")
+                        .clone()?;
+                    Ok((*block_number, *hash, result))
                 },
             )))
         }

@@ -58,8 +58,8 @@ impl<DB: Database, D: HeaderDownloader, C: Consensus, H: HeadersClient> Stage<DB
         let last_block_num = input.stage_progress.unwrap_or_default();
         self.update_head::<DB>(tx, last_block_num).await?;

-        // TODO: add batch size
         // download the headers
+        // TODO: handle input.max_block
         let last_hash = tx
             .get::<tables::CanonicalHeaders>(last_block_num)?
             .ok_or(DatabaseIntegrityError::CanonicalHash { number: last_block_num })?;
|
||||
@ -190,214 +190,99 @@ impl<D: HeaderDownloader, C: Consensus, H: HeadersClient> HeaderStage<D, C, H> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::util::test_utils::StageTestRunner;
|
||||
use assert_matches::assert_matches;
|
||||
use reth_interfaces::{
|
||||
consensus,
|
||||
test_utils::{
|
||||
generators::{random_header, random_header_range},
|
||||
TestHeaderDownloader,
|
||||
},
|
||||
use crate::test_utils::{
|
||||
stage_test_suite, ExecuteStageTestRunner, UnwindStageTestRunner, PREV_STAGE_ID,
|
||||
};
|
||||
use test_utils::HeadersTestRunner;
|
||||
use assert_matches::assert_matches;
|
||||
use test_runner::HeadersTestRunner;
|
||||
|
||||
const TEST_STAGE: StageId = StageId("Headers");
|
||||
stage_test_suite!(HeadersTestRunner);
|
||||
|
||||
/// Check that the execution errors on empty database or
|
||||
/// prev progress missing from the database.
|
||||
#[tokio::test]
|
||||
async fn execute_empty_db() {
|
||||
let runner = HeadersTestRunner::default();
|
||||
let rx = runner.execute(ExecInput::default());
|
||||
assert_matches!(
|
||||
rx.await.unwrap(),
|
||||
Err(StageError::DatabaseIntegrity(DatabaseIntegrityError::CanonicalHeader { .. }))
|
||||
);
|
||||
}
|
||||
|
||||
/// Check that the execution exits on downloader timeout.
|
||||
#[tokio::test]
|
||||
// Validate that the execution does not fail on timeout
|
||||
async fn execute_timeout() {
|
||||
let head = random_header(0, None);
|
||||
let runner = HeadersTestRunner::with_downloader(TestHeaderDownloader::new(Err(
|
||||
DownloadError::Timeout { request_id: 0 },
|
||||
)));
|
||||
runner.insert_header(&head).expect("failed to insert header");
|
||||
|
||||
let rx = runner.execute(ExecInput::default());
|
||||
let mut runner = HeadersTestRunner::default();
|
||||
let input = ExecInput::default();
|
||||
runner.seed_execution(input).expect("failed to seed execution");
|
||||
let rx = runner.execute(input);
|
||||
runner.consensus.update_tip(H256::from_low_u64_be(1));
|
||||
assert_matches!(rx.await.unwrap(), Ok(ExecOutput { done, .. }) if !done);
|
||||
let result = rx.await.unwrap();
|
||||
assert_matches!(
|
||||
result,
|
||||
Ok(ExecOutput { done: false, reached_tip: false, stage_progress: 0 })
|
||||
);
|
||||
assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
|
||||
}
|
||||
|
||||
/// Check that validation error is propagated during the execution.
|
||||
#[tokio::test]
|
||||
async fn execute_validation_error() {
|
||||
let head = random_header(0, None);
|
||||
let runner = HeadersTestRunner::with_downloader(TestHeaderDownloader::new(Err(
|
||||
DownloadError::HeaderValidation {
|
||||
hash: H256::zero(),
|
||||
error: consensus::Error::BaseFeeMissing,
|
||||
},
|
||||
)));
|
||||
runner.insert_header(&head).expect("failed to insert header");
|
||||
|
||||
let rx = runner.execute(ExecInput::default());
|
||||
runner.consensus.update_tip(H256::from_low_u64_be(1));
|
||||
assert_matches!(rx.await.unwrap(), Err(StageError::Validation { block, error: consensus::Error::BaseFeeMissing, }) if block == 0);
|
||||
}
|
||||
|
||||
/// Validate that all necessary tables are updated after the
|
||||
/// header download on no previous progress.
|
||||
#[tokio::test]
|
||||
async fn execute_no_progress() {
|
||||
let (start, end) = (0, 100);
|
||||
let head = random_header(start, None);
|
||||
let headers = random_header_range(start + 1..end, head.hash());
|
||||
|
||||
let result = headers.iter().rev().cloned().collect::<Vec<_>>();
|
||||
let runner = HeadersTestRunner::with_downloader(TestHeaderDownloader::new(Ok(result)));
|
||||
runner.insert_header(&head).expect("failed to insert header");
|
||||
|
||||
let rx = runner.execute(ExecInput::default());
|
||||
let tip = headers.last().unwrap();
|
||||
runner.consensus.update_tip(tip.hash());
|
||||
|
||||
assert_matches!(
|
||||
rx.await.unwrap(),
|
||||
Ok(ExecOutput { done, reached_tip, stage_progress })
|
||||
if done && reached_tip && stage_progress == tip.number
|
||||
);
|
||||
assert!(headers.iter().try_for_each(|h| runner.validate_db_header(h)).is_ok());
|
||||
}
|
||||
|
||||
/// Validate that all necessary tables are updated after the
|
||||
/// header download with some previous progress.
|
||||
#[tokio::test]
|
||||
async fn execute_prev_progress() {
|
||||
let (start, end) = (10000, 10241);
|
||||
let head = random_header(start, None);
|
||||
let headers = random_header_range(start + 1..end, head.hash());
|
||||
|
||||
let result = headers.iter().rev().cloned().collect::<Vec<_>>();
|
||||
let runner = HeadersTestRunner::with_downloader(TestHeaderDownloader::new(Ok(result)));
|
||||
runner.insert_header(&head).expect("failed to insert header");
|
||||
|
||||
let rx = runner.execute(ExecInput {
|
||||
previous_stage: Some((TEST_STAGE, head.number)),
|
||||
stage_progress: Some(head.number),
|
||||
});
|
||||
let tip = headers.last().unwrap();
|
||||
runner.consensus.update_tip(tip.hash());
|
||||
|
||||
assert_matches!(
|
||||
rx.await.unwrap(),
|
||||
Ok(ExecOutput { done, reached_tip, stage_progress })
|
||||
if done && reached_tip && stage_progress == tip.number
|
||||
);
|
||||
assert!(headers.iter().try_for_each(|h| runner.validate_db_header(h)).is_ok());
|
||||
let mut runner = HeadersTestRunner::default();
|
||||
runner.consensus.set_fail_validation(true);
|
||||
let input = ExecInput::default();
|
||||
let seed = runner.seed_execution(input).expect("failed to seed execution");
|
||||
let rx = runner.execute(input);
|
||||
runner.after_execution(seed).await.expect("failed to run after execution hook");
|
||||
let result = rx.await.unwrap();
|
||||
assert_matches!(result, Err(StageError::Validation { .. }));
|
||||
assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
|
||||
}
|
||||
|
||||
/// Execute the stage with linear downloader
|
||||
#[tokio::test]
|
||||
async fn execute_with_linear_downloader() {
|
||||
let (start, end) = (1000, 1200);
|
||||
let head = random_header(start, None);
|
||||
let headers = random_header_range(start + 1..end, head.hash());
|
||||
|
||||
let runner = HeadersTestRunner::with_linear_downloader();
|
||||
runner.insert_header(&head).expect("failed to insert header");
|
||||
let rx = runner.execute(ExecInput {
|
||||
previous_stage: Some((TEST_STAGE, head.number)),
|
||||
stage_progress: Some(head.number),
|
||||
});
|
||||
let mut runner = HeadersTestRunner::with_linear_downloader();
|
||||
let (stage_progress, previous_stage) = (1000, 1200);
|
||||
let input = ExecInput {
|
||||
previous_stage: Some((PREV_STAGE_ID, previous_stage)),
|
||||
stage_progress: Some(stage_progress),
|
||||
};
|
||||
let headers = runner.seed_execution(input).expect("failed to seed execution");
|
||||
let rx = runner.execute(input);
|
||||
|
||||
// skip `after_execution` hook for linear downloader
|
||||
let tip = headers.last().unwrap();
|
||||
runner.consensus.update_tip(tip.hash());
|
||||
|
||||
let mut download_result = headers.clone();
|
||||
download_result.insert(0, head);
|
||||
let download_result = headers.clone();
|
||||
runner
|
||||
.client
|
||||
.on_header_request(1, |id, _| {
|
||||
runner.client.send_header_response(
|
||||
id,
|
||||
download_result.clone().into_iter().map(|h| h.unseal()).collect(),
|
||||
)
|
||||
let response = download_result.iter().map(|h| h.clone().unseal()).collect();
|
||||
runner.client.send_header_response(id, response)
|
||||
})
|
||||
.await;
|
||||
|
||||
let result = rx.await.unwrap();
|
||||
assert_matches!(
|
||||
rx.await.unwrap(),
|
||||
Ok(ExecOutput { done, reached_tip, stage_progress })
|
||||
if done && reached_tip && stage_progress == tip.number
|
||||
result,
|
||||
Ok(ExecOutput { done: true, reached_tip: true, stage_progress }) if stage_progress == tip.number
|
||||
);
|
||||
assert!(headers.iter().try_for_each(|h| runner.validate_db_header(h)).is_ok());
|
||||
assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
|
||||
}
|
||||
|
||||
/// Check that unwind does not panic on empty database.
|
||||
#[tokio::test]
|
||||
async fn unwind_empty_db() {
|
||||
let unwind_to = 100;
|
||||
let runner = HeadersTestRunner::default();
|
||||
let rx =
|
||||
runner.unwind(UnwindInput { bad_block: None, stage_progress: unwind_to, unwind_to });
|
||||
assert_matches!(
|
||||
rx.await.unwrap(),
|
||||
Ok(UnwindOutput {stage_progress} ) if stage_progress == unwind_to
|
||||
);
|
||||
}
|
||||
|
||||
/// Check that unwind can remove headers across gaps
|
||||
#[tokio::test]
|
||||
async fn unwind_db_gaps() {
|
||||
let runner = HeadersTestRunner::default();
|
||||
let head = random_header(0, None);
|
||||
let first_range = random_header_range(1..20, head.hash());
|
||||
let second_range = random_header_range(50..100, H256::zero());
|
||||
runner.insert_header(&head).expect("failed to insert header");
|
||||
runner
|
||||
.insert_headers(first_range.iter().chain(second_range.iter()))
|
||||
.expect("failed to insert headers");
|
||||
|
||||
let unwind_to = 15;
|
||||
let rx =
|
||||
runner.unwind(UnwindInput { bad_block: None, stage_progress: unwind_to, unwind_to });
|
||||
assert_matches!(
|
||||
rx.await.unwrap(),
|
||||
Ok(UnwindOutput {stage_progress} ) if stage_progress == unwind_to
|
||||
);
|
||||
|
||||
runner
|
||||
.db()
|
||||
.check_no_entry_above::<tables::CanonicalHeaders, _>(unwind_to, |key| key)
|
||||
.expect("failed to check cannonical headers");
|
||||
runner
|
||||
.db()
|
||||
.check_no_entry_above_by_value::<tables::HeaderNumbers, _>(unwind_to, |val| val)
|
||||
.expect("failed to check header numbers");
|
||||
runner
|
||||
.db()
|
||||
.check_no_entry_above::<tables::Headers, _>(unwind_to, |key| key.number())
|
||||
.expect("failed to check headers");
|
||||
runner
|
||||
.db()
|
||||
.check_no_entry_above::<tables::HeaderTD, _>(unwind_to, |key| key.number())
|
||||
.expect("failed to check td");
|
||||
}
|
||||
|
||||
mod test_utils {
|
||||
mod test_runner {
|
||||
use crate::{
|
||||
stages::headers::HeaderStage,
|
||||
util::test_utils::{StageTestDB, StageTestRunner},
|
||||
test_utils::{
|
||||
ExecuteStageTestRunner, StageTestDB, StageTestRunner, TestRunnerError,
|
||||
UnwindStageTestRunner,
|
||||
},
|
||||
ExecInput, ExecOutput, UnwindInput,
|
||||
};
|
||||
use reth_headers_downloaders::linear::{LinearDownloadBuilder, LinearDownloader};
|
||||
use reth_interfaces::{
|
||||
db::{self, models::blocks::BlockNumHash, tables, DbTx},
|
||||
db::{models::blocks::BlockNumHash, tables, DbTx},
|
||||
p2p::headers::downloader::HeaderDownloader,
|
||||
test_utils::{TestConsensus, TestHeaderDownloader, TestHeadersClient},
|
||||
test_utils::{
|
||||
generators::{random_header, random_header_range},
|
||||
TestConsensus, TestHeaderDownloader, TestHeadersClient,
|
||||
},
|
||||
};
|
||||
use reth_primitives::{rpc::BigEndianHash, SealedHeader, H256, U256};
|
||||
use std::{ops::Deref, sync::Arc};
|
||||
use reth_primitives::{BlockNumber, SealedHeader, H256, U256};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub(crate) struct HeadersTestRunner<D: HeaderDownloader> {
|
||||
pub(crate) consensus: Arc<TestConsensus>,
|
||||
@@ -408,10 +293,12 @@ mod tests {

         impl Default for HeadersTestRunner<TestHeaderDownloader> {
             fn default() -> Self {
+                let client = Arc::new(TestHeadersClient::default());
+                let consensus = Arc::new(TestConsensus::default());
                 Self {
-                    client: Arc::new(TestHeadersClient::default()),
-                    consensus: Arc::new(TestConsensus::default()),
-                    downloader: Arc::new(TestHeaderDownloader::new(Ok(Vec::default()))),
+                    client: client.clone(),
+                    consensus: consensus.clone(),
+                    downloader: Arc::new(TestHeaderDownloader::new(client, consensus)),
                     db: StageTestDB::default(),
                 }
             }
@@ -433,6 +320,99 @@ mod tests {
             }
         }

+        #[async_trait::async_trait]
+        impl<D: HeaderDownloader + 'static> ExecuteStageTestRunner for HeadersTestRunner<D> {
+            type Seed = Vec<SealedHeader>;
+
+            fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
+                let start = input.stage_progress.unwrap_or_default();
+                let head = random_header(start, None);
+                self.db.insert_headers(std::iter::once(&head))?;
+
+                // use previous progress as seed size
+                let end = input.previous_stage.map(|(_, num)| num).unwrap_or_default() + 1;
+
+                if start + 1 >= end {
+                    return Ok(Vec::default())
+                }
+
+                let mut headers = random_header_range(start + 1..end, head.hash());
+                headers.insert(0, head);
+                Ok(headers)
+            }
+
+            async fn after_execution(&self, headers: Self::Seed) -> Result<(), TestRunnerError> {
+                let tip = if !headers.is_empty() {
+                    headers.last().unwrap().hash()
+                } else {
+                    H256::from_low_u64_be(rand::random())
+                };
+                self.consensus.update_tip(tip);
+                self.client
+                    .send_header_response_delayed(
+                        0,
+                        headers.into_iter().map(|h| h.unseal()).collect(),
+                        1,
+                    )
+                    .await;
+                Ok(())
+            }
+
+            /// Validate stored headers
+            fn validate_execution(
+                &self,
+                input: ExecInput,
+                output: Option<ExecOutput>,
+            ) -> Result<(), TestRunnerError> {
+                let initial_stage_progress = input.stage_progress.unwrap_or_default();
+                match output {
+                    Some(output) if output.stage_progress > initial_stage_progress => {
+                        self.db.query(|tx| {
+                            for block_num in (initial_stage_progress..output.stage_progress).rev() {
+                                // look up the header hash
+                                let hash = tx
+                                    .get::<tables::CanonicalHeaders>(block_num)?
+                                    .expect("no header hash");
+                                let key: BlockNumHash = (block_num, hash).into();
+
+                                // validate the header number
+                                assert_eq!(tx.get::<tables::HeaderNumbers>(hash)?, Some(block_num));
+
+                                // validate the header
+                                let header = tx.get::<tables::Headers>(key)?;
+                                assert!(header.is_some());
+                                let header = header.unwrap().seal();
+                                assert_eq!(header.hash(), hash);
+
+                                // validate td consistency in the database
+                                if header.number > initial_stage_progress {
+                                    let parent_td = tx.get::<tables::HeaderTD>(
+                                        (header.number - 1, header.parent_hash).into(),
+                                    )?;
+                                    let td = tx.get::<tables::HeaderTD>(key)?.unwrap();
+                                    assert_eq!(
+                                        parent_td.map(
+                                            |td| U256::from_big_endian(&td) + header.difficulty
+                                        ),
+                                        Some(U256::from_big_endian(&td))
+                                    );
+                                }
+                            }
+                            Ok(())
+                        })?;
+                    }
+                    _ => self.check_no_header_entry_above(initial_stage_progress)?,
+                };
+                Ok(())
+            }
+        }
+
+        impl<D: HeaderDownloader + 'static> UnwindStageTestRunner for HeadersTestRunner<D> {
+            fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
+                self.check_no_header_entry_above(input.unwind_to)
+            }
+        }
+
         impl HeadersTestRunner<LinearDownloader<TestConsensus, TestHeadersClient>> {
             pub(crate) fn with_linear_downloader() -> Self {
                 let client = Arc::new(TestHeadersClient::default());
@ -445,74 +425,15 @@ mod tests {
|
||||
}
|
||||
|
||||
impl<D: HeaderDownloader> HeadersTestRunner<D> {
|
||||
pub(crate) fn with_downloader(downloader: D) -> Self {
|
||||
HeadersTestRunner {
|
||||
client: Arc::new(TestHeadersClient::default()),
|
||||
consensus: Arc::new(TestConsensus::default()),
|
||||
downloader: Arc::new(downloader),
|
||||
db: StageTestDB::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert header into tables
|
||||
pub(crate) fn insert_header(&self, header: &SealedHeader) -> Result<(), db::Error> {
|
||||
self.insert_headers(std::iter::once(header))
|
||||
}
|
||||
|
||||
/// Insert headers into tables
|
||||
pub(crate) fn insert_headers<'a, I>(&self, headers: I) -> Result<(), db::Error>
|
||||
where
|
||||
I: Iterator<Item = &'a SealedHeader>,
|
||||
{
|
||||
let headers = headers.collect::<Vec<_>>();
|
||||
self.db
|
||||
.map_put::<tables::HeaderNumbers, _, _>(&headers, |h| (h.hash(), h.number))?;
|
||||
self.db.map_put::<tables::Headers, _, _>(&headers, |h| {
|
||||
(BlockNumHash((h.number, h.hash())), h.deref().clone().unseal())
|
||||
})?;
|
||||
self.db.map_put::<tables::CanonicalHeaders, _, _>(&headers, |h| {
|
||||
(h.number, h.hash())
|
||||
})?;
|
||||
|
||||
self.db.transform_append::<tables::HeaderTD, _, _>(&headers, |prev, h| {
|
||||
let prev_td = U256::from_big_endian(&prev.clone().unwrap_or_default());
|
||||
(
|
||||
BlockNumHash((h.number, h.hash())),
|
||||
H256::from_uint(&(prev_td + h.difficulty)).as_bytes().to_vec(),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
    /// Validate stored header against provided
    pub(crate) fn validate_db_header(
        &self,
        header: &SealedHeader,
    ) -> Result<(), db::Error> {
        let db = self.db.container();
        let tx = db.get();
        let key: BlockNumHash = (header.number, header.hash()).into();

        let db_number = tx.get::<tables::HeaderNumbers>(header.hash())?;
        assert_eq!(db_number, Some(header.number));

        let db_header = tx.get::<tables::Headers>(key)?;
        assert_eq!(db_header, Some(header.clone().unseal()));

        let db_canonical_header = tx.get::<tables::CanonicalHeaders>(header.number)?;
        assert_eq!(db_canonical_header, Some(header.hash()));

        if header.number != 0 {
            let parent_key: BlockNumHash = (header.number - 1, header.parent_hash).into();
            let parent_td = tx.get::<tables::HeaderTD>(parent_key)?;
            let td = U256::from_big_endian(&tx.get::<tables::HeaderTD>(key)?.unwrap());
            assert_eq!(
                parent_td.map(|td| U256::from_big_endian(&td) + header.difficulty),
                Some(td)
            );
        }
        Ok(())
    }

    pub(crate) fn check_no_header_entry_above(
        &self,
        block: BlockNumber,
    ) -> Result<(), TestRunnerError> {
        self.db
            .check_no_entry_above_by_value::<tables::HeaderNumbers, _>(block, |val| val)?;
        self.db.check_no_entry_above::<tables::CanonicalHeaders, _>(block, |key| key)?;
        self.db.check_no_entry_above::<tables::Headers, _>(block, |key| key.number())?;
        self.db.check_no_entry_above::<tables::HeaderTD, _>(block, |key| key.number())?;
        Ok(())
    }
}
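The total-difficulty bookkeeping above hinges on one invariant: each `HeaderTD` value is the running sum of difficulties up to that block, stored as the 32-byte big-endian encoding of a `U256`. A standalone sketch of that round trip, using the same `reth_primitives` types as the code above (the helper names are illustrative, not part of the commit):

```rust
use reth_primitives::{BigEndianHash, H256, U256};

// Encode a running total difficulty the way `insert_headers` stores it.
fn encode_td(td: U256) -> Vec<u8> {
    H256::from_uint(&td).as_bytes().to_vec()
}

// Decode it the way `validate_db_header` reads it back.
fn decode_td(bytes: &[u8]) -> U256 {
    U256::from_big_endian(bytes)
}

// Per-header update: child TD = parent TD + child difficulty.
fn next_td(parent: Option<&[u8]>, difficulty: U256) -> Vec<u8> {
    let parent_td = parent.map(decode_td).unwrap_or_default();
    encode_td(parent_td + difficulty)
}
```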

@ -87,141 +87,18 @@ impl<DB: Database> Stage<DB> for TxIndex {
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::test_utils::{StageTestDB, StageTestRunner};
    use crate::test_utils::{
        stage_test_suite, ExecuteStageTestRunner, StageTestDB, StageTestRunner, TestRunnerError,
        UnwindStageTestRunner,
    };
    use assert_matches::assert_matches;
    use reth_interfaces::{db::models::BlockNumHash, test_utils::generators::random_header_range};
    use reth_interfaces::{
        db::models::{BlockNumHash, StoredBlockBody},
        test_utils::generators::random_header_range,
    };
    use reth_primitives::H256;

    const TEST_STAGE: StageId = StageId("PrevStage");

    #[tokio::test]
    async fn execute_empty_db() {
        let runner = TxIndexTestRunner::default();
        let rx = runner.execute(ExecInput::default());
        assert_matches!(
            rx.await.unwrap(),
            Err(StageError::DatabaseIntegrity(DatabaseIntegrityError::CanonicalHeader { .. }))
        );
    }

    #[tokio::test]
    async fn execute_no_prev_tx_count() {
        let runner = TxIndexTestRunner::default();
        let headers = random_header_range(0..10, H256::zero());
        runner
            .db()
            .map_put::<tables::CanonicalHeaders, _, _>(&headers, |h| (h.number, h.hash()))
            .expect("failed to insert");

        let (head, tail) = (headers.first().unwrap(), headers.last().unwrap());
        let input = ExecInput {
            previous_stage: Some((TEST_STAGE, tail.number)),
            stage_progress: Some(head.number),
        };
        let rx = runner.execute(input);
        assert_matches!(
            rx.await.unwrap(),
            Err(StageError::DatabaseIntegrity(DatabaseIntegrityError::CumulativeTxCount { .. }))
        );
    }

    #[tokio::test]
    async fn execute() {
        let runner = TxIndexTestRunner::default();
        let (start, pivot, end) = (0, 100, 200);
        let headers = random_header_range(start..end, H256::zero());
        runner
            .db()
            .map_put::<tables::CanonicalHeaders, _, _>(&headers, |h| (h.number, h.hash()))
            .expect("failed to insert");
        runner
            .db()
            .transform_append::<tables::CumulativeTxCount, _, _>(&headers[..=pivot], |prev, h| {
                (
                    BlockNumHash((h.number, h.hash())),
                    prev.unwrap_or_default() + (rand::random::<u8>() as u64),
                )
            })
            .expect("failed to insert");

        let (pivot, tail) = (headers.get(pivot).unwrap(), headers.last().unwrap());
        let input = ExecInput {
            previous_stage: Some((TEST_STAGE, tail.number)),
            stage_progress: Some(pivot.number),
        };
        let rx = runner.execute(input);
        assert_matches!(
            rx.await.unwrap(),
            Ok(ExecOutput { stage_progress, done, reached_tip })
                if done && reached_tip && stage_progress == tail.number
        );
    }

    #[tokio::test]
    async fn unwind_empty_db() {
        let runner = TxIndexTestRunner::default();
        let rx = runner.unwind(UnwindInput::default());
        assert_matches!(
            rx.await.unwrap(),
            Ok(UnwindOutput { stage_progress }) if stage_progress == 0
        );
    }

    #[tokio::test]
    async fn unwind_no_input() {
        let runner = TxIndexTestRunner::default();
        let headers = random_header_range(0..10, H256::zero());
        runner
            .db()
            .transform_append::<tables::CumulativeTxCount, _, _>(&headers, |prev, h| {
                (
                    BlockNumHash((h.number, h.hash())),
                    prev.unwrap_or_default() + (rand::random::<u8>() as u64),
                )
            })
            .expect("failed to insert");

        let rx = runner.unwind(UnwindInput::default());
        assert_matches!(
            rx.await.unwrap(),
            Ok(UnwindOutput { stage_progress }) if stage_progress == 0
        );
        runner
            .db()
            .check_no_entry_above::<tables::CumulativeTxCount, _>(0, |h| h.number())
            .expect("failed to check tx count");
    }

    #[tokio::test]
    async fn unwind_with_db_gaps() {
        let runner = TxIndexTestRunner::default();
        let first_range = random_header_range(0..20, H256::zero());
        let second_range = random_header_range(50..100, H256::zero());
        runner
            .db()
            .transform_append::<tables::CumulativeTxCount, _, _>(
                &first_range.iter().chain(second_range.iter()).collect::<Vec<_>>(),
                |prev, h| {
                    (
                        BlockNumHash((h.number, h.hash())),
                        prev.unwrap_or_default() + (rand::random::<u8>() as u64),
                    )
                },
            )
            .expect("failed to insert");

        let unwind_to = 10;
        let input = UnwindInput { unwind_to, ..Default::default() };
        let rx = runner.unwind(input);
        assert_matches!(
            rx.await.unwrap(),
            Ok(UnwindOutput { stage_progress }) if stage_progress == unwind_to
        );
        runner
            .db()
            .check_no_entry_above::<tables::CumulativeTxCount, _>(unwind_to, |h| h.number())
            .expect("failed to check tx count");
    }

    stage_test_suite!(TxIndexTestRunner);

    #[derive(Default)]
    pub(crate) struct TxIndexTestRunner {

@ -239,4 +116,82 @@ mod tests {
            TxIndex {}
        }
    }

    impl ExecuteStageTestRunner for TxIndexTestRunner {
        type Seed = ();

        fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
            let pivot = input.stage_progress.unwrap_or_default();
            let start = pivot.saturating_sub(100);
            let mut end = input.previous_stage.as_ref().map(|(_, num)| *num).unwrap_or_default();
            end += 2; // generate 2 additional headers to account for start header lookup
            let headers = random_header_range(start..end, H256::zero());

            let headers =
                headers.into_iter().map(|h| (h, rand::random::<u8>())).collect::<Vec<_>>();

            self.db.map_put::<tables::CanonicalHeaders, _, _>(&headers, |(h, _)| {
                (h.number, h.hash())
            })?;
            self.db.map_put::<tables::BlockBodies, _, _>(&headers, |(h, count)| {
                (
                    BlockNumHash((h.number, h.hash())),
                    StoredBlockBody { base_tx_id: 0, tx_amount: *count as u64, ommers: vec![] },
                )
            })?;

            let slice_up_to =
                std::cmp::min(pivot.saturating_sub(start) as usize, headers.len() - 1);
            self.db.transform_append::<tables::CumulativeTxCount, _, _>(
                &headers[..=slice_up_to],
                |prev, (h, count)| {
                    (BlockNumHash((h.number, h.hash())), prev.unwrap_or_default() + (*count as u64))
                },
            )?;

            Ok(())
        }

        fn validate_execution(
            &self,
            input: ExecInput,
            _output: Option<ExecOutput>,
        ) -> Result<(), TestRunnerError> {
            self.db.query(|tx| {
                let (start, end) = (
                    input.stage_progress.unwrap_or_default(),
                    input.previous_stage.as_ref().map(|(_, num)| *num).unwrap_or_default(),
                );
                if start >= end {
                    return Ok(())
                }

                let start_hash =
                    tx.get::<tables::CanonicalHeaders>(start)?.expect("no canonical found");
                let mut tx_count_cursor = tx.cursor::<tables::CumulativeTxCount>()?;
                let mut tx_count_walker = tx_count_cursor.walk((start, start_hash).into())?;
                let mut count = tx_count_walker.next().unwrap()?.1;
                let mut last_num = start;
                while let Some(entry) = tx_count_walker.next() {
                    let (key, db_count) = entry?;
                    count += tx.get::<tables::BlockBodies>(key)?.unwrap().tx_amount as u64;
                    assert_eq!(db_count, count);
                    last_num = key.number();
                }
                assert_eq!(last_num, end);

                Ok(())
            })?;
            Ok(())
        }
    }

    impl UnwindStageTestRunner for TxIndexTestRunner {
        fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
            self.db.check_no_entry_above::<tables::CumulativeTxCount, _>(input.unwind_to, |h| {
                h.number()
            })?;
            Ok(())
        }
    }
}
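The invariant `validate_execution` walks above is plain arithmetic: every `CumulativeTxCount` entry must equal the previous entry plus the block body's `tx_amount`. A toy, database-free sketch of that running sum (the function name is hypothetical):

```rust
// Toy model of the invariant: cumulative[i] = cumulative[i - 1] + tx_amount[i].
fn cumulative_counts(tx_amounts: &[u64]) -> Vec<u64> {
    let mut total = 0u64;
    tx_amounts
        .iter()
        .map(|amount| {
            total += amount;
            total
        })
        .collect()
}

#[test]
fn toy_cumulative() {
    assert_eq!(cumulative_counts(&[3, 0, 2]), vec![3, 3, 5]);
}
```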

145
crates/stages/src/test_utils/macros.rs
Normal file
@ -0,0 +1,145 @@
macro_rules! stage_test_suite {
    ($runner:ident) => {
        /// Check that the execution is short-circuited if the database is empty.
        #[tokio::test]
        async fn execute_empty_db() {
            // Set up the runner
            let runner = $runner::default();

            // Execute the stage with an empty database
            let input = crate::stage::ExecInput::default();

            // Run stage execution
            let result = runner.execute(input).await.unwrap();
            assert_matches!(
                result,
                Err(crate::error::StageError::DatabaseIntegrity(_))
            );

            // Validate the stage execution
            assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation");
        }

        /// Check that the execution is short-circuited if the target was already reached.
        #[tokio::test]
        async fn execute_already_reached_target() {
            let stage_progress = 1000;

            // Set up the runner
            let mut runner = $runner::default();
            let input = crate::stage::ExecInput {
                previous_stage: Some((crate::test_utils::PREV_STAGE_ID, stage_progress)),
                stage_progress: Some(stage_progress),
            };
            let seed = runner.seed_execution(input).expect("failed to seed");

            // Run stage execution
            let rx = runner.execute(input);

            // Run `after_execution` hook
            runner.after_execution(seed).await.expect("failed to run after execution hook");

            // Assert the successful result
            let result = rx.await.unwrap();
            assert_matches!(
                result,
                Ok(ExecOutput { done, reached_tip, stage_progress: progress })
                    if done && reached_tip && progress == stage_progress
            );

            // Validate the stage execution
            assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation");
        }

        /// Run the complete stage execution flow.
        #[tokio::test]
        async fn execute() {
            let (previous_stage, stage_progress) = (500, 100);

            // Set up the runner
            let mut runner = $runner::default();
            let input = crate::stage::ExecInput {
                previous_stage: Some((crate::test_utils::PREV_STAGE_ID, previous_stage)),
                stage_progress: Some(stage_progress),
            };
            let seed = runner.seed_execution(input).expect("failed to seed");
            let rx = runner.execute(input);

            // Run `after_execution` hook
            runner.after_execution(seed).await.expect("failed to run after execution hook");

            // Assert the successful result
            let result = rx.await.unwrap();
            assert_matches!(
                result,
                Ok(ExecOutput { done, reached_tip, stage_progress })
                    if done && reached_tip && stage_progress == previous_stage
            );

            // Validate the stage execution
            assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation");
        }

        /// Check that unwind does not panic on an empty database.
        #[tokio::test]
        async fn unwind_empty_db() {
            // Set up the runner
            let runner = $runner::default();
            let input = crate::stage::UnwindInput::default();

            // Run stage unwind
            let rx = runner.unwind(input).await;
            assert_matches!(
                rx,
                Ok(UnwindOutput { stage_progress }) if stage_progress == input.unwind_to
            );

            // Validate the stage unwind
            assert!(runner.validate_unwind(input).is_ok(), "unwind validation");
        }

        /// Run the complete execute and unwind flow.
        #[tokio::test]
        async fn unwind() {
            let (previous_stage, stage_progress) = (500, 100);

            // Set up the runner
            let mut runner = $runner::default();
            let execute_input = crate::stage::ExecInput {
                previous_stage: Some((crate::test_utils::PREV_STAGE_ID, previous_stage)),
                stage_progress: Some(stage_progress),
            };
            let seed = runner.seed_execution(execute_input).expect("failed to seed");

            // Run stage execution
            let rx = runner.execute(execute_input);
            runner.after_execution(seed).await.expect("failed to run after execution hook");

            // Assert the successful execution result
            let result = rx.await.unwrap();
            assert_matches!(
                result,
                Ok(ExecOutput { done, reached_tip, stage_progress })
                    if done && reached_tip && stage_progress == previous_stage
            );
            assert!(
                runner.validate_execution(execute_input, result.ok()).is_ok(),
                "execution validation"
            );

            // Run stage unwind. Execution left the stage at `previous_stage`,
            // so unwind back down to the original `stage_progress`.
            let unwind_input = crate::stage::UnwindInput {
                unwind_to: stage_progress,
                stage_progress: previous_stage,
                bad_block: None,
            };
            let rx = runner.unwind(unwind_input).await;

            // Assert the successful unwind result
            assert_matches!(
                rx,
                Ok(UnwindOutput { stage_progress }) if stage_progress == unwind_input.unwind_to
            );

            // Validate the stage unwind
            assert!(runner.validate_unwind(unwind_input).is_ok(), "unwind validation");
        }
    };
}

pub(crate) use stage_test_suite;
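A stage then opts into the suite by implementing the runner traits and invoking the macro from its test module, as the `TxIndex` tests above do. A minimal sketch (the runner name is hypothetical; the invoking module must have `assert_matches`, `ExecOutput`, and `UnwindOutput` in scope, since the generated tests refer to them unqualified):

```rust
#[cfg(test)]
mod tests {
    use crate::test_utils::stage_test_suite;
    use crate::{ExecOutput, UnwindOutput};
    use assert_matches::assert_matches;

    // `MyStageTestRunner` is hypothetical: a `Default` type implementing
    // `StageTestRunner`, `ExecuteStageTestRunner`, and `UnwindStageTestRunner`.
    stage_test_suite!(MyStageTestRunner);
}
```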

14
crates/stages/src/test_utils/mod.rs
Normal file
@ -0,0 +1,14 @@
use crate::StageId;

mod macros;
mod runner;
mod stage_db;

pub(crate) use macros::*;
pub(crate) use runner::{
    ExecuteStageTestRunner, StageTestRunner, TestRunnerError, UnwindStageTestRunner,
};
pub(crate) use stage_db::StageTestDB;

/// The mock [StageId] of the previous stage, used for testing
pub(crate) const PREV_STAGE_ID: StageId = StageId("PrevStage");

82
crates/stages/src/test_utils/runner.rs
Normal file
@ -0,0 +1,82 @@
use reth_db::{kv::Env, mdbx::WriteMap};
use reth_interfaces::db::{self, DBContainer};
use std::borrow::Borrow;
use tokio::sync::oneshot;

use super::StageTestDB;
use crate::{ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput};

#[derive(thiserror::Error, Debug)]
pub(crate) enum TestRunnerError {
    #[error("Database error occurred.")]
    Database(#[from] db::Error),
    #[error("Internal runner error occurred.")]
    Internal(#[from] Box<dyn std::error::Error>),
}

/// A generic test runner for stages.
#[async_trait::async_trait]
pub(crate) trait StageTestRunner {
    type S: Stage<Env<WriteMap>> + 'static;

    /// Return a reference to the database.
    fn db(&self) -> &StageTestDB;

    /// Return an instance of a Stage.
    fn stage(&self) -> Self::S;
}

#[async_trait::async_trait]
pub(crate) trait ExecuteStageTestRunner: StageTestRunner {
    type Seed: Send + Sync;

    /// Seed the database for stage execution
    fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError>;

    /// Validate the stage execution
    fn validate_execution(
        &self,
        input: ExecInput,
        output: Option<ExecOutput>,
    ) -> Result<(), TestRunnerError>;

    /// Run [Stage::execute] and return a receiver for the result.
    fn execute(&self, input: ExecInput) -> oneshot::Receiver<Result<ExecOutput, StageError>> {
        let (tx, rx) = oneshot::channel();
        let (db, mut stage) = (self.db().inner(), self.stage());
        tokio::spawn(async move {
            let mut db = DBContainer::new(db.borrow()).expect("failed to create db container");
            let result = stage.execute(&mut db, input).await;
            db.commit().expect("failed to commit");
            tx.send(result).expect("failed to send message")
        });
        rx
    }

    /// Run a hook after [Stage::execute]. Required for Headers & Bodies stages.
    async fn after_execution(&self, _seed: Self::Seed) -> Result<(), TestRunnerError> {
        Ok(())
    }
}

#[async_trait::async_trait]
pub(crate) trait UnwindStageTestRunner: StageTestRunner {
    /// Validate the unwind
    fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError>;

    /// Run [Stage::unwind] and return the result once the stage has committed.
    async fn unwind(
        &self,
        input: UnwindInput,
    ) -> Result<UnwindOutput, Box<dyn std::error::Error + Send + Sync>> {
        let (tx, rx) = oneshot::channel();
        let (db, mut stage) = (self.db().inner(), self.stage());
        tokio::spawn(async move {
            let mut db = DBContainer::new(db.borrow()).expect("failed to create db container");
            let result = stage.unwind(&mut db, input).await;
            db.commit().expect("failed to commit");
            tx.send(result).expect("failed to send result");
        });
        Box::pin(rx).await.unwrap()
    }
}
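To make the trait family concrete, here is a minimal sketch of the shape the suite macro expects from a runner, for a hypothetical `NoopStage` assumed to implement `Stage<Env<WriteMap>>`; real runners, such as `TxIndexTestRunner` above, seed and validate actual tables:

```rust
// Sketch only; `NoopStage` is hypothetical and assumed to be `Default`
// with no-op execute/unwind.
#[derive(Default)]
pub(crate) struct NoopTestRunner {
    db: StageTestDB,
}

impl StageTestRunner for NoopTestRunner {
    type S = NoopStage;

    fn db(&self) -> &StageTestDB {
        &self.db
    }

    fn stage(&self) -> Self::S {
        NoopStage::default()
    }
}

#[async_trait::async_trait]
impl ExecuteStageTestRunner for NoopTestRunner {
    type Seed = ();

    // Nothing to seed for a no-op stage.
    fn seed_execution(&mut self, _input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
        Ok(())
    }

    // Nothing to validate either.
    fn validate_execution(
        &self,
        _input: ExecInput,
        _output: Option<ExecOutput>,
    ) -> Result<(), TestRunnerError> {
        Ok(())
    }
}

#[async_trait::async_trait]
impl UnwindStageTestRunner for NoopTestRunner {
    fn validate_unwind(&self, _input: UnwindInput) -> Result<(), TestRunnerError> {
        Ok(())
    }
}
```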

179
crates/stages/src/test_utils/stage_db.rs
Normal file
@ -0,0 +1,179 @@
use reth_db::{
    kv::{test_utils::create_test_db, tx::Tx, Env, EnvKind},
    mdbx::{WriteMap, RW},
};
use reth_interfaces::db::{
    self, models::BlockNumHash, tables, DBContainer, DbCursorRO, DbCursorRW, DbTx, DbTxMut, Table,
};
use reth_primitives::{BigEndianHash, BlockNumber, SealedHeader, H256, U256};
use std::{borrow::Borrow, sync::Arc};

/// The [StageTestDB] is used as an internal
/// database for testing stage implementations.
///
/// ```rust
/// let db = StageTestDB::default();
/// stage.execute(&mut db.container(), input);
/// ```
pub(crate) struct StageTestDB {
    db: Arc<Env<WriteMap>>,
}

impl Default for StageTestDB {
    /// Create a new instance of [StageTestDB]
    fn default() -> Self {
        Self { db: create_test_db::<WriteMap>(EnvKind::RW) }
    }
}

impl StageTestDB {
    /// Return a database wrapped in [DBContainer].
    fn container(&self) -> DBContainer<'_, Env<WriteMap>> {
        DBContainer::new(self.db.borrow()).expect("failed to create db container")
    }

    /// Get a pointer to an internal database.
    pub(crate) fn inner(&self) -> Arc<Env<WriteMap>> {
        self.db.clone()
    }

    /// Invoke a callback with a read-write transaction, committing it afterwards
    pub(crate) fn commit<F>(&self, f: F) -> Result<(), db::Error>
    where
        F: FnOnce(&mut Tx<'_, RW, WriteMap>) -> Result<(), db::Error>,
    {
        let mut db = self.container();
        let tx = db.get_mut();
        f(tx)?;
        db.commit()?;
        Ok(())
    }

    /// Invoke a callback with a transaction for read-only queries
    pub(crate) fn query<F, R>(&self, f: F) -> Result<R, db::Error>
    where
        F: FnOnce(&Tx<'_, RW, WriteMap>) -> Result<R, db::Error>,
    {
        f(self.container().get())
    }

    /// Map a collection of values and store them in the database.
    /// This function commits the transaction before exiting.
    ///
    /// ```rust
    /// let db = StageTestDB::default();
    /// db.map_put::<Table, _, _>(&items, |item| item.clone())?;
    /// ```
    pub(crate) fn map_put<T, S, F>(&self, values: &[S], mut map: F) -> Result<(), db::Error>
    where
        T: Table,
        S: Clone,
        F: FnMut(&S) -> (T::Key, T::Value),
    {
        self.commit(|tx| {
            values.iter().try_for_each(|src| {
                let (k, v) = map(src);
                tx.put::<T>(k, v)
            })
        })
    }

    /// Transform a collection of values using a callback and store
    /// them in the database. The callback additionally receives the
    /// last value that was stored, if any.
    /// This function commits the transaction before exiting.
    ///
    /// ```rust
    /// let db = StageTestDB::default();
    /// db.transform_append::<Table, _, _>(&items, |prev, (key, value)| {
    ///     (key.clone(), prev.clone().unwrap_or_default() + value)
    /// })?;
    /// ```
    pub(crate) fn transform_append<T, S, F>(
        &self,
        values: &[S],
        mut transform: F,
    ) -> Result<(), db::Error>
    where
        T: Table,
        <T as Table>::Value: Clone,
        S: Clone,
        F: FnMut(&Option<<T as Table>::Value>, &S) -> (T::Key, T::Value),
    {
        self.commit(|tx| {
            let mut cursor = tx.cursor_mut::<T>()?;
            let mut last = cursor.last()?.map(|(_, v)| v);
            values.iter().try_for_each(|src| {
                let (k, v) = transform(&last, src);
                last = Some(v.clone());
                cursor.append(k, v)
            })
        })
    }

    /// Check that there is no table entry above a given
    /// block by [Table::Key]
    pub(crate) fn check_no_entry_above<T, F>(
        &self,
        block: BlockNumber,
        mut selector: F,
    ) -> Result<(), db::Error>
    where
        T: Table,
        F: FnMut(T::Key) -> BlockNumber,
    {
        self.query(|tx| {
            let mut cursor = tx.cursor::<T>()?;
            if let Some((key, _)) = cursor.last()? {
                assert!(selector(key) <= block);
            }
            Ok(())
        })
    }

    /// Check that there is no table entry above a given
    /// block by [Table::Value]
    pub(crate) fn check_no_entry_above_by_value<T, F>(
        &self,
        block: BlockNumber,
        mut selector: F,
    ) -> Result<(), db::Error>
    where
        T: Table,
        F: FnMut(T::Value) -> BlockNumber,
    {
        self.query(|tx| {
            let mut cursor = tx.cursor::<T>()?;
            if let Some((_, value)) = cursor.last()? {
                assert!(selector(value) <= block);
            }
            Ok(())
        })
    }

    /// Insert an ordered collection of [SealedHeader] into the corresponding tables
    /// that are supposed to be populated by the headers stage.
    pub(crate) fn insert_headers<'a, I>(&self, headers: I) -> Result<(), db::Error>
    where
        I: Iterator<Item = &'a SealedHeader>,
    {
        self.commit(|tx| {
            let headers = headers.collect::<Vec<_>>();

            let mut td = U256::from_big_endian(
                &tx.cursor::<tables::HeaderTD>()?.last()?.map(|(_, v)| v).unwrap_or_default(),
            );

            for header in headers {
                let key: BlockNumHash = (header.number, header.hash()).into();

                tx.put::<tables::CanonicalHeaders>(header.number, header.hash())?;
                tx.put::<tables::HeaderNumbers>(header.hash(), header.number)?;
                tx.put::<tables::Headers>(key, header.clone().unseal())?;

                td += header.difficulty;
                tx.put::<tables::HeaderTD>(key, H256::from_uint(&td).as_bytes().to_vec())?;
            }

            Ok(())
        })
    }
}
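As a quick orientation for how these helpers combine in a stage test, a hedged usage sketch (the function name is hypothetical; `random_header_range` and `H256` are the same generator and type used in the tests above):

```rust
use reth_interfaces::test_utils::generators::random_header_range;
use reth_primitives::H256;

// Hypothetical helper: seed the four header tables with ten random headers.
fn seed_headers(stage_db: &StageTestDB) -> Result<(), db::Error> {
    let headers = random_header_range(0..10, H256::zero());
    stage_db.insert_headers(headers.iter())
}
```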

@ -135,189 +135,3 @@ pub(crate) mod unwind {
        Ok(())
    }
}

#[cfg(test)]
pub(crate) mod test_utils {
    use reth_db::{
        kv::{test_utils::create_test_db, Env, EnvKind},
        mdbx::WriteMap,
    };
    use reth_interfaces::db::{DBContainer, DbCursorRO, DbCursorRW, DbTx, DbTxMut, Error, Table};
    use reth_primitives::BlockNumber;
    use std::{borrow::Borrow, sync::Arc};
    use tokio::sync::oneshot;

    use crate::{ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput};

    /// The [StageTestDB] is used as an internal
    /// database for testing stage implementation.
    ///
    /// ```rust
    /// let db = StageTestDB::default();
    /// stage.execute(&mut db.container(), input);
    /// ```
    pub(crate) struct StageTestDB {
        db: Arc<Env<WriteMap>>,
    }

    impl Default for StageTestDB {
        /// Create a new instance of [StageTestDB]
        fn default() -> Self {
            Self { db: create_test_db::<WriteMap>(EnvKind::RW) }
        }
    }

    impl StageTestDB {
        /// Get a pointer to an internal database.
        pub(crate) fn inner(&self) -> Arc<Env<WriteMap>> {
            self.db.clone()
        }

        /// Return a database wrapped in [DBContainer].
        pub(crate) fn container(&self) -> DBContainer<'_, Env<WriteMap>> {
            DBContainer::new(self.db.borrow()).expect("failed to create db container")
        }

        /// Map a collection of values and store them in the database.
        /// This function commits the transaction before exiting.
        ///
        /// ```rust
        /// let db = StageTestDB::default();
        /// db.map_put::<Table, _, _>(&items, |item| item)?;
        /// ```
        pub(crate) fn map_put<T, S, F>(&self, values: &[S], mut map: F) -> Result<(), Error>
        where
            T: Table,
            S: Clone,
            F: FnMut(&S) -> (T::Key, T::Value),
        {
            let mut db = self.container();
            let tx = db.get_mut();
            values.iter().try_for_each(|src| {
                let (k, v) = map(src);
                tx.put::<T>(k, v)
            })?;
            db.commit()?;
            Ok(())
        }

        /// Transform a collection of values using a callback and store
        /// them in the database. The callback additionally accepts the
        /// optional last element that was stored.
        /// This function commits the transaction before exiting.
        ///
        /// ```rust
        /// let db = StageTestDB::default();
        /// db.transform_append::<Table, _, _>(&items, |prev, item| prev.unwrap_or_default() + item)?;
        /// ```
        pub(crate) fn transform_append<T, S, F>(
            &self,
            values: &[S],
            mut transform: F,
        ) -> Result<(), Error>
        where
            T: Table,
            <T as Table>::Value: Clone,
            S: Clone,
            F: FnMut(&Option<<T as Table>::Value>, &S) -> (T::Key, T::Value),
        {
            let mut db = self.container();
            let tx = db.get_mut();
            let mut cursor = tx.cursor_mut::<T>()?;
            let mut last = cursor.last()?.map(|(_, v)| v);
            values.iter().try_for_each(|src| {
                let (k, v) = transform(&last, src);
                last = Some(v.clone());
                cursor.append(k, v)
            })?;
            db.commit()?;
            Ok(())
        }

        /// Check that there is no table entry above a given
        /// block by [Table::Key]
        pub(crate) fn check_no_entry_above<T, F>(
            &self,
            block: BlockNumber,
            mut selector: F,
        ) -> Result<(), Error>
        where
            T: Table,
            F: FnMut(T::Key) -> BlockNumber,
        {
            let db = self.container();
            let tx = db.get();

            let mut cursor = tx.cursor::<T>()?;
            if let Some((key, _)) = cursor.last()? {
                assert!(selector(key) <= block);
            }

            Ok(())
        }

        /// Check that there is no table entry above a given
        /// block by [Table::Value]
        pub(crate) fn check_no_entry_above_by_value<T, F>(
            &self,
            block: BlockNumber,
            mut selector: F,
        ) -> Result<(), Error>
        where
            T: Table,
            F: FnMut(T::Value) -> BlockNumber,
        {
            let db = self.container();
            let tx = db.get();

            let mut cursor = tx.cursor::<T>()?;
            if let Some((_, value)) = cursor.last()? {
                assert!(selector(value) <= block);
            }

            Ok(())
        }
    }

    /// A generic test runner for stages.
    #[async_trait::async_trait]
    pub(crate) trait StageTestRunner {
        type S: Stage<Env<WriteMap>> + 'static;

        /// Return a reference to the database.
        fn db(&self) -> &StageTestDB;

        /// Return an instance of a Stage.
        fn stage(&self) -> Self::S;

        /// Run [Stage::execute] and return a receiver for the result.
        fn execute(&self, input: ExecInput) -> oneshot::Receiver<Result<ExecOutput, StageError>> {
            let (tx, rx) = oneshot::channel();
            let (db, mut stage) = (self.db().inner(), self.stage());
            tokio::spawn(async move {
                let mut db = DBContainer::new(db.borrow()).expect("failed to create db container");
                let result = stage.execute(&mut db, input).await;
                db.commit().expect("failed to commit");
                tx.send(result).expect("failed to send message")
            });
            rx
        }

        /// Run [Stage::unwind] and return a receiver for the result.
        fn unwind(
            &self,
            input: UnwindInput,
        ) -> oneshot::Receiver<Result<UnwindOutput, Box<dyn std::error::Error + Send + Sync>>>
        {
            let (tx, rx) = oneshot::channel();
            let (db, mut stage) = (self.db().inner(), self.stage());
            tokio::spawn(async move {
                let mut db = DBContainer::new(db.borrow()).expect("failed to create db container");
                let result = stage.unwind(&mut db, input).await;
                db.commit().expect("failed to commit");
                tx.send(result).expect("failed to send result");
            });
            rx
        }
    }
}