feat(sync): sender recovery stage (#181)

* feat(sync): sender recovery stage

* execute tests

* more tests & cleanup

* senders cont

* clean up & comments

* clippy

* cleanup

* comments

* put back arbitrary dep
This commit is contained in:
Roman Krasiuk
2022-11-23 10:22:17 +02:00
committed by GitHub
parent 027fc2bbf2
commit 82b37b9bfb
12 changed files with 403 additions and 45 deletions

13
Cargo.lock generated
View File

@ -2984,11 +2984,10 @@ dependencies = [
[[package]]
name = "rayon"
version = "1.5.3"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd99e5772ead8baa5215278c9b15bf92087709e9c1b2d1f97cdb5a183c933a7d"
checksum = "1e060280438193c554f654141c9ea9417886713b7acd75974c85b18a69a88e0b"
dependencies = [
"autocfg",
"crossbeam-deque",
"either",
"rayon-core",
@ -2996,9 +2995,9 @@ dependencies = [
[[package]]
name = "rayon-core"
version = "1.9.3"
version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "258bcdb5ac6dad48491bb2992db6b7cf74878b0384908af124823d118c99683f"
checksum = "cac410af5d00ab6884528b4ab69d1e8e146e8d471201800fa1b4524126de6ad3"
dependencies = [
"crossbeam-channel",
"crossbeam-deque",
@ -3280,6 +3279,7 @@ dependencies = [
"bytes",
"futures",
"heapless",
"hex-literal",
"modular-bitfield",
"parity-scale-codec",
"postcard",
@ -3289,6 +3289,7 @@ dependencies = [
"reth-eth-wire",
"reth-primitives",
"reth-rpc-types",
"secp256k1",
"serde",
"test-fuzz",
"thiserror",
@ -3486,8 +3487,10 @@ dependencies = [
"assert_matches",
"async-trait",
"futures-util",
"itertools 0.10.5",
"metrics",
"rand 0.8.5",
"rayon",
"reth-bodies-downloaders",
"reth-db",
"reth-eth-wire",

View File

@ -27,7 +27,8 @@ parity-scale-codec = { version = "3.2.1", features = ["bytes"] }
futures = "0.3.25"
tokio-stream = "0.1.11"
rand = "0.8.5"
arbitrary = { version = "1.1.7", features = ["derive"], optional = true}
arbitrary = { version = "1.1.7", features = ["derive"], optional = true }
secp256k1 = { version = "0.24.0", default-features = false, features = ["alloc", "recovery", "rand"], optional = true }
modular-bitfield = "0.11.2"
[dev-dependencies]
@ -36,7 +37,9 @@ test-fuzz = "3.0.4"
tokio = { version = "1.21.2", features = ["full"] }
tokio-stream = { version = "0.1.11", features = ["sync"] }
arbitrary = { version = "1.1.7", features = ["derive"]}
hex-literal = "0.3"
secp256k1 = { version = "0.24.0", default-features = false, features = ["alloc", "recovery", "rand"] }
[features]
bench = []
test-utils = ["tokio-stream/sync"]
test-utils = ["tokio-stream/sync", "secp256k1"]

View File

@ -3,6 +3,7 @@ use reth_primitives::{
proofs, Address, BlockLocked, Bytes, Header, SealedHeader, Signature, Transaction,
TransactionKind, TransactionSigned, H256, U256,
};
use secp256k1::{KeyPair, Message as SecpMessage, Secp256k1, SecretKey};
// TODO(onbjerg): Maybe we should split this off to its own crate, or move the helpers to the
// relevant crates?
@ -62,18 +63,27 @@ pub fn random_tx() -> Transaction {
///
/// - There is no guarantee that the nonce is not used twice for the same account
pub fn random_signed_tx() -> TransactionSigned {
let secp = Secp256k1::new();
let key_pair = KeyPair::new(&secp, &mut rand::thread_rng());
let tx = random_tx();
let hash = tx.signature_hash();
TransactionSigned {
transaction: tx,
hash,
signature: Signature {
// TODO
r: Default::default(),
s: Default::default(),
odd_y_parity: false,
},
}
let signature =
sign_message(H256::from_slice(&key_pair.secret_bytes()[..]), tx.signature_hash()).unwrap();
TransactionSigned::from_transaction_and_signature(tx, signature)
}
/// Signs message with the given secret key.
/// Returns the corresponding signature.
pub fn sign_message(secret: H256, message: H256) -> Result<Signature, secp256k1::Error> {
let secp = Secp256k1::new();
let sec = SecretKey::from_slice(secret.as_ref())?;
let s = secp.sign_ecdsa_recoverable(&SecpMessage::from_slice(&message[..])?, &sec);
let (rec_id, data) = s.serialize_compact();
Ok(Signature {
r: U256::from_big_endian(&data[..32]),
s: U256::from_big_endian(&data[32..64]),
odd_y_parity: rec_id.to_i32() != 0,
})
}
/// Generate a random block filled with a random number of signed transactions (generated using
@ -139,3 +149,45 @@ pub fn random_block_range(rng: std::ops::Range<u64>, head: H256) -> Vec<BlockLoc
}
blocks
}
#[cfg(test)]
mod test {
use super::*;
use hex_literal::hex;
use reth_primitives::{keccak256, AccessList, Address, TransactionKind};
use secp256k1::KeyPair;
#[test]
fn test_sign_message() {
let secp = Secp256k1::new();
let tx = Transaction::Eip1559 {
chain_id: 1,
nonce: 0x42,
gas_limit: 44386,
to: TransactionKind::Call(hex!("6069a6c32cf691f5982febae4faf8a6f3ab2f0f6").into()),
value: 0_u128,
input: hex!("a22cb4650000000000000000000000005eee75727d804a2b13038928d36f8b188945a57a0000000000000000000000000000000000000000000000000000000000000000").into(),
max_fee_per_gas: 0x4a817c800,
max_priority_fee_per_gas: 0x3b9aca00,
access_list: AccessList::default(),
};
let signature_hash = tx.signature_hash();
for _ in 0..100 {
let key_pair = KeyPair::new(&secp, &mut rand::thread_rng());
let signature =
sign_message(H256::from_slice(&key_pair.secret_bytes()[..]), signature_hash)
.unwrap();
let signed = TransactionSigned::from_transaction_and_signature(tx.clone(), signature);
let recovered = signed.recover_signer().unwrap();
let public_key_hash = keccak256(&key_pair.public_key().serialize_uncompressed()[1..]);
let expected = Address::from_slice(&public_key_hash[12..]);
assert_eq!(recovered, expected);
}
}
}

View File

@ -19,6 +19,8 @@ tokio = { version = "1.21.2", features = ["sync"] }
aquamarine = "0.1.12"
metrics = "0.20.1"
futures-util = "0.3.25"
itertools = "0.10.5"
rayon = "1.6.0"
[dev-dependencies]
reth-db = { path = "../db", features = ["test-utils"] }

View File

@ -12,6 +12,13 @@ pub struct ExecInput {
pub stage_progress: Option<BlockNumber>,
}
impl ExecInput {
/// Return the progress of the previous stage or default.
pub fn previous_stage_progress(&self) -> BlockNumber {
self.previous_stage.as_ref().map(|(_, num)| *num).unwrap_or_default()
}
}
/// Stage unwind input, see [Stage::unwind].
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct UnwindInput {

View File

@ -77,8 +77,7 @@ impl<DB: Database, D: BodyDownloader, C: Consensus> Stage<DB> for BodyStage<D, C
) -> Result<ExecOutput, StageError> {
let tx = db.get_mut();
let previous_stage_progress =
input.previous_stage.as_ref().map(|(_, block)| *block).unwrap_or_default();
let previous_stage_progress = input.previous_stage_progress();
if previous_stage_progress == 0 {
warn!("The body stage seems to be running first, no work can be completed.");
return Err(StageError::DatabaseIntegrity(DatabaseIntegrityError::BlockBody {
@ -547,8 +546,7 @@ mod tests {
fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
let start = input.stage_progress.unwrap_or_default();
let end =
input.previous_stage.as_ref().map(|(_, num)| *num + 1).unwrap_or_default();
let end = input.previous_stage_progress() + 1;
let blocks = random_block_range(start..end, GENESIS_HASH);
self.insert_genesis()?;
self.db.insert_headers(blocks.iter().map(|block| &block.header))?;

View File

@ -2,5 +2,7 @@
pub mod bodies;
/// The headers stage.
pub mod headers;
/// The sender recovery stage.
pub mod senders;
/// The cumulative transaction index stage.
pub mod tx_index;

View File

@ -0,0 +1,285 @@
use crate::{
util::unwind::unwind_table_by_num, DatabaseIntegrityError, ExecInput, ExecOutput, Stage,
StageError, StageId, UnwindInput, UnwindOutput,
};
use itertools::Itertools;
use rayon::prelude::*;
use reth_interfaces::db::{
self, tables, DBContainer, Database, DbCursorRO, DbCursorRW, DbTx, DbTxMut,
};
use reth_primitives::TxNumber;
use std::fmt::Debug;
use thiserror::Error;
const SENDERS: StageId = StageId("Senders");
/// The senders stage iterates over existing transactions,
/// recovers the transaction signer and stores them
/// in [`TxSenders`][reth_interfaces::db::tables::TxSenders] table.
#[derive(Debug)]
pub struct SendersStage {
/// The size of the chunk for parallel sender recovery
pub batch_size: usize,
}
#[derive(Error, Debug)]
enum SendersStageError {
#[error("Sender recovery failed for transaction {tx}.")]
SenderRecovery { tx: TxNumber },
}
impl From<SendersStageError> for StageError {
fn from(error: SendersStageError) -> Self {
StageError::Internal(Box::new(error))
}
}
#[async_trait::async_trait]
impl<DB: Database> Stage<DB> for SendersStage {
/// Return the id of the stage
fn id(&self) -> StageId {
SENDERS
}
/// Retrieve the range of transactions to iterate over by querying
/// [`CumulativeTxCount`][reth_interfaces::db::tables::CumulativeTxCount],
/// collect transactions within that range,
/// recover signer for each transaction and store entries in
/// the [`TxSenders`][reth_interfaces::db::tables::TxSenders] table.
async fn execute(
&mut self,
db: &mut DBContainer<'_, DB>,
input: ExecInput,
) -> Result<ExecOutput, StageError> {
let tx = db.get_mut();
// Look up the start index for transaction range
let last_block_num = input.stage_progress.unwrap_or_default();
let last_block_hash = tx
.get::<tables::CanonicalHeaders>(last_block_num)?
.ok_or(DatabaseIntegrityError::CanonicalHash { number: last_block_num })?;
let start_tx_index = tx
.get::<tables::CumulativeTxCount>((last_block_num, last_block_hash).into())?
.ok_or(DatabaseIntegrityError::CumulativeTxCount {
number: last_block_num,
hash: last_block_hash,
})?;
// Look up the end index for transaction range (exclusive)
let max_block_num = input.previous_stage_progress();
let max_block_hash = tx
.get::<tables::CanonicalHeaders>(max_block_num)?
.ok_or(DatabaseIntegrityError::CanonicalHash { number: max_block_num })?;
let end_tx_index = tx
.get::<tables::CumulativeTxCount>((max_block_num, max_block_hash).into())?
.ok_or(DatabaseIntegrityError::CumulativeTxCount {
number: last_block_num,
hash: last_block_hash,
})?;
// Acquire the cursor for inserting elements
let mut senders_cursor = tx.cursor_mut::<tables::TxSenders>()?;
// Acquire the cursor over the transactions
let mut tx_cursor = tx.cursor::<tables::Transactions>()?;
// Walk the transactions from start to end index (exclusive)
let entries = tx_cursor
.walk(start_tx_index)?
.take_while(|res| res.as_ref().map(|(k, _)| *k < end_tx_index).unwrap_or_default());
// Iterate over transactions in chunks
for chunk in &entries.chunks(self.batch_size) {
let transactions = chunk.collect::<Result<Vec<_>, db::Error>>()?;
// Recover signers for the chunk in parallel
let recovered = transactions
.into_par_iter()
.map(|(id, transaction)| {
let signer =
transaction.recover_signer().ok_or_else::<StageError, _>(|| {
SendersStageError::SenderRecovery { tx: id }.into()
})?;
Ok((id, signer))
})
.collect::<Result<Vec<_>, StageError>>()?;
// Append the signers to the table
recovered.into_iter().try_for_each(|(id, sender)| senders_cursor.append(id, sender))?;
}
Ok(ExecOutput { stage_progress: max_block_num, done: true, reached_tip: true })
}
/// Unwind the stage.
async fn unwind(
&mut self,
db: &mut DBContainer<'_, DB>,
input: UnwindInput,
) -> Result<UnwindOutput, Box<dyn std::error::Error + Send + Sync>> {
let tx = db.get_mut();
// Look up the hash of the unwind block
if let Some(unwind_hash) = tx.get::<tables::CanonicalHeaders>(input.unwind_to)? {
// Look up the cumulative tx count at unwind block
let latest_tx = tx
.get::<tables::CumulativeTxCount>((input.unwind_to, unwind_hash).into())?
.ok_or(DatabaseIntegrityError::CumulativeTxCount {
number: input.unwind_to,
hash: unwind_hash,
})?;
unwind_table_by_num::<DB, tables::TxSenders>(tx, latest_tx - 1)?;
}
Ok(UnwindOutput { stage_progress: input.unwind_to })
}
}
#[cfg(test)]
mod tests {
use reth_interfaces::{
db::models::StoredBlockBody, test_utils::generators::random_block_range,
};
use reth_primitives::{BlockLocked, BlockNumber, H256};
use super::*;
use crate::test_utils::{
stage_test_suite, ExecuteStageTestRunner, StageTestDB, StageTestRunner, TestRunnerError,
UnwindStageTestRunner,
};
stage_test_suite!(SendersTestRunner);
#[derive(Default)]
struct SendersTestRunner {
db: StageTestDB,
}
impl StageTestRunner for SendersTestRunner {
type S = SendersStage;
fn db(&self) -> &StageTestDB {
&self.db
}
fn stage(&self) -> Self::S {
SendersStage { batch_size: 100 }
}
}
impl ExecuteStageTestRunner for SendersTestRunner {
type Seed = Vec<BlockLocked>;
fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
let stage_progress = input.stage_progress.unwrap_or_default();
let end = input.previous_stage_progress() + 1;
let blocks = random_block_range(stage_progress..end, H256::zero());
self.db.commit(|tx| {
let mut base_tx_id = 0;
blocks.iter().try_for_each(|b| {
let ommers = b.ommers.iter().map(|o| o.clone().unseal()).collect::<Vec<_>>();
let txs = b.body.clone();
let tx_amount = txs.len() as u64;
let num_hash = (b.number, b.hash()).into();
tx.put::<tables::CanonicalHeaders>(b.number, b.hash())?;
tx.put::<tables::BlockBodies>(
num_hash,
StoredBlockBody { base_tx_id, tx_amount, ommers },
)?;
tx.put::<tables::CumulativeTxCount>(num_hash, base_tx_id + tx_amount)?;
for body_tx in txs {
tx.put::<tables::Transactions>(base_tx_id, body_tx)?;
base_tx_id += 1;
}
Ok(())
})?;
Ok(())
})?;
Ok(blocks)
}
fn validate_execution(
&self,
input: ExecInput,
output: Option<ExecOutput>,
) -> Result<(), TestRunnerError> {
if let Some(output) = output {
self.db.query(|tx| {
let start_block = input.stage_progress.unwrap_or_default() + 1;
let end_block = output.stage_progress;
if start_block > end_block {
return Ok(())
}
let start_hash = tx.get::<tables::CanonicalHeaders>(start_block)?.unwrap();
let mut body_cursor = tx.cursor::<tables::BlockBodies>()?;
let mut body_walker = body_cursor.walk((start_block, start_hash).into())?;
while let Some(entry) = body_walker.next() {
let (_, body) = entry?;
for tx_id in body.base_tx_id..body.base_tx_id + body.tx_amount {
let transaction = tx
.get::<tables::Transactions>(tx_id)?
.expect("no transaction entry");
let signer =
transaction.recover_signer().expect("failed to recover signer");
assert_eq!(Some(signer), tx.get::<tables::TxSenders>(tx_id)?);
}
}
Ok(())
})?;
} else {
self.check_no_transaction_by_block(input.stage_progress.unwrap_or_default())?;
}
Ok(())
}
}
impl UnwindStageTestRunner for SendersTestRunner {
fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
self.check_no_transaction_by_block(input.unwind_to)
}
}
impl SendersTestRunner {
fn check_no_transaction_by_block(&self, block: BlockNumber) -> Result<(), TestRunnerError> {
match self.get_block_body_entry(block)? {
Some(body) => {
let last_index = body.base_tx_id + body.tx_amount;
self.db.check_no_entry_above::<tables::TxSenders, _>(last_index, |key| key)?;
}
None => {
assert!(self.db.table_is_empty::<tables::TxSenders>()?);
}
};
Ok(())
}
/// Get the block body entry at block number. If it doesn't exist,
/// fallback to the previous entry.
fn get_block_body_entry(
&self,
block: BlockNumber,
) -> Result<Option<StoredBlockBody>, TestRunnerError> {
let entry = self.db.query(|tx| match tx.get::<tables::CanonicalHeaders>(block)? {
Some(hash) => {
let mut body_cursor = tx.cursor::<tables::BlockBodies>()?;
let entry = match body_cursor.seek_exact((block, hash).into())? {
Some((_, block)) => Some(block),
_ => body_cursor.prev()?.map(|(_, block)| block),
};
Ok(entry)
}
None => Ok(None),
})?;
Ok(entry)
}
}
}

View File

@ -46,7 +46,7 @@ impl<DB: Database> Stage<DB> for TxIndex {
.ok_or(DatabaseIntegrityError::CanonicalHeader { number: start_block })?;
// The maximum block that this stage should insert to
let max_block = input.previous_stage.as_ref().map(|(_, block)| *block).unwrap_or_default();
let max_block = input.previous_stage_progress();
// Get the cursor over the table
let mut cursor = tx.cursor_mut::<tables::CumulativeTxCount>()?;
@ -91,7 +91,6 @@ mod tests {
stage_test_suite, ExecuteStageTestRunner, StageTestDB, StageTestRunner, TestRunnerError,
UnwindStageTestRunner,
};
use assert_matches::assert_matches;
use reth_interfaces::{
db::models::{BlockNumHash, StoredBlockBody},
test_utils::generators::random_header_range,
@ -123,7 +122,7 @@ mod tests {
fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
let pivot = input.stage_progress.unwrap_or_default();
let start = pivot.saturating_sub(100);
let mut end = input.previous_stage.as_ref().map(|(_, num)| *num).unwrap_or_default();
let mut end = input.previous_stage_progress();
end += 2; // generate 2 additional headers to account for start header lookup
let headers = random_header_range(start..end, H256::zero());
@ -157,11 +156,10 @@ mod tests {
input: ExecInput,
_output: Option<ExecOutput>,
) -> Result<(), TestRunnerError> {
// TODO: validate that base_tx_index of next block equals the cum count at current
self.db.query(|tx| {
let (start, end) = (
input.stage_progress.unwrap_or_default(),
input.previous_stage.as_ref().map(|(_, num)| *num).unwrap_or_default(),
);
let (start, end) =
(input.stage_progress.unwrap_or_default(), input.previous_stage_progress());
if start >= end {
return Ok(())
}

View File

@ -11,7 +11,7 @@ macro_rules! stage_test_suite {
// Run stage execution
let result = runner.execute(input).await.unwrap();
assert_matches!(
assert_matches::assert_matches!(
result,
Err(crate::error::StageError::DatabaseIntegrity(_))
);
@ -41,7 +41,7 @@ macro_rules! stage_test_suite {
// Assert the successful result
let result = rx.await.unwrap();
assert_matches!(
assert_matches::assert_matches!(
result,
Ok(ExecOutput { done, reached_tip, stage_progress })
if done && reached_tip && stage_progress == stage_progress
@ -70,7 +70,7 @@ macro_rules! stage_test_suite {
// Assert the successful result
let result = rx.await.unwrap();
assert_matches!(
assert_matches::assert_matches!(
result,
Ok(ExecOutput { done, reached_tip, stage_progress })
if done && reached_tip && stage_progress == previous_stage
@ -89,7 +89,7 @@ macro_rules! stage_test_suite {
// Run stage unwind
let rx = runner.unwind(input).await;
assert_matches!(
assert_matches::assert_matches!(
rx,
Ok(UnwindOutput { stage_progress }) if stage_progress == input.unwind_to
);
@ -117,7 +117,7 @@ macro_rules! stage_test_suite {
// Assert the successful execution result
let result = rx.await.unwrap();
assert_matches!(
assert_matches::assert_matches!(
result,
Ok(ExecOutput { done, reached_tip, stage_progress })
if done && reached_tip && stage_progress == previous_stage
@ -131,7 +131,7 @@ macro_rules! stage_test_suite {
let rx = runner.unwind(unwind_input).await;
// Assert the successful unwind result
assert_matches!(
assert_matches::assert_matches!(
rx,
Ok(UnwindOutput { stage_progress }) if stage_progress == unwind_input.unwind_to
);

View File

@ -57,6 +57,14 @@ impl StageTestDB {
f(self.container().get())
}
/// Check if the table is empty
pub(crate) fn table_is_empty<T: Table>(&self) -> Result<bool, db::Error> {
self.query(|tx| {
let last = tx.cursor::<T>()?.last()?;
Ok(last.is_none())
})
}
/// Map a collection of values and store them in the database.
/// This function commits the transaction before exiting.
///
@ -110,10 +118,10 @@ impl StageTestDB {
}
/// Check that there is no table entry above a given
/// block by [Table::Key]
/// number by [Table::Key]
pub(crate) fn check_no_entry_above<T, F>(
&self,
block: BlockNumber,
num: u64,
mut selector: F,
) -> Result<(), db::Error>
where
@ -123,17 +131,17 @@ impl StageTestDB {
self.query(|tx| {
let mut cursor = tx.cursor::<T>()?;
if let Some((key, _)) = cursor.last()? {
assert!(selector(key) <= block);
assert!(selector(key) <= num);
}
Ok(())
})
}
/// Check that there is no table entry above a given
/// block by [Table::Value]
/// number by [Table::Value]
pub(crate) fn check_no_entry_above_by_value<T, F>(
&self,
block: BlockNumber,
num: u64,
mut selector: F,
) -> Result<(), db::Error>
where
@ -143,7 +151,7 @@ impl StageTestDB {
self.query(|tx| {
let mut cursor = tx.cursor::<T>()?;
if let Some((_, value)) = cursor.last()? {
assert!(selector(value) <= block);
assert!(selector(value) <= num);
}
Ok(())
})

View File

@ -68,17 +68,17 @@ pub(crate) mod unwind {
};
use reth_primitives::BlockNumber;
/// Unwind table by block number key
/// Unwind table by some number key
#[inline]
pub(crate) fn unwind_table_by_num<DB, T>(
tx: &mut <DB as DatabaseGAT<'_>>::TXMut,
block: BlockNumber,
num: u64,
) -> Result<(), Error>
where
DB: Database,
T: Table<Key = BlockNumber>,
T: Table<Key = u64>,
{
unwind_table::<DB, T, _>(tx, block, |key| key)
unwind_table::<DB, T, _>(tx, num, |key| key)
}
/// Unwind table by composite block number hash key