feat: implement eth handshake disconnects (#1494)

This commit is contained in:
Dan Cline
2023-02-22 06:18:12 -05:00
committed by GitHub
parent 0fc9f67af8
commit c168ef4433
7 changed files with 137 additions and 41 deletions

1
Cargo.lock generated
View File

@ -4557,6 +4557,7 @@ name = "reth-eth-wire"
version = "0.1.0"
dependencies = [
"arbitrary",
"async-trait",
"bytes",
"ethers-core",
"futures",

View File

@ -15,12 +15,14 @@ serde = { version = "1", optional = true }
# reth
reth-codecs = { path = "../../storage/codecs" }
reth-primitives = { path = "../../primitives" }
reth-ecies = { path = "../ecies" }
reth-rlp = { path = "../../rlp", features = ["alloc", "derive", "std", "ethereum-types", "smol_str"] }
# used for Chain and builders
ethers-core = { git = "https://github.com/gakonst/ethers-rs", default-features = false }
tokio = { version = "1.21.2", features = ["full"] }
tokio-util = { version = "0.7.4", features = ["io", "codec"] }
futures = "0.3.24"
tokio-stream = "0.1.11"
pin-project = "1.0"
@ -28,6 +30,7 @@ tracing = "0.1.37"
snap = "1.0.5"
smol_str = "0.1"
metrics = "0.20.1"
async-trait = "0.1"
# arbitrary utils
arbitrary = { version = "1.1.7", features = ["derive"], optional = true }
@ -36,7 +39,6 @@ proptest-derive = { version = "0.3", optional = true }
[dev-dependencies]
reth-primitives = { path = "../../primitives", features = ["arbitrary"] }
reth-ecies = { path = "../ecies" }
reth-tracing = { path = "../../tracing" }
ethers-core = { git = "https://github.com/gakonst/ethers-rs", default-features = false }

View File

@ -1,10 +1,15 @@
//! Disconnect
use bytes::Bytes;
use futures::{Sink, SinkExt};
use reth_codecs::derive_arbitrary;
use reth_ecies::stream::ECIESStream;
use reth_primitives::bytes::{Buf, BufMut};
use reth_rlp::{Decodable, DecodeError, Encodable, Header};
use std::fmt::Display;
use thiserror::Error;
use tokio::io::AsyncWrite;
use tokio_util::codec::{Encoder, Framed};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
@ -143,6 +148,45 @@ impl Decodable for DisconnectReason {
}
}
/// This trait is meant to allow higher level protocols like `eth` to disconnect from a peer, using
/// lower-level disconnect functions (such as those that exist in the `p2p` protocol) if the
/// underlying stream supports it.
#[async_trait::async_trait]
pub trait CanDisconnect<T>: Sink<T> + Unpin + Sized {
/// Disconnects from the underlying stream, using a [`DisconnectReason`] as disconnect
/// information if the stream implements a protocol that can carry the additional disconnect
/// metadata.
async fn disconnect(
&mut self,
reason: DisconnectReason,
) -> Result<(), <Self as Sink<T>>::Error>;
}
// basic impls for things like Framed<TcpStream, etc>
#[async_trait::async_trait]
impl<T, I, U> CanDisconnect<I> for Framed<T, U>
where
T: AsyncWrite + Unpin + Send,
U: Encoder<I> + Send,
{
async fn disconnect(
&mut self,
_reason: DisconnectReason,
) -> Result<(), <Self as Sink<I>>::Error> {
self.close().await
}
}
#[async_trait::async_trait]
impl<S> CanDisconnect<Bytes> for ECIESStream<S>
where
S: AsyncWrite + Unpin + Send,
{
async fn disconnect(&mut self, _reason: DisconnectReason) -> Result<(), std::io::Error> {
self.close().await
}
}
#[cfg(test)]
mod tests {
use crate::{p2pstream::P2PMessage, DisconnectReason};

View File

@ -2,7 +2,7 @@ use crate::{
errors::{EthHandshakeError, EthStreamError},
message::{EthBroadcastMessage, ProtocolBroadcastMessage},
types::{EthMessage, ProtocolMessage, Status},
EthVersion,
CanDisconnect, DisconnectReason, EthVersion,
};
use futures::{ready, Sink, SinkExt, StreamExt};
use pin_project::pin_project;
@ -43,8 +43,8 @@ impl<S> UnauthedEthStream<S> {
impl<S, E> UnauthedEthStream<S>
where
S: Stream<Item = Result<BytesMut, E>> + Sink<Bytes, Error = E> + Unpin,
EthStreamError: From<E>,
S: Stream<Item = Result<BytesMut, E>> + CanDisconnect<Bytes> + Unpin,
EthStreamError: From<E> + From<<S as Sink<Bytes>>::Error>,
{
/// Consumes the [`UnauthedEthStream`] and returns an [`EthStream`] after the `Status`
/// handshake is completed successfully. This also returns the `Status` message sent by the
@ -67,13 +67,18 @@ where
self.inner.send(our_status_bytes).await?;
tracing::trace!("waiting for eth status from peer");
let their_msg = self
.inner
.next()
.await
.ok_or(EthStreamError::EthHandshakeError(EthHandshakeError::NoResponse))??;
let their_msg_res = self.inner.next().await;
let their_msg = match their_msg_res {
Some(msg) => msg,
None => {
self.inner.disconnect(DisconnectReason::DisconnectRequested).await?;
return Err(EthStreamError::EthHandshakeError(EthHandshakeError::NoResponse))
}
}?;
if their_msg.len() > MAX_MESSAGE_SIZE {
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
return Err(EthStreamError::MessageTooBig(their_msg.len()))
}
@ -82,6 +87,7 @@ where
Ok(m) => m,
Err(err) => {
tracing::debug!("decode error in eth handshake: msg={their_msg:x}");
self.inner.disconnect(DisconnectReason::DisconnectRequested).await?;
return Err(err)
}
};
@ -95,6 +101,7 @@ where
"validating incoming eth status from peer"
);
if status.genesis != resp.genesis {
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
return Err(EthHandshakeError::MismatchedGenesis {
expected: status.genesis,
got: resp.genesis,
@ -103,6 +110,7 @@ where
}
if status.version != resp.version {
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
return Err(EthHandshakeError::MismatchedProtocolVersion {
expected: status.version,
got: resp.version,
@ -111,6 +119,7 @@ where
}
if status.chain != resp.chain {
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
return Err(EthHandshakeError::MismatchedChain {
expected: status.chain,
got: resp.chain,
@ -121,6 +130,7 @@ where
// TD at mainnet block #7753254 is 76 bits. If it becomes 100 million times
// larger, it will still fit within 100 bits
if status.total_difficulty.bit_len() > 100 {
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
return Err(EthHandshakeError::TotalDifficultyBitLenTooLarge {
maximum: 100,
got: status.total_difficulty.bit_len(),
@ -128,7 +138,12 @@ where
.into())
}
fork_filter.validate(resp.forkid).map_err(EthHandshakeError::InvalidFork)?;
if let Err(err) =
fork_filter.validate(resp.forkid).map_err(EthHandshakeError::InvalidFork)
{
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
return Err(err.into())
}
// now we can create the `EthStream` because the peer has successfully completed
// the handshake
@ -136,9 +151,12 @@ where
Ok((stream, resp))
}
_ => Err(EthStreamError::EthHandshakeError(
EthHandshakeError::NonStatusMessageInHandshake,
)),
_ => {
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
Err(EthStreamError::EthHandshakeError(
EthHandshakeError::NonStatusMessageInHandshake,
))
}
}
}
}
@ -239,10 +257,10 @@ where
}
}
impl<S, E> Sink<EthMessage> for EthStream<S>
impl<S> Sink<EthMessage> for EthStream<S>
where
S: Sink<Bytes, Error = E> + Unpin,
EthStreamError: From<E>,
S: CanDisconnect<Bytes> + Unpin,
EthStreamError: From<<S as Sink<Bytes>>::Error>,
{
type Error = EthStreamError;
@ -252,6 +270,15 @@ where
fn start_send(self: Pin<&mut Self>, item: EthMessage) -> Result<(), Self::Error> {
if matches!(item, EthMessage::Status(_)) {
// TODO: to disconnect here we would need to do something similar to P2PStream's
// start_disconnect, which would ideally be a part of the CanDisconnect trait, or at
// least similar.
//
// Other parts of reth do not need traits like CanDisconnect because they work
// exclusively with EthStream<P2PStream<S>>, where the inner P2PStream is accessible,
// allowing for its start_disconnect method to be called.
//
// self.project().inner.start_disconnect(DisconnectReason::ProtocolBreach);
return Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake))
}
@ -273,6 +300,17 @@ where
}
}
#[async_trait::async_trait]
impl<S> CanDisconnect<EthMessage> for EthStream<S>
where
S: CanDisconnect<Bytes> + Send,
EthStreamError: From<<S as Sink<Bytes>>::Error>,
{
async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), EthStreamError> {
self.inner.disconnect(reason).await.map_err(Into::into)
}
}
#[cfg(test)]
mod tests {
use super::UnauthedEthStream;

View File

@ -24,7 +24,7 @@ pub use tokio_util::codec::{
};
pub use crate::{
disconnect::DisconnectReason,
disconnect::{CanDisconnect, DisconnectReason},
ethstream::{EthStream, UnauthedEthStream, MAX_MESSAGE_SIZE},
hello::HelloMessage,
p2pstream::{P2PMessage, P2PMessageID, P2PStream, ProtocolVersion, UnauthedP2PStream},

View File

@ -1,6 +1,7 @@
#![allow(dead_code, unreachable_pub, missing_docs, unused_variables)]
use crate::{
capability::{Capability, SharedCapability},
disconnect::CanDisconnect,
errors::{P2PHandshakeError, P2PStreamError},
pinger::{Pinger, PingerEvent},
DisconnectReason, HelloMessage,
@ -72,25 +73,6 @@ impl<S> UnauthedP2PStream<S> {
}
}
impl<S> UnauthedP2PStream<S>
where
S: Sink<Bytes, Error = io::Error> + Unpin,
{
/// Send a disconnect message during the handshake. This is sent without snappy compression.
pub async fn send_disconnect(
&mut self,
reason: DisconnectReason,
) -> Result<(), P2PStreamError> {
let mut buf = BytesMut::new();
P2PMessage::Disconnect(reason).encode(&mut buf);
tracing::trace!(
%reason,
"Sending disconnect message during the handshake",
);
self.inner.send(buf.freeze()).await.map_err(P2PStreamError::Io)
}
}
impl<S> UnauthedP2PStream<S>
where
S: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
@ -180,6 +162,35 @@ where
}
}
impl<S> UnauthedP2PStream<S>
where
S: Sink<Bytes, Error = io::Error> + Unpin,
{
/// Send a disconnect message during the handshake. This is sent without snappy compression.
pub async fn send_disconnect(
&mut self,
reason: DisconnectReason,
) -> Result<(), P2PStreamError> {
let mut buf = BytesMut::new();
P2PMessage::Disconnect(reason).encode(&mut buf);
tracing::trace!(
%reason,
"Sending disconnect message during the handshake",
);
self.inner.send(buf.freeze()).await.map_err(P2PStreamError::Io)
}
}
#[async_trait::async_trait]
impl<S> CanDisconnect<Bytes> for P2PStream<S>
where
S: Sink<Bytes, Error = io::Error> + Unpin + Send + Sync,
{
async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
self.disconnect(reason).await
}
}
/// A P2PStream wraps over any `Stream` that yields bytes and makes it compatible with `p2p`
/// protocol messages.
#[pin_project]
@ -284,13 +295,13 @@ impl<S> P2PStream<S> {
impl<S> P2PStream<S>
where
S: Sink<Bytes, Error = io::Error> + Unpin,
S: Sink<Bytes, Error = io::Error> + Unpin + Send,
{
/// Disconnects the connection by sending a disconnect message.
///
/// This future resolves once the disconnect message has been sent and the stream has been
/// closed.
pub async fn disconnect(mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
pub async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
self.start_disconnect(reason)?;
self.close().await
}
@ -821,7 +832,7 @@ mod tests {
let (server_hello, _) = eth_hello();
let (p2p_stream, _) =
let (mut p2p_stream, _) =
UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();
p2p_stream.disconnect(expected_disconnect).await.unwrap();

View File

@ -753,9 +753,9 @@ mod tests {
&self,
local_addr: SocketAddr,
f: F,
) -> Pin<Box<dyn Future<Output = ()> + Send + Sync>>
) -> Pin<Box<dyn Future<Output = ()> + Send>>
where
F: FnOnce(EthStream<P2PStream<ECIESStream<TcpStream>>>) -> O + Send + Sync + 'static,
F: FnOnce(EthStream<P2PStream<ECIESStream<TcpStream>>>) -> O + Send + 'static,
O: Future<Output = ()> + Send + Sync,
{
let status = self.status;