From b40546b9999f97d38e1a8a53c633df2ae32585cc Mon Sep 17 00:00:00 2001 From: Matthias Seitz Date: Mon, 5 Dec 2022 20:49:22 +0100 Subject: [PATCH] refactor(net): misc P2Pstream refactor (#336) * refactor(net): misc P2Pstream refactor * update note --- crates/net/eth-wire/src/ethstream.rs | 17 ++-- crates/net/eth-wire/src/p2pstream.rs | 119 ++++++++++++----------- crates/net/eth-wire/src/types/message.rs | 2 +- 3 files changed, 71 insertions(+), 67 deletions(-) diff --git a/crates/net/eth-wire/src/ethstream.rs b/crates/net/eth-wire/src/ethstream.rs index 604cb42c8..d9b8dce3f 100644 --- a/crates/net/eth-wire/src/ethstream.rs +++ b/crates/net/eth-wire/src/ethstream.rs @@ -40,7 +40,7 @@ impl UnauthedEthStream { impl UnauthedEthStream where - S: Stream> + Sink + Unpin, + S: Stream> + Sink + Unpin, EthStreamError: From, { /// Consumes the [`UnauthedEthStream`] and returns an [`EthStream`] after the `Status` @@ -54,7 +54,8 @@ where tracing::trace!("sending eth status ..."); // we need to encode and decode here on our own because we don't have an `EthStream` yet - let mut our_status_bytes = BytesMut::new(); + // The max length for a status with TTD is: + + let mut our_status_bytes = BytesMut::with_capacity(1 + 88); ProtocolMessage::from(EthMessage::Status(status)).encode(&mut our_status_bytes); let our_status_bytes = our_status_bytes.freeze(); self.inner.send(our_status_bytes).await?; @@ -277,7 +278,7 @@ mod tests { let handle = tokio::spawn(async move { // roughly based off of the design of tokio::net::TcpListener let (incoming, _) = listener.accept().await.unwrap(); - let stream = crate::PassthroughCodec::default().framed(incoming); + let stream = PassthroughCodec::default().framed(incoming); let (_, their_status) = UnauthedEthStream::new(stream) .handshake(status_clone, fork_filter_clone) .await @@ -288,7 +289,7 @@ mod tests { }); let outgoing = TcpStream::connect(local_addr).await.unwrap(); - let sink = crate::PassthroughCodec::default().framed(outgoing); + let sink = PassthroughCodec::default().framed(outgoing); // try to connect let (_, their_status) = @@ -307,8 +308,8 @@ mod tests { let local_addr = listener.local_addr().unwrap(); let test_msg = EthMessage::NewBlockHashes( vec![ - BlockHashNumber { hash: reth_primitives::H256::random(), number: 5 }, - BlockHashNumber { hash: reth_primitives::H256::random(), number: 6 }, + BlockHashNumber { hash: H256::random(), number: 5 }, + BlockHashNumber { hash: H256::random(), number: 6 }, ] .into(), ); @@ -342,8 +343,8 @@ mod tests { let server_key = SecretKey::new(&mut rand::thread_rng()); let test_msg = EthMessage::NewBlockHashes( vec![ - BlockHashNumber { hash: reth_primitives::H256::random(), number: 5 }, - BlockHashNumber { hash: reth_primitives::H256::random(), number: 6 }, + BlockHashNumber { hash: H256::random(), number: 5 }, + BlockHashNumber { hash: H256::random(), number: 6 }, ] .into(), ); diff --git a/crates/net/eth-wire/src/p2pstream.rs b/crates/net/eth-wire/src/p2pstream.rs index 64d6d2ebc..7502a1d71 100644 --- a/crates/net/eth-wire/src/p2pstream.rs +++ b/crates/net/eth-wire/src/p2pstream.rs @@ -67,7 +67,7 @@ impl UnauthedP2PStream { impl UnauthedP2PStream where - S: Stream> + Sink + Unpin, + S: Stream> + Sink + Unpin, { /// Consumes the `UnauthedP2PStream` and returns a `P2PStream` after the `Hello` handshake is /// completed successfully. This also returns the `Hello` message sent by the remote peer. @@ -218,8 +218,9 @@ impl P2PStream { pub fn start_disconnect(&mut self, reason: DisconnectReason) { // clear any buffered messages and queue in self.outgoing_messages.clear(); - let mut buf = BytesMut::new(); - P2PMessage::Disconnect(reason).encode(&mut buf); + let disconnect = P2PMessage::Disconnect(reason); + let mut buf = BytesMut::with_capacity(disconnect.length()); + disconnect.encode(&mut buf); self.outgoing_messages.push_back(buf.freeze()); } } @@ -241,14 +242,14 @@ where // protocol impl Stream for P2PStream where - S: Stream> + Sink + Unpin, + S: Stream> + Sink + Unpin, { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); + let this = self.get_mut(); - if *this.disconnecting { + if this.disconnecting { // if disconnecting, stop reading messages return Poll::Ready(None) } @@ -271,12 +272,7 @@ where } _ => { // encode the disconnect message - let mut disconnect_bytes = BytesMut::new(); - P2PMessage::Disconnect(DisconnectReason::PingTimeout).encode(&mut disconnect_bytes); - - // clear any buffered messages so that the next message will be disconnect - this.outgoing_messages.clear(); - this.outgoing_messages.push_back(disconnect_bytes.freeze()); + this.start_disconnect(DisconnectReason::PingTimeout); // End the stream after ping related error return Poll::Ready(None) @@ -285,7 +281,7 @@ where // we should loop here to ensure we don't return Poll::Pending if we have a message to // return behind any pings we need to respond to - while let Poll::Ready(res) = this.inner.as_mut().poll_next(cx) { + while let Poll::Ready(res) = this.inner.poll_next_unpin(cx) { let bytes = match res { Some(Ok(bytes)) => bytes, Some(Err(err)) => return Poll::Ready(Some(Err(err.into()))), @@ -293,57 +289,64 @@ where }; let id = *bytes.first().ok_or(P2PStreamError::EmptyProtocolMessage)?; - if id == P2PMessageID::Ping as u8 { - // TODO: do we need to decode the ping? - // we have received a ping, so we will send a pong - let mut pong_bytes = BytesMut::new(); - P2PMessage::Pong.encode(&mut pong_bytes); - // check if the buffer is full - if this.outgoing_messages.len() >= MAX_P2P_CAPACITY { - return Poll::Ready(Some(Err(P2PStreamError::SendBufferFull))) + match id { + _ if id == P2PMessageID::Ping as u8 => { + // we have received a ping, so we will send a pong + let pong = P2PMessage::Pong; + let mut pong_bytes = BytesMut::with_capacity(pong.length()); + pong.encode(&mut pong_bytes); + + // check if the buffer is full + if this.outgoing_messages.len() >= MAX_P2P_CAPACITY { + return Poll::Ready(Some(Err(P2PStreamError::SendBufferFull))) + } + // continue to the next message if there is one + this.outgoing_messages.push_back(pong_bytes.into()); } - this.outgoing_messages.push_back(pong_bytes.into()); - - // continue to the next message if there is one - } else if id == P2PMessageID::Disconnect as u8 { - let reason = DisconnectReason::decode(&mut &bytes[1..])?; - return Poll::Ready(Some(Err(P2PStreamError::Disconnected(reason)))) - } else if id == P2PMessageID::Hello as u8 { - // we have received a hello message outside of the handshake, so we will return an - // error - return Poll::Ready(Some(Err(P2PStreamError::HandshakeError( - P2PHandshakeError::HelloNotInHandshake, - )))) - } else if id == P2PMessageID::Pong as u8 { - // TODO: do we need to decode the pong? - // if we were waiting for a pong, this will reset the pinger state - this.pinger.on_pong()? - } else if id > MAX_P2P_MESSAGE_ID && id <= MAX_RESERVED_MESSAGE_ID { - // we have received an unknown reserved message - return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(id)))) - } else { - // first check that the compressed message length does not exceed the max message - // size - let decompressed_len = snap::raw::decompress_len(&bytes[1..])?; - if decompressed_len > MAX_PAYLOAD_SIZE { - return Poll::Ready(Some(Err(P2PStreamError::MessageTooBig { - message_size: decompressed_len, - max_size: MAX_PAYLOAD_SIZE, - }))) + _ if id == P2PMessageID::Disconnect as u8 => { + let reason = DisconnectReason::decode(&mut &bytes[1..])?; + return Poll::Ready(Some(Err(P2PStreamError::Disconnected(reason)))) } + _ if id == P2PMessageID::Hello as u8 => { + // we have received a hello message outside of the handshake, so we will return + // an error + return Poll::Ready(Some(Err(P2PStreamError::HandshakeError( + P2PHandshakeError::HelloNotInHandshake, + )))) + } + _ if id == P2PMessageID::Pong as u8 => { + // TODO: do we need to decode the pong? + // if we were waiting for a pong, this will reset the pinger state + this.pinger.on_pong()? + } + _ if id > MAX_P2P_MESSAGE_ID && id <= MAX_RESERVED_MESSAGE_ID => { + // we have received an unknown reserved message + return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(id)))) + } + _ => { + // first check that the compressed message length does not exceed the max + // message size + let decompressed_len = snap::raw::decompress_len(&bytes[1..])?; + if decompressed_len > MAX_PAYLOAD_SIZE { + return Poll::Ready(Some(Err(P2PStreamError::MessageTooBig { + message_size: decompressed_len, + max_size: MAX_PAYLOAD_SIZE, + }))) + } - // then decompress the message - let mut decompress_buf = BytesMut::zeroed(decompressed_len + 1); + // then decompress the message + let mut decompress_buf = BytesMut::zeroed(decompressed_len + 1); - // we have a subprotocol message that needs to be sent in the stream. - // first, switch the message id based on offset so the next layer can decode it - // without being aware of the p2p stream's state (shared capabilities / the message - // id offset) - decompress_buf[0] = bytes[0] - this.shared_capability.offset(); - this.decoder.decompress(&bytes[1..], &mut decompress_buf[1..])?; + // we have a subprotocol message that needs to be sent in the stream. + // first, switch the message id based on offset so the next layer can decode it + // without being aware of the p2p stream's state (shared capabilities / the + // message id offset) + decompress_buf[0] = bytes[0] - this.shared_capability.offset(); + this.decoder.decompress(&bytes[1..], &mut decompress_buf[1..])?; - return Poll::Ready(Some(Ok(decompress_buf))) + return Poll::Ready(Some(Ok(decompress_buf))) + } } } diff --git a/crates/net/eth-wire/src/types/message.rs b/crates/net/eth-wire/src/types/message.rs index a88b0fd56..55a42c8eb 100644 --- a/crates/net/eth-wire/src/types/message.rs +++ b/crates/net/eth-wire/src/types/message.rs @@ -289,7 +289,7 @@ pub enum EthMessageID { } impl Encodable for EthMessageID { - fn encode(&self, out: &mut dyn bytes::BufMut) { + fn encode(&self, out: &mut dyn BufMut) { out.put_u8(*self as u8); } fn length(&self) -> usize {