mirror of
https://github.com/hl-archive-node/nanoreth.git
synced 2025-12-06 10:59:55 +00:00
feat(net): add disconnect function (#223)
This commit is contained in:
@ -3,7 +3,7 @@ use std::io;
|
||||
|
||||
use reth_primitives::{Chain, ValidationError, H256};
|
||||
|
||||
use crate::capability::SharedCapabilityError;
|
||||
use crate::{capability::SharedCapabilityError, DisconnectReason};
|
||||
|
||||
/// Errors when sending/receiving messages
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
@ -69,9 +69,8 @@ pub enum P2PStreamError {
|
||||
PingBeforeHandshake,
|
||||
#[error("too many messages buffered before sending")]
|
||||
SendBufferFull,
|
||||
// TODO: remove / reconsider
|
||||
#[error("disconnected")]
|
||||
Disconnected,
|
||||
Disconnected(DisconnectReason),
|
||||
}
|
||||
|
||||
/// Errors when conducting a p2p handshake
|
||||
|
||||
@ -157,6 +157,10 @@ pub struct P2PStream<S> {
|
||||
|
||||
/// Outgoing messages buffered for sending to the underlying stream.
|
||||
outgoing_messages: VecDeque<Bytes>,
|
||||
|
||||
/// Whether this stream is currently in the process of disconnecting by sending a disconnect
|
||||
/// message.
|
||||
disconnecting: bool,
|
||||
}
|
||||
|
||||
impl<S> P2PStream<S> {
|
||||
@ -171,8 +175,39 @@ impl<S> P2PStream<S> {
|
||||
pinger: Pinger::new(PING_INTERVAL, PING_TIMEOUT),
|
||||
shared_capability: capability,
|
||||
outgoing_messages: VecDeque::new(),
|
||||
disconnecting: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if the connection is about to disconnect.
|
||||
pub fn is_disconnecting(&self) -> bool {
|
||||
self.disconnecting
|
||||
}
|
||||
|
||||
/// Starts to gracefully disconnect the connection by sending a Disconnect message and stop
|
||||
/// reading new messages.
|
||||
///
|
||||
/// Once disconnect process has started, the [`Stream`] will terminate immediately.
|
||||
pub fn start_disconnect(&mut self, reason: DisconnectReason) {
|
||||
// clear any buffered messages and queue in
|
||||
self.outgoing_messages.clear();
|
||||
let mut buf = BytesMut::new();
|
||||
P2PMessage::Disconnect(reason).encode(&mut buf);
|
||||
self.outgoing_messages.push_back(buf.freeze());
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> P2PStream<S>
|
||||
where
|
||||
S: Sink<Bytes, Error = io::Error> + Unpin,
|
||||
{
|
||||
/// Disconnects the connection by sending a disconnect message.
|
||||
///
|
||||
/// This future resolves once the disconnect message has been sent.
|
||||
pub async fn disconnect(mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
|
||||
self.start_disconnect(reason);
|
||||
self.close().await
|
||||
}
|
||||
}
|
||||
|
||||
// S must also be `Sink` because we need to be able to respond with ping messages to follow the
|
||||
@ -186,6 +221,11 @@ where
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let mut this = self.project();
|
||||
|
||||
if *this.disconnecting {
|
||||
// if disconnecting, stop reading messages
|
||||
return Poll::Ready(None)
|
||||
}
|
||||
|
||||
// poll the pinger to determine if we should send a ping
|
||||
match this.pinger.poll_ping(cx) {
|
||||
Poll::Pending => {}
|
||||
@ -241,8 +281,7 @@ where
|
||||
// continue to the next message if there is one
|
||||
} else if id == P2PMessageID::Disconnect as u8 {
|
||||
let reason = DisconnectReason::decode(&mut &bytes[1..])?;
|
||||
// TODO: do something with the reason
|
||||
return Poll::Ready(Some(Err(P2PStreamError::Disconnected)))
|
||||
return Poll::Ready(Some(Err(P2PStreamError::Disconnected(reason))))
|
||||
} else if id == P2PMessageID::Hello as u8 {
|
||||
// we have received a hello message outside of the handshake, so we will return an
|
||||
// error
|
||||
@ -477,16 +516,6 @@ impl P2PMessage {
|
||||
}
|
||||
|
||||
impl Encodable for P2PMessage {
|
||||
fn length(&self) -> usize {
|
||||
let payload_len = match self {
|
||||
P2PMessage::Hello(msg) => msg.length(),
|
||||
P2PMessage::Disconnect(msg) => msg.length(),
|
||||
P2PMessage::Ping => 3, // len([0x01, 0x00, 0x80]) = 3
|
||||
P2PMessage::Pong => 3, // len([0x01, 0x00, 0x80]) = 3
|
||||
};
|
||||
payload_len + 1 // (1 for length of p2p message id)
|
||||
}
|
||||
|
||||
fn encode(&self, out: &mut dyn bytes::BufMut) {
|
||||
out.put_u8(self.message_id() as u8);
|
||||
match self {
|
||||
@ -504,6 +533,16 @@ impl Encodable for P2PMessage {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn length(&self) -> usize {
|
||||
let payload_len = match self {
|
||||
P2PMessage::Hello(msg) => msg.length(),
|
||||
P2PMessage::Disconnect(msg) => msg.length(),
|
||||
P2PMessage::Ping => 3, // len([0x01, 0x00, 0x80]) = 3
|
||||
P2PMessage::Pong => 3, // len([0x01, 0x00, 0x80]) = 3
|
||||
};
|
||||
payload_len + 1 // (1 for length of p2p message id)
|
||||
}
|
||||
}
|
||||
|
||||
impl Decodable for P2PMessage {
|
||||
@ -746,15 +785,61 @@ impl Decodable for DisconnectReason {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::EthVersion;
|
||||
use reth_ecies::util::pk2id;
|
||||
use reth_rlp::EMPTY_STRING_CODE;
|
||||
use secp256k1::{SecretKey, SECP256K1};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio_util::codec::Decoder;
|
||||
|
||||
use crate::EthVersion;
|
||||
/// Returns a testing `HelloMessage` and new secretkey
|
||||
fn eth_hello() -> (HelloMessage, SecretKey) {
|
||||
let server_key = SecretKey::new(&mut rand::thread_rng());
|
||||
let hello = HelloMessage {
|
||||
protocol_version: ProtocolVersion::V5,
|
||||
client_version: "bitcoind/1.0.0".to_string(),
|
||||
capabilities: vec![EthVersion::Eth67.into()],
|
||||
port: 30303,
|
||||
id: pk2id(&server_key.public_key(SECP256K1)),
|
||||
};
|
||||
(hello, server_key)
|
||||
}
|
||||
|
||||
use super::*;
|
||||
#[tokio::test]
|
||||
async fn test_can_disconnect() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let local_addr = listener.local_addr().unwrap();
|
||||
|
||||
let expected_disconnect = DisconnectReason::UselessPeer;
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
// roughly based off of the design of tokio::net::TcpListener
|
||||
let (incoming, _) = listener.accept().await.unwrap();
|
||||
let stream = crate::PassthroughCodec::default().framed(incoming);
|
||||
|
||||
let (server_hello, _) = eth_hello();
|
||||
|
||||
let (p2p_stream, _) =
|
||||
UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();
|
||||
|
||||
p2p_stream.disconnect(expected_disconnect).await.unwrap();
|
||||
});
|
||||
|
||||
let outgoing = TcpStream::connect(local_addr).await.unwrap();
|
||||
let sink = crate::PassthroughCodec::default().framed(outgoing);
|
||||
|
||||
let (client_hello, _) = eth_hello();
|
||||
|
||||
let (mut p2p_stream, _) =
|
||||
UnauthedP2PStream::new(sink).handshake(client_hello).await.unwrap();
|
||||
|
||||
let err = p2p_stream.next().await.unwrap().unwrap_err();
|
||||
match err {
|
||||
P2PStreamError::Disconnected(reason) => assert_eq!(reason, expected_disconnect),
|
||||
_ => panic!("unexpected err"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_handshake_passthrough() {
|
||||
@ -768,14 +853,7 @@ mod tests {
|
||||
let (incoming, _) = listener.accept().await.unwrap();
|
||||
let stream = crate::PassthroughCodec::default().framed(incoming);
|
||||
|
||||
let server_key = SecretKey::new(&mut rand::thread_rng());
|
||||
let server_hello = HelloMessage {
|
||||
protocol_version: ProtocolVersion::V5,
|
||||
client_version: "bitcoind/1.0.0".to_string(),
|
||||
capabilities: vec![EthVersion::Eth67.into()],
|
||||
port: 30303,
|
||||
id: pk2id(&server_key.public_key(SECP256K1)),
|
||||
};
|
||||
let (server_hello, _) = eth_hello();
|
||||
|
||||
let unauthed_stream = UnauthedP2PStream::new(stream);
|
||||
let (p2p_stream, _) = unauthed_stream.handshake(server_hello).await.unwrap();
|
||||
@ -790,18 +868,10 @@ mod tests {
|
||||
);
|
||||
});
|
||||
|
||||
let client_key = SecretKey::new(&mut rand::thread_rng());
|
||||
|
||||
let outgoing = TcpStream::connect(local_addr).await.unwrap();
|
||||
let sink = crate::PassthroughCodec::default().framed(outgoing);
|
||||
|
||||
let client_hello = HelloMessage {
|
||||
protocol_version: ProtocolVersion::V5,
|
||||
client_version: "bitcoind/1.0.0".to_string(),
|
||||
capabilities: vec![EthVersion::Eth67.into()],
|
||||
port: 30303,
|
||||
id: pk2id(&client_key.public_key(SECP256K1)),
|
||||
};
|
||||
let (client_hello, _) = eth_hello();
|
||||
|
||||
let unauthed_stream = UnauthedP2PStream::new(sink);
|
||||
let (p2p_stream, _) = unauthed_stream.handshake(client_hello).await.unwrap();
|
||||
|
||||
Reference in New Issue
Block a user