feat: make downloaders and clients generic over block parts (#12469)

Co-authored-by: Matthias Seitz <matthias.seitz@outlook.de>
This commit is contained in:
Arsenii Kulikov
2024-11-12 19:13:21 +04:00
committed by GitHub
parent 3a337cd7d4
commit aece53ae88
60 changed files with 631 additions and 409 deletions

View File

@ -5,10 +5,11 @@ use crate::{
p2pstream::MAX_RESERVED_MESSAGE_ID,
protocol::{ProtoVersion, Protocol},
version::ParseVersionError,
Capability, EthMessage, EthMessageID, EthVersion,
Capability, EthMessageID, EthVersion,
};
use alloy_primitives::bytes::Bytes;
use derive_more::{Deref, DerefMut};
use reth_eth_wire_types::{EthMessage, EthNetworkPrimitives, NetworkPrimitives};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::{
@ -30,9 +31,13 @@ pub struct RawCapabilityMessage {
/// network.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum CapabilityMessage {
pub enum CapabilityMessage<N: NetworkPrimitives = EthNetworkPrimitives> {
/// Eth sub-protocol message.
Eth(EthMessage),
#[cfg_attr(
feature = "serde",
serde(bound = "EthMessage<N>: Serialize + serde::de::DeserializeOwned")
)]
Eth(EthMessage<N>),
/// Any other capability message.
Other(RawCapabilityMessage),
}

View File

@ -8,6 +8,7 @@ use crate::{
use alloy_primitives::bytes::{Bytes, BytesMut};
use futures::{ready, Sink, SinkExt, StreamExt};
use pin_project::pin_project;
use reth_eth_wire_types::NetworkPrimitives;
use reth_primitives::{ForkFilter, GotExpected};
use std::{
pin::Pin,
@ -54,32 +55,32 @@ where
/// Consumes the [`UnauthedEthStream`] and returns an [`EthStream`] after the `Status`
/// handshake is completed successfully. This also returns the `Status` message sent by the
/// remote peer.
pub async fn handshake(
pub async fn handshake<N: NetworkPrimitives>(
self,
status: Status,
fork_filter: ForkFilter,
) -> Result<(EthStream<S>, Status), EthStreamError> {
) -> Result<(EthStream<S, N>, Status), EthStreamError> {
self.handshake_with_timeout(status, fork_filter, HANDSHAKE_TIMEOUT).await
}
/// Wrapper around handshake which enforces a timeout.
pub async fn handshake_with_timeout(
pub async fn handshake_with_timeout<N: NetworkPrimitives>(
self,
status: Status,
fork_filter: ForkFilter,
timeout_limit: Duration,
) -> Result<(EthStream<S>, Status), EthStreamError> {
) -> Result<(EthStream<S, N>, Status), EthStreamError> {
timeout(timeout_limit, Self::handshake_without_timeout(self, status, fork_filter))
.await
.map_err(|_| EthStreamError::StreamTimeout)?
}
/// Handshake with no timeout
pub async fn handshake_without_timeout(
pub async fn handshake_without_timeout<N: NetworkPrimitives>(
mut self,
status: Status,
fork_filter: ForkFilter,
) -> Result<(EthStream<S>, Status), EthStreamError> {
) -> Result<(EthStream<S, N>, Status), EthStreamError> {
trace!(
%status,
"sending eth status to peer"
@ -89,10 +90,8 @@ where
// The max length for a status with TTD is: <msg id = 1 byte> + <rlp(status) = 88 byte>
self.inner
.send(
alloy_rlp::encode(ProtocolMessage::from(
EthMessage::<EthNetworkPrimitives>::Status(status),
))
.into(),
alloy_rlp::encode(ProtocolMessage::<N>::from(EthMessage::<N>::Status(status)))
.into(),
)
.await?;
@ -112,15 +111,14 @@ where
}
let version = status.version;
let msg: ProtocolMessage =
match ProtocolMessage::decode_message(version, &mut their_msg.as_ref()) {
Ok(m) => m,
Err(err) => {
debug!("decode error in eth handshake: msg={their_msg:x}");
self.inner.disconnect(DisconnectReason::DisconnectRequested).await?;
return Err(EthStreamError::InvalidMessage(err))
}
};
let msg = match ProtocolMessage::<N>::decode_message(version, &mut their_msg.as_ref()) {
Ok(m) => m,
Err(err) => {
debug!("decode error in eth handshake: msg={their_msg:x}");
self.inner.disconnect(DisconnectReason::DisconnectRequested).await?;
return Err(EthStreamError::InvalidMessage(err))
}
};
// The following checks should match the checks in go-ethereum:
// https://github.com/ethereum/go-ethereum/blob/9244d5cd61f3ea5a7645fdf2a1a96d53421e412f/eth/protocols/eth/handshake.go#L87-L89
@ -194,19 +192,21 @@ where
/// compatible with eth-networking protocol messages, which get RLP encoded/decoded.
#[pin_project]
#[derive(Debug)]
pub struct EthStream<S> {
pub struct EthStream<S, N = EthNetworkPrimitives> {
/// Negotiated eth version.
version: EthVersion,
#[pin]
inner: S,
_pd: std::marker::PhantomData<N>,
}
impl<S> EthStream<S> {
impl<S, N> EthStream<S, N> {
/// Creates a new unauthed [`EthStream`] from a provided stream. You will need
/// to manually handshake a peer.
#[inline]
pub const fn new(version: EthVersion, inner: S) -> Self {
Self { version, inner }
Self { version, inner, _pd: std::marker::PhantomData }
}
/// Returns the eth version.
@ -234,15 +234,16 @@ impl<S> EthStream<S> {
}
}
impl<S, E> EthStream<S>
impl<S, E, N> EthStream<S, N>
where
S: Sink<Bytes, Error = E> + Unpin,
EthStreamError: From<E>,
N: NetworkPrimitives,
{
/// Same as [`Sink::start_send`] but accepts a [`EthBroadcastMessage`] instead.
pub fn start_send_broadcast(
&mut self,
item: EthBroadcastMessage,
item: EthBroadcastMessage<N>,
) -> Result<(), EthStreamError> {
self.inner.start_send_unpin(Bytes::from(alloy_rlp::encode(
ProtocolBroadcastMessage::from(item),
@ -252,12 +253,13 @@ where
}
}
impl<S, E> Stream for EthStream<S>
impl<S, E, N> Stream for EthStream<S, N>
where
S: Stream<Item = Result<BytesMut, E>> + Unpin,
EthStreamError: From<E>,
N: NetworkPrimitives,
{
type Item = Result<EthMessage, EthStreamError>;
type Item = Result<EthMessage<N>, EthStreamError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
@ -299,10 +301,11 @@ where
}
}
impl<S> Sink<EthMessage> for EthStream<S>
impl<S, N> Sink<EthMessage<N>> for EthStream<S, N>
where
S: CanDisconnect<Bytes> + Unpin,
EthStreamError: From<<S as Sink<Bytes>>::Error>,
N: NetworkPrimitives,
{
type Error = EthStreamError;
@ -310,7 +313,7 @@ where
self.project().inner.poll_ready(cx).map_err(Into::into)
}
fn start_send(self: Pin<&mut Self>, item: EthMessage) -> Result<(), Self::Error> {
fn start_send(self: Pin<&mut Self>, item: EthMessage<N>) -> Result<(), Self::Error> {
if matches!(item, EthMessage::Status(_)) {
// TODO: to disconnect here we would need to do something similar to P2PStream's
// start_disconnect, which would ideally be a part of the CanDisconnect trait, or at
@ -340,10 +343,11 @@ where
}
}
impl<S> CanDisconnect<EthMessage> for EthStream<S>
impl<S, N> CanDisconnect<EthMessage<N>> for EthStream<S, N>
where
S: CanDisconnect<Bytes> + Send,
EthStreamError: From<<S as Sink<Bytes>>::Error>,
N: NetworkPrimitives,
{
async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), EthStreamError> {
self.inner.disconnect(reason).await.map_err(Into::into)
@ -365,6 +369,7 @@ mod tests {
use futures::{SinkExt, StreamExt};
use reth_chainspec::NamedChain;
use reth_ecies::stream::ECIESStream;
use reth_eth_wire_types::EthNetworkPrimitives;
use reth_network_peers::pk2id;
use reth_primitives::{ForkFilter, Head};
use secp256k1::{SecretKey, SECP256K1};
@ -397,7 +402,7 @@ mod tests {
let (incoming, _) = listener.accept().await.unwrap();
let stream = PassthroughCodec::default().framed(incoming);
let (_, their_status) = UnauthedEthStream::new(stream)
.handshake(status_clone, fork_filter_clone)
.handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
.await
.unwrap();
@ -409,8 +414,10 @@ mod tests {
let sink = PassthroughCodec::default().framed(outgoing);
// try to connect
let (_, their_status) =
UnauthedEthStream::new(sink).handshake(status, fork_filter).await.unwrap();
let (_, their_status) = UnauthedEthStream::new(sink)
.handshake::<EthNetworkPrimitives>(status, fork_filter)
.await
.unwrap();
// their status is a clone of our status, these should be equal
assert_eq!(their_status, status);
@ -444,7 +451,7 @@ mod tests {
let (incoming, _) = listener.accept().await.unwrap();
let stream = PassthroughCodec::default().framed(incoming);
let (_, their_status) = UnauthedEthStream::new(stream)
.handshake(status_clone, fork_filter_clone)
.handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
.await
.unwrap();
@ -456,8 +463,10 @@ mod tests {
let sink = PassthroughCodec::default().framed(outgoing);
// try to connect
let (_, their_status) =
UnauthedEthStream::new(sink).handshake(status, fork_filter).await.unwrap();
let (_, their_status) = UnauthedEthStream::new(sink)
.handshake::<EthNetworkPrimitives>(status, fork_filter)
.await
.unwrap();
// their status is a clone of our status, these should be equal
assert_eq!(their_status, status);
@ -490,8 +499,9 @@ mod tests {
// roughly based off of the design of tokio::net::TcpListener
let (incoming, _) = listener.accept().await.unwrap();
let stream = PassthroughCodec::default().framed(incoming);
let handshake_res =
UnauthedEthStream::new(stream).handshake(status_clone, fork_filter_clone).await;
let handshake_res = UnauthedEthStream::new(stream)
.handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
.await;
// make sure the handshake fails due to td too high
assert!(matches!(
@ -506,7 +516,9 @@ mod tests {
let sink = PassthroughCodec::default().framed(outgoing);
// try to connect
let handshake_res = UnauthedEthStream::new(sink).handshake(status, fork_filter).await;
let handshake_res = UnauthedEthStream::new(sink)
.handshake::<EthNetworkPrimitives>(status, fork_filter)
.await;
// this handshake should also fail due to td too high
assert!(matches!(
@ -524,7 +536,7 @@ mod tests {
async fn can_write_and_read_cleartext() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
let test_msg = EthMessage::NewBlockHashes(
let test_msg: EthMessage = EthMessage::NewBlockHashes(
vec![
BlockHashNumber { hash: B256::random(), number: 5 },
BlockHashNumber { hash: B256::random(), number: 6 },
@ -559,7 +571,7 @@ mod tests {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
let server_key = SecretKey::new(&mut rand::thread_rng());
let test_msg = EthMessage::NewBlockHashes(
let test_msg: EthMessage = EthMessage::NewBlockHashes(
vec![
BlockHashNumber { hash: B256::random(), number: 5 },
BlockHashNumber { hash: B256::random(), number: 6 },
@ -601,7 +613,7 @@ mod tests {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
let server_key = SecretKey::new(&mut rand::thread_rng());
let test_msg = EthMessage::NewBlockHashes(
let test_msg: EthMessage = EthMessage::NewBlockHashes(
vec![
BlockHashNumber { hash: B256::random(), number: 5 },
BlockHashNumber { hash: B256::random(), number: 6 },
@ -705,7 +717,7 @@ mod tests {
let (incoming, _) = listener.accept().await.unwrap();
let stream = PassthroughCodec::default().framed(incoming);
let (_, their_status) = UnauthedEthStream::new(stream)
.handshake(status_clone, fork_filter_clone)
.handshake::<EthNetworkPrimitives>(status_clone, fork_filter_clone)
.await
.unwrap();
@ -718,7 +730,11 @@ mod tests {
// try to connect
let handshake_result = UnauthedEthStream::new(sink)
.handshake_with_timeout(status, fork_filter, Duration::from_secs(1))
.handshake_with_timeout::<EthNetworkPrimitives>(
status,
fork_filter,
Duration::from_secs(1),
)
.await;
// Assert that a timeout error occurred

View File

@ -24,6 +24,7 @@ use crate::{
};
use bytes::{Bytes, BytesMut};
use futures::{Sink, SinkExt, Stream, StreamExt, TryStream, TryStreamExt};
use reth_eth_wire_types::NetworkPrimitives;
use reth_primitives::ForkFilter;
use tokio::sync::{mpsc, mpsc::UnboundedSender};
use tokio_stream::wrappers::UnboundedReceiverStream;
@ -204,11 +205,11 @@ impl<St> RlpxProtocolMultiplexer<St> {
/// Converts this multiplexer into a [`RlpxSatelliteStream`] with eth protocol as the given
/// primary protocol.
pub async fn into_eth_satellite_stream(
pub async fn into_eth_satellite_stream<N: NetworkPrimitives>(
self,
status: Status,
fork_filter: ForkFilter,
) -> Result<(RlpxSatelliteStream<St, EthStream<ProtocolProxy>>, Status), EthStreamError>
) -> Result<(RlpxSatelliteStream<St, EthStream<ProtocolProxy, N>>, Status), EthStreamError>
where
St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
{
@ -674,6 +675,7 @@ mod tests {
},
UnauthedP2PStream,
};
use reth_eth_wire_types::EthNetworkPrimitives;
use tokio::{net::TcpListener, sync::oneshot};
use tokio_util::codec::Decoder;
@ -693,7 +695,7 @@ mod tests {
UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();
let (_eth_stream, _) = UnauthedEthStream::new(p2p_stream)
.handshake(other_status, other_fork_filter)
.handshake::<EthNetworkPrimitives>(other_status, other_fork_filter)
.await
.unwrap();
@ -708,7 +710,9 @@ mod tests {
.into_satellite_stream_with_handshake(
eth.capability().as_ref(),
move |proxy| async move {
UnauthedEthStream::new(proxy).handshake(status, fork_filter).await
UnauthedEthStream::new(proxy)
.handshake::<EthNetworkPrimitives>(status, fork_filter)
.await
},
)
.await
@ -731,7 +735,7 @@ mod tests {
let (conn, _) = UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();
let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn)
.into_eth_satellite_stream(other_status, other_fork_filter)
.into_eth_satellite_stream::<EthNetworkPrimitives>(other_status, other_fork_filter)
.await
.unwrap();
@ -762,7 +766,7 @@ mod tests {
let conn = connect_passthrough(local_addr, test_hello().0).await;
let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn)
.into_eth_satellite_stream(status, fork_filter)
.into_eth_satellite_stream::<EthNetworkPrimitives>(status, fork_filter)
.await
.unwrap();