test(sync): stage test suite (#204)

* test(sync): stage test suite

* cleanup txindex tests

* nit

* start revamping bodies testing

* revamp body testing

* add comments to suite tests

* fmt

* cleanup dup code

* cleanup insert_headers helper fn

* fix tests

* linter

* switch mutex to atomic

* cleanup

* revert

* test: make unwind runner return value instead of channel

* test: make execute runner return value instead of channel

* Revert "test: make execute runner return value instead of channel"

This reverts commit f8608654f2e4cf97f60ce6aa95c28009f71d5331.

Co-authored-by: Georgios Konstantopoulos <me@gakonst.com>
This commit is contained in:
Roman Krasiuk
2022-11-19 03:57:29 +02:00
committed by GitHub
parent ae8f7a2dd6
commit 4936d467c9
12 changed files with 951 additions and 932 deletions

View File

@ -61,7 +61,7 @@ impl<E: EnvironmentKind> Env<E> {
inner: Environment::new()
.set_max_dbs(TABLES.len())
.set_geometry(Geometry {
size: Some(0..0x100000), // TODO: reevaluate
size: Some(0..0x1000000), // TODO: reevaluate
growth_step: Some(0x100000), // TODO: reevaluate
shrink_threshold: None,
page_size: Some(PageSize::Set(default_page_size())),

View File

@ -1,6 +1,6 @@
//! Testing support for headers related interfaces.
use crate::{
consensus::{self, Consensus, Error},
consensus::{self, Consensus},
p2p::headers::{
client::{HeadersClient, HeadersRequest, HeadersResponse, HeadersStream},
downloader::HeaderDownloader,
@ -9,20 +9,28 @@ use crate::{
};
use reth_primitives::{BlockLocked, Header, SealedHeader, H256, H512};
use reth_rpc_types::engine::ForkchoiceState;
use std::{collections::HashSet, sync::Arc, time::Duration};
use std::{
collections::HashSet,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use tokio::sync::{broadcast, mpsc, watch};
use tokio_stream::{wrappers::BroadcastStream, StreamExt};
/// A test downloader which just returns the values that have been pushed to it.
#[derive(Debug)]
pub struct TestHeaderDownloader {
result: Result<Vec<SealedHeader>, DownloadError>,
client: Arc<TestHeadersClient>,
consensus: Arc<TestConsensus>,
}
impl TestHeaderDownloader {
/// Instantiates the downloader with the mock responses
pub fn new(result: Result<Vec<SealedHeader>, DownloadError>) -> Self {
Self { result }
pub fn new(client: Arc<TestHeadersClient>, consensus: Arc<TestConsensus>) -> Self {
Self { client, consensus }
}
}
@ -36,11 +44,11 @@ impl HeaderDownloader for TestHeaderDownloader {
}
fn consensus(&self) -> &Self::Consensus {
unimplemented!()
&self.consensus
}
fn client(&self) -> &Self::Client {
unimplemented!()
&self.client
}
async fn download(
@ -48,7 +56,27 @@ impl HeaderDownloader for TestHeaderDownloader {
_: &SealedHeader,
_: &ForkchoiceState,
) -> Result<Vec<SealedHeader>, DownloadError> {
self.result.clone()
// call consensus stub first. fails if the flag is set
let empty = SealedHeader::default();
self.consensus
.validate_header(&empty, &empty)
.map_err(|error| DownloadError::HeaderValidation { hash: empty.hash(), error })?;
let stream = self.client.stream_headers().await;
let stream = stream.timeout(Duration::from_secs(1));
match Box::pin(stream).try_next().await {
Ok(Some(res)) => {
let mut headers = res.headers.iter().map(|h| h.clone().seal()).collect::<Vec<_>>();
if !headers.is_empty() {
headers.sort_unstable_by_key(|h| h.number);
headers.remove(0); // remove head from response
headers.reverse();
}
Ok(headers)
}
_ => Err(DownloadError::Timeout { request_id: 0 }),
}
}
}
@ -93,6 +121,12 @@ impl TestHeadersClient {
pub fn send_header_response(&self, id: u64, headers: Vec<Header>) {
self.res_tx.send((id, headers).into()).expect("failed to send header response");
}
/// Helper for pushing responses to the client
pub async fn send_header_response_delayed(&self, id: u64, headers: Vec<Header>, secs: u64) {
tokio::time::sleep(Duration::from_secs(secs)).await;
self.send_header_response(id, headers);
}
}
#[async_trait::async_trait]
@ -106,6 +140,9 @@ impl HeadersClient for TestHeadersClient {
}
async fn stream_headers(&self) -> HeadersStream {
if !self.res_rx.is_empty() {
println!("WARNING: broadcast receiver already contains messages.")
}
Box::pin(BroadcastStream::new(self.res_rx.resubscribe()).filter_map(|e| e.ok()))
}
}
@ -116,7 +153,7 @@ pub struct TestConsensus {
/// Watcher over the forkchoice state
channel: (watch::Sender<ForkchoiceState>, watch::Receiver<ForkchoiceState>),
/// Flag whether the header validation should purposefully fail
fail_validation: bool,
fail_validation: AtomicBool,
}
impl Default for TestConsensus {
@ -127,7 +164,7 @@ impl Default for TestConsensus {
finalized_block_hash: H256::zero(),
safe_block_hash: H256::zero(),
}),
fail_validation: false,
fail_validation: AtomicBool::new(false),
}
}
}
@ -143,9 +180,14 @@ impl TestConsensus {
self.channel.0.send(state).expect("updating fork choice state failed");
}
/// Get the failed validation flag
pub fn fail_validation(&self) -> bool {
self.fail_validation.load(Ordering::SeqCst)
}
/// Update the validation flag
pub fn set_fail_validation(&mut self, val: bool) {
self.fail_validation = val;
pub fn set_fail_validation(&self, val: bool) {
self.fail_validation.store(val, Ordering::SeqCst)
}
}
@ -160,15 +202,15 @@ impl Consensus for TestConsensus {
_header: &SealedHeader,
_parent: &SealedHeader,
) -> Result<(), consensus::Error> {
if self.fail_validation {
if self.fail_validation() {
Err(consensus::Error::BaseFeeMissing)
} else {
Ok(())
}
}
fn pre_validate_block(&self, _block: &BlockLocked) -> Result<(), Error> {
if self.fail_validation {
fn pre_validate_block(&self, _block: &BlockLocked) -> Result<(), consensus::Error> {
if self.fail_validation() {
Err(consensus::Error::BaseFeeMissing)
} else {
Ok(())

View File

@ -215,7 +215,7 @@ mod tests {
static CONSENSUS: Lazy<Arc<TestConsensus>> = Lazy::new(|| Arc::new(TestConsensus::default()));
static CONSENSUS_FAIL: Lazy<Arc<TestConsensus>> = Lazy::new(|| {
let mut consensus = TestConsensus::default();
let consensus = TestConsensus::default();
consensus.set_fail_validation(true);
Arc::new(consensus)
});

View File

@ -20,6 +20,9 @@ mod pipeline;
mod stage;
mod util;
#[cfg(test)]
mod test_utils;
/// Implementations of stages.
pub mod stages;

View File

@ -15,7 +15,7 @@ use reth_primitives::{
proofs::{EMPTY_LIST_HASH, EMPTY_ROOT},
BlockLocked, BlockNumber, SealedHeader, H256,
};
use std::fmt::Debug;
use std::{fmt::Debug, sync::Arc};
use tracing::warn;
const BODIES: StageId = StageId("Bodies");
@ -51,9 +51,9 @@ const BODIES: StageId = StageId("Bodies");
#[derive(Debug)]
pub struct BodyStage<D: BodyDownloader, C: Consensus> {
/// The body downloader.
pub downloader: D,
pub downloader: Arc<D>,
/// The consensus engine.
pub consensus: C,
pub consensus: Arc<C>,
/// The maximum amount of block bodies to process in one stage execution.
///
/// Smaller batch sizes result in less memory usage, but more disk I/O. Larger batch sizes
@ -81,6 +81,9 @@ impl<DB: Database, D: BodyDownloader, C: Consensus> Stage<DB> for BodyStage<D, C
input.previous_stage.as_ref().map(|(_, block)| *block).unwrap_or_default();
if previous_stage_progress == 0 {
warn!("The body stage seems to be running first, no work can be completed.");
return Err(StageError::DatabaseIntegrity(DatabaseIntegrityError::BlockBody {
number: 0,
}))
}
// The block we ended at last sync, and the one we are starting on now
@ -230,67 +233,36 @@ impl<D: BodyDownloader, C: Consensus> BodyStage<D, C> {
#[cfg(test)]
mod tests {
use super::*;
use crate::util::test_utils::StageTestRunner;
use assert_matches::assert_matches;
use reth_eth_wire::BlockBody;
use reth_interfaces::{
consensus,
p2p::bodies::error::DownloadError,
test_utils::generators::{random_block, random_block_range},
use crate::test_utils::{
stage_test_suite, ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner,
PREV_STAGE_ID,
};
use reth_primitives::{BlockNumber, H256};
use assert_matches::assert_matches;
use reth_interfaces::{consensus, p2p::bodies::error::DownloadError};
use std::collections::HashMap;
use test_utils::*;
/// Check that the execution is short-circuited if the database is empty.
#[tokio::test]
async fn empty_db() {
let runner = BodyTestRunner::new(TestBodyDownloader::default);
let rx = runner.execute(ExecInput::default());
assert_matches!(
rx.await.unwrap(),
Ok(ExecOutput { stage_progress: 0, reached_tip: true, done: true })
)
}
/// Check that the execution is short-circuited if the target was already reached.
#[tokio::test]
async fn already_reached_target() {
let runner = BodyTestRunner::new(TestBodyDownloader::default);
let rx = runner.execute(ExecInput {
previous_stage: Some((StageId("Headers"), 100)),
stage_progress: Some(100),
});
assert_matches!(
rx.await.unwrap(),
Ok(ExecOutput { stage_progress: 100, reached_tip: true, done: true })
)
}
stage_test_suite!(BodyTestRunner);
/// Checks that the stage downloads at most `batch_size` blocks.
#[tokio::test]
async fn partial_body_download() {
// Generate blocks
let blocks = random_block_range(1..200, GENESIS_HASH);
let bodies: HashMap<H256, Result<BlockBody, DownloadError>> =
blocks.iter().map(body_by_hash).collect();
let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone()));
let (stage_progress, previous_stage) = (1, 200);
// Set up test runner
let mut runner = BodyTestRunner::default();
let input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress),
};
runner.seed_execution(input).expect("failed to seed execution");
// Set the batch size (max we sync per stage execution) to less than the number of blocks
// the previous stage synced (10 vs 20)
runner.set_batch_size(10);
// Insert required state
runner.insert_genesis().expect("Could not insert genesis block");
runner
.insert_headers(blocks.iter().map(|block| &block.header))
.expect("Could not insert headers");
// Run the stage
let rx = runner.execute(ExecInput {
previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)),
stage_progress: None,
});
let rx = runner.execute(input);
// Check that we only synced around `batch_size` blocks even though the number of blocks
// synced by the previous stage is higher
@ -299,34 +271,27 @@ mod tests {
output,
Ok(ExecOutput { stage_progress, reached_tip: true, done: false }) if stage_progress < 200
);
runner
.validate_db_blocks(output.unwrap().stage_progress)
.expect("Written block data invalid");
assert!(runner.validate_execution(input, output.ok()).is_ok(), "execution validation");
}
/// Same as [partial_body_download] except the `batch_size` is not hit.
#[tokio::test]
async fn full_body_download() {
// Generate blocks #1-20
let blocks = random_block_range(1..21, GENESIS_HASH);
let bodies: HashMap<H256, Result<BlockBody, DownloadError>> =
blocks.iter().map(body_by_hash).collect();
let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone()));
let (stage_progress, previous_stage) = (1, 20);
// Set up test runner
let mut runner = BodyTestRunner::default();
let input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress),
};
runner.seed_execution(input).expect("failed to seed execution");
// Set the batch size to more than what the previous stage synced (40 vs 20)
runner.set_batch_size(40);
// Insert required state
runner.insert_genesis().expect("Could not insert genesis block");
runner
.insert_headers(blocks.iter().map(|block| &block.header))
.expect("Could not insert headers");
// Run the stage
let rx = runner.execute(ExecInput {
previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)),
stage_progress: None,
});
let rx = runner.execute(input);
// Check that we synced all blocks successfully, even though our `batch_size` allows us to
// sync more (if there were more headers)
@ -335,31 +300,26 @@ mod tests {
output,
Ok(ExecOutput { stage_progress: 20, reached_tip: true, done: true })
);
runner
.validate_db_blocks(output.unwrap().stage_progress)
.expect("Written block data invalid");
assert!(runner.validate_execution(input, output.ok()).is_ok(), "execution validation");
}
/// Same as [full_body_download] except we have made progress before
#[tokio::test]
async fn sync_from_previous_progress() {
// Generate blocks #1-20
let blocks = random_block_range(1..21, GENESIS_HASH);
let bodies: HashMap<H256, Result<BlockBody, DownloadError>> =
blocks.iter().map(body_by_hash).collect();
let runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone()));
let (stage_progress, previous_stage) = (1, 21);
// Insert required state
runner.insert_genesis().expect("Could not insert genesis block");
runner
.insert_headers(blocks.iter().map(|block| &block.header))
.expect("Could not insert headers");
// Set up test runner
let mut runner = BodyTestRunner::default();
let input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress),
};
runner.seed_execution(input).expect("failed to seed execution");
runner.set_batch_size(10);
// Run the stage
let rx = runner.execute(ExecInput {
previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)),
stage_progress: None,
});
let rx = runner.execute(input);
// Check that we synced at least 10 blocks
let first_run = rx.await.unwrap();
@ -370,10 +330,11 @@ mod tests {
let first_run_progress = first_run.unwrap().stage_progress;
// Execute again on top of the previous run
let rx = runner.execute(ExecInput {
previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)),
let input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, previous_stage)),
stage_progress: Some(first_run_progress),
});
};
let rx = runner.execute(input);
// Check that we synced more blocks
let output = rx.await.unwrap();
@ -381,175 +342,86 @@ mod tests {
output,
Ok(ExecOutput { stage_progress, reached_tip: true, done: true }) if stage_progress > first_run_progress
);
runner
.validate_db_blocks(output.unwrap().stage_progress)
.expect("Written block data invalid");
assert!(runner.validate_execution(input, output.ok()).is_ok(), "execution validation");
}
/// Checks that the stage asks to unwind if pre-validation of the block fails.
#[tokio::test]
async fn pre_validation_failure() {
// Generate blocks #1-19
let blocks = random_block_range(1..20, GENESIS_HASH);
let bodies: HashMap<H256, Result<BlockBody, DownloadError>> =
blocks.iter().map(body_by_hash).collect();
let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone()));
let (stage_progress, previous_stage) = (1, 20);
// Set up test runner
let mut runner = BodyTestRunner::default();
let input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress),
};
runner.seed_execution(input).expect("failed to seed execution");
// Fail validation
runner.set_fail_validation(true);
// Insert required state
runner.insert_genesis().expect("Could not insert genesis block");
runner
.insert_headers(blocks.iter().map(|block| &block.header))
.expect("Could not insert headers");
runner.consensus.set_fail_validation(true);
// Run the stage
let rx = runner.execute(ExecInput {
previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)),
stage_progress: None,
});
let rx = runner.execute(input);
// Check that the error bubbles up
assert_matches!(
rx.await.unwrap(),
Err(StageError::Validation { block: 1, error: consensus::Error::BaseFeeMissing })
Err(StageError::Validation { error: consensus::Error::BaseFeeMissing, .. })
);
}
/// Checks that the stage unwinds correctly with no data.
#[tokio::test]
async fn unwind_empty_db() {
let unwind_to = 10;
let runner = BodyTestRunner::new(TestBodyDownloader::default);
let rx = runner.unwind(UnwindInput { bad_block: None, stage_progress: 20, unwind_to });
assert_matches!(
rx.await.unwrap(),
Ok(UnwindOutput { stage_progress }) if stage_progress == unwind_to
)
}
/// Checks that the stage unwinds correctly with data.
#[tokio::test]
async fn unwind() {
// Generate blocks #1-20
let blocks = random_block_range(1..21, GENESIS_HASH);
let bodies: HashMap<H256, Result<BlockBody, DownloadError>> =
blocks.iter().map(body_by_hash).collect();
let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone()));
// Set the batch size to more than what the previous stage synced (40 vs 20)
runner.set_batch_size(40);
// Insert required state
runner.insert_genesis().expect("Could not insert genesis block");
runner
.insert_headers(blocks.iter().map(|block| &block.header))
.expect("Could not insert headers");
// Run the stage
let rx = runner.execute(ExecInput {
previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)),
stage_progress: None,
});
// Check that we synced all blocks successfully, even though our `batch_size` allows us to
// sync more (if there were more headers)
let output = rx.await.unwrap();
assert_matches!(
output,
Ok(ExecOutput { stage_progress: 20, reached_tip: true, done: true })
);
let stage_progress = output.unwrap().stage_progress;
runner.validate_db_blocks(stage_progress).expect("Written block data invalid");
// Unwind all of it
let unwind_to = 1;
let rx = runner.unwind(UnwindInput { bad_block: None, stage_progress, unwind_to });
assert_matches!(
rx.await.unwrap(),
Ok(UnwindOutput { stage_progress }) if stage_progress == 1
);
let last_body = runner.last_body().expect("Could not read last body");
let last_tx_id = last_body.base_tx_id + last_body.tx_amount;
runner
.db()
.check_no_entry_above::<tables::BlockBodies, _>(unwind_to, |key| key.number())
.expect("Did not unwind block bodies correctly.");
runner
.db()
.check_no_entry_above::<tables::Transactions, _>(last_tx_id, |key| key)
.expect("Did not unwind transactions correctly.")
assert!(runner.validate_execution(input, None).is_ok(), "execution validation");
}
/// Checks that the stage unwinds correctly, even if a transaction in a block is missing.
#[tokio::test]
async fn unwind_missing_tx() {
// Generate blocks #1-20
let blocks = random_block_range(1..21, GENESIS_HASH);
let bodies: HashMap<H256, Result<BlockBody, DownloadError>> =
blocks.iter().map(body_by_hash).collect();
let mut runner = BodyTestRunner::new(|| TestBodyDownloader::new(bodies.clone()));
let (stage_progress, previous_stage) = (1, 20);
// Set up test runner
let mut runner = BodyTestRunner::default();
let input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress),
};
runner.seed_execution(input).expect("failed to seed execution");
// Set the batch size to more than what the previous stage synced (40 vs 20)
runner.set_batch_size(40);
// Insert required state
runner.insert_genesis().expect("Could not insert genesis block");
runner
.insert_headers(blocks.iter().map(|block| &block.header))
.expect("Could not insert headers");
// Run the stage
let rx = runner.execute(ExecInput {
previous_stage: Some((StageId("Headers"), blocks.len() as BlockNumber)),
stage_progress: None,
});
let rx = runner.execute(input);
// Check that we synced all blocks successfully, even though our `batch_size` allows us to
// sync more (if there were more headers)
let output = rx.await.unwrap();
assert_matches!(
output,
Ok(ExecOutput { stage_progress: 20, reached_tip: true, done: true })
Ok(ExecOutput { stage_progress, reached_tip: true, done: true }) if stage_progress == previous_stage
);
let stage_progress = output.unwrap().stage_progress;
runner.validate_db_blocks(stage_progress).expect("Written block data invalid");
// Delete a transaction
{
let mut db = runner.db().container();
let mut tx_cursor = db
.get_mut()
.cursor_mut::<tables::Transactions>()
.expect("Could not get transaction cursor");
tx_cursor
.last()
.expect("Could not read database")
.expect("Could not read last transaction");
tx_cursor.delete_current().expect("Could not delete last transaction");
db.commit().expect("Could not commit database");
}
runner
.db()
.commit(|tx| {
let mut tx_cursor = tx.cursor_mut::<tables::Transactions>()?;
tx_cursor.last()?.expect("Could not read last transaction");
tx_cursor.delete_current()?;
Ok(())
})
.expect("Could not delete a transaction");
// Unwind all of it
let unwind_to = 1;
let rx = runner.unwind(UnwindInput { bad_block: None, stage_progress, unwind_to });
let input = UnwindInput { bad_block: None, stage_progress, unwind_to };
let res = runner.unwind(input).await;
assert_matches!(
rx.await.unwrap(),
res,
Ok(UnwindOutput { stage_progress }) if stage_progress == 1
);
let last_body = runner.last_body().expect("Could not read last body");
let last_tx_id = last_body.base_tx_id + last_body.tx_amount;
runner
.db()
.check_no_entry_above::<tables::BlockBodies, _>(unwind_to, |key| key.number())
.expect("Did not unwind block bodies correctly.");
runner
.db()
.check_no_entry_above::<tables::Transactions, _>(last_tx_id, |key| key)
.expect("Did not unwind transactions correctly.")
assert!(runner.validate_unwind(input).is_ok(), "unwind validation");
}
/// Checks that the stage exits if the downloader times out
@ -557,54 +429,53 @@ mod tests {
/// try again?
#[tokio::test]
async fn downloader_timeout() {
// Generate a header
let header = random_block(1, Some(GENESIS_HASH)).header;
let runner = BodyTestRunner::new(|| {
TestBodyDownloader::new(HashMap::from([(
header.hash(),
Err(DownloadError::Timeout { header_hash: header.hash() }),
)]))
});
let (stage_progress, previous_stage) = (1, 2);
// Insert required state
runner.insert_genesis().expect("Could not insert genesis block");
runner.insert_header(&header).expect("Could not insert header");
// Set up test runner
let mut runner = BodyTestRunner::default();
let input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress),
};
let blocks = runner.seed_execution(input).expect("failed to seed execution");
// overwrite responses
let header = blocks.last().unwrap();
runner.set_responses(HashMap::from([(
header.hash(),
Err(DownloadError::Timeout { header_hash: header.hash() }),
)]));
// Run the stage
let rx = runner.execute(ExecInput {
previous_stage: Some((StageId("Headers"), 1)),
stage_progress: None,
});
let rx = runner.execute(input);
// Check that the error bubbles up
assert_matches!(rx.await.unwrap(), Err(StageError::Internal(_)));
assert!(runner.validate_execution(input, None).is_ok(), "execution validation");
}
mod test_utils {
use crate::{
stages::bodies::BodyStage,
util::test_utils::{StageTestDB, StageTestRunner},
test_utils::{
ExecuteStageTestRunner, StageTestDB, StageTestRunner, TestRunnerError,
UnwindStageTestRunner,
},
ExecInput, ExecOutput, UnwindInput,
};
use assert_matches::assert_matches;
use async_trait::async_trait;
use reth_eth_wire::BlockBody;
use reth_interfaces::{
db,
db::{
models::{BlockNumHash, StoredBlockBody},
tables, DbCursorRO, DbTx, DbTxMut,
},
db::{models::StoredBlockBody, tables, DbCursorRO, DbTx, DbTxMut},
p2p::bodies::{
client::BodiesClient,
downloader::{BodiesStream, BodyDownloader},
error::{BodiesClientError, DownloadError},
},
test_utils::TestConsensus,
test_utils::{generators::random_block_range, TestConsensus},
};
use reth_primitives::{
BigEndianHash, BlockLocked, BlockNumber, Header, SealedHeader, H256, U256,
};
use std::{collections::HashMap, ops::Deref, time::Duration};
use reth_primitives::{BlockLocked, BlockNumber, Header, SealedHeader, H256};
use std::{collections::HashMap, sync::Arc, time::Duration};
/// The block hash of the genesis block.
pub(crate) const GENESIS_HASH: H256 = H256::zero();
@ -623,43 +494,38 @@ mod tests {
}
/// A helper struct for running the [BodyStage].
pub(crate) struct BodyTestRunner<F>
where
F: Fn() -> TestBodyDownloader,
{
downloader_builder: F,
pub(crate) struct BodyTestRunner {
pub(crate) consensus: Arc<TestConsensus>,
responses: HashMap<H256, Result<BlockBody, DownloadError>>,
db: StageTestDB,
batch_size: u64,
fail_validation: bool,
}
impl<F> BodyTestRunner<F>
where
F: Fn() -> TestBodyDownloader,
{
/// Build a new test runner.
pub(crate) fn new(downloader_builder: F) -> Self {
BodyTestRunner {
downloader_builder,
impl Default for BodyTestRunner {
fn default() -> Self {
Self {
consensus: Arc::new(TestConsensus::default()),
responses: HashMap::default(),
db: StageTestDB::default(),
batch_size: 10,
fail_validation: false,
batch_size: 1000,
}
}
}
impl BodyTestRunner {
pub(crate) fn set_batch_size(&mut self, batch_size: u64) {
self.batch_size = batch_size;
}
pub(crate) fn set_fail_validation(&mut self, fail_validation: bool) {
self.fail_validation = fail_validation;
pub(crate) fn set_responses(
&mut self,
responses: HashMap<H256, Result<BlockBody, DownloadError>>,
) {
self.responses = responses;
}
}
impl<F> StageTestRunner for BodyTestRunner<F>
where
F: Fn() -> TestBodyDownloader,
{
impl StageTestRunner for BodyTestRunner {
type S = BodyStage<TestBodyDownloader, TestConsensus>;
fn db(&self) -> &StageTestDB {
@ -667,115 +533,115 @@ mod tests {
}
fn stage(&self) -> Self::S {
let mut consensus = TestConsensus::default();
consensus.set_fail_validation(self.fail_validation);
BodyStage {
downloader: (self.downloader_builder)(),
consensus,
downloader: Arc::new(TestBodyDownloader::new(self.responses.clone())),
consensus: self.consensus.clone(),
batch_size: self.batch_size,
}
}
}
impl<F> BodyTestRunner<F>
where
F: Fn() -> TestBodyDownloader,
{
#[async_trait::async_trait]
impl ExecuteStageTestRunner for BodyTestRunner {
type Seed = Vec<BlockLocked>;
fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
let start = input.stage_progress.unwrap_or_default();
let end =
input.previous_stage.as_ref().map(|(_, num)| *num + 1).unwrap_or_default();
let blocks = random_block_range(start..end, GENESIS_HASH);
self.insert_genesis()?;
self.db.insert_headers(blocks.iter().map(|block| &block.header))?;
self.set_responses(blocks.iter().map(body_by_hash).collect());
Ok(blocks)
}
fn validate_execution(
&self,
input: ExecInput,
output: Option<ExecOutput>,
) -> Result<(), TestRunnerError> {
let highest_block = match output.as_ref() {
Some(output) => output.stage_progress,
None => input.stage_progress.unwrap_or_default(),
};
self.validate_db_blocks(highest_block)
}
}
impl UnwindStageTestRunner for BodyTestRunner {
fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
self.db.check_no_entry_above::<tables::BlockBodies, _>(input.unwind_to, |key| {
key.number()
})?;
if let Some(last_body) = self.last_body() {
let last_tx_id = last_body.base_tx_id + last_body.tx_amount;
self.db
.check_no_entry_above::<tables::Transactions, _>(last_tx_id, |key| key)?;
}
Ok(())
}
}
impl BodyTestRunner {
/// Insert the genesis block into the appropriate tables
///
/// The genesis block always has no transactions and no ommers, and it always has the
/// same hash.
pub(crate) fn insert_genesis(&self) -> Result<(), db::Error> {
self.insert_header(&SealedHeader::new(Header::default(), GENESIS_HASH))?;
let mut db = self.db.container();
let tx = db.get_mut();
tx.put::<tables::BlockBodies>(
(0, GENESIS_HASH).into(),
StoredBlockBody { base_tx_id: 0, tx_amount: 0, ommers: vec![] },
)?;
db.commit()?;
Ok(())
}
/// Insert header into tables
pub(crate) fn insert_header(&self, header: &SealedHeader) -> Result<(), db::Error> {
self.insert_headers(std::iter::once(header))
}
/// Insert headers into tables
pub(crate) fn insert_headers<'a, I>(&self, headers: I) -> Result<(), db::Error>
where
I: Iterator<Item = &'a SealedHeader>,
{
let headers = headers.collect::<Vec<_>>();
self.db
.map_put::<tables::HeaderNumbers, _, _>(&headers, |h| (h.hash(), h.number))?;
self.db.map_put::<tables::Headers, _, _>(&headers, |h| {
(BlockNumHash((h.number, h.hash())), h.deref().clone().unseal())
})?;
self.db.map_put::<tables::CanonicalHeaders, _, _>(&headers, |h| {
(h.number, h.hash())
})?;
self.db.transform_append::<tables::HeaderTD, _, _>(&headers, |prev, h| {
let prev_td = U256::from_big_endian(&prev.clone().unwrap_or_default());
(
BlockNumHash((h.number, h.hash())),
H256::from_uint(&(prev_td + h.difficulty)).as_bytes().to_vec(),
pub(crate) fn insert_genesis(&self) -> Result<(), TestRunnerError> {
let header = SealedHeader::new(Header::default(), GENESIS_HASH);
self.db.insert_headers(std::iter::once(&header))?;
self.db.commit(|tx| {
tx.put::<tables::BlockBodies>(
(0, GENESIS_HASH).into(),
StoredBlockBody { base_tx_id: 0, tx_amount: 0, ommers: vec![] },
)
})?;
Ok(())
}
/// Retrieve the last body from the database
pub(crate) fn last_body(&self) -> Option<StoredBlockBody> {
Some(
self.db()
.container()
.get()
.cursor::<tables::BlockBodies>()
.ok()?
.last()
.ok()??
.1,
)
self.db
.query(|tx| Ok(tx.cursor::<tables::BlockBodies>()?.last()?.map(|e| e.1)))
.ok()
.flatten()
}
/// Validate that the inserted block data is valid
pub(crate) fn validate_db_blocks(
&self,
highest_block: BlockNumber,
) -> Result<(), db::Error> {
let db = self.db.container();
let tx = db.get();
) -> Result<(), TestRunnerError> {
self.db.query(|tx| {
let mut block_body_cursor = tx.cursor::<tables::BlockBodies>()?;
let mut transaction_cursor = tx.cursor::<tables::Transactions>()?;
let mut block_body_cursor = tx.cursor::<tables::BlockBodies>()?;
let mut transaction_cursor = tx.cursor::<tables::Transactions>()?;
let mut entry = block_body_cursor.first()?;
let mut prev_max_tx_id = 0;
while let Some((key, body)) = entry {
assert!(
key.number() <= highest_block,
"We wrote a block body outside of our synced range. Found block with number {}, highest block according to stage is {}",
key.number(), highest_block
);
assert!(prev_max_tx_id == body.base_tx_id, "Transaction IDs are malformed.");
for num in 0..body.tx_amount {
let tx_id = body.base_tx_id + num;
assert_matches!(
transaction_cursor.seek_exact(tx_id),
Ok(Some(_)),
"A transaction is missing."
let mut entry = block_body_cursor.first()?;
let mut prev_max_tx_id = 0;
while let Some((key, body)) = entry {
assert!(
key.number() <= highest_block,
"We wrote a block body outside of our synced range. Found block with number {}, highest block according to stage is {}",
key.number(), highest_block
);
}
prev_max_tx_id = body.base_tx_id + body.tx_amount;
entry = block_body_cursor.next()?;
}
assert!(prev_max_tx_id == body.base_tx_id, "Transaction IDs are malformed.");
for num in 0..body.tx_amount {
let tx_id = body.base_tx_id + num;
assert_matches!(
transaction_cursor.seek_exact(tx_id),
Ok(Some(_)),
"A transaction is missing."
);
}
prev_max_tx_id = body.base_tx_id + body.tx_amount;
entry = block_body_cursor.next()?;
}
Ok(())
})?;
Ok(())
}
}
@ -785,7 +651,7 @@ mod tests {
#[derive(Debug)]
pub(crate) struct NoopClient;
#[async_trait]
#[async_trait::async_trait]
impl BodiesClient for NoopClient {
async fn get_block_body(&self, _: H256) -> Result<BlockBody, BodiesClientError> {
panic!("Noop client should not be called")
@ -794,7 +660,7 @@ mod tests {
// TODO(onbjerg): Move
/// A [BodyDownloader] that is backed by an internal [HashMap] for testing.
#[derive(Debug, Default)]
#[derive(Debug, Default, Clone)]
pub(crate) struct TestBodyDownloader {
responses: HashMap<H256, Result<BlockBody, DownloadError>>,
}
@ -824,14 +690,12 @@ mod tests {
{
Box::pin(futures_util::stream::iter(hashes.into_iter().map(
|(block_number, hash)| {
Ok((
*block_number,
*hash,
self.responses
.get(hash)
.expect("Stage tried downloading a block we do not have.")
.clone()?,
))
let result = self
.responses
.get(hash)
.expect("Stage tried downloading a block we do not have.")
.clone()?;
Ok((*block_number, *hash, result))
},
)))
}

View File

@ -58,8 +58,8 @@ impl<DB: Database, D: HeaderDownloader, C: Consensus, H: HeadersClient> Stage<DB
let last_block_num = input.stage_progress.unwrap_or_default();
self.update_head::<DB>(tx, last_block_num).await?;
// TODO: add batch size
// download the headers
// TODO: handle input.max_block
let last_hash = tx
.get::<tables::CanonicalHeaders>(last_block_num)?
.ok_or(DatabaseIntegrityError::CanonicalHash { number: last_block_num })?;
@ -190,214 +190,99 @@ impl<D: HeaderDownloader, C: Consensus, H: HeadersClient> HeaderStage<D, C, H> {
#[cfg(test)]
mod tests {
use super::*;
use crate::util::test_utils::StageTestRunner;
use assert_matches::assert_matches;
use reth_interfaces::{
consensus,
test_utils::{
generators::{random_header, random_header_range},
TestHeaderDownloader,
},
use crate::test_utils::{
stage_test_suite, ExecuteStageTestRunner, UnwindStageTestRunner, PREV_STAGE_ID,
};
use test_utils::HeadersTestRunner;
use assert_matches::assert_matches;
use test_runner::HeadersTestRunner;
const TEST_STAGE: StageId = StageId("Headers");
stage_test_suite!(HeadersTestRunner);
/// Check that the execution errors on empty database or
/// prev progress missing from the database.
#[tokio::test]
async fn execute_empty_db() {
let runner = HeadersTestRunner::default();
let rx = runner.execute(ExecInput::default());
assert_matches!(
rx.await.unwrap(),
Err(StageError::DatabaseIntegrity(DatabaseIntegrityError::CanonicalHeader { .. }))
);
}
/// Check that the execution exits on downloader timeout.
#[tokio::test]
// Validate that the execution does not fail on timeout
async fn execute_timeout() {
let head = random_header(0, None);
let runner = HeadersTestRunner::with_downloader(TestHeaderDownloader::new(Err(
DownloadError::Timeout { request_id: 0 },
)));
runner.insert_header(&head).expect("failed to insert header");
let rx = runner.execute(ExecInput::default());
let mut runner = HeadersTestRunner::default();
let input = ExecInput::default();
runner.seed_execution(input).expect("failed to seed execution");
let rx = runner.execute(input);
runner.consensus.update_tip(H256::from_low_u64_be(1));
assert_matches!(rx.await.unwrap(), Ok(ExecOutput { done, .. }) if !done);
let result = rx.await.unwrap();
assert_matches!(
result,
Ok(ExecOutput { done: false, reached_tip: false, stage_progress: 0 })
);
assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
}
/// Check that validation error is propagated during the execution.
#[tokio::test]
async fn execute_validation_error() {
let head = random_header(0, None);
let runner = HeadersTestRunner::with_downloader(TestHeaderDownloader::new(Err(
DownloadError::HeaderValidation {
hash: H256::zero(),
error: consensus::Error::BaseFeeMissing,
},
)));
runner.insert_header(&head).expect("failed to insert header");
let rx = runner.execute(ExecInput::default());
runner.consensus.update_tip(H256::from_low_u64_be(1));
assert_matches!(rx.await.unwrap(), Err(StageError::Validation { block, error: consensus::Error::BaseFeeMissing, }) if block == 0);
}
/// Validate that all necessary tables are updated after the
/// header download on no previous progress.
#[tokio::test]
async fn execute_no_progress() {
let (start, end) = (0, 100);
let head = random_header(start, None);
let headers = random_header_range(start + 1..end, head.hash());
let result = headers.iter().rev().cloned().collect::<Vec<_>>();
let runner = HeadersTestRunner::with_downloader(TestHeaderDownloader::new(Ok(result)));
runner.insert_header(&head).expect("failed to insert header");
let rx = runner.execute(ExecInput::default());
let tip = headers.last().unwrap();
runner.consensus.update_tip(tip.hash());
assert_matches!(
rx.await.unwrap(),
Ok(ExecOutput { done, reached_tip, stage_progress })
if done && reached_tip && stage_progress == tip.number
);
assert!(headers.iter().try_for_each(|h| runner.validate_db_header(h)).is_ok());
}
/// Validate that all necessary tables are updated after the
/// header download with some previous progress.
#[tokio::test]
async fn execute_prev_progress() {
let (start, end) = (10000, 10241);
let head = random_header(start, None);
let headers = random_header_range(start + 1..end, head.hash());
let result = headers.iter().rev().cloned().collect::<Vec<_>>();
let runner = HeadersTestRunner::with_downloader(TestHeaderDownloader::new(Ok(result)));
runner.insert_header(&head).expect("failed to insert header");
let rx = runner.execute(ExecInput {
previous_stage: Some((TEST_STAGE, head.number)),
stage_progress: Some(head.number),
});
let tip = headers.last().unwrap();
runner.consensus.update_tip(tip.hash());
assert_matches!(
rx.await.unwrap(),
Ok(ExecOutput { done, reached_tip, stage_progress })
if done && reached_tip && stage_progress == tip.number
);
assert!(headers.iter().try_for_each(|h| runner.validate_db_header(h)).is_ok());
let mut runner = HeadersTestRunner::default();
runner.consensus.set_fail_validation(true);
let input = ExecInput::default();
let seed = runner.seed_execution(input).expect("failed to seed execution");
let rx = runner.execute(input);
runner.after_execution(seed).await.expect("failed to run after execution hook");
let result = rx.await.unwrap();
assert_matches!(result, Err(StageError::Validation { .. }));
assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
}
/// Execute the stage with linear downloader
#[tokio::test]
async fn execute_with_linear_downloader() {
let (start, end) = (1000, 1200);
let head = random_header(start, None);
let headers = random_header_range(start + 1..end, head.hash());
let runner = HeadersTestRunner::with_linear_downloader();
runner.insert_header(&head).expect("failed to insert header");
let rx = runner.execute(ExecInput {
previous_stage: Some((TEST_STAGE, head.number)),
stage_progress: Some(head.number),
});
let mut runner = HeadersTestRunner::with_linear_downloader();
let (stage_progress, previous_stage) = (1000, 1200);
let input = ExecInput {
previous_stage: Some((PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress),
};
let headers = runner.seed_execution(input).expect("failed to seed execution");
let rx = runner.execute(input);
// skip `after_execution` hook for linear downloader
let tip = headers.last().unwrap();
runner.consensus.update_tip(tip.hash());
let mut download_result = headers.clone();
download_result.insert(0, head);
let download_result = headers.clone();
runner
.client
.on_header_request(1, |id, _| {
runner.client.send_header_response(
id,
download_result.clone().into_iter().map(|h| h.unseal()).collect(),
)
let response = download_result.iter().map(|h| h.clone().unseal()).collect();
runner.client.send_header_response(id, response)
})
.await;
let result = rx.await.unwrap();
assert_matches!(
rx.await.unwrap(),
Ok(ExecOutput { done, reached_tip, stage_progress })
if done && reached_tip && stage_progress == tip.number
result,
Ok(ExecOutput { done: true, reached_tip: true, stage_progress }) if stage_progress == tip.number
);
assert!(headers.iter().try_for_each(|h| runner.validate_db_header(h)).is_ok());
assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed");
}
/// Check that unwind does not panic on empty database.
#[tokio::test]
async fn unwind_empty_db() {
let unwind_to = 100;
let runner = HeadersTestRunner::default();
let rx =
runner.unwind(UnwindInput { bad_block: None, stage_progress: unwind_to, unwind_to });
assert_matches!(
rx.await.unwrap(),
Ok(UnwindOutput {stage_progress} ) if stage_progress == unwind_to
);
}
/// Check that unwind can remove headers across gaps
#[tokio::test]
async fn unwind_db_gaps() {
let runner = HeadersTestRunner::default();
let head = random_header(0, None);
let first_range = random_header_range(1..20, head.hash());
let second_range = random_header_range(50..100, H256::zero());
runner.insert_header(&head).expect("failed to insert header");
runner
.insert_headers(first_range.iter().chain(second_range.iter()))
.expect("failed to insert headers");
let unwind_to = 15;
let rx =
runner.unwind(UnwindInput { bad_block: None, stage_progress: unwind_to, unwind_to });
assert_matches!(
rx.await.unwrap(),
Ok(UnwindOutput {stage_progress} ) if stage_progress == unwind_to
);
runner
.db()
.check_no_entry_above::<tables::CanonicalHeaders, _>(unwind_to, |key| key)
.expect("failed to check cannonical headers");
runner
.db()
.check_no_entry_above_by_value::<tables::HeaderNumbers, _>(unwind_to, |val| val)
.expect("failed to check header numbers");
runner
.db()
.check_no_entry_above::<tables::Headers, _>(unwind_to, |key| key.number())
.expect("failed to check headers");
runner
.db()
.check_no_entry_above::<tables::HeaderTD, _>(unwind_to, |key| key.number())
.expect("failed to check td");
}
mod test_utils {
mod test_runner {
use crate::{
stages::headers::HeaderStage,
util::test_utils::{StageTestDB, StageTestRunner},
test_utils::{
ExecuteStageTestRunner, StageTestDB, StageTestRunner, TestRunnerError,
UnwindStageTestRunner,
},
ExecInput, ExecOutput, UnwindInput,
};
use reth_headers_downloaders::linear::{LinearDownloadBuilder, LinearDownloader};
use reth_interfaces::{
db::{self, models::blocks::BlockNumHash, tables, DbTx},
db::{models::blocks::BlockNumHash, tables, DbTx},
p2p::headers::downloader::HeaderDownloader,
test_utils::{TestConsensus, TestHeaderDownloader, TestHeadersClient},
test_utils::{
generators::{random_header, random_header_range},
TestConsensus, TestHeaderDownloader, TestHeadersClient,
},
};
use reth_primitives::{rpc::BigEndianHash, SealedHeader, H256, U256};
use std::{ops::Deref, sync::Arc};
use reth_primitives::{BlockNumber, SealedHeader, H256, U256};
use std::sync::Arc;
pub(crate) struct HeadersTestRunner<D: HeaderDownloader> {
pub(crate) consensus: Arc<TestConsensus>,
@ -408,10 +293,12 @@ mod tests {
impl Default for HeadersTestRunner<TestHeaderDownloader> {
fn default() -> Self {
let client = Arc::new(TestHeadersClient::default());
let consensus = Arc::new(TestConsensus::default());
Self {
client: Arc::new(TestHeadersClient::default()),
consensus: Arc::new(TestConsensus::default()),
downloader: Arc::new(TestHeaderDownloader::new(Ok(Vec::default()))),
client: client.clone(),
consensus: consensus.clone(),
downloader: Arc::new(TestHeaderDownloader::new(client, consensus)),
db: StageTestDB::default(),
}
}
@ -433,6 +320,99 @@ mod tests {
}
}
#[async_trait::async_trait]
impl<D: HeaderDownloader + 'static> ExecuteStageTestRunner for HeadersTestRunner<D> {
type Seed = Vec<SealedHeader>;
fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
let start = input.stage_progress.unwrap_or_default();
let head = random_header(start, None);
self.db.insert_headers(std::iter::once(&head))?;
// use previous progress as seed size
let end = input.previous_stage.map(|(_, num)| num).unwrap_or_default() + 1;
if start + 1 >= end {
return Ok(Vec::default())
}
let mut headers = random_header_range(start + 1..end, head.hash());
headers.insert(0, head);
Ok(headers)
}
async fn after_execution(&self, headers: Self::Seed) -> Result<(), TestRunnerError> {
let tip = if !headers.is_empty() {
headers.last().unwrap().hash()
} else {
H256::from_low_u64_be(rand::random())
};
self.consensus.update_tip(tip);
self.client
.send_header_response_delayed(
0,
headers.into_iter().map(|h| h.unseal()).collect(),
1,
)
.await;
Ok(())
}
/// Validate stored headers
fn validate_execution(
&self,
input: ExecInput,
output: Option<ExecOutput>,
) -> Result<(), TestRunnerError> {
let initial_stage_progress = input.stage_progress.unwrap_or_default();
match output {
Some(output) if output.stage_progress > initial_stage_progress => {
self.db.query(|tx| {
for block_num in (initial_stage_progress..output.stage_progress).rev() {
// look up the header hash
let hash = tx
.get::<tables::CanonicalHeaders>(block_num)?
.expect("no header hash");
let key: BlockNumHash = (block_num, hash).into();
// validate the header number
assert_eq!(tx.get::<tables::HeaderNumbers>(hash)?, Some(block_num));
// validate the header
let header = tx.get::<tables::Headers>(key)?;
assert!(header.is_some());
let header = header.unwrap().seal();
assert_eq!(header.hash(), hash);
// validate td consistency in the database
if header.number > initial_stage_progress {
let parent_td = tx.get::<tables::HeaderTD>(
(header.number - 1, header.parent_hash).into(),
)?;
let td = tx.get::<tables::HeaderTD>(key)?.unwrap();
assert_eq!(
parent_td.map(
|td| U256::from_big_endian(&td) + header.difficulty
),
Some(U256::from_big_endian(&td))
);
}
}
Ok(())
})?;
}
_ => self.check_no_header_entry_above(initial_stage_progress)?,
};
Ok(())
}
}
impl<D: HeaderDownloader + 'static> UnwindStageTestRunner for HeadersTestRunner<D> {
fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
self.check_no_header_entry_above(input.unwind_to)
}
}
impl HeadersTestRunner<LinearDownloader<TestConsensus, TestHeadersClient>> {
pub(crate) fn with_linear_downloader() -> Self {
let client = Arc::new(TestHeadersClient::default());
@ -445,74 +425,15 @@ mod tests {
}
impl<D: HeaderDownloader> HeadersTestRunner<D> {
pub(crate) fn with_downloader(downloader: D) -> Self {
HeadersTestRunner {
client: Arc::new(TestHeadersClient::default()),
consensus: Arc::new(TestConsensus::default()),
downloader: Arc::new(downloader),
db: StageTestDB::default(),
}
}
/// Insert header into tables
pub(crate) fn insert_header(&self, header: &SealedHeader) -> Result<(), db::Error> {
self.insert_headers(std::iter::once(header))
}
/// Insert headers into tables
pub(crate) fn insert_headers<'a, I>(&self, headers: I) -> Result<(), db::Error>
where
I: Iterator<Item = &'a SealedHeader>,
{
let headers = headers.collect::<Vec<_>>();
self.db
.map_put::<tables::HeaderNumbers, _, _>(&headers, |h| (h.hash(), h.number))?;
self.db.map_put::<tables::Headers, _, _>(&headers, |h| {
(BlockNumHash((h.number, h.hash())), h.deref().clone().unseal())
})?;
self.db.map_put::<tables::CanonicalHeaders, _, _>(&headers, |h| {
(h.number, h.hash())
})?;
self.db.transform_append::<tables::HeaderTD, _, _>(&headers, |prev, h| {
let prev_td = U256::from_big_endian(&prev.clone().unwrap_or_default());
(
BlockNumHash((h.number, h.hash())),
H256::from_uint(&(prev_td + h.difficulty)).as_bytes().to_vec(),
)
})?;
Ok(())
}
/// Validate stored header against provided
pub(crate) fn validate_db_header(
pub(crate) fn check_no_header_entry_above(
&self,
header: &SealedHeader,
) -> Result<(), db::Error> {
let db = self.db.container();
let tx = db.get();
let key: BlockNumHash = (header.number, header.hash()).into();
let db_number = tx.get::<tables::HeaderNumbers>(header.hash())?;
assert_eq!(db_number, Some(header.number));
let db_header = tx.get::<tables::Headers>(key)?;
assert_eq!(db_header, Some(header.clone().unseal()));
let db_canonical_header = tx.get::<tables::CanonicalHeaders>(header.number)?;
assert_eq!(db_canonical_header, Some(header.hash()));
if header.number != 0 {
let parent_key: BlockNumHash = (header.number - 1, header.parent_hash).into();
let parent_td = tx.get::<tables::HeaderTD>(parent_key)?;
let td = U256::from_big_endian(&tx.get::<tables::HeaderTD>(key)?.unwrap());
assert_eq!(
parent_td.map(|td| U256::from_big_endian(&td) + header.difficulty),
Some(td)
);
}
block: BlockNumber,
) -> Result<(), TestRunnerError> {
self.db
.check_no_entry_above_by_value::<tables::HeaderNumbers, _>(block, |val| val)?;
self.db.check_no_entry_above::<tables::CanonicalHeaders, _>(block, |key| key)?;
self.db.check_no_entry_above::<tables::Headers, _>(block, |key| key.number())?;
self.db.check_no_entry_above::<tables::HeaderTD, _>(block, |key| key.number())?;
Ok(())
}
}

View File

@ -87,141 +87,18 @@ impl<DB: Database> Stage<DB> for TxIndex {
#[cfg(test)]
mod tests {
use super::*;
use crate::util::test_utils::{StageTestDB, StageTestRunner};
use crate::test_utils::{
stage_test_suite, ExecuteStageTestRunner, StageTestDB, StageTestRunner, TestRunnerError,
UnwindStageTestRunner,
};
use assert_matches::assert_matches;
use reth_interfaces::{db::models::BlockNumHash, test_utils::generators::random_header_range};
use reth_interfaces::{
db::models::{BlockNumHash, StoredBlockBody},
test_utils::generators::random_header_range,
};
use reth_primitives::H256;
const TEST_STAGE: StageId = StageId("PrevStage");
#[tokio::test]
async fn execute_empty_db() {
let runner = TxIndexTestRunner::default();
let rx = runner.execute(ExecInput::default());
assert_matches!(
rx.await.unwrap(),
Err(StageError::DatabaseIntegrity(DatabaseIntegrityError::CanonicalHeader { .. }))
);
}
#[tokio::test]
async fn execute_no_prev_tx_count() {
let runner = TxIndexTestRunner::default();
let headers = random_header_range(0..10, H256::zero());
runner
.db()
.map_put::<tables::CanonicalHeaders, _, _>(&headers, |h| (h.number, h.hash()))
.expect("failed to insert");
let (head, tail) = (headers.first().unwrap(), headers.last().unwrap());
let input = ExecInput {
previous_stage: Some((TEST_STAGE, tail.number)),
stage_progress: Some(head.number),
};
let rx = runner.execute(input);
assert_matches!(
rx.await.unwrap(),
Err(StageError::DatabaseIntegrity(DatabaseIntegrityError::CumulativeTxCount { .. }))
);
}
#[tokio::test]
async fn execute() {
let runner = TxIndexTestRunner::default();
let (start, pivot, end) = (0, 100, 200);
let headers = random_header_range(start..end, H256::zero());
runner
.db()
.map_put::<tables::CanonicalHeaders, _, _>(&headers, |h| (h.number, h.hash()))
.expect("failed to insert");
runner
.db()
.transform_append::<tables::CumulativeTxCount, _, _>(&headers[..=pivot], |prev, h| {
(
BlockNumHash((h.number, h.hash())),
prev.unwrap_or_default() + (rand::random::<u8>() as u64),
)
})
.expect("failed to insert");
let (pivot, tail) = (headers.get(pivot).unwrap(), headers.last().unwrap());
let input = ExecInput {
previous_stage: Some((TEST_STAGE, tail.number)),
stage_progress: Some(pivot.number),
};
let rx = runner.execute(input);
assert_matches!(
rx.await.unwrap(),
Ok(ExecOutput { stage_progress, done, reached_tip })
if done && reached_tip && stage_progress == tail.number
);
}
#[tokio::test]
async fn unwind_empty_db() {
let runner = TxIndexTestRunner::default();
let rx = runner.unwind(UnwindInput::default());
assert_matches!(
rx.await.unwrap(),
Ok(UnwindOutput { stage_progress }) if stage_progress == 0
);
}
#[tokio::test]
async fn unwind_no_input() {
let runner = TxIndexTestRunner::default();
let headers = random_header_range(0..10, H256::zero());
runner
.db()
.transform_append::<tables::CumulativeTxCount, _, _>(&headers, |prev, h| {
(
BlockNumHash((h.number, h.hash())),
prev.unwrap_or_default() + (rand::random::<u8>() as u64),
)
})
.expect("failed to insert");
let rx = runner.unwind(UnwindInput::default());
assert_matches!(
rx.await.unwrap(),
Ok(UnwindOutput { stage_progress }) if stage_progress == 0
);
runner
.db()
.check_no_entry_above::<tables::CumulativeTxCount, _>(0, |h| h.number())
.expect("failed to check tx count");
}
#[tokio::test]
async fn unwind_with_db_gaps() {
let runner = TxIndexTestRunner::default();
let first_range = random_header_range(0..20, H256::zero());
let second_range = random_header_range(50..100, H256::zero());
runner
.db()
.transform_append::<tables::CumulativeTxCount, _, _>(
&first_range.iter().chain(second_range.iter()).collect::<Vec<_>>(),
|prev, h| {
(
BlockNumHash((h.number, h.hash())),
prev.unwrap_or_default() + (rand::random::<u8>() as u64),
)
},
)
.expect("failed to insert");
let unwind_to = 10;
let input = UnwindInput { unwind_to, ..Default::default() };
let rx = runner.unwind(input);
assert_matches!(
rx.await.unwrap(),
Ok(UnwindOutput { stage_progress }) if stage_progress == unwind_to
);
runner
.db()
.check_no_entry_above::<tables::CumulativeTxCount, _>(unwind_to, |h| h.number())
.expect("failed to check tx count");
}
stage_test_suite!(TxIndexTestRunner);
#[derive(Default)]
pub(crate) struct TxIndexTestRunner {
@ -239,4 +116,82 @@ mod tests {
TxIndex {}
}
}
impl ExecuteStageTestRunner for TxIndexTestRunner {
type Seed = ();
fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
let pivot = input.stage_progress.unwrap_or_default();
let start = pivot.saturating_sub(100);
let mut end = input.previous_stage.as_ref().map(|(_, num)| *num).unwrap_or_default();
end += 2; // generate 2 additional headers to account for start header lookup
let headers = random_header_range(start..end, H256::zero());
let headers =
headers.into_iter().map(|h| (h, rand::random::<u8>())).collect::<Vec<_>>();
self.db.map_put::<tables::CanonicalHeaders, _, _>(&headers, |(h, _)| {
(h.number, h.hash())
})?;
self.db.map_put::<tables::BlockBodies, _, _>(&headers, |(h, count)| {
(
BlockNumHash((h.number, h.hash())),
StoredBlockBody { base_tx_id: 0, tx_amount: *count as u64, ommers: vec![] },
)
})?;
let slice_up_to =
std::cmp::min(pivot.saturating_sub(start) as usize, headers.len() - 1);
self.db.transform_append::<tables::CumulativeTxCount, _, _>(
&headers[..=slice_up_to],
|prev, (h, count)| {
(BlockNumHash((h.number, h.hash())), prev.unwrap_or_default() + (*count as u64))
},
)?;
Ok(())
}
fn validate_execution(
&self,
input: ExecInput,
_output: Option<ExecOutput>,
) -> Result<(), TestRunnerError> {
self.db.query(|tx| {
let (start, end) = (
input.stage_progress.unwrap_or_default(),
input.previous_stage.as_ref().map(|(_, num)| *num).unwrap_or_default(),
);
if start >= end {
return Ok(())
}
let start_hash =
tx.get::<tables::CanonicalHeaders>(start)?.expect("no canonical found");
let mut tx_count_cursor = tx.cursor::<tables::CumulativeTxCount>()?;
let mut tx_count_walker = tx_count_cursor.walk((start, start_hash).into())?;
let mut count = tx_count_walker.next().unwrap()?.1;
let mut last_num = start;
while let Some(entry) = tx_count_walker.next() {
let (key, db_count) = entry?;
count += tx.get::<tables::BlockBodies>(key)?.unwrap().tx_amount as u64;
assert_eq!(db_count, count);
last_num = key.number();
}
assert_eq!(last_num, end);
Ok(())
})?;
Ok(())
}
}
impl UnwindStageTestRunner for TxIndexTestRunner {
fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
self.db.check_no_entry_above::<tables::CumulativeTxCount, _>(input.unwind_to, |h| {
h.number()
})?;
Ok(())
}
}
}

View File

@ -0,0 +1,145 @@
macro_rules! stage_test_suite {
($runner:ident) => {
/// Check that the execution is short-circuited if the database is empty.
#[tokio::test]
async fn execute_empty_db() {
// Set up the runner
let runner = $runner::default();
// Execute the stage with empty database
let input = crate::stage::ExecInput::default();
// Run stage execution
let result = runner.execute(input).await.unwrap();
assert_matches!(
result,
Err(crate::error::StageError::DatabaseIntegrity(_))
);
// Validate the stage execution
assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation");
}
/// Check that the execution is short-circuited if the target was already reached.
#[tokio::test]
async fn execute_already_reached_target() {
let stage_progress = 1000;
// Set up the runner
let mut runner = $runner::default();
let input = crate::stage::ExecInput {
previous_stage: Some((crate::test_utils::PREV_STAGE_ID, stage_progress)),
stage_progress: Some(stage_progress),
};
let seed = runner.seed_execution(input).expect("failed to seed");
// Run stage execution
let rx = runner.execute(input);
// Run `after_execution` hook
runner.after_execution(seed).await.expect("failed to run after execution hook");
// Assert the successful result
let result = rx.await.unwrap();
assert_matches!(
result,
Ok(ExecOutput { done, reached_tip, stage_progress })
if done && reached_tip && stage_progress == stage_progress
);
// Validate the stage execution
assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation");
}
// Run the complete stage execution flow.
#[tokio::test]
async fn execute() {
let (previous_stage, stage_progress) = (500, 100);
// Set up the runner
let mut runner = $runner::default();
let input = crate::stage::ExecInput {
previous_stage: Some((crate::test_utils::PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress),
};
let seed = runner.seed_execution(input).expect("failed to seed");
let rx = runner.execute(input);
// Run `after_execution` hook
runner.after_execution(seed).await.expect("failed to run after execution hook");
// Assert the successful result
let result = rx.await.unwrap();
assert_matches!(
result,
Ok(ExecOutput { done, reached_tip, stage_progress })
if done && reached_tip && stage_progress == previous_stage
);
// Validate the stage execution
assert!(runner.validate_execution(input, result.ok()).is_ok(), "execution validation");
}
// Check that unwind does not panic on empty database.
#[tokio::test]
async fn unwind_empty_db() {
// Set up the runner
let runner = $runner::default();
let input = crate::stage::UnwindInput::default();
// Run stage unwind
let rx = runner.unwind(input).await;
assert_matches!(
rx,
Ok(UnwindOutput { stage_progress }) if stage_progress == input.unwind_to
);
// Validate the stage unwind
assert!(runner.validate_unwind(input).is_ok(), "unwind validation");
}
// Run complete execute and unwind flow.
#[tokio::test]
async fn unwind() {
let (previous_stage, stage_progress) = (500, 100);
// Set up the runner
let mut runner = $runner::default();
let execute_input = crate::stage::ExecInput {
previous_stage: Some((crate::test_utils::PREV_STAGE_ID, previous_stage)),
stage_progress: Some(stage_progress),
};
let seed = runner.seed_execution(execute_input).expect("failed to seed");
// Run stage execution
let rx = runner.execute(execute_input);
runner.after_execution(seed).await.expect("failed to run after execution hook");
// Assert the successful execution result
let result = rx.await.unwrap();
assert_matches!(
result,
Ok(ExecOutput { done, reached_tip, stage_progress })
if done && reached_tip && stage_progress == previous_stage
);
assert!(runner.validate_execution(execute_input, result.ok()).is_ok(), "execution validation");
// Run stage unwind
let unwind_input = crate::stage::UnwindInput {
unwind_to: stage_progress, stage_progress, bad_block: None,
};
let rx = runner.unwind(unwind_input).await;
// Assert the successful unwind result
assert_matches!(
rx,
Ok(UnwindOutput { stage_progress }) if stage_progress == unwind_input.unwind_to
);
// Validate the stage unwind
assert!(runner.validate_unwind(unwind_input).is_ok(), "unwind validation");
}
};
}
pub(crate) use stage_test_suite;

View File

@ -0,0 +1,14 @@
use crate::StageId;
mod macros;
mod runner;
mod stage_db;
pub(crate) use macros::*;
pub(crate) use runner::{
ExecuteStageTestRunner, StageTestRunner, TestRunnerError, UnwindStageTestRunner,
};
pub(crate) use stage_db::StageTestDB;
/// The previous test stage id mock used for testing
pub(crate) const PREV_STAGE_ID: StageId = StageId("PrevStage");

View File

@ -0,0 +1,82 @@
use reth_db::{kv::Env, mdbx::WriteMap};
use reth_interfaces::db::{self, DBContainer};
use std::borrow::Borrow;
use tokio::sync::oneshot;
use super::StageTestDB;
use crate::{ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput};
#[derive(thiserror::Error, Debug)]
pub(crate) enum TestRunnerError {
#[error("Database error occured.")]
Database(#[from] db::Error),
#[error("Internal runner error occured.")]
Internal(#[from] Box<dyn std::error::Error>),
}
/// A generic test runner for stages.
#[async_trait::async_trait]
pub(crate) trait StageTestRunner {
type S: Stage<Env<WriteMap>> + 'static;
/// Return a reference to the database.
fn db(&self) -> &StageTestDB;
/// Return an instance of a Stage.
fn stage(&self) -> Self::S;
}
#[async_trait::async_trait]
pub(crate) trait ExecuteStageTestRunner: StageTestRunner {
type Seed: Send + Sync;
/// Seed database for stage execution
fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError>;
/// Validate stage execution
fn validate_execution(
&self,
input: ExecInput,
output: Option<ExecOutput>,
) -> Result<(), TestRunnerError>;
/// Run [Stage::execute] and return a receiver for the result.
fn execute(&self, input: ExecInput) -> oneshot::Receiver<Result<ExecOutput, StageError>> {
let (tx, rx) = oneshot::channel();
let (db, mut stage) = (self.db().inner(), self.stage());
tokio::spawn(async move {
let mut db = DBContainer::new(db.borrow()).expect("failed to create db container");
let result = stage.execute(&mut db, input).await;
db.commit().expect("failed to commit");
tx.send(result).expect("failed to send message")
});
rx
}
/// Run a hook after [Stage::execute]. Required for Headers & Bodies stages.
async fn after_execution(&self, _seed: Self::Seed) -> Result<(), TestRunnerError> {
Ok(())
}
}
#[async_trait::async_trait]
pub(crate) trait UnwindStageTestRunner: StageTestRunner {
/// Validate the unwind
fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError>;
/// Run [Stage::unwind] and return a receiver for the result.
async fn unwind(
&self,
input: UnwindInput,
) -> Result<UnwindOutput, Box<dyn std::error::Error + Send + Sync>> {
let (tx, rx) = oneshot::channel();
let (db, mut stage) = (self.db().inner(), self.stage());
tokio::spawn(async move {
let mut db = DBContainer::new(db.borrow()).expect("failed to create db container");
let result = stage.unwind(&mut db, input).await;
db.commit().expect("failed to commit");
tx.send(result).expect("failed to send result");
});
Box::pin(rx).await.unwrap()
}
}

View File

@ -0,0 +1,179 @@
use reth_db::{
kv::{test_utils::create_test_db, tx::Tx, Env, EnvKind},
mdbx::{WriteMap, RW},
};
use reth_interfaces::db::{
self, models::BlockNumHash, tables, DBContainer, DbCursorRO, DbCursorRW, DbTx, DbTxMut, Table,
};
use reth_primitives::{BigEndianHash, BlockNumber, SealedHeader, H256, U256};
use std::{borrow::Borrow, sync::Arc};
/// The [StageTestDB] is used as an internal
/// database for testing stage implementation.
///
/// ```rust
/// let db = StageTestDB::default();
/// stage.execute(&mut db.container(), input);
/// ```
pub(crate) struct StageTestDB {
db: Arc<Env<WriteMap>>,
}
impl Default for StageTestDB {
/// Create a new instance of [StageTestDB]
fn default() -> Self {
Self { db: create_test_db::<WriteMap>(EnvKind::RW) }
}
}
impl StageTestDB {
/// Return a database wrapped in [DBContainer].
fn container(&self) -> DBContainer<'_, Env<WriteMap>> {
DBContainer::new(self.db.borrow()).expect("failed to create db container")
}
/// Get a pointer to an internal database.
pub(crate) fn inner(&self) -> Arc<Env<WriteMap>> {
self.db.clone()
}
/// Invoke a callback with transaction committing it afterwards
pub(crate) fn commit<F>(&self, f: F) -> Result<(), db::Error>
where
F: FnOnce(&mut Tx<'_, RW, WriteMap>) -> Result<(), db::Error>,
{
let mut db = self.container();
let tx = db.get_mut();
f(tx)?;
db.commit()?;
Ok(())
}
/// Invoke a callback with a read transaction
pub(crate) fn query<F, R>(&self, f: F) -> Result<R, db::Error>
where
F: FnOnce(&Tx<'_, RW, WriteMap>) -> Result<R, db::Error>,
{
f(self.container().get())
}
/// Map a collection of values and store them in the database.
/// This function commits the transaction before exiting.
///
/// ```rust
/// let db = StageTestDB::default();
/// db.map_put::<Table, _, _>(&items, |item| item)?;
/// ```
pub(crate) fn map_put<T, S, F>(&self, values: &[S], mut map: F) -> Result<(), db::Error>
where
T: Table,
S: Clone,
F: FnMut(&S) -> (T::Key, T::Value),
{
self.commit(|tx| {
values.iter().try_for_each(|src| {
let (k, v) = map(src);
tx.put::<T>(k, v)
})
})
}
/// Transform a collection of values using a callback and store
/// them in the database. The callback additionally accepts the
/// optional last element that was stored.
/// This function commits the transaction before exiting.
///
/// ```rust
/// let db = StageTestDB::default();
/// db.transform_append::<Table, _, _>(&items, |prev, item| prev.unwrap_or_default() + item)?;
/// ```
pub(crate) fn transform_append<T, S, F>(
&self,
values: &[S],
mut transform: F,
) -> Result<(), db::Error>
where
T: Table,
<T as Table>::Value: Clone,
S: Clone,
F: FnMut(&Option<<T as Table>::Value>, &S) -> (T::Key, T::Value),
{
self.commit(|tx| {
let mut cursor = tx.cursor_mut::<T>()?;
let mut last = cursor.last()?.map(|(_, v)| v);
values.iter().try_for_each(|src| {
let (k, v) = transform(&last, src);
last = Some(v.clone());
cursor.append(k, v)
})
})
}
/// Check that there is no table entry above a given
/// block by [Table::Key]
pub(crate) fn check_no_entry_above<T, F>(
&self,
block: BlockNumber,
mut selector: F,
) -> Result<(), db::Error>
where
T: Table,
F: FnMut(T::Key) -> BlockNumber,
{
self.query(|tx| {
let mut cursor = tx.cursor::<T>()?;
if let Some((key, _)) = cursor.last()? {
assert!(selector(key) <= block);
}
Ok(())
})
}
/// Check that there is no table entry above a given
/// block by [Table::Value]
pub(crate) fn check_no_entry_above_by_value<T, F>(
&self,
block: BlockNumber,
mut selector: F,
) -> Result<(), db::Error>
where
T: Table,
F: FnMut(T::Value) -> BlockNumber,
{
self.query(|tx| {
let mut cursor = tx.cursor::<T>()?;
if let Some((_, value)) = cursor.last()? {
assert!(selector(value) <= block);
}
Ok(())
})
}
/// Insert ordered collection of [SealedHeader] into the corresponding tables
/// that are supposed to be populated by the headers stage.
pub(crate) fn insert_headers<'a, I>(&self, headers: I) -> Result<(), db::Error>
where
I: Iterator<Item = &'a SealedHeader>,
{
self.commit(|tx| {
let headers = headers.collect::<Vec<_>>();
let mut td = U256::from_big_endian(
&tx.cursor::<tables::HeaderTD>()?.last()?.map(|(_, v)| v).unwrap_or_default(),
);
for header in headers {
let key: BlockNumHash = (header.number, header.hash()).into();
tx.put::<tables::CanonicalHeaders>(header.number, header.hash())?;
tx.put::<tables::HeaderNumbers>(header.hash(), header.number)?;
tx.put::<tables::Headers>(key, header.clone().unseal())?;
td += header.difficulty;
tx.put::<tables::HeaderTD>(key, H256::from_uint(&td).as_bytes().to_vec())?;
}
Ok(())
})
}
}

View File

@ -135,189 +135,3 @@ pub(crate) mod unwind {
Ok(())
}
}
#[cfg(test)]
pub(crate) mod test_utils {
use reth_db::{
kv::{test_utils::create_test_db, Env, EnvKind},
mdbx::WriteMap,
};
use reth_interfaces::db::{DBContainer, DbCursorRO, DbCursorRW, DbTx, DbTxMut, Error, Table};
use reth_primitives::BlockNumber;
use std::{borrow::Borrow, sync::Arc};
use tokio::sync::oneshot;
use crate::{ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput};
/// The [StageTestDB] is used as an internal
/// database for testing stage implementation.
///
/// ```rust
/// let db = StageTestDB::default();
/// stage.execute(&mut db.container(), input);
/// ```
pub(crate) struct StageTestDB {
db: Arc<Env<WriteMap>>,
}
impl Default for StageTestDB {
/// Create a new instance of [StageTestDB]
fn default() -> Self {
Self { db: create_test_db::<WriteMap>(EnvKind::RW) }
}
}
impl StageTestDB {
/// Get a pointer to an internal database.
pub(crate) fn inner(&self) -> Arc<Env<WriteMap>> {
self.db.clone()
}
/// Return a database wrapped in [DBContainer].
pub(crate) fn container(&self) -> DBContainer<'_, Env<WriteMap>> {
DBContainer::new(self.db.borrow()).expect("failed to create db container")
}
/// Map a collection of values and store them in the database.
/// This function commits the transaction before exiting.
///
/// ```rust
/// let db = StageTestDB::default();
/// db.map_put::<Table, _, _>(&items, |item| item)?;
/// ```
pub(crate) fn map_put<T, S, F>(&self, values: &[S], mut map: F) -> Result<(), Error>
where
T: Table,
S: Clone,
F: FnMut(&S) -> (T::Key, T::Value),
{
let mut db = self.container();
let tx = db.get_mut();
values.iter().try_for_each(|src| {
let (k, v) = map(src);
tx.put::<T>(k, v)
})?;
db.commit()?;
Ok(())
}
/// Transform a collection of values using a callback and store
/// them in the database. The callback additionally accepts the
/// optional last element that was stored.
/// This function commits the transaction before exiting.
///
/// ```rust
/// let db = StageTestDB::default();
/// db.transform_append::<Table, _, _>(&items, |prev, item| prev.unwrap_or_default() + item)?;
/// ```
pub(crate) fn transform_append<T, S, F>(
&self,
values: &[S],
mut transform: F,
) -> Result<(), Error>
where
T: Table,
<T as Table>::Value: Clone,
S: Clone,
F: FnMut(&Option<<T as Table>::Value>, &S) -> (T::Key, T::Value),
{
let mut db = self.container();
let tx = db.get_mut();
let mut cursor = tx.cursor_mut::<T>()?;
let mut last = cursor.last()?.map(|(_, v)| v);
values.iter().try_for_each(|src| {
let (k, v) = transform(&last, src);
last = Some(v.clone());
cursor.append(k, v)
})?;
db.commit()?;
Ok(())
}
/// Check that there is no table entry above a given
/// block by [Table::Key]
pub(crate) fn check_no_entry_above<T, F>(
&self,
block: BlockNumber,
mut selector: F,
) -> Result<(), Error>
where
T: Table,
F: FnMut(T::Key) -> BlockNumber,
{
let db = self.container();
let tx = db.get();
let mut cursor = tx.cursor::<T>()?;
if let Some((key, _)) = cursor.last()? {
assert!(selector(key) <= block);
}
Ok(())
}
/// Check that there is no table entry above a given
/// block by [Table::Value]
pub(crate) fn check_no_entry_above_by_value<T, F>(
&self,
block: BlockNumber,
mut selector: F,
) -> Result<(), Error>
where
T: Table,
F: FnMut(T::Value) -> BlockNumber,
{
let db = self.container();
let tx = db.get();
let mut cursor = tx.cursor::<T>()?;
if let Some((_, value)) = cursor.last()? {
assert!(selector(value) <= block);
}
Ok(())
}
}
/// A generic test runner for stages.
#[async_trait::async_trait]
pub(crate) trait StageTestRunner {
type S: Stage<Env<WriteMap>> + 'static;
/// Return a reference to the database.
fn db(&self) -> &StageTestDB;
/// Return an instance of a Stage.
fn stage(&self) -> Self::S;
/// Run [Stage::execute] and return a receiver for the result.
fn execute(&self, input: ExecInput) -> oneshot::Receiver<Result<ExecOutput, StageError>> {
let (tx, rx) = oneshot::channel();
let (db, mut stage) = (self.db().inner(), self.stage());
tokio::spawn(async move {
let mut db = DBContainer::new(db.borrow()).expect("failed to create db container");
let result = stage.execute(&mut db, input).await;
db.commit().expect("failed to commit");
tx.send(result).expect("failed to send message")
});
rx
}
/// Run [Stage::unwind] and return a receiver for the result.
fn unwind(
&self,
input: UnwindInput,
) -> oneshot::Receiver<Result<UnwindOutput, Box<dyn std::error::Error + Send + Sync>>>
{
let (tx, rx) = oneshot::channel();
let (db, mut stage) = (self.db().inner(), self.stage());
tokio::spawn(async move {
let mut db = DBContainer::new(db.borrow()).expect("failed to create db container");
let result = stage.unwind(&mut db, input).await;
db.commit().expect("failed to commit");
tx.send(result).expect("failed to send result");
});
rx
}
}
}