feat(net): add disconnect function (#223)

This commit is contained in:
Matthias Seitz
2022-11-18 16:26:49 +01:00
committed by GitHub
parent b4098b9558
commit ebd27b6025
2 changed files with 103 additions and 34 deletions

View File

@ -3,7 +3,7 @@ use std::io;
use reth_primitives::{Chain, ValidationError, H256};
use crate::capability::SharedCapabilityError;
use crate::{capability::SharedCapabilityError, DisconnectReason};
/// Errors when sending/receiving messages
#[derive(thiserror::Error, Debug)]
@ -69,9 +69,8 @@ pub enum P2PStreamError {
PingBeforeHandshake,
#[error("too many messages buffered before sending")]
SendBufferFull,
// TODO: remove / reconsider
#[error("disconnected")]
Disconnected,
Disconnected(DisconnectReason),
}
/// Errors when conducting a p2p handshake

View File

@ -157,6 +157,10 @@ pub struct P2PStream<S> {
/// Outgoing messages buffered for sending to the underlying stream.
outgoing_messages: VecDeque<Bytes>,
/// Whether this stream is currently in the process of disconnecting by sending a disconnect
/// message.
disconnecting: bool,
}
impl<S> P2PStream<S> {
@ -171,8 +175,39 @@ impl<S> P2PStream<S> {
pinger: Pinger::new(PING_INTERVAL, PING_TIMEOUT),
shared_capability: capability,
outgoing_messages: VecDeque::new(),
disconnecting: false,
}
}
/// Returns `true` if the connection is about to disconnect.
pub fn is_disconnecting(&self) -> bool {
self.disconnecting
}
/// Starts to gracefully disconnect the connection by sending a Disconnect message and stop
/// reading new messages.
///
/// Once disconnect process has started, the [`Stream`] will terminate immediately.
pub fn start_disconnect(&mut self, reason: DisconnectReason) {
// clear any buffered messages and queue in
self.outgoing_messages.clear();
let mut buf = BytesMut::new();
P2PMessage::Disconnect(reason).encode(&mut buf);
self.outgoing_messages.push_back(buf.freeze());
}
}
impl<S> P2PStream<S>
where
S: Sink<Bytes, Error = io::Error> + Unpin,
{
/// Disconnects the connection by sending a disconnect message.
///
/// This future resolves once the disconnect message has been sent.
pub async fn disconnect(mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
self.start_disconnect(reason);
self.close().await
}
}
// S must also be `Sink` because we need to be able to respond with ping messages to follow the
@ -186,6 +221,11 @@ where
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();
if *this.disconnecting {
// if disconnecting, stop reading messages
return Poll::Ready(None)
}
// poll the pinger to determine if we should send a ping
match this.pinger.poll_ping(cx) {
Poll::Pending => {}
@ -241,8 +281,7 @@ where
// continue to the next message if there is one
} else if id == P2PMessageID::Disconnect as u8 {
let reason = DisconnectReason::decode(&mut &bytes[1..])?;
// TODO: do something with the reason
return Poll::Ready(Some(Err(P2PStreamError::Disconnected)))
return Poll::Ready(Some(Err(P2PStreamError::Disconnected(reason))))
} else if id == P2PMessageID::Hello as u8 {
// we have received a hello message outside of the handshake, so we will return an
// error
@ -477,16 +516,6 @@ impl P2PMessage {
}
impl Encodable for P2PMessage {
fn length(&self) -> usize {
let payload_len = match self {
P2PMessage::Hello(msg) => msg.length(),
P2PMessage::Disconnect(msg) => msg.length(),
P2PMessage::Ping => 3, // len([0x01, 0x00, 0x80]) = 3
P2PMessage::Pong => 3, // len([0x01, 0x00, 0x80]) = 3
};
payload_len + 1 // (1 for length of p2p message id)
}
fn encode(&self, out: &mut dyn bytes::BufMut) {
out.put_u8(self.message_id() as u8);
match self {
@ -504,6 +533,16 @@ impl Encodable for P2PMessage {
}
}
}
fn length(&self) -> usize {
let payload_len = match self {
P2PMessage::Hello(msg) => msg.length(),
P2PMessage::Disconnect(msg) => msg.length(),
P2PMessage::Ping => 3, // len([0x01, 0x00, 0x80]) = 3
P2PMessage::Pong => 3, // len([0x01, 0x00, 0x80]) = 3
};
payload_len + 1 // (1 for length of p2p message id)
}
}
impl Decodable for P2PMessage {
@ -746,15 +785,61 @@ impl Decodable for DisconnectReason {
#[cfg(test)]
mod tests {
use super::*;
use crate::EthVersion;
use reth_ecies::util::pk2id;
use reth_rlp::EMPTY_STRING_CODE;
use secp256k1::{SecretKey, SECP256K1};
use tokio::net::{TcpListener, TcpStream};
use tokio_util::codec::Decoder;
use crate::EthVersion;
/// Returns a testing `HelloMessage` and new secretkey
fn eth_hello() -> (HelloMessage, SecretKey) {
let server_key = SecretKey::new(&mut rand::thread_rng());
let hello = HelloMessage {
protocol_version: ProtocolVersion::V5,
client_version: "bitcoind/1.0.0".to_string(),
capabilities: vec![EthVersion::Eth67.into()],
port: 30303,
id: pk2id(&server_key.public_key(SECP256K1)),
};
(hello, server_key)
}
use super::*;
#[tokio::test]
async fn test_can_disconnect() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
let expected_disconnect = DisconnectReason::UselessPeer;
let handle = tokio::spawn(async move {
// roughly based off of the design of tokio::net::TcpListener
let (incoming, _) = listener.accept().await.unwrap();
let stream = crate::PassthroughCodec::default().framed(incoming);
let (server_hello, _) = eth_hello();
let (p2p_stream, _) =
UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();
p2p_stream.disconnect(expected_disconnect).await.unwrap();
});
let outgoing = TcpStream::connect(local_addr).await.unwrap();
let sink = crate::PassthroughCodec::default().framed(outgoing);
let (client_hello, _) = eth_hello();
let (mut p2p_stream, _) =
UnauthedP2PStream::new(sink).handshake(client_hello).await.unwrap();
let err = p2p_stream.next().await.unwrap().unwrap_err();
match err {
P2PStreamError::Disconnected(reason) => assert_eq!(reason, expected_disconnect),
_ => panic!("unexpected err"),
}
}
#[tokio::test]
async fn test_handshake_passthrough() {
@ -768,14 +853,7 @@ mod tests {
let (incoming, _) = listener.accept().await.unwrap();
let stream = crate::PassthroughCodec::default().framed(incoming);
let server_key = SecretKey::new(&mut rand::thread_rng());
let server_hello = HelloMessage {
protocol_version: ProtocolVersion::V5,
client_version: "bitcoind/1.0.0".to_string(),
capabilities: vec![EthVersion::Eth67.into()],
port: 30303,
id: pk2id(&server_key.public_key(SECP256K1)),
};
let (server_hello, _) = eth_hello();
let unauthed_stream = UnauthedP2PStream::new(stream);
let (p2p_stream, _) = unauthed_stream.handshake(server_hello).await.unwrap();
@ -790,18 +868,10 @@ mod tests {
);
});
let client_key = SecretKey::new(&mut rand::thread_rng());
let outgoing = TcpStream::connect(local_addr).await.unwrap();
let sink = crate::PassthroughCodec::default().framed(outgoing);
let client_hello = HelloMessage {
protocol_version: ProtocolVersion::V5,
client_version: "bitcoind/1.0.0".to_string(),
capabilities: vec![EthVersion::Eth67.into()],
port: 30303,
id: pk2id(&client_key.public_key(SECP256K1)),
};
let (client_hello, _) = eth_hello();
let unauthed_stream = UnauthedP2PStream::new(sink);
let (p2p_stream, _) = unauthed_stream.handshake(client_hello).await.unwrap();