feat(engine): introduce sync implementation of StateRootTask (#12378)

This commit is contained in:
Federico Gimenez
2024-11-18 14:58:31 +01:00
committed by GitHub
parent 26ce7fbdb2
commit 8339c716b4
4 changed files with 257 additions and 54 deletions

4
Cargo.lock generated
View File

@ -7224,9 +7224,10 @@ dependencies = [
"alloy-rlp",
"alloy-rpc-types-engine",
"assert_matches",
"criterion",
"crossbeam-channel",
"futures",
"metrics",
"pin-project",
"reth-beacon-consensus",
"reth-blockchain-tree",
"reth-blockchain-tree-api",
@ -7261,7 +7262,6 @@ dependencies = [
"revm-primitives",
"thiserror 1.0.69",
"tokio",
"tokio-stream",
"tracing",
]

View File

@ -45,9 +45,7 @@ revm-primitives.workspace = true
# common
futures.workspace = true
pin-project.workspace = true
tokio = { workspace = true, features = ["macros", "sync"] }
tokio-stream.workspace = true
thiserror.workspace = true
# metrics
@ -82,6 +80,12 @@ reth-chainspec.workspace = true
alloy-rlp.workspace = true
assert_matches.workspace = true
criterion.workspace = true
crossbeam-channel = "0.5.13"
[[bench]]
name = "channel_perf"
harness = false
[features]
test-utils = [

View File

@ -0,0 +1,132 @@
//! Benchmark comparing `std::sync::mpsc` and `crossbeam` channels for `StateRootTask`.
#![allow(missing_docs)]
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use revm_primitives::{
Account, AccountInfo, AccountStatus, Address, EvmState, EvmStorage, EvmStorageSlot, HashMap,
B256, U256,
};
use std::thread;
/// Creates a mock state with the specified number of accounts for benchmarking
fn create_bench_state(num_accounts: usize) -> EvmState {
let mut state_changes = HashMap::default();
for i in 0..num_accounts {
let storage =
EvmStorage::from_iter([(U256::from(i), EvmStorageSlot::new(U256::from(i + 1)))]);
let account = Account {
info: AccountInfo {
balance: U256::from(100),
nonce: 10,
code_hash: B256::random(),
code: Default::default(),
},
storage,
status: AccountStatus::Loaded,
};
let address = Address::random();
state_changes.insert(address, account);
}
state_changes
}
/// Simulated `StateRootTask` with `std::sync::mpsc`
struct StdStateRootTask {
rx: std::sync::mpsc::Receiver<EvmState>,
}
impl StdStateRootTask {
const fn new(rx: std::sync::mpsc::Receiver<EvmState>) -> Self {
Self { rx }
}
fn run(self) {
while let Ok(state) = self.rx.recv() {
criterion::black_box(state);
}
}
}
/// Simulated `StateRootTask` with `crossbeam-channel`
struct CrossbeamStateRootTask {
rx: crossbeam_channel::Receiver<EvmState>,
}
impl CrossbeamStateRootTask {
const fn new(rx: crossbeam_channel::Receiver<EvmState>) -> Self {
Self { rx }
}
fn run(self) {
while let Ok(state) = self.rx.recv() {
criterion::black_box(state);
}
}
}
/// Benchmarks the performance of different channel implementations for state streaming
fn bench_state_stream(c: &mut Criterion) {
let mut group = c.benchmark_group("state_stream_channels");
group.sample_size(10);
for size in &[1, 10, 100] {
let bench_setup = || {
let states: Vec<_> = (0..100).map(|_| create_bench_state(*size)).collect();
states
};
group.bench_with_input(BenchmarkId::new("std_channel", size), size, |b, _| {
b.iter_batched(
bench_setup,
|states| {
let (tx, rx) = std::sync::mpsc::channel();
let task = StdStateRootTask::new(rx);
let processor = thread::spawn(move || {
task.run();
});
for state in states {
tx.send(state).unwrap();
}
drop(tx);
processor.join().unwrap();
},
BatchSize::LargeInput,
);
});
group.bench_with_input(BenchmarkId::new("crossbeam_channel", size), size, |b, _| {
b.iter_batched(
bench_setup,
|states| {
let (tx, rx) = crossbeam_channel::unbounded();
let task = CrossbeamStateRootTask::new(rx);
let processor = thread::spawn(move || {
task.run();
});
for state in states {
tx.send(state).unwrap();
}
drop(tx);
processor.join().unwrap();
},
BatchSize::LargeInput,
);
});
}
group.finish();
}
criterion_group!(benches, bench_state_stream);
criterion_main!(benches);

View File

@ -1,18 +1,13 @@
//! State root task related functionality.
use futures::Stream;
use pin_project::pin_project;
use reth_provider::providers::ConsistentDbView;
use reth_trie::{updates::TrieUpdates, TrieInput};
use reth_trie_parallel::root::ParallelStateRootError;
use revm_primitives::{EvmState, B256};
use std::{
future::Future,
pin::Pin,
sync::{mpsc, Arc},
task::{Context, Poll},
use std::sync::{
mpsc::{self, Receiver, RecvError},
Arc,
};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::debug;
/// Result of the state root calculation
@ -28,12 +23,43 @@ pub(crate) struct StateRootHandle {
#[allow(dead_code)]
impl StateRootHandle {
/// Creates a new handle from a receiver.
pub(crate) const fn new(rx: mpsc::Receiver<StateRootResult>) -> Self {
Self { rx }
}
/// Waits for the state root calculation to complete.
pub(crate) fn wait_for_result(self) -> StateRootResult {
self.rx.recv().expect("state root task was dropped without sending result")
}
}
/// Common configuration for state root tasks
#[derive(Debug)]
pub(crate) struct StateRootConfig<Factory> {
/// View over the state in the database.
pub consistent_view: ConsistentDbView<Factory>,
/// Latest trie input.
pub input: Arc<TrieInput>,
}
/// Wrapper for std channel receiver to maintain compatibility with `UnboundedReceiverStream`
#[allow(dead_code)]
pub(crate) struct StdReceiverStream {
rx: Receiver<EvmState>,
}
#[allow(dead_code)]
impl StdReceiverStream {
pub(crate) const fn new(rx: Receiver<EvmState>) -> Self {
Self { rx }
}
pub(crate) fn recv(&self) -> Result<EvmState, RecvError> {
self.rx.recv()
}
}
/// Standalone task that receives a transaction state stream and updates relevant
/// data structures to calculate state root.
///
@ -42,15 +68,12 @@ impl StateRootHandle {
/// fetches the proofs for relevant accounts from the database and reveal them
/// to the tree.
/// Then it updates relevant leaves according to the result of the transaction.
#[pin_project]
#[allow(dead_code)]
pub(crate) struct StateRootTask<Factory> {
/// View over the state in the database.
consistent_view: ConsistentDbView<Factory>,
/// Incoming state updates.
#[pin]
state_stream: UnboundedReceiverStream<EvmState>,
/// Latest trie input.
input: Arc<TrieInput>,
state_stream: StdReceiverStream,
/// Task configuration.
config: StateRootConfig<Factory>,
}
#[allow(dead_code)]
@ -60,65 +83,109 @@ where
{
/// Creates a new `StateRootTask`.
pub(crate) const fn new(
consistent_view: ConsistentDbView<Factory>,
input: Arc<TrieInput>,
state_stream: UnboundedReceiverStream<EvmState>,
config: StateRootConfig<Factory>,
state_stream: StdReceiverStream,
) -> Self {
Self { consistent_view, state_stream, input }
Self { config, state_stream }
}
/// Spawns the state root task and returns a handle to await its result.
pub(crate) fn spawn(self) -> StateRootHandle {
let (tx, rx) = mpsc::channel();
// Spawn the task that will process state updates and calculate the root
tokio::spawn(async move {
let (tx, rx) = mpsc::sync_channel(1);
std::thread::Builder::new()
.name("State Root Task".to_string())
.spawn(move || {
debug!(target: "engine::tree", "Starting state root task");
let result = self.await;
let result = self.run();
let _ = tx.send(result);
});
})
.expect("failed to spawn state root thread");
StateRootHandle { rx }
StateRootHandle::new(rx)
}
/// Handles state updates.
fn on_state_update(
_view: &ConsistentDbView<Factory>,
_input: &Arc<TrieInput>,
_view: &reth_provider::providers::ConsistentDbView<impl Send + 'static>,
_input: &std::sync::Arc<reth_trie::TrieInput>,
_state: EvmState,
) {
// Default implementation of state update handling
// TODO: calculate hashed state update and dispatch proof gathering for it.
}
}
impl<Factory> Future for StateRootTask<Factory>
#[allow(dead_code)]
impl<Factory> StateRootTask<Factory>
where
Factory: Send + 'static,
{
type Output = StateRootResult;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
// Process all items until the stream is closed
loop {
match this.state_stream.as_mut().poll_next(cx) {
Poll::Ready(Some(state)) => {
Self::on_state_update(this.consistent_view, this.input, state);
}
Poll::Ready(None) => {
// stream closed, return final result
return Poll::Ready(Ok((B256::default(), TrieUpdates::default())));
}
Poll::Pending => {
return Poll::Pending;
}
}
fn run(self) -> StateRootResult {
while let Ok(state) = self.state_stream.recv() {
Self::on_state_update(&self.config.consistent_view, &self.config.input, state);
}
// TODO:
// * keep track of proof calculation
// * keep track of intermediate root computation
// * return final state root result
Ok((B256::default(), TrieUpdates::default()))
}
}
#[cfg(test)]
mod tests {
use super::*;
use reth_provider::{providers::ConsistentDbView, test_utils::MockEthProvider};
use reth_trie::TrieInput;
use revm_primitives::{
Account, AccountInfo, AccountStatus, Address, EvmState, EvmStorage, EvmStorageSlot,
HashMap, B256, U256,
};
use std::sync::Arc;
fn create_mock_config() -> StateRootConfig<MockEthProvider> {
let factory = MockEthProvider::default();
let view = ConsistentDbView::new(factory, None);
let input = Arc::new(TrieInput::default());
StateRootConfig { consistent_view: view, input }
}
fn create_mock_state() -> revm_primitives::EvmState {
let mut state_changes: EvmState = HashMap::default();
let storage = EvmStorage::from_iter([(U256::from(1), EvmStorageSlot::new(U256::from(2)))]);
let account = Account {
info: AccountInfo {
balance: U256::from(100),
nonce: 10,
code_hash: B256::random(),
code: Default::default(),
},
storage,
status: AccountStatus::Loaded,
};
let address = Address::random();
state_changes.insert(address, account);
state_changes
}
#[test]
fn test_state_root_task() {
let config = create_mock_config();
let (tx, rx) = std::sync::mpsc::channel();
let stream = StdReceiverStream::new(rx);
let task = StateRootTask::new(config, stream);
let handle = task.spawn();
for _ in 0..10 {
tx.send(create_mock_state()).expect("failed to send state");
}
drop(tx);
let result = handle.wait_for_result();
assert!(result.is_ok(), "sync block execution failed");
}
}