From 9b74d7d39db21a18e644013c813d65291c8e853f Mon Sep 17 00:00:00 2001 From: Dan Cline <6798349+Rjected@users.noreply.github.com> Date: Fri, 4 Nov 2022 03:36:40 -0400 Subject: [PATCH] feat(eth-wire): use UnauthedEthStream to create EthStream (#162) * Create UnauthedEthStream * remove authed flag from EthStream * encode and decode in status handshake * update test to assert the proper status is communicated * cargo fmt Co-authored-by: Matthias Seitz --- crates/net/eth-wire/src/ethstream.rs | 114 ++++++++++++++++++--------- crates/net/eth-wire/src/p2pstream.rs | 4 +- 2 files changed, 77 insertions(+), 41 deletions(-) diff --git a/crates/net/eth-wire/src/ethstream.rs b/crates/net/eth-wire/src/ethstream.rs index 325db58ff..743f2ecf3 100644 --- a/crates/net/eth-wire/src/ethstream.rs +++ b/crates/net/eth-wire/src/ethstream.rs @@ -16,62 +16,62 @@ use tokio_stream::Stream; // https://github.com/ethereum/go-ethereum/blob/30602163d5d8321fbc68afdcbbaf2362b2641bde/eth/protocols/eth/protocol.go#L50 const MAX_MESSAGE_SIZE: usize = 10 * 1024 * 1024; -/// An `EthStream` wraps over any `Stream` that yields bytes and makes it -/// compatible with eth-networking protocol messages, which get RLP encoded/decoded. +/// An un-authenticated [`EthStream`]. This is consumed and returns a [`EthStream`] after the +/// `Status` handshake is completed. #[pin_project] -pub struct EthStream { +pub struct UnauthedEthStream { #[pin] inner: S, - /// Whether the `Status` handshake has been completed - authed: bool, } -impl EthStream { - /// Creates a new unauthed [`EthStream`] from a provided stream. You will need - /// to manually handshake a peer. +impl UnauthedEthStream { + /// Create a new `UnauthedEthStream` from a type `S` which implements `Stream` and `Sink`. pub fn new(inner: S) -> Self { - Self { inner, authed: false } + Self { inner } } } -impl EthStream +impl UnauthedEthStream where S: Stream> + Sink + Unpin, EthStreamError: From, { - /// Given an instantiated transport layer, it proceeds to return an [`EthStream`] - /// after performing a [`Status`] message handshake as specified in - pub async fn connect( - inner: S, - status: Status, - fork_filter: ForkFilter, - ) -> Result { - let mut this = Self::new(inner); - this.handshake(status, fork_filter).await?; - Ok(this) - } - - /// Performs a handshake with the connected peer over the transport stream. + /// Consumes the [`UnauthedEthStream`] and returns an [`EthStream`] after the `Status` + /// handshake is completed successfully. This also returns the `Status` message sent by the + /// remote peer. pub async fn handshake( - &mut self, + mut self, status: Status, fork_filter: ForkFilter, - ) -> Result<(), EthStreamError> { + ) -> Result<(EthStream, Status), EthStreamError> { tracing::trace!("sending eth status ..."); - self.send(EthMessage::Status(status)).await?; + + // we need to encode and decode here on our own because we don't have an `EthStream` yet + let mut our_status_bytes = BytesMut::new(); + ProtocolMessage::from(EthMessage::Status(status)).encode(&mut our_status_bytes); + let our_status_bytes = our_status_bytes.freeze(); + self.inner.send(our_status_bytes).await?; tracing::trace!("waiting for eth status from peer ..."); - let msg = self + let their_msg = self + .inner .next() .await .ok_or(EthStreamError::HandshakeError(HandshakeError::NoResponse))??; + if their_msg.len() > MAX_MESSAGE_SIZE { + return Err(EthStreamError::MessageTooBig(their_msg.len())) + } + + let msg = match ProtocolMessage::decode(&mut their_msg.as_ref()) { + Ok(m) => m, + Err(err) => return Err(err.into()), + }; + // TODO: Add any missing checks // https://github.com/ethereum/go-ethereum/blob/9244d5cd61f3ea5a7645fdf2a1a96d53421e412f/eth/protocols/eth/handshake.go#L87-L89 - match msg { + match msg.message { EthMessage::Status(resp) => { - self.authed = true; - if status.genesis != resp.genesis { return Err(HandshakeError::MismatchedGenesis { expected: status.genesis, @@ -96,13 +96,35 @@ where .into()) } - Ok(fork_filter.validate(resp.forkid).map_err(HandshakeError::InvalidFork)?) + fork_filter.validate(resp.forkid).map_err(HandshakeError::InvalidFork)?; + + // now we can create the `EthStream` because the peer has successfully completed + // the handshake + let stream = EthStream::new(self.inner); + + Ok((stream, resp)) } _ => Err(EthStreamError::HandshakeError(HandshakeError::NonStatusMessageInHandshake)), } } } +/// An `EthStream` wraps over any `Stream` that yields bytes and makes it +/// compatible with eth-networking protocol messages, which get RLP encoded/decoded. +#[pin_project] +pub struct EthStream { + #[pin] + inner: S, +} + +impl EthStream { + /// Creates a new unauthed [`EthStream`] from a provided stream. You will need + /// to manually handshake a peer. + pub fn new(inner: S) -> Self { + Self { inner } + } +} + impl Stream for EthStream where S: Stream> + Unpin, @@ -128,7 +150,7 @@ where Err(err) => return Poll::Ready(Some(Err(err.into()))), }; - if *this.authed && matches!(msg.message, EthMessage::Status(_)) { + if matches!(msg.message, EthMessage::Status(_)) { return Poll::Ready(Some(Err(EthStreamError::HandshakeError( HandshakeError::StatusNotInHandshake, )))) @@ -150,7 +172,7 @@ where } fn start_send(self: Pin<&mut Self>, item: EthMessage) -> Result<(), Self::Error> { - if self.authed && matches!(item, EthMessage::Status(_)) { + if matches!(item, EthMessage::Status(_)) { return Err(EthStreamError::HandshakeError(HandshakeError::StatusNotInHandshake)) } @@ -189,6 +211,8 @@ mod tests { use ethers_core::types::Chain; use reth_primitives::{H256, U256}; + use super::UnauthedEthStream; + #[tokio::test] async fn can_handshake() { let genesis = H256::random(); @@ -213,14 +237,24 @@ mod tests { // roughly based off of the design of tokio::net::TcpListener let (incoming, _) = listener.accept().await.unwrap(); let stream = crate::PassthroughCodec::default().framed(incoming); - let _ = EthStream::connect(stream, status_clone, fork_filter_clone).await.unwrap(); + let (_, their_status) = UnauthedEthStream::new(stream) + .handshake(status_clone, fork_filter_clone) + .await + .unwrap(); + + // just make sure it equals our status (our status is a clone of their status) + assert_eq!(their_status, status_clone); }); let outgoing = TcpStream::connect(local_addr).await.unwrap(); let sink = crate::PassthroughCodec::default().framed(outgoing); // try to connect - let _ = EthStream::connect(sink, status, fork_filter).await.unwrap(); + let (_, their_status) = + UnauthedEthStream::new(sink).handshake(status, fork_filter).await.unwrap(); + + // their status is a clone of our status, these should be equal + assert_eq!(their_status, status); // wait for it to finish handle.await.unwrap(); @@ -349,8 +383,10 @@ mod tests { let unauthed_stream = UnauthedP2PStream::new(stream); let (p2p_stream, _) = unauthed_stream.handshake(server_hello).await.unwrap(); - let mut eth_stream = EthStream::new(p2p_stream); - eth_stream.handshake(status_copy, fork_filter_clone).await.unwrap(); + let (mut eth_stream, _) = UnauthedEthStream::new(p2p_stream) + .handshake(status_copy, fork_filter_clone) + .await + .unwrap(); // use the stream to get the next message let message = eth_stream.next().await.unwrap().unwrap(); @@ -378,8 +414,8 @@ mod tests { let unauthed_stream = UnauthedP2PStream::new(sink); let (p2p_stream, _) = unauthed_stream.handshake(client_hello).await.unwrap(); - let mut client_stream = EthStream::new(p2p_stream); - client_stream.handshake(status, fork_filter).await.unwrap(); + let (mut client_stream, _) = + UnauthedEthStream::new(p2p_stream).handshake(status, fork_filter).await.unwrap(); client_stream.send(test_msg).await.unwrap(); diff --git a/crates/net/eth-wire/src/p2pstream.rs b/crates/net/eth-wire/src/p2pstream.rs index 614b426c2..a8901a136 100644 --- a/crates/net/eth-wire/src/p2pstream.rs +++ b/crates/net/eth-wire/src/p2pstream.rs @@ -60,7 +60,7 @@ pub struct UnauthedP2PStream { } impl UnauthedP2PStream { - /// Create a new `UnauthedP2PStream` from a `Stream` of bytes. + /// Create a new `UnauthedP2PStream` from a type `S` which implements `Stream` and `Sink`. pub fn new(inner: S) -> Self { Self { inner } } @@ -635,7 +635,7 @@ impl Display for DisconnectReason { DisconnectReason::SubprotocolSpecific => "Some other reason specific to a subprotocol", }; - write!(f, "{}", message) + write!(f, "{message}") } }