mirror of
https://github.com/hl-archive-node/nanoreth.git
synced 2025-12-06 10:59:55 +00:00
feat: implement eth handshake disconnects (#1494)
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -4557,6 +4557,7 @@ name = "reth-eth-wire"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"arbitrary",
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"ethers-core",
|
||||
"futures",
|
||||
|
||||
@ -15,12 +15,14 @@ serde = { version = "1", optional = true }
|
||||
# reth
|
||||
reth-codecs = { path = "../../storage/codecs" }
|
||||
reth-primitives = { path = "../../primitives" }
|
||||
reth-ecies = { path = "../ecies" }
|
||||
reth-rlp = { path = "../../rlp", features = ["alloc", "derive", "std", "ethereum-types", "smol_str"] }
|
||||
|
||||
# used for Chain and builders
|
||||
ethers-core = { git = "https://github.com/gakonst/ethers-rs", default-features = false }
|
||||
|
||||
tokio = { version = "1.21.2", features = ["full"] }
|
||||
tokio-util = { version = "0.7.4", features = ["io", "codec"] }
|
||||
futures = "0.3.24"
|
||||
tokio-stream = "0.1.11"
|
||||
pin-project = "1.0"
|
||||
@ -28,6 +30,7 @@ tracing = "0.1.37"
|
||||
snap = "1.0.5"
|
||||
smol_str = "0.1"
|
||||
metrics = "0.20.1"
|
||||
async-trait = "0.1"
|
||||
|
||||
# arbitrary utils
|
||||
arbitrary = { version = "1.1.7", features = ["derive"], optional = true }
|
||||
@ -36,7 +39,6 @@ proptest-derive = { version = "0.3", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
reth-primitives = { path = "../../primitives", features = ["arbitrary"] }
|
||||
reth-ecies = { path = "../ecies" }
|
||||
reth-tracing = { path = "../../tracing" }
|
||||
ethers-core = { git = "https://github.com/gakonst/ethers-rs", default-features = false }
|
||||
|
||||
|
||||
@ -1,10 +1,15 @@
|
||||
//! Disconnect
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures::{Sink, SinkExt};
|
||||
use reth_codecs::derive_arbitrary;
|
||||
use reth_ecies::stream::ECIESStream;
|
||||
use reth_primitives::bytes::{Buf, BufMut};
|
||||
use reth_rlp::{Decodable, DecodeError, Encodable, Header};
|
||||
use std::fmt::Display;
|
||||
use thiserror::Error;
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio_util::codec::{Encoder, Framed};
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
@ -143,6 +148,45 @@ impl Decodable for DisconnectReason {
|
||||
}
|
||||
}
|
||||
|
||||
/// This trait is meant to allow higher level protocols like `eth` to disconnect from a peer, using
|
||||
/// lower-level disconnect functions (such as those that exist in the `p2p` protocol) if the
|
||||
/// underlying stream supports it.
|
||||
#[async_trait::async_trait]
|
||||
pub trait CanDisconnect<T>: Sink<T> + Unpin + Sized {
|
||||
/// Disconnects from the underlying stream, using a [`DisconnectReason`] as disconnect
|
||||
/// information if the stream implements a protocol that can carry the additional disconnect
|
||||
/// metadata.
|
||||
async fn disconnect(
|
||||
&mut self,
|
||||
reason: DisconnectReason,
|
||||
) -> Result<(), <Self as Sink<T>>::Error>;
|
||||
}
|
||||
|
||||
// basic impls for things like Framed<TcpStream, etc>
|
||||
#[async_trait::async_trait]
|
||||
impl<T, I, U> CanDisconnect<I> for Framed<T, U>
|
||||
where
|
||||
T: AsyncWrite + Unpin + Send,
|
||||
U: Encoder<I> + Send,
|
||||
{
|
||||
async fn disconnect(
|
||||
&mut self,
|
||||
_reason: DisconnectReason,
|
||||
) -> Result<(), <Self as Sink<I>>::Error> {
|
||||
self.close().await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl<S> CanDisconnect<Bytes> for ECIESStream<S>
|
||||
where
|
||||
S: AsyncWrite + Unpin + Send,
|
||||
{
|
||||
async fn disconnect(&mut self, _reason: DisconnectReason) -> Result<(), std::io::Error> {
|
||||
self.close().await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{p2pstream::P2PMessage, DisconnectReason};
|
||||
|
||||
@ -2,7 +2,7 @@ use crate::{
|
||||
errors::{EthHandshakeError, EthStreamError},
|
||||
message::{EthBroadcastMessage, ProtocolBroadcastMessage},
|
||||
types::{EthMessage, ProtocolMessage, Status},
|
||||
EthVersion,
|
||||
CanDisconnect, DisconnectReason, EthVersion,
|
||||
};
|
||||
use futures::{ready, Sink, SinkExt, StreamExt};
|
||||
use pin_project::pin_project;
|
||||
@ -43,8 +43,8 @@ impl<S> UnauthedEthStream<S> {
|
||||
|
||||
impl<S, E> UnauthedEthStream<S>
|
||||
where
|
||||
S: Stream<Item = Result<BytesMut, E>> + Sink<Bytes, Error = E> + Unpin,
|
||||
EthStreamError: From<E>,
|
||||
S: Stream<Item = Result<BytesMut, E>> + CanDisconnect<Bytes> + Unpin,
|
||||
EthStreamError: From<E> + From<<S as Sink<Bytes>>::Error>,
|
||||
{
|
||||
/// Consumes the [`UnauthedEthStream`] and returns an [`EthStream`] after the `Status`
|
||||
/// handshake is completed successfully. This also returns the `Status` message sent by the
|
||||
@ -67,13 +67,18 @@ where
|
||||
self.inner.send(our_status_bytes).await?;
|
||||
|
||||
tracing::trace!("waiting for eth status from peer");
|
||||
let their_msg = self
|
||||
.inner
|
||||
.next()
|
||||
.await
|
||||
.ok_or(EthStreamError::EthHandshakeError(EthHandshakeError::NoResponse))??;
|
||||
let their_msg_res = self.inner.next().await;
|
||||
|
||||
let their_msg = match their_msg_res {
|
||||
Some(msg) => msg,
|
||||
None => {
|
||||
self.inner.disconnect(DisconnectReason::DisconnectRequested).await?;
|
||||
return Err(EthStreamError::EthHandshakeError(EthHandshakeError::NoResponse))
|
||||
}
|
||||
}?;
|
||||
|
||||
if their_msg.len() > MAX_MESSAGE_SIZE {
|
||||
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
|
||||
return Err(EthStreamError::MessageTooBig(their_msg.len()))
|
||||
}
|
||||
|
||||
@ -82,6 +87,7 @@ where
|
||||
Ok(m) => m,
|
||||
Err(err) => {
|
||||
tracing::debug!("decode error in eth handshake: msg={their_msg:x}");
|
||||
self.inner.disconnect(DisconnectReason::DisconnectRequested).await?;
|
||||
return Err(err)
|
||||
}
|
||||
};
|
||||
@ -95,6 +101,7 @@ where
|
||||
"validating incoming eth status from peer"
|
||||
);
|
||||
if status.genesis != resp.genesis {
|
||||
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
|
||||
return Err(EthHandshakeError::MismatchedGenesis {
|
||||
expected: status.genesis,
|
||||
got: resp.genesis,
|
||||
@ -103,6 +110,7 @@ where
|
||||
}
|
||||
|
||||
if status.version != resp.version {
|
||||
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
|
||||
return Err(EthHandshakeError::MismatchedProtocolVersion {
|
||||
expected: status.version,
|
||||
got: resp.version,
|
||||
@ -111,6 +119,7 @@ where
|
||||
}
|
||||
|
||||
if status.chain != resp.chain {
|
||||
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
|
||||
return Err(EthHandshakeError::MismatchedChain {
|
||||
expected: status.chain,
|
||||
got: resp.chain,
|
||||
@ -121,6 +130,7 @@ where
|
||||
// TD at mainnet block #7753254 is 76 bits. If it becomes 100 million times
|
||||
// larger, it will still fit within 100 bits
|
||||
if status.total_difficulty.bit_len() > 100 {
|
||||
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
|
||||
return Err(EthHandshakeError::TotalDifficultyBitLenTooLarge {
|
||||
maximum: 100,
|
||||
got: status.total_difficulty.bit_len(),
|
||||
@ -128,7 +138,12 @@ where
|
||||
.into())
|
||||
}
|
||||
|
||||
fork_filter.validate(resp.forkid).map_err(EthHandshakeError::InvalidFork)?;
|
||||
if let Err(err) =
|
||||
fork_filter.validate(resp.forkid).map_err(EthHandshakeError::InvalidFork)
|
||||
{
|
||||
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
|
||||
return Err(err.into())
|
||||
}
|
||||
|
||||
// now we can create the `EthStream` because the peer has successfully completed
|
||||
// the handshake
|
||||
@ -136,9 +151,12 @@ where
|
||||
|
||||
Ok((stream, resp))
|
||||
}
|
||||
_ => Err(EthStreamError::EthHandshakeError(
|
||||
EthHandshakeError::NonStatusMessageInHandshake,
|
||||
)),
|
||||
_ => {
|
||||
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
|
||||
Err(EthStreamError::EthHandshakeError(
|
||||
EthHandshakeError::NonStatusMessageInHandshake,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -239,10 +257,10 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, E> Sink<EthMessage> for EthStream<S>
|
||||
impl<S> Sink<EthMessage> for EthStream<S>
|
||||
where
|
||||
S: Sink<Bytes, Error = E> + Unpin,
|
||||
EthStreamError: From<E>,
|
||||
S: CanDisconnect<Bytes> + Unpin,
|
||||
EthStreamError: From<<S as Sink<Bytes>>::Error>,
|
||||
{
|
||||
type Error = EthStreamError;
|
||||
|
||||
@ -252,6 +270,15 @@ where
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: EthMessage) -> Result<(), Self::Error> {
|
||||
if matches!(item, EthMessage::Status(_)) {
|
||||
// TODO: to disconnect here we would need to do something similar to P2PStream's
|
||||
// start_disconnect, which would ideally be a part of the CanDisconnect trait, or at
|
||||
// least similar.
|
||||
//
|
||||
// Other parts of reth do not need traits like CanDisconnect because they work
|
||||
// exclusively with EthStream<P2PStream<S>>, where the inner P2PStream is accessible,
|
||||
// allowing for its start_disconnect method to be called.
|
||||
//
|
||||
// self.project().inner.start_disconnect(DisconnectReason::ProtocolBreach);
|
||||
return Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake))
|
||||
}
|
||||
|
||||
@ -273,6 +300,17 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl<S> CanDisconnect<EthMessage> for EthStream<S>
|
||||
where
|
||||
S: CanDisconnect<Bytes> + Send,
|
||||
EthStreamError: From<<S as Sink<Bytes>>::Error>,
|
||||
{
|
||||
async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), EthStreamError> {
|
||||
self.inner.disconnect(reason).await.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::UnauthedEthStream;
|
||||
|
||||
@ -24,7 +24,7 @@ pub use tokio_util::codec::{
|
||||
};
|
||||
|
||||
pub use crate::{
|
||||
disconnect::DisconnectReason,
|
||||
disconnect::{CanDisconnect, DisconnectReason},
|
||||
ethstream::{EthStream, UnauthedEthStream, MAX_MESSAGE_SIZE},
|
||||
hello::HelloMessage,
|
||||
p2pstream::{P2PMessage, P2PMessageID, P2PStream, ProtocolVersion, UnauthedP2PStream},
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
#![allow(dead_code, unreachable_pub, missing_docs, unused_variables)]
|
||||
use crate::{
|
||||
capability::{Capability, SharedCapability},
|
||||
disconnect::CanDisconnect,
|
||||
errors::{P2PHandshakeError, P2PStreamError},
|
||||
pinger::{Pinger, PingerEvent},
|
||||
DisconnectReason, HelloMessage,
|
||||
@ -72,25 +73,6 @@ impl<S> UnauthedP2PStream<S> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> UnauthedP2PStream<S>
|
||||
where
|
||||
S: Sink<Bytes, Error = io::Error> + Unpin,
|
||||
{
|
||||
/// Send a disconnect message during the handshake. This is sent without snappy compression.
|
||||
pub async fn send_disconnect(
|
||||
&mut self,
|
||||
reason: DisconnectReason,
|
||||
) -> Result<(), P2PStreamError> {
|
||||
let mut buf = BytesMut::new();
|
||||
P2PMessage::Disconnect(reason).encode(&mut buf);
|
||||
tracing::trace!(
|
||||
%reason,
|
||||
"Sending disconnect message during the handshake",
|
||||
);
|
||||
self.inner.send(buf.freeze()).await.map_err(P2PStreamError::Io)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> UnauthedP2PStream<S>
|
||||
where
|
||||
S: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
|
||||
@ -180,6 +162,35 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> UnauthedP2PStream<S>
|
||||
where
|
||||
S: Sink<Bytes, Error = io::Error> + Unpin,
|
||||
{
|
||||
/// Send a disconnect message during the handshake. This is sent without snappy compression.
|
||||
pub async fn send_disconnect(
|
||||
&mut self,
|
||||
reason: DisconnectReason,
|
||||
) -> Result<(), P2PStreamError> {
|
||||
let mut buf = BytesMut::new();
|
||||
P2PMessage::Disconnect(reason).encode(&mut buf);
|
||||
tracing::trace!(
|
||||
%reason,
|
||||
"Sending disconnect message during the handshake",
|
||||
);
|
||||
self.inner.send(buf.freeze()).await.map_err(P2PStreamError::Io)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl<S> CanDisconnect<Bytes> for P2PStream<S>
|
||||
where
|
||||
S: Sink<Bytes, Error = io::Error> + Unpin + Send + Sync,
|
||||
{
|
||||
async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
|
||||
self.disconnect(reason).await
|
||||
}
|
||||
}
|
||||
|
||||
/// A P2PStream wraps over any `Stream` that yields bytes and makes it compatible with `p2p`
|
||||
/// protocol messages.
|
||||
#[pin_project]
|
||||
@ -284,13 +295,13 @@ impl<S> P2PStream<S> {
|
||||
|
||||
impl<S> P2PStream<S>
|
||||
where
|
||||
S: Sink<Bytes, Error = io::Error> + Unpin,
|
||||
S: Sink<Bytes, Error = io::Error> + Unpin + Send,
|
||||
{
|
||||
/// Disconnects the connection by sending a disconnect message.
|
||||
///
|
||||
/// This future resolves once the disconnect message has been sent and the stream has been
|
||||
/// closed.
|
||||
pub async fn disconnect(mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
|
||||
pub async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
|
||||
self.start_disconnect(reason)?;
|
||||
self.close().await
|
||||
}
|
||||
@ -821,7 +832,7 @@ mod tests {
|
||||
|
||||
let (server_hello, _) = eth_hello();
|
||||
|
||||
let (p2p_stream, _) =
|
||||
let (mut p2p_stream, _) =
|
||||
UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();
|
||||
|
||||
p2p_stream.disconnect(expected_disconnect).await.unwrap();
|
||||
|
||||
@ -753,9 +753,9 @@ mod tests {
|
||||
&self,
|
||||
local_addr: SocketAddr,
|
||||
f: F,
|
||||
) -> Pin<Box<dyn Future<Output = ()> + Send + Sync>>
|
||||
) -> Pin<Box<dyn Future<Output = ()> + Send>>
|
||||
where
|
||||
F: FnOnce(EthStream<P2PStream<ECIESStream<TcpStream>>>) -> O + Send + Sync + 'static,
|
||||
F: FnOnce(EthStream<P2PStream<ECIESStream<TcpStream>>>) -> O + Send + 'static,
|
||||
O: Future<Output = ()> + Send + Sync,
|
||||
{
|
||||
let status = self.status;
|
||||
|
||||
Reference in New Issue
Block a user