From ebd27b60253385965a02ed71530ca757428252f2 Mon Sep 17 00:00:00 2001 From: Matthias Seitz Date: Fri, 18 Nov 2022 16:26:49 +0100 Subject: [PATCH] feat(net): add disconnect function (#223) --- crates/net/eth-wire/src/error.rs | 5 +- crates/net/eth-wire/src/p2pstream.rs | 132 ++++++++++++++++++++------- 2 files changed, 103 insertions(+), 34 deletions(-) diff --git a/crates/net/eth-wire/src/error.rs b/crates/net/eth-wire/src/error.rs index 7a8037fe4..967afee4b 100644 --- a/crates/net/eth-wire/src/error.rs +++ b/crates/net/eth-wire/src/error.rs @@ -3,7 +3,7 @@ use std::io; use reth_primitives::{Chain, ValidationError, H256}; -use crate::capability::SharedCapabilityError; +use crate::{capability::SharedCapabilityError, DisconnectReason}; /// Errors when sending/receiving messages #[derive(thiserror::Error, Debug)] @@ -69,9 +69,8 @@ pub enum P2PStreamError { PingBeforeHandshake, #[error("too many messages buffered before sending")] SendBufferFull, - // TODO: remove / reconsider #[error("disconnected")] - Disconnected, + Disconnected(DisconnectReason), } /// Errors when conducting a p2p handshake diff --git a/crates/net/eth-wire/src/p2pstream.rs b/crates/net/eth-wire/src/p2pstream.rs index 5c7614830..e0ac1d31d 100644 --- a/crates/net/eth-wire/src/p2pstream.rs +++ b/crates/net/eth-wire/src/p2pstream.rs @@ -157,6 +157,10 @@ pub struct P2PStream { /// Outgoing messages buffered for sending to the underlying stream. outgoing_messages: VecDeque, + + /// Whether this stream is currently in the process of disconnecting by sending a disconnect + /// message. + disconnecting: bool, } impl P2PStream { @@ -171,8 +175,39 @@ impl P2PStream { pinger: Pinger::new(PING_INTERVAL, PING_TIMEOUT), shared_capability: capability, outgoing_messages: VecDeque::new(), + disconnecting: false, } } + + /// Returns `true` if the connection is about to disconnect. + pub fn is_disconnecting(&self) -> bool { + self.disconnecting + } + + /// Starts to gracefully disconnect the connection by sending a Disconnect message and stop + /// reading new messages. + /// + /// Once disconnect process has started, the [`Stream`] will terminate immediately. + pub fn start_disconnect(&mut self, reason: DisconnectReason) { + // clear any buffered messages and queue in + self.outgoing_messages.clear(); + let mut buf = BytesMut::new(); + P2PMessage::Disconnect(reason).encode(&mut buf); + self.outgoing_messages.push_back(buf.freeze()); + } +} + +impl P2PStream +where + S: Sink + Unpin, +{ + /// Disconnects the connection by sending a disconnect message. + /// + /// This future resolves once the disconnect message has been sent. + pub async fn disconnect(mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> { + self.start_disconnect(reason); + self.close().await + } } // S must also be `Sink` because we need to be able to respond with ping messages to follow the @@ -186,6 +221,11 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); + if *this.disconnecting { + // if disconnecting, stop reading messages + return Poll::Ready(None) + } + // poll the pinger to determine if we should send a ping match this.pinger.poll_ping(cx) { Poll::Pending => {} @@ -241,8 +281,7 @@ where // continue to the next message if there is one } else if id == P2PMessageID::Disconnect as u8 { let reason = DisconnectReason::decode(&mut &bytes[1..])?; - // TODO: do something with the reason - return Poll::Ready(Some(Err(P2PStreamError::Disconnected))) + return Poll::Ready(Some(Err(P2PStreamError::Disconnected(reason)))) } else if id == P2PMessageID::Hello as u8 { // we have received a hello message outside of the handshake, so we will return an // error @@ -477,16 +516,6 @@ impl P2PMessage { } impl Encodable for P2PMessage { - fn length(&self) -> usize { - let payload_len = match self { - P2PMessage::Hello(msg) => msg.length(), - P2PMessage::Disconnect(msg) => msg.length(), - P2PMessage::Ping => 3, // len([0x01, 0x00, 0x80]) = 3 - P2PMessage::Pong => 3, // len([0x01, 0x00, 0x80]) = 3 - }; - payload_len + 1 // (1 for length of p2p message id) - } - fn encode(&self, out: &mut dyn bytes::BufMut) { out.put_u8(self.message_id() as u8); match self { @@ -504,6 +533,16 @@ impl Encodable for P2PMessage { } } } + + fn length(&self) -> usize { + let payload_len = match self { + P2PMessage::Hello(msg) => msg.length(), + P2PMessage::Disconnect(msg) => msg.length(), + P2PMessage::Ping => 3, // len([0x01, 0x00, 0x80]) = 3 + P2PMessage::Pong => 3, // len([0x01, 0x00, 0x80]) = 3 + }; + payload_len + 1 // (1 for length of p2p message id) + } } impl Decodable for P2PMessage { @@ -746,15 +785,61 @@ impl Decodable for DisconnectReason { #[cfg(test)] mod tests { + use super::*; + use crate::EthVersion; use reth_ecies::util::pk2id; use reth_rlp::EMPTY_STRING_CODE; use secp256k1::{SecretKey, SECP256K1}; use tokio::net::{TcpListener, TcpStream}; use tokio_util::codec::Decoder; - use crate::EthVersion; + /// Returns a testing `HelloMessage` and new secretkey + fn eth_hello() -> (HelloMessage, SecretKey) { + let server_key = SecretKey::new(&mut rand::thread_rng()); + let hello = HelloMessage { + protocol_version: ProtocolVersion::V5, + client_version: "bitcoind/1.0.0".to_string(), + capabilities: vec![EthVersion::Eth67.into()], + port: 30303, + id: pk2id(&server_key.public_key(SECP256K1)), + }; + (hello, server_key) + } - use super::*; + #[tokio::test] + async fn test_can_disconnect() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let local_addr = listener.local_addr().unwrap(); + + let expected_disconnect = DisconnectReason::UselessPeer; + + let handle = tokio::spawn(async move { + // roughly based off of the design of tokio::net::TcpListener + let (incoming, _) = listener.accept().await.unwrap(); + let stream = crate::PassthroughCodec::default().framed(incoming); + + let (server_hello, _) = eth_hello(); + + let (p2p_stream, _) = + UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap(); + + p2p_stream.disconnect(expected_disconnect).await.unwrap(); + }); + + let outgoing = TcpStream::connect(local_addr).await.unwrap(); + let sink = crate::PassthroughCodec::default().framed(outgoing); + + let (client_hello, _) = eth_hello(); + + let (mut p2p_stream, _) = + UnauthedP2PStream::new(sink).handshake(client_hello).await.unwrap(); + + let err = p2p_stream.next().await.unwrap().unwrap_err(); + match err { + P2PStreamError::Disconnected(reason) => assert_eq!(reason, expected_disconnect), + _ => panic!("unexpected err"), + } + } #[tokio::test] async fn test_handshake_passthrough() { @@ -768,14 +853,7 @@ mod tests { let (incoming, _) = listener.accept().await.unwrap(); let stream = crate::PassthroughCodec::default().framed(incoming); - let server_key = SecretKey::new(&mut rand::thread_rng()); - let server_hello = HelloMessage { - protocol_version: ProtocolVersion::V5, - client_version: "bitcoind/1.0.0".to_string(), - capabilities: vec![EthVersion::Eth67.into()], - port: 30303, - id: pk2id(&server_key.public_key(SECP256K1)), - }; + let (server_hello, _) = eth_hello(); let unauthed_stream = UnauthedP2PStream::new(stream); let (p2p_stream, _) = unauthed_stream.handshake(server_hello).await.unwrap(); @@ -790,18 +868,10 @@ mod tests { ); }); - let client_key = SecretKey::new(&mut rand::thread_rng()); - let outgoing = TcpStream::connect(local_addr).await.unwrap(); let sink = crate::PassthroughCodec::default().framed(outgoing); - let client_hello = HelloMessage { - protocol_version: ProtocolVersion::V5, - client_version: "bitcoind/1.0.0".to_string(), - capabilities: vec![EthVersion::Eth67.into()], - port: 30303, - id: pk2id(&client_key.public_key(SECP256K1)), - }; + let (client_hello, _) = eth_hello(); let unauthed_stream = UnauthedP2PStream::new(sink); let (p2p_stream, _) = unauthed_stream.handshake(client_hello).await.unwrap();