mirror of
https://github.com/hl-archive-node/nanoreth.git
synced 2025-12-06 10:59:55 +00:00
feat(eth-wire): Implement p2p stream (#114)
* wip: p2pstream * add comment for handshake timeout * temp allow some lint violations * ignore unused_variables * start of ping task * TODO: make it compile * TODO: test ping/pong/disconnect state machine * TODO: send subprotocol messages to stream * TODO: encode non-hello p2p messages as snappy encoding without using an encoder * TODO: create test comparing encoder to hand-written snappy encoding for ping, pong, disconnect messages * implement message handling in stream poll method * restricts S to be Stream+Sink for P2PStream to implement Stream * start of a poll-based refactor * impl Stream and Sink for P2PStream * add tests * TODO: make stream/sink types compatible * TODO: handshake message ids * TODO: inner poll fn * TODO: pinger interval * TODO: ethstream test * TODO: passthrough test * create pingers and test * impl working timeout interval pinger * it should be much easier to poll for pings and detect timeouts now * use pinger in p2p stream * change item produced by stream so it's compatible with EthStream * add note on pros/cons * shorten message sends in stream * improve errors and remove redundant methods * fix handshake * debugging printlns * fix encoding and decoding * switch to snappy formatting for non-hello p2p messages * cargo fmt * perform handshake in ethstream over p2pstream test * remove check for `Hello` messages outside of the handshake because `P2PStream`s should assume messages sent in the sink are subprotocol messages, not `p2p` messages. * impl From<EthVersion> for CapabilityMessage * remove printlns * add total_message method to EthVersion * decode Hello in handshake * disallow protocol versions other than v5 * Integrate snappy and implement message size limits * document constants, move stream definition * fix missing hello message id * implement shared capabilities * todo: test shared capabilities * todo: determine how / when / why to support multiple capabilities * removes obsolete authed and offset fields * add sink api TODOs * remove les * should add protocols when necessary rather than name unsupported protocols * fix snappy compression length * add test for p2pstream over a passthrough codec which tests that peers agree on a single shared capability * fix some clippy lints
This commit is contained in:
@ -31,6 +31,8 @@ pub fn pk2id(pk: &PublicKey) -> PeerId {
|
||||
/// Converts a [PeerId] to a [secp256k1::PublicKey] by prepending the [PeerId] bytes with the
|
||||
/// SECP256K1_TAG_PUBKEY_UNCOMPRESSED tag.
|
||||
pub(crate) fn id2pk(id: PeerId) -> Result<PublicKey, secp256k1::Error> {
|
||||
// NOTE: H512 is used as a PeerId not because it represents a hash, but because 512 bits is
|
||||
// enough to represent an uncompressed public key.
|
||||
let mut s = [0_u8; 65];
|
||||
// SECP256K1_TAG_PUBKEY_UNCOMPRESSED = 0x04
|
||||
// see: https://github.com/bitcoin-core/secp256k1/blob/master/include/secp256k1.h#L211
|
||||
|
||||
@ -26,7 +26,9 @@ tokio-stream = "0.1.11"
|
||||
secp256k1 = { version = "0.24.0", features = ["global-context", "rand-std", "recovery"] }
|
||||
tokio-util = { version = "0.7.4", features = ["io"] }
|
||||
pin-project = "1.0"
|
||||
pin-utils = "0.1.0"
|
||||
tracing = "0.1.37"
|
||||
snap = "1.0.5"
|
||||
|
||||
[dev-dependencies]
|
||||
hex-literal = "0.3"
|
||||
|
||||
87
crates/net/eth-wire/src/capability.rs
Normal file
87
crates/net/eth-wire/src/capability.rs
Normal file
@ -0,0 +1,87 @@
|
||||
use crate::{version::ParseVersionError, EthVersion};
|
||||
|
||||
/// This represents a shared capability, its version, and its offset.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum SharedCapability {
|
||||
/// The `eth` capability.
|
||||
Eth { version: EthVersion, offset: u8 },
|
||||
|
||||
/// An unknown capability.
|
||||
UnknownCapability { name: String, version: u8, offset: u8 },
|
||||
}
|
||||
|
||||
impl SharedCapability {
|
||||
/// Creates a new [`SharedCapability`] based on the given name, offset, and version.
|
||||
pub(crate) fn new(name: &str, version: u8, offset: u8) -> Result<Self, SharedCapabilityError> {
|
||||
match name {
|
||||
"eth" => Ok(Self::Eth { version: EthVersion::try_from(version)?, offset }),
|
||||
_ => Ok(Self::UnknownCapability { name: name.to_string(), version, offset }),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the name of the capability.
|
||||
pub(crate) fn name(&self) -> &str {
|
||||
match self {
|
||||
SharedCapability::Eth { .. } => "eth",
|
||||
SharedCapability::UnknownCapability { name, .. } => name,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the version of the capability.
|
||||
pub(crate) fn version(&self) -> u8 {
|
||||
match self {
|
||||
SharedCapability::Eth { version, .. } => *version as u8,
|
||||
SharedCapability::UnknownCapability { version, .. } => *version,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the message ID offset of the current capability.
|
||||
pub(crate) fn offset(&self) -> u8 {
|
||||
match self {
|
||||
SharedCapability::Eth { offset, .. } => *offset,
|
||||
SharedCapability::UnknownCapability { offset, .. } => *offset,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the number of protocol messages supported by this capability.
|
||||
pub(crate) fn num_messages(&self) -> Result<u8, SharedCapabilityError> {
|
||||
match self {
|
||||
SharedCapability::Eth { version, .. } => Ok(version.total_messages()),
|
||||
_ => Err(SharedCapabilityError::UnknownCapability),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An error that may occur while creating a [`SharedCapability`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum SharedCapabilityError {
|
||||
/// Unsupported `eth` version.
|
||||
#[error(transparent)]
|
||||
UnsupportedVersion(#[from] ParseVersionError),
|
||||
/// Cannot determine the number of messages for unknown capabilities.
|
||||
#[error("cannot determine the number of messages for unknown capabilities")]
|
||||
UnknownCapability,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn from_eth_67() {
|
||||
let capability = SharedCapability::new("eth", 67, 0).unwrap();
|
||||
|
||||
assert_eq!(capability.name(), "eth");
|
||||
assert_eq!(capability.version(), 67);
|
||||
assert_eq!(capability, SharedCapability::Eth { version: EthVersion::Eth67, offset: 0 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_eth_66() {
|
||||
let capability = SharedCapability::new("eth", 66, 0).unwrap();
|
||||
|
||||
assert_eq!(capability.name(), "eth");
|
||||
assert_eq!(capability.version(), 66);
|
||||
assert_eq!(capability, SharedCapability::Eth { version: EthVersion::Eth66, offset: 0 });
|
||||
}
|
||||
}
|
||||
@ -3,7 +3,7 @@ use std::io;
|
||||
|
||||
use reth_primitives::{Chain, H256};
|
||||
|
||||
use crate::types::forkid::ValidationError;
|
||||
use crate::{capability::SharedCapabilityError, types::forkid::ValidationError};
|
||||
|
||||
/// Errors when sending/receiving messages
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
@ -14,6 +14,8 @@ pub enum EthStreamError {
|
||||
#[error(transparent)]
|
||||
Rlp(#[from] reth_rlp::DecodeError),
|
||||
#[error(transparent)]
|
||||
P2PStreamError(#[from] P2PStreamError),
|
||||
#[error(transparent)]
|
||||
HandshakeError(#[from] HandshakeError),
|
||||
#[error("message size ({0}) exceeds max length (10MB)")]
|
||||
MessageTooBig(usize),
|
||||
@ -37,3 +39,66 @@ pub enum HandshakeError {
|
||||
#[error("mismatched chain in Status message. expected: {expected:?}, got: {got:?}")]
|
||||
MismatchedChain { expected: Chain, got: Chain },
|
||||
}
|
||||
|
||||
/// Errors when sending/receiving p2p messages. These should result in kicking the peer.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum P2PStreamError {
|
||||
#[error(transparent)]
|
||||
Io(#[from] io::Error),
|
||||
#[error(transparent)]
|
||||
Rlp(#[from] reth_rlp::DecodeError),
|
||||
#[error(transparent)]
|
||||
Snap(#[from] snap::Error),
|
||||
#[error(transparent)]
|
||||
HandshakeError(#[from] P2PHandshakeError),
|
||||
#[error("message size ({message_size}) exceeds max length ({max_size})")]
|
||||
MessageTooBig { message_size: usize, max_size: usize },
|
||||
#[error("unknown reserved p2p message id: {0}")]
|
||||
UnknownReservedMessageId(u8),
|
||||
#[error("empty protocol message received")]
|
||||
EmptyProtocolMessage,
|
||||
#[error(transparent)]
|
||||
PingerError(#[from] PingerError),
|
||||
#[error("ping timed out with {0} retries")]
|
||||
PingTimeout(u8),
|
||||
#[error(transparent)]
|
||||
ParseVersionError(#[from] SharedCapabilityError),
|
||||
#[error("mismatched protocol version in Hello message. expected: {expected:?}, got: {got:?}")]
|
||||
MismatchedProtocolVersion { expected: u8, got: u8 },
|
||||
#[error("started ping task before the handshake completed")]
|
||||
PingBeforeHandshake,
|
||||
// TODO: remove / reconsider
|
||||
#[error("disconnected")]
|
||||
Disconnected,
|
||||
}
|
||||
|
||||
/// Errors when conducting a p2p handshake
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum P2PHandshakeError {
|
||||
#[error("hello message can only be recv/sent in handshake")]
|
||||
HelloNotInHandshake,
|
||||
#[error("received non-hello message when trying to handshake")]
|
||||
NonHelloMessageInHandshake,
|
||||
#[error("no capabilities shared with peer")]
|
||||
NoSharedCapabilities,
|
||||
#[error("no response received when sending out handshake")]
|
||||
NoResponse,
|
||||
#[error("handshake timed out")]
|
||||
Timeout,
|
||||
}
|
||||
|
||||
/// An error that can occur when interacting with a [`Pinger`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum PingerError {
|
||||
/// A ping was sent while the pinger was in the `TimedOut` state.
|
||||
#[error("ping sent while timed out")]
|
||||
PingWhileTimedOut,
|
||||
|
||||
/// A pong was received while the pinger was in the `Ready` state.
|
||||
#[error("pong received while ready")]
|
||||
PongWhileReady,
|
||||
|
||||
/// A pong was received while the pinger was in the `TimedOut` state.
|
||||
#[error("pong received while timed out")]
|
||||
PongWhileTimedOut,
|
||||
}
|
||||
|
||||
@ -2,12 +2,11 @@ use crate::{
|
||||
error::{EthStreamError, HandshakeError},
|
||||
types::{forkid::ForkFilter, EthMessage, ProtocolMessage, Status},
|
||||
};
|
||||
use bytes::BytesMut;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use futures::{ready, Sink, SinkExt, StreamExt};
|
||||
use pin_project::pin_project;
|
||||
use reth_rlp::{Decodable, Encodable};
|
||||
use std::{
|
||||
io,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
@ -35,11 +34,10 @@ impl<S> EthStream<S> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> EthStream<S>
|
||||
impl<S, E> EthStream<S>
|
||||
where
|
||||
S: Stream<Item = Result<bytes::BytesMut, io::Error>>
|
||||
+ Sink<bytes::Bytes, Error = io::Error>
|
||||
+ Unpin,
|
||||
S: Stream<Item = Result<bytes::BytesMut, E>> + Sink<bytes::Bytes, Error = E> + Unpin,
|
||||
EthStreamError: From<E>,
|
||||
{
|
||||
/// Given an instantiated transport layer, it proceeds to return an [`EthStream`]
|
||||
/// after performing a [`Status`] message handshake as specified in
|
||||
@ -105,9 +103,10 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Stream for EthStream<S>
|
||||
impl<S, E> Stream for EthStream<S>
|
||||
where
|
||||
S: Stream<Item = Result<bytes::BytesMut, io::Error>> + Unpin,
|
||||
S: Stream<Item = Result<bytes::BytesMut, E>> + Unpin,
|
||||
EthStreamError: From<E>,
|
||||
{
|
||||
type Item = Result<EthMessage, EthStreamError>;
|
||||
|
||||
@ -139,9 +138,10 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Sink<EthMessage> for EthStream<S>
|
||||
impl<S, E> Sink<EthMessage> for EthStream<S>
|
||||
where
|
||||
S: Sink<bytes::Bytes, Error = io::Error> + Unpin,
|
||||
S: Sink<Bytes, Error = E> + Unpin,
|
||||
EthStreamError: From<E>,
|
||||
{
|
||||
type Error = EthStreamError;
|
||||
|
||||
@ -175,6 +175,7 @@ where
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
p2pstream::{CapabilityMessage, HelloMessage, ProtocolVersion, UnauthedP2PStream},
|
||||
types::{broadcast::BlockHashNumber, forkid::ForkFilter, EthMessage, Status},
|
||||
EthStream, PassthroughCodec,
|
||||
};
|
||||
@ -298,4 +299,91 @@ mod tests {
|
||||
// make sure the server receives the message and asserts before ending the test
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ethstream_over_p2p() {
|
||||
// create a p2p stream and server, then confirm that the two are authed
|
||||
// create tcpstream
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let local_addr = listener.local_addr().unwrap();
|
||||
let server_key = SecretKey::new(&mut rand::thread_rng());
|
||||
let test_msg = EthMessage::NewBlockHashes(
|
||||
vec![
|
||||
BlockHashNumber { hash: reth_primitives::H256::random(), number: 5 },
|
||||
BlockHashNumber { hash: reth_primitives::H256::random(), number: 6 },
|
||||
]
|
||||
.into(),
|
||||
);
|
||||
|
||||
let genesis = H256::random();
|
||||
let fork_filter = ForkFilter::new(0, genesis, vec![]);
|
||||
|
||||
let status = Status {
|
||||
version: EthVersion::Eth67 as u8,
|
||||
chain: Chain::Mainnet.into(),
|
||||
total_difficulty: U256::from(0),
|
||||
blockhash: H256::random(),
|
||||
genesis,
|
||||
// Pass the current fork id.
|
||||
forkid: fork_filter.current(),
|
||||
};
|
||||
|
||||
let status_copy = status;
|
||||
let fork_filter_clone = fork_filter.clone();
|
||||
let test_msg_clone = test_msg.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
// roughly based off of the design of tokio::net::TcpListener
|
||||
let (incoming, _) = listener.accept().await.unwrap();
|
||||
let stream = ECIESStream::incoming(incoming, server_key).await.unwrap();
|
||||
|
||||
let server_hello = HelloMessage {
|
||||
protocol_version: ProtocolVersion::V5,
|
||||
client_version: "bitcoind/1.0.0".to_string(),
|
||||
capabilities: vec![CapabilityMessage::new(
|
||||
"eth".to_string(),
|
||||
EthVersion::Eth67 as usize,
|
||||
)],
|
||||
port: 30303,
|
||||
id: pk2id(&server_key.public_key(SECP256K1)),
|
||||
};
|
||||
|
||||
let unauthed_stream = UnauthedP2PStream::new(stream);
|
||||
let p2p_stream = unauthed_stream.handshake(server_hello).await.unwrap();
|
||||
let mut eth_stream = EthStream::new(p2p_stream);
|
||||
eth_stream.handshake(status_copy, fork_filter_clone).await.unwrap();
|
||||
|
||||
// use the stream to get the next message
|
||||
let message = eth_stream.next().await.unwrap().unwrap();
|
||||
assert_eq!(message, test_msg_clone);
|
||||
});
|
||||
|
||||
// create the server pubkey
|
||||
let server_id = pk2id(&server_key.public_key(SECP256K1));
|
||||
|
||||
let client_key = SecretKey::new(&mut rand::thread_rng());
|
||||
|
||||
let outgoing = TcpStream::connect(local_addr).await.unwrap();
|
||||
let sink = ECIESStream::connect(outgoing, client_key, server_id).await.unwrap();
|
||||
|
||||
let client_hello = HelloMessage {
|
||||
protocol_version: ProtocolVersion::V5,
|
||||
client_version: "bitcoind/1.0.0".to_string(),
|
||||
capabilities: vec![CapabilityMessage::new(
|
||||
"eth".to_string(),
|
||||
EthVersion::Eth67 as usize,
|
||||
)],
|
||||
port: 30303,
|
||||
id: pk2id(&client_key.public_key(SECP256K1)),
|
||||
};
|
||||
|
||||
let unauthed_stream = UnauthedP2PStream::new(sink);
|
||||
let p2p_stream = unauthed_stream.handshake(client_hello).await.unwrap();
|
||||
let mut client_stream = EthStream::new(p2p_stream);
|
||||
client_stream.handshake(status, fork_filter).await.unwrap();
|
||||
|
||||
client_stream.send(test_msg).await.unwrap();
|
||||
|
||||
// make sure the server receives the message and asserts before ending the test
|
||||
handle.await.unwrap();
|
||||
}
|
||||
}
|
||||
@ -9,9 +9,12 @@
|
||||
pub use tokio_util::codec::{
|
||||
LengthDelimitedCodec as PassthroughCodec, LengthDelimitedCodecError as PassthroughCodecError,
|
||||
};
|
||||
mod capability;
|
||||
pub mod error;
|
||||
mod stream;
|
||||
mod ethstream;
|
||||
mod p2pstream;
|
||||
mod pinger;
|
||||
pub mod types;
|
||||
pub use types::*;
|
||||
|
||||
pub use stream::EthStream;
|
||||
pub use ethstream::EthStream;
|
||||
|
||||
977
crates/net/eth-wire/src/p2pstream.rs
Normal file
977
crates/net/eth-wire/src/p2pstream.rs
Normal file
@ -0,0 +1,977 @@
|
||||
#![allow(dead_code, unreachable_pub, missing_docs, unused_variables)]
|
||||
use bytes::{Buf, Bytes, BytesMut};
|
||||
use futures::{ready, FutureExt, Sink, SinkExt, StreamExt};
|
||||
use pin_project::pin_project;
|
||||
use reth_primitives::H512 as PeerId;
|
||||
use reth_rlp::{Decodable, DecodeError, Encodable, RlpDecodable, RlpEncodable};
|
||||
use std::{
|
||||
collections::{BTreeSet, HashMap},
|
||||
fmt::Display,
|
||||
io,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio_stream::Stream;
|
||||
|
||||
use crate::{
|
||||
capability::SharedCapability,
|
||||
error::{P2PHandshakeError, P2PStreamError},
|
||||
pinger::{IntervalTimeoutPinger, PingerEvent},
|
||||
};
|
||||
|
||||
/// [`MAX_PAYLOAD_SIZE`] is the maximum size of an uncompressed message payload.
|
||||
/// This is defined in [EIP-706](https://eips.ethereum.org/EIPS/eip-706).
|
||||
const MAX_PAYLOAD_SIZE: usize = 16 * 1024 * 1024;
|
||||
|
||||
/// [`MAX_RESERVED_MESSAGE_ID`] is the maximum message ID reserved for the `p2p` subprotocol. If
|
||||
/// there are any incoming messages with an ID greater than this, they are subprotocol messages.
|
||||
const MAX_RESERVED_MESSAGE_ID: u8 = 0x0f;
|
||||
|
||||
/// [`MAX_P2P_MESSAGE_ID`] is the maximum message ID in use for the `p2p` subprotocol.
|
||||
const MAX_P2P_MESSAGE_ID: u8 = P2PMessageID::Pong as u8;
|
||||
|
||||
/// [`HANDSHAKE_TIMEOUT`] determines the amount of time to wait before determining that a `p2p`
|
||||
/// handshake has timed out.
|
||||
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
|
||||
/// [`PING_TIMEOUT`] determines the amount of time to wait before determining that a `p2p` ping has
|
||||
/// timed out.
|
||||
const PING_TIMEOUT: Duration = Duration::from_secs(15);
|
||||
|
||||
/// [`PING_INTERVAL`] determines the amount of time to wait between sending `p2p` ping messages
|
||||
/// when the peer is responsive.
|
||||
const PING_INTERVAL: Duration = Duration::from_secs(60);
|
||||
|
||||
/// [`GRACE_PERIOD`] determines the amount of time to wait for a peer to disconnect after sending a
|
||||
/// [`P2PMessage::Disconnect`] message.
|
||||
const GRACE_PERIOD: Duration = Duration::from_secs(2);
|
||||
|
||||
/// [`MAX_FAILED_PINGS`] determines the maximum number of failed ping attempts before disconnecting
|
||||
/// from a peer.
|
||||
const MAX_FAILED_PINGS: u8 = 3;
|
||||
|
||||
/// An un-authenticated `P2PStream`. This is consumed and returns a [`P2PStream`] after the `Hello`
|
||||
/// handshake is completed.
|
||||
#[pin_project]
|
||||
pub struct UnauthedP2PStream<S> {
|
||||
#[pin]
|
||||
inner: S,
|
||||
}
|
||||
|
||||
impl<S> UnauthedP2PStream<S> {
|
||||
/// Create a new `UnauthedP2PStream` from a `Stream` of bytes.
|
||||
pub fn new(inner: S) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> UnauthedP2PStream<S>
|
||||
where
|
||||
S: Stream<Item = Result<BytesMut, io::Error>> + Sink<Bytes, Error = io::Error> + Unpin,
|
||||
{
|
||||
/// Consumes the `UnauthedP2PStream` and returns a `P2PStream` after the `Hello` handshake is
|
||||
/// completed.
|
||||
pub async fn handshake(mut self, hello: HelloMessage) -> Result<P2PStream<S>, P2PStreamError> {
|
||||
tracing::trace!("sending p2p hello ...");
|
||||
|
||||
// send our hello message with the Sink
|
||||
let mut raw_hello_bytes = BytesMut::new();
|
||||
P2PMessage::Hello(hello.clone()).encode(&mut raw_hello_bytes);
|
||||
self.inner.send(raw_hello_bytes.into()).await?;
|
||||
|
||||
tracing::trace!("waiting for p2p hello from peer ...");
|
||||
|
||||
let hello_bytes = tokio::time::timeout(HANDSHAKE_TIMEOUT, self.inner.next())
|
||||
.await
|
||||
.or(Err(P2PStreamError::HandshakeError(P2PHandshakeError::Timeout)))?
|
||||
.ok_or(P2PStreamError::HandshakeError(P2PHandshakeError::NoResponse))??;
|
||||
|
||||
// let's check the compressed length first, we will need to check again once confirming
|
||||
// that it contains snappy-compressed data (this will be the case for all non-p2p messages).
|
||||
if hello_bytes.len() > MAX_PAYLOAD_SIZE {
|
||||
return Err(P2PStreamError::MessageTooBig {
|
||||
message_size: hello_bytes.len(),
|
||||
max_size: MAX_PAYLOAD_SIZE,
|
||||
})
|
||||
}
|
||||
|
||||
// get the message id
|
||||
let id = *hello_bytes.first().ok_or_else(|| P2PStreamError::EmptyProtocolMessage)?;
|
||||
|
||||
// the first message sent MUST be the hello message
|
||||
if id != P2PMessageID::Hello as u8 {
|
||||
return Err(P2PStreamError::HandshakeError(
|
||||
P2PHandshakeError::NonHelloMessageInHandshake,
|
||||
))
|
||||
}
|
||||
|
||||
let their_hello = match P2PMessage::decode(&mut &hello_bytes[..])? {
|
||||
P2PMessage::Hello(hello) => Ok(hello),
|
||||
_ => {
|
||||
// TODO: this should never occur due to the id check
|
||||
Err(P2PStreamError::HandshakeError(P2PHandshakeError::NonHelloMessageInHandshake))
|
||||
}
|
||||
}?;
|
||||
|
||||
// TODO: explicitly document that we only support v5.
|
||||
if their_hello.protocol_version != ProtocolVersion::V5 {
|
||||
// TODO: do we want to send a `Disconnect` message here?
|
||||
return Err(P2PStreamError::MismatchedProtocolVersion {
|
||||
expected: ProtocolVersion::V5 as u8,
|
||||
got: their_hello.protocol_version as u8,
|
||||
})
|
||||
}
|
||||
|
||||
// determine shared capabilities (currently returns only one capability)
|
||||
let capability = set_capability_offsets(hello.capabilities, their_hello.capabilities)?;
|
||||
|
||||
let stream = P2PStream::new(self.inner, capability);
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
}
|
||||
|
||||
/// A P2PStream wraps over any `Stream` that yields bytes and makes it compatible with `p2p`
|
||||
/// protocol messages.
|
||||
#[pin_project]
|
||||
pub struct P2PStream<S> {
|
||||
#[pin]
|
||||
inner: S,
|
||||
|
||||
/// The snappy encoder used for compressing outgoing messages
|
||||
encoder: snap::raw::Encoder,
|
||||
|
||||
/// The snappy decoder used for decompressing incoming messages
|
||||
decoder: snap::raw::Decoder,
|
||||
|
||||
/// The state machine used for keeping track of the peer's ping status.
|
||||
pinger: IntervalTimeoutPinger,
|
||||
|
||||
/// The supported capability for this stream.
|
||||
shared_capability: SharedCapability,
|
||||
}
|
||||
|
||||
impl<S> P2PStream<S> {
|
||||
/// Create a new unauthed [`P2PStream`] from the provided stream. You will need to manually
|
||||
/// handshake with a peer.
|
||||
pub fn new(inner: S, capability: SharedCapability) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
encoder: snap::raw::Encoder::new(),
|
||||
decoder: snap::raw::Decoder::new(),
|
||||
pinger: IntervalTimeoutPinger::new(MAX_FAILED_PINGS, PING_INTERVAL, PING_TIMEOUT),
|
||||
shared_capability: capability,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// S must also be `Sink` because we need to be able to respond with ping messages to follow the
|
||||
// protocol
|
||||
impl<S> Stream for P2PStream<S>
|
||||
where
|
||||
S: Stream<Item = Result<BytesMut, io::Error>> + Sink<Bytes, Error = io::Error> + Unpin,
|
||||
{
|
||||
type Item = Result<BytesMut, P2PStreamError>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let mut this = self.project();
|
||||
|
||||
// poll the pinger to determine if we should send a ping
|
||||
let pinger_res = ready!(Pin::new(&mut this.pinger).poll_next(cx));
|
||||
match pinger_res {
|
||||
Some(Ok(PingerEvent::Ping)) => {
|
||||
// encode the ping message
|
||||
let mut ping_bytes = BytesMut::new();
|
||||
P2PMessage::Ping.encode(&mut ping_bytes);
|
||||
|
||||
// TODO: fix use of Sink API
|
||||
let send_res = Pin::new(&mut this.inner).send(ping_bytes.into()).poll_unpin(cx)?;
|
||||
ready!(send_res)
|
||||
}
|
||||
// either None (stream ended) or Some(PingEvent::Timeout) or Err(err)
|
||||
_ => {
|
||||
// encode the disconnect message
|
||||
let mut disconnect_bytes = BytesMut::new();
|
||||
P2PMessage::Disconnect(DisconnectReason::PingTimeout).encode(&mut disconnect_bytes);
|
||||
|
||||
// TODO: fix use of Sink API
|
||||
let send_res =
|
||||
Pin::new(&mut this.inner).send(disconnect_bytes.into()).poll_unpin(cx)?;
|
||||
ready!(send_res);
|
||||
|
||||
// since the ping stream has timed out, let's send a None
|
||||
return Poll::Ready(None)
|
||||
}
|
||||
};
|
||||
|
||||
// we should loop here to ensure we don't return Poll::Pending if we have a message to
|
||||
// return behind any pings we need to respond to
|
||||
while let Poll::Ready(res) = this.inner.as_mut().poll_next(cx) {
|
||||
let bytes = match res {
|
||||
Some(Ok(bytes)) => bytes,
|
||||
Some(Err(err)) => return Poll::Ready(Some(Err(err.into()))),
|
||||
None => return Poll::Ready(None),
|
||||
};
|
||||
|
||||
let id = *bytes.first().ok_or(P2PStreamError::EmptyProtocolMessage)?;
|
||||
if id == P2PMessageID::Ping as u8 {
|
||||
// TODO: do we need to decode the ping?
|
||||
// we have received a ping, so we will send a pong
|
||||
let mut pong_bytes = BytesMut::new();
|
||||
P2PMessage::Pong.encode(&mut pong_bytes);
|
||||
|
||||
// TODO: fix use of Sink API
|
||||
let send_res = Pin::new(&mut this.inner).send(pong_bytes.into()).poll_unpin(cx)?;
|
||||
ready!(send_res)
|
||||
|
||||
// continue to the next message if there is one
|
||||
} else if id == P2PMessageID::Disconnect as u8 {
|
||||
let reason = DisconnectReason::decode(&mut &bytes[1..])?;
|
||||
// TODO: do something with the reason
|
||||
return Poll::Ready(Some(Err(P2PStreamError::Disconnected)))
|
||||
} else if id == P2PMessageID::Hello as u8 {
|
||||
// we have received a hello message outside of the handshake, so we will return an
|
||||
// error
|
||||
return Poll::Ready(Some(Err(P2PStreamError::HandshakeError(
|
||||
P2PHandshakeError::HelloNotInHandshake,
|
||||
))))
|
||||
} else if id == P2PMessageID::Pong as u8 {
|
||||
// TODO: do we need to decode the pong?
|
||||
// if we were waiting for a pong, this will reset the pinger state
|
||||
this.pinger.pong_received()?
|
||||
} else if id > MAX_P2P_MESSAGE_ID && id <= MAX_RESERVED_MESSAGE_ID {
|
||||
// we have received an unknown reserved message
|
||||
return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(id))))
|
||||
} else {
|
||||
// first check that the compressed message length does not exceed the max message
|
||||
// size
|
||||
let decompressed_len = snap::raw::decompress_len(&bytes[1..])?;
|
||||
if decompressed_len > MAX_PAYLOAD_SIZE {
|
||||
return Poll::Ready(Some(Err(P2PStreamError::MessageTooBig {
|
||||
message_size: decompressed_len,
|
||||
max_size: MAX_PAYLOAD_SIZE,
|
||||
})))
|
||||
}
|
||||
|
||||
// then decompress the message
|
||||
let mut decompress_buf = BytesMut::zeroed(decompressed_len + 1);
|
||||
|
||||
// we have a subprotocol message that needs to be sent in the stream.
|
||||
// first, switch the message id based on offset so the next layer can decode it
|
||||
// without being aware of the p2p stream's state (shared capabilities / the message
|
||||
// id offset)
|
||||
decompress_buf[0] = bytes[0] - this.shared_capability.offset();
|
||||
this.decoder.decompress(&bytes[1..], &mut decompress_buf[1..])?;
|
||||
|
||||
return Poll::Ready(Some(Ok(decompress_buf)))
|
||||
}
|
||||
}
|
||||
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Sink<Bytes> for P2PStream<S>
|
||||
where
|
||||
S: Sink<Bytes, Error = io::Error> + Unpin,
|
||||
{
|
||||
type Error = P2PStreamError;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.project().inner.poll_ready(cx).map_err(Into::into)
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
|
||||
let this = self.project();
|
||||
|
||||
let mut compressed = BytesMut::zeroed(1 + snap::raw::max_compress_len(item.len() - 1));
|
||||
|
||||
// all messages sent in this stream are subprotocol messages, so we need to switch the
|
||||
// message id based on the offset
|
||||
compressed[0] = item[0] + this.shared_capability.offset();
|
||||
let compressed_size = this.encoder.compress(&item[1..], &mut compressed[1..])?;
|
||||
|
||||
// truncate the compressed buffer to the actual compressed size (plus one for the message
|
||||
// id)
|
||||
compressed.truncate(compressed_size + 1);
|
||||
|
||||
this.inner.start_send(compressed.freeze())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.project().inner.poll_flush(cx).map_err(Into::into)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.project().inner.poll_close(cx).map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
/// Determines the offsets for each shared capability between the input list of peer
|
||||
/// capabilities and the input list of locally supported capabilities.
|
||||
///
|
||||
/// Currently only `eth` versions 66 and 67 are supported.
|
||||
pub fn set_capability_offsets(
|
||||
local_capabilities: Vec<CapabilityMessage>,
|
||||
peer_capabilities: Vec<CapabilityMessage>,
|
||||
) -> Result<SharedCapability, P2PStreamError> {
|
||||
// find intersection of capabilities
|
||||
let our_capabilities_map =
|
||||
local_capabilities.into_iter().map(|c| (c.name, c.version)).collect::<HashMap<_, _>>();
|
||||
|
||||
// map of capability name to version
|
||||
let mut shared_capabilities = HashMap::new();
|
||||
|
||||
// sorted list of capability names
|
||||
// TODO: the Ord implementation for strings says the following:
|
||||
// > Strings are ordered lexicographically by their byte values. This orders Unicode code
|
||||
// points based on their positions in the code charts. This is not necessarily the same as
|
||||
// “alphabetical” order.
|
||||
// We need to implement a case-sensitive alphabetical sort
|
||||
let mut shared_capability_names = BTreeSet::new();
|
||||
|
||||
// find highest shared version of each shared capability
|
||||
for capability in peer_capabilities {
|
||||
// if this is Some, we share this capability
|
||||
if let Some(version) = our_capabilities_map.get(&capability.name) {
|
||||
// If multiple versions are shared of the same (equal name) capability, the numerically
|
||||
// highest wins, others are ignored
|
||||
if capability.version <= *version {
|
||||
shared_capabilities.insert(capability.name.clone(), capability.version);
|
||||
shared_capability_names.insert(capability.name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// disconnect if we don't share any capabilities
|
||||
if shared_capabilities.is_empty() {
|
||||
// TODO: send a disconnect message? if we want to do this, this will need to be a member
|
||||
// method of `UnauthedP2PStream` so it can access the inner stream
|
||||
return Err(P2PStreamError::HandshakeError(P2PHandshakeError::NoSharedCapabilities))
|
||||
}
|
||||
|
||||
// order versions based on capability name (alphabetical) and select offsets based on
|
||||
// BASE_OFFSET + prev_total_message
|
||||
let mut shared_with_offsets = Vec::new();
|
||||
|
||||
// Message IDs are assumed to be compact from ID 0x10 onwards (0x00-0x0f is reserved for the
|
||||
// "p2p" capability) and given to each shared (equal-version, equal-name) capability in
|
||||
// alphabetic order.
|
||||
let mut offset = MAX_RESERVED_MESSAGE_ID + 1;
|
||||
for name in shared_capability_names {
|
||||
let version = shared_capabilities.get(&name).unwrap();
|
||||
|
||||
let shared_capability = SharedCapability::new(&name, *version as u8, offset)?;
|
||||
|
||||
match shared_capability {
|
||||
SharedCapability::UnknownCapability { .. } => {
|
||||
// Capabilities which are not shared are ignored
|
||||
tracing::warn!("unknown capability: name={:?}, version={}", name, version,);
|
||||
}
|
||||
SharedCapability::Eth { .. } => {
|
||||
shared_with_offsets.push(shared_capability.clone());
|
||||
|
||||
// increment the offset if the capability is known
|
||||
offset += shared_capability.num_messages()?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: support multiple capabilities - we would need a new Stream type to go on top of
|
||||
// `P2PStream` containing its capability. `P2PStream` would still send pings and handle
|
||||
// pongs, but instead contain a map of capabilities to their respective stream / channel.
|
||||
// Each channel would be responsible for containing the offset for that stream and would
|
||||
// only increment / decrement message IDs.
|
||||
// NOTE: since the `P2PStream` currently only supports one capability, we set the
|
||||
// capability with the lowest offset.
|
||||
Ok(shared_with_offsets
|
||||
.first()
|
||||
.ok_or_else(|| P2PStreamError::HandshakeError(P2PHandshakeError::NoSharedCapabilities))?
|
||||
.clone())
|
||||
}
|
||||
|
||||
/// This represents only the reserved `p2p` subprotocol messages.
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum P2PMessage {
|
||||
/// The first packet sent over the connection, and sent once by both sides.
|
||||
Hello(HelloMessage),
|
||||
|
||||
/// Inform the peer that a disconnection is imminent; if received, a peer should disconnect
|
||||
/// immediately.
|
||||
Disconnect(DisconnectReason),
|
||||
|
||||
/// Requests an immediate reply of [`Pong`] from the peer.
|
||||
Ping,
|
||||
|
||||
/// Reply to the peer's [`Ping`] packet.
|
||||
Pong,
|
||||
}
|
||||
|
||||
impl P2PMessage {
|
||||
/// Gets the [`P2PMessageID`] for the given message.
|
||||
pub fn message_id(&self) -> P2PMessageID {
|
||||
match self {
|
||||
P2PMessage::Hello(_) => P2PMessageID::Hello,
|
||||
P2PMessage::Disconnect(_) => P2PMessageID::Disconnect,
|
||||
P2PMessage::Ping => P2PMessageID::Ping,
|
||||
P2PMessage::Pong => P2PMessageID::Pong,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Encodable for P2PMessage {
|
||||
fn length(&self) -> usize {
|
||||
let payload_len = match self {
|
||||
P2PMessage::Hello(msg) => msg.length(),
|
||||
P2PMessage::Disconnect(msg) => msg.length(),
|
||||
P2PMessage::Ping => 3, // len([0x01, 0x00, 0x80]) = 3
|
||||
P2PMessage::Pong => 3, // len([0x01, 0x00, 0x80]) = 3
|
||||
};
|
||||
payload_len + 1 // (1 for length of p2p message id)
|
||||
}
|
||||
|
||||
fn encode(&self, out: &mut dyn bytes::BufMut) {
|
||||
out.put_u8(self.message_id() as u8);
|
||||
match self {
|
||||
P2PMessage::Hello(msg) => msg.encode(out),
|
||||
P2PMessage::Disconnect(msg) => msg.encode(out),
|
||||
P2PMessage::Ping => {
|
||||
out.put_u8(0x01);
|
||||
out.put_u8(0x00);
|
||||
out.put_u8(0x80);
|
||||
}
|
||||
P2PMessage::Pong => {
|
||||
out.put_u8(0x01);
|
||||
out.put_u8(0x00);
|
||||
out.put_u8(0x80);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Decodable for P2PMessage {
|
||||
fn decode(buf: &mut &[u8]) -> Result<Self, DecodeError> {
|
||||
let first = buf.first().expect("cannot decode empty p2p message");
|
||||
let id = P2PMessageID::try_from(*first)
|
||||
.or(Err(DecodeError::Custom("unknown p2p message id")))?;
|
||||
buf.advance(1);
|
||||
match id {
|
||||
P2PMessageID::Hello => Ok(P2PMessage::Hello(HelloMessage::decode(buf)?)),
|
||||
P2PMessageID::Disconnect => Ok(P2PMessage::Disconnect(DisconnectReason::decode(buf)?)),
|
||||
P2PMessageID::Ping => {
|
||||
// len([0x01, 0x00, 0x80]) = 3
|
||||
buf.advance(3);
|
||||
Ok(P2PMessage::Ping)
|
||||
}
|
||||
P2PMessageID::Pong => {
|
||||
// len([0x01, 0x00, 0x80]) = 3
|
||||
buf.advance(3);
|
||||
Ok(P2PMessage::Pong)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Message IDs for `p2p` subprotocol messages.
|
||||
pub enum P2PMessageID {
|
||||
/// Message ID for the [`P2PMessage::Hello`] message.
|
||||
Hello = 0x00,
|
||||
|
||||
/// Message ID for the [`P2PMessage::Disconnect`] message.
|
||||
Disconnect = 0x01,
|
||||
|
||||
/// Message ID for the [`P2PMessage::Ping`] message.
|
||||
Ping = 0x02,
|
||||
|
||||
/// Message ID for the [`P2PMessage::Pong`] message.
|
||||
Pong = 0x03,
|
||||
}
|
||||
|
||||
impl From<P2PMessage> for P2PMessageID {
|
||||
fn from(msg: P2PMessage) -> Self {
|
||||
match msg {
|
||||
P2PMessage::Hello(_) => P2PMessageID::Hello,
|
||||
P2PMessage::Disconnect(_) => P2PMessageID::Disconnect,
|
||||
P2PMessage::Ping => P2PMessageID::Ping,
|
||||
P2PMessage::Pong => P2PMessageID::Pong,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for P2PMessageID {
|
||||
type Error = P2PStreamError;
|
||||
|
||||
fn try_from(id: u8) -> Result<Self, Self::Error> {
|
||||
match id {
|
||||
0x00 => Ok(P2PMessageID::Hello),
|
||||
0x01 => Ok(P2PMessageID::Disconnect),
|
||||
0x02 => Ok(P2PMessageID::Ping),
|
||||
0x03 => Ok(P2PMessageID::Pong),
|
||||
_ => Err(P2PStreamError::UnknownReservedMessageId(id)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A message indicating a supported capability and capability version.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, RlpEncodable, RlpDecodable)]
|
||||
pub struct CapabilityMessage {
|
||||
/// The name of the subprotocol
|
||||
pub name: String,
|
||||
/// The version of the subprotocol
|
||||
pub version: usize,
|
||||
}
|
||||
|
||||
impl CapabilityMessage {
|
||||
/// Create a new `CapabilityMessage` with the given name and version.
|
||||
pub fn new(name: String, version: usize) -> Self {
|
||||
Self { name, version }
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: determine if we should allow for the extra fields at the end like EIP-706 suggests
|
||||
/// Message used in the `p2p` handshake, containing information about the supported RLPx protocol
|
||||
/// version and capabilities.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, RlpEncodable, RlpDecodable)]
|
||||
pub struct HelloMessage {
|
||||
/// The version of the `p2p` protocol.
|
||||
pub protocol_version: ProtocolVersion,
|
||||
/// Specifies the client software identity, as a human-readable string (e.g.
|
||||
/// "Ethereum(++)/1.0.0").
|
||||
pub client_version: String,
|
||||
/// The list of supported capabilities and their versions.
|
||||
pub capabilities: Vec<CapabilityMessage>,
|
||||
/// The port that the client is listening on, zero indicates the client is not listening.
|
||||
pub port: u16,
|
||||
/// The secp256k1 public key corresponding to the node's private key.
|
||||
pub id: PeerId,
|
||||
}
|
||||
|
||||
/// RLPx `p2p` protocol version
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||
pub enum ProtocolVersion {
|
||||
/// `p2p` version 4
|
||||
V4 = 4,
|
||||
/// `p2p` version 5
|
||||
V5 = 5,
|
||||
}
|
||||
|
||||
impl Encodable for ProtocolVersion {
|
||||
fn length(&self) -> usize {
|
||||
// the version should be a single byte
|
||||
(*self as u8).length()
|
||||
}
|
||||
fn encode(&self, out: &mut dyn bytes::BufMut) {
|
||||
(*self as u8).encode(out)
|
||||
}
|
||||
}
|
||||
|
||||
impl Decodable for ProtocolVersion {
|
||||
fn decode(buf: &mut &[u8]) -> Result<Self, DecodeError> {
|
||||
let version = u8::decode(buf)?;
|
||||
match version {
|
||||
4 => Ok(ProtocolVersion::V4),
|
||||
5 => Ok(ProtocolVersion::V5),
|
||||
_ => Err(DecodeError::Custom("unknown p2p protocol version")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// RLPx disconnect reason.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum DisconnectReason {
|
||||
/// Disconnect requested by the local node or remote peer.
|
||||
DisconnectRequested = 0x00,
|
||||
/// TCP related error
|
||||
TcpSubsystemError = 0x01,
|
||||
/// Breach of protocol at the transport or p2p level
|
||||
ProtocolBreach = 0x02,
|
||||
/// Node has no matching protocols.
|
||||
UselessPeer = 0x03,
|
||||
/// Either the remote or local node has too many peers.
|
||||
TooManyPeers = 0x04,
|
||||
/// Already connected to the peer.
|
||||
AlreadyConnected = 0x05,
|
||||
/// `p2p` protocol version is incompatible
|
||||
IncompatibleP2PProtocolVersion = 0x06,
|
||||
NullNodeIdentity = 0x07,
|
||||
ClientQuitting = 0x08,
|
||||
UnexpectedHandshakeIdentity = 0x09,
|
||||
/// The node is connected to itself
|
||||
ConnectedToSelf = 0x0a,
|
||||
/// Peer or local node did not respond to a ping in time.
|
||||
PingTimeout = 0x0b,
|
||||
/// Peer or local node violated a subprotocol-specific rule.
|
||||
SubprotocolSpecific = 0x10,
|
||||
}
|
||||
|
||||
impl Display for DisconnectReason {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let message = match self {
|
||||
DisconnectReason::DisconnectRequested => "Disconnect requested",
|
||||
DisconnectReason::TcpSubsystemError => "TCP sub-system error",
|
||||
DisconnectReason::ProtocolBreach => {
|
||||
"Breach of protocol, e.g. a malformed message, bad RLP, ..."
|
||||
}
|
||||
DisconnectReason::UselessPeer => "Useless peer",
|
||||
DisconnectReason::TooManyPeers => "Too many peers",
|
||||
DisconnectReason::AlreadyConnected => "Already connected",
|
||||
DisconnectReason::IncompatibleP2PProtocolVersion => "Incompatible P2P protocol version",
|
||||
DisconnectReason::NullNodeIdentity => {
|
||||
"Null node identity received - this is automatically invalid"
|
||||
}
|
||||
DisconnectReason::ClientQuitting => "Client quitting",
|
||||
DisconnectReason::UnexpectedHandshakeIdentity => "Unexpected identity in handshake",
|
||||
DisconnectReason::ConnectedToSelf => {
|
||||
"Identity is the same as this node (i.e. connected to itself)"
|
||||
}
|
||||
DisconnectReason::PingTimeout => "Ping timeout",
|
||||
DisconnectReason::SubprotocolSpecific => "Some other reason specific to a subprotocol",
|
||||
};
|
||||
|
||||
write!(f, "{}", message)
|
||||
}
|
||||
}
|
||||
|
||||
/// This represents an unknown disconnect reason with the given code.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UnknownDisconnectReason(u8);
|
||||
|
||||
impl TryFrom<u8> for DisconnectReason {
|
||||
// This error type should not be used to crash the node, but rather to log the error and
|
||||
// disconnect the peer.
|
||||
type Error = UnknownDisconnectReason;
|
||||
|
||||
fn try_from(value: u8) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
0x00 => Ok(DisconnectReason::DisconnectRequested),
|
||||
0x01 => Ok(DisconnectReason::TcpSubsystemError),
|
||||
0x02 => Ok(DisconnectReason::ProtocolBreach),
|
||||
0x03 => Ok(DisconnectReason::UselessPeer),
|
||||
0x04 => Ok(DisconnectReason::TooManyPeers),
|
||||
0x05 => Ok(DisconnectReason::AlreadyConnected),
|
||||
0x06 => Ok(DisconnectReason::IncompatibleP2PProtocolVersion),
|
||||
0x07 => Ok(DisconnectReason::NullNodeIdentity),
|
||||
0x08 => Ok(DisconnectReason::ClientQuitting),
|
||||
0x09 => Ok(DisconnectReason::UnexpectedHandshakeIdentity),
|
||||
0x0a => Ok(DisconnectReason::ConnectedToSelf),
|
||||
0x0b => Ok(DisconnectReason::PingTimeout),
|
||||
0x10 => Ok(DisconnectReason::SubprotocolSpecific),
|
||||
_ => Err(UnknownDisconnectReason(value)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Encodable for DisconnectReason {
|
||||
fn length(&self) -> usize {
|
||||
// disconnect reasons are snappy encoded as follows:
|
||||
// [0x01, 0x00, reason as u8]
|
||||
// this is 3 bytes
|
||||
3
|
||||
}
|
||||
fn encode(&self, out: &mut dyn bytes::BufMut) {
|
||||
// disconnect reasons are snappy encoded as follows:
|
||||
// [0x01, 0x00, reason as u8]
|
||||
// this is 3 bytes
|
||||
out.put_u8(0x01);
|
||||
out.put_u8(0x00);
|
||||
out.put_u8(*self as u8);
|
||||
}
|
||||
}
|
||||
|
||||
impl Decodable for DisconnectReason {
|
||||
fn decode(buf: &mut &[u8]) -> Result<Self, DecodeError> {
|
||||
let first = *buf.first().expect("disconnect reason should have at least 1 byte");
|
||||
buf.advance(1);
|
||||
if first != 0x01 {
|
||||
return Err(DecodeError::Custom("invalid disconnect reason - invalid snappy header"))
|
||||
}
|
||||
|
||||
let second = *buf.first().expect("disconnect reason should have at least 2 bytes");
|
||||
buf.advance(1);
|
||||
if second != 0x00 {
|
||||
// TODO: make sure this error message is correct
|
||||
return Err(DecodeError::Custom("invalid disconnect reason - invalid snappy header"))
|
||||
}
|
||||
|
||||
let reason = *buf.first().expect("disconnect reason should have 3 bytes");
|
||||
buf.advance(1);
|
||||
DisconnectReason::try_from(reason)
|
||||
.map_err(|_| DecodeError::Custom("unknown disconnect reason"))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use reth_ecies::util::pk2id;
|
||||
use reth_rlp::EMPTY_STRING_CODE;
|
||||
use secp256k1::{SecretKey, SECP256K1};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio_util::codec::Decoder;
|
||||
|
||||
use crate::EthVersion;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_handshake_passthrough() {
|
||||
// create a p2p stream and server, then confirm that the two are authed
|
||||
// create tcpstream
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let local_addr = listener.local_addr().unwrap();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
// roughly based off of the design of tokio::net::TcpListener
|
||||
let (incoming, _) = listener.accept().await.unwrap();
|
||||
let stream = crate::PassthroughCodec::default().framed(incoming);
|
||||
|
||||
let server_key = SecretKey::new(&mut rand::thread_rng());
|
||||
let server_hello = HelloMessage {
|
||||
protocol_version: ProtocolVersion::V5,
|
||||
client_version: "bitcoind/1.0.0".to_string(),
|
||||
capabilities: vec![EthVersion::Eth67.into()],
|
||||
port: 30303,
|
||||
id: pk2id(&server_key.public_key(SECP256K1)),
|
||||
};
|
||||
|
||||
let unauthed_stream = UnauthedP2PStream::new(stream);
|
||||
let p2p_stream = unauthed_stream.handshake(server_hello).await.unwrap();
|
||||
|
||||
// ensure that the two share a single capability, eth67
|
||||
assert_eq!(
|
||||
p2p_stream.shared_capability,
|
||||
SharedCapability::Eth {
|
||||
version: EthVersion::Eth67,
|
||||
offset: MAX_RESERVED_MESSAGE_ID + 1
|
||||
}
|
||||
);
|
||||
});
|
||||
|
||||
let client_key = SecretKey::new(&mut rand::thread_rng());
|
||||
|
||||
let outgoing = TcpStream::connect(local_addr).await.unwrap();
|
||||
let sink = crate::PassthroughCodec::default().framed(outgoing);
|
||||
|
||||
let client_hello = HelloMessage {
|
||||
protocol_version: ProtocolVersion::V5,
|
||||
client_version: "bitcoind/1.0.0".to_string(),
|
||||
capabilities: vec![EthVersion::Eth67.into()],
|
||||
port: 30303,
|
||||
id: pk2id(&client_key.public_key(SECP256K1)),
|
||||
};
|
||||
|
||||
let unauthed_stream = UnauthedP2PStream::new(sink);
|
||||
let p2p_stream = unauthed_stream.handshake(client_hello).await.unwrap();
|
||||
|
||||
// ensure that the two share a single capability, eth67
|
||||
assert_eq!(
|
||||
p2p_stream.shared_capability,
|
||||
SharedCapability::Eth {
|
||||
version: EthVersion::Eth67,
|
||||
offset: MAX_RESERVED_MESSAGE_ID + 1
|
||||
}
|
||||
);
|
||||
|
||||
// make sure the server receives the message and asserts before ending the test
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ping_snappy_encoding_parity() {
|
||||
// encode ping using our `Encodable` implementation
|
||||
let ping = P2PMessage::Ping;
|
||||
let mut ping_encoded = Vec::new();
|
||||
ping.encode(&mut ping_encoded);
|
||||
|
||||
// the definition of ping is 0x80 (an empty rlp string)
|
||||
let ping_raw = vec![EMPTY_STRING_CODE];
|
||||
let mut snappy_encoder = snap::raw::Encoder::new();
|
||||
let ping_compressed = snappy_encoder.compress_vec(&ping_raw).unwrap();
|
||||
let mut ping_expected = vec![P2PMessageID::Ping as u8];
|
||||
ping_expected.extend(&ping_compressed);
|
||||
|
||||
// ensure that the two encodings are equal
|
||||
assert_eq!(
|
||||
ping_expected, ping_encoded,
|
||||
"left: {:#x?}, right: {:#x?}",
|
||||
ping_expected, ping_encoded
|
||||
);
|
||||
|
||||
// also ensure that the length is correct
|
||||
assert_eq!(ping_expected.len(), P2PMessage::Ping.length());
|
||||
|
||||
// try to decode using Decodable
|
||||
let p2p_message = P2PMessage::decode(&mut &ping_expected[..]).unwrap();
|
||||
assert_eq!(p2p_message, P2PMessage::Ping);
|
||||
|
||||
// finally decode the encoded message with snappy
|
||||
let mut snappy_decoder = snap::raw::Decoder::new();
|
||||
|
||||
// the message id is not compressed, only compress the latest bits
|
||||
let decompressed = snappy_decoder.decompress_vec(&ping_encoded[1..]).unwrap();
|
||||
|
||||
assert_eq!(decompressed, ping_raw);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pong_snappy_encoding_parity() {
|
||||
// encode pong using our `Encodable` implementation
|
||||
let pong = P2PMessage::Pong;
|
||||
let mut pong_encoded = Vec::new();
|
||||
pong.encode(&mut pong_encoded);
|
||||
|
||||
// the definition of pong is 0x80 (an empty rlp string)
|
||||
let pong_raw = vec![EMPTY_STRING_CODE];
|
||||
let mut snappy_encoder = snap::raw::Encoder::new();
|
||||
let pong_compressed = snappy_encoder.compress_vec(&pong_raw).unwrap();
|
||||
let mut pong_expected = vec![P2PMessageID::Pong as u8];
|
||||
pong_expected.extend(&pong_compressed);
|
||||
|
||||
// ensure that the two encodings are equal
|
||||
assert_eq!(
|
||||
pong_expected, pong_encoded,
|
||||
"left: {:#x?}, right: {:#x?}",
|
||||
pong_expected, pong_encoded
|
||||
);
|
||||
|
||||
// also ensure that the length is correct
|
||||
assert_eq!(pong_expected.len(), P2PMessage::Pong.length());
|
||||
|
||||
// try to decode using Decodable
|
||||
let p2p_message = P2PMessage::decode(&mut &pong_expected[..]).unwrap();
|
||||
assert_eq!(p2p_message, P2PMessage::Pong);
|
||||
|
||||
// finally decode the encoded message with snappy
|
||||
let mut snappy_decoder = snap::raw::Decoder::new();
|
||||
|
||||
// the message id is not compressed, only compress the latest bits
|
||||
let decompressed = snappy_decoder.decompress_vec(&pong_encoded[1..]).unwrap();
|
||||
|
||||
assert_eq!(decompressed, pong_raw);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hello_encoding_round_trip() {
|
||||
let secret_key = SecretKey::new(&mut rand::thread_rng());
|
||||
let id = pk2id(&secret_key.public_key(SECP256K1));
|
||||
let hello = P2PMessage::Hello(HelloMessage {
|
||||
protocol_version: ProtocolVersion::V5,
|
||||
client_version: "reth/0.1.0".to_string(),
|
||||
capabilities: vec![CapabilityMessage::new(
|
||||
"eth".to_string(),
|
||||
EthVersion::Eth67 as usize,
|
||||
)],
|
||||
port: 30303,
|
||||
id,
|
||||
});
|
||||
|
||||
let mut hello_encoded = Vec::new();
|
||||
hello.encode(&mut hello_encoded);
|
||||
|
||||
let hello_decoded = P2PMessage::decode(&mut &hello_encoded[..]).unwrap();
|
||||
|
||||
assert_eq!(hello, hello_decoded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hello_encoding_length() {
|
||||
let secret_key = SecretKey::new(&mut rand::thread_rng());
|
||||
let id = pk2id(&secret_key.public_key(SECP256K1));
|
||||
let hello = P2PMessage::Hello(HelloMessage {
|
||||
protocol_version: ProtocolVersion::V5,
|
||||
client_version: "reth/0.1.0".to_string(),
|
||||
capabilities: vec![CapabilityMessage::new(
|
||||
"eth".to_string(),
|
||||
EthVersion::Eth67 as usize,
|
||||
)],
|
||||
port: 30303,
|
||||
id,
|
||||
});
|
||||
|
||||
let mut hello_encoded = Vec::new();
|
||||
hello.encode(&mut hello_encoded);
|
||||
|
||||
assert_eq!(hello_encoded.len(), hello.length());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hello_message_id_prefix() {
|
||||
// ensure that the hello message id is prefixed
|
||||
let secret_key = SecretKey::new(&mut rand::thread_rng());
|
||||
let id = pk2id(&secret_key.public_key(SECP256K1));
|
||||
let hello = P2PMessage::Hello(HelloMessage {
|
||||
protocol_version: ProtocolVersion::V5,
|
||||
client_version: "reth/0.1.0".to_string(),
|
||||
capabilities: vec![CapabilityMessage::new(
|
||||
"eth".to_string(),
|
||||
EthVersion::Eth67 as usize,
|
||||
)],
|
||||
port: 30303,
|
||||
id,
|
||||
});
|
||||
|
||||
let mut hello_encoded = Vec::new();
|
||||
hello.encode(&mut hello_encoded);
|
||||
|
||||
assert_eq!(hello_encoded[0], P2PMessageID::Hello as u8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn disconnect_round_trip() {
|
||||
let all_reasons = vec![
|
||||
DisconnectReason::DisconnectRequested,
|
||||
DisconnectReason::TcpSubsystemError,
|
||||
DisconnectReason::ProtocolBreach,
|
||||
DisconnectReason::UselessPeer,
|
||||
DisconnectReason::TooManyPeers,
|
||||
DisconnectReason::AlreadyConnected,
|
||||
DisconnectReason::IncompatibleP2PProtocolVersion,
|
||||
DisconnectReason::NullNodeIdentity,
|
||||
DisconnectReason::ClientQuitting,
|
||||
DisconnectReason::UnexpectedHandshakeIdentity,
|
||||
DisconnectReason::ConnectedToSelf,
|
||||
DisconnectReason::PingTimeout,
|
||||
DisconnectReason::SubprotocolSpecific,
|
||||
];
|
||||
|
||||
for reason in all_reasons {
|
||||
let disconnect = P2PMessage::Disconnect(reason);
|
||||
|
||||
let mut disconnect_encoded = Vec::new();
|
||||
disconnect.encode(&mut disconnect_encoded);
|
||||
|
||||
let disconnect_decoded = P2PMessage::decode(&mut &disconnect_encoded[..]).unwrap();
|
||||
|
||||
assert_eq!(disconnect, disconnect_decoded);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn disconnect_encoding_length() {
|
||||
let all_reasons = vec![
|
||||
DisconnectReason::DisconnectRequested,
|
||||
DisconnectReason::TcpSubsystemError,
|
||||
DisconnectReason::ProtocolBreach,
|
||||
DisconnectReason::UselessPeer,
|
||||
DisconnectReason::TooManyPeers,
|
||||
DisconnectReason::AlreadyConnected,
|
||||
DisconnectReason::IncompatibleP2PProtocolVersion,
|
||||
DisconnectReason::NullNodeIdentity,
|
||||
DisconnectReason::ClientQuitting,
|
||||
DisconnectReason::UnexpectedHandshakeIdentity,
|
||||
DisconnectReason::ConnectedToSelf,
|
||||
DisconnectReason::PingTimeout,
|
||||
DisconnectReason::SubprotocolSpecific,
|
||||
];
|
||||
|
||||
for reason in all_reasons {
|
||||
let disconnect = P2PMessage::Disconnect(reason);
|
||||
|
||||
let mut disconnect_encoded = Vec::new();
|
||||
disconnect.encode(&mut disconnect_encoded);
|
||||
|
||||
assert_eq!(disconnect_encoded.len(), disconnect.length());
|
||||
}
|
||||
}
|
||||
}
|
||||
581
crates/net/eth-wire/src/pinger.rs
Normal file
581
crates/net/eth-wire/src/pinger.rs
Normal file
@ -0,0 +1,581 @@
|
||||
use futures::{ready, StreamExt};
|
||||
use std::{
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::time::interval;
|
||||
use tokio_stream::{wrappers::IntervalStream, Stream};
|
||||
|
||||
use crate::error::PingerError;
|
||||
|
||||
/// This represents the possible states of the pinger.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
|
||||
pub(crate) enum PingState {
|
||||
/// There are no pings in flight, or all pings have been responded to and we are ready to send
|
||||
/// a ping at a later point.
|
||||
Ready,
|
||||
|
||||
/// We have sent a ping and are waiting for a pong, but the peer has missed n pongs.
|
||||
WaitingForPong(u8),
|
||||
|
||||
/// The peer has missed n pongs and is considered timed out.
|
||||
TimedOut(u8),
|
||||
}
|
||||
|
||||
/// The pinger is a state machine that is created with a maximum number of pongs that can be
|
||||
/// missed.
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct Pinger {
|
||||
/// The maximum number of pongs that can be missed.
|
||||
max_missed: u8,
|
||||
|
||||
/// The current state of the pinger.
|
||||
state: PingState,
|
||||
}
|
||||
|
||||
impl Pinger {
|
||||
/// Create a new pinger with the given maximum number of pongs that can be missed.
|
||||
pub(crate) fn new(max_missed: u8) -> Self {
|
||||
Self { max_missed, state: PingState::Ready }
|
||||
}
|
||||
|
||||
/// Return the current state of the pinger.
|
||||
pub(crate) fn state(&self) -> &PingState {
|
||||
&self.state
|
||||
}
|
||||
|
||||
/// Check if the pinger is in the `Ready` state.
|
||||
pub(crate) fn is_ready(&self) -> bool {
|
||||
matches!(self.state, PingState::Ready)
|
||||
}
|
||||
|
||||
/// Check if the pinger is in the `WaitingForPong` state.
|
||||
pub(crate) fn is_waiting_for_pong(&self) -> bool {
|
||||
matches!(self.state, PingState::WaitingForPong(_))
|
||||
}
|
||||
|
||||
/// Check if the pinger is in the `TimedOut` state.
|
||||
pub(crate) fn is_timed_out(&self) -> bool {
|
||||
matches!(self.state, PingState::TimedOut(_))
|
||||
}
|
||||
|
||||
/// Transition the pinger to the `WaitingForPong` state if it was in the `Ready` state.
|
||||
///
|
||||
/// If the pinger is in the `WaitingForPong` state, the number of missed pongs will be
|
||||
/// incremented. If the number of missed pongs exceeds the maximum missed pongs allowed, the
|
||||
/// pinger will be transitioned to the `TimedOut` state.
|
||||
///
|
||||
/// If the pinger is in the `TimedOut` state, this method will return an error.
|
||||
pub(crate) fn next_state(&mut self) -> Result<(), PingerError> {
|
||||
match self.state {
|
||||
PingState::Ready => {
|
||||
self.state = PingState::WaitingForPong(0);
|
||||
Ok(())
|
||||
}
|
||||
PingState::WaitingForPong(missed) => {
|
||||
if missed + 1 >= self.max_missed {
|
||||
self.state = PingState::TimedOut(missed + 1);
|
||||
Ok(())
|
||||
} else {
|
||||
self.state = PingState::WaitingForPong(missed + 1);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
PingState::TimedOut(_) => Err(PingerError::PingWhileTimedOut),
|
||||
}
|
||||
}
|
||||
|
||||
/// Mark a pong as received, and transition the pinger to the `Ready` state if it was in the
|
||||
/// `WaitingForPong` state.
|
||||
///
|
||||
/// If the pinger is in the `Ready` or `TimedOut` state, this method will return an error.
|
||||
pub(crate) fn pong_received(&mut self) -> Result<(), PingerError> {
|
||||
match self.state {
|
||||
PingState::Ready => Err(PingerError::PongWhileReady),
|
||||
PingState::WaitingForPong(_) => {
|
||||
self.state = PingState::Ready;
|
||||
Ok(())
|
||||
}
|
||||
PingState::TimedOut(_) => Err(PingerError::PongWhileTimedOut),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A Pinger that can be used as a `Stream`, which will emit
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct PingerStream {
|
||||
/// The pinger.
|
||||
pinger: Pinger,
|
||||
|
||||
/// Whether a `Timeout` event has already been sent.
|
||||
timeout_sent: bool,
|
||||
}
|
||||
|
||||
impl PingerStream {
|
||||
/// Poll the [`Pinger`] for a [`Option<PingEvent>`], which can be either a [`PingEvent::Ping`]
|
||||
/// or a final [`PingEvent::Timeout`] event, after which the stream will end and return
|
||||
/// None.
|
||||
pub(crate) fn poll(&mut self) -> Option<Result<PingerEvent, PingerError>> {
|
||||
// the stream has already sent a timeout event, so we return None
|
||||
if self.timeout_sent {
|
||||
return None
|
||||
}
|
||||
|
||||
match self.pinger.state {
|
||||
PingState::Ready => {
|
||||
// the pinger is ready, send a ping
|
||||
match self.pinger.next_state() {
|
||||
Ok(()) => Some(Ok(PingerEvent::Ping)),
|
||||
Err(e) => Some(Err(e)),
|
||||
}
|
||||
}
|
||||
PingState::WaitingForPong(_) => {
|
||||
// the peer has not timed out (yet), send another ping if the pinger does
|
||||
// not exceed the maximum number of missed pongs
|
||||
match self.pinger.next_state() {
|
||||
Ok(()) => {
|
||||
match self.pinger.state() {
|
||||
PingState::TimedOut(_) => {
|
||||
// the pinger has timed out, send a timeout event and end the
|
||||
// stream
|
||||
self.timeout_sent = true;
|
||||
Some(Ok(PingerEvent::Timeout))
|
||||
}
|
||||
_ => {
|
||||
// the pinger is still waiting for a pong, send another ping
|
||||
Some(Ok(PingerEvent::Ping))
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => Some(Err(e)),
|
||||
}
|
||||
}
|
||||
PingState::TimedOut(_) => {
|
||||
self.timeout_sent = true;
|
||||
Some(Ok(PingerEvent::Timeout))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for PingerStream {
|
||||
type Item = Result<PingerEvent, PingerError>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
if self.timeout_sent {
|
||||
return Poll::Ready(None)
|
||||
}
|
||||
|
||||
match self.pinger.state {
|
||||
PingState::Ready => {
|
||||
// the pinger is ready, send a ping
|
||||
self.pinger.next_state()?;
|
||||
Poll::Ready(Some(Ok(PingerEvent::Ping)))
|
||||
}
|
||||
PingState::WaitingForPong(_) => {
|
||||
// the peer has not timed out (yet), send another ping if the pinger does
|
||||
// not exceed the maximum number of missed pongs
|
||||
self.pinger.next_state()?;
|
||||
match self.pinger.state() {
|
||||
PingState::TimedOut(_) => {
|
||||
// the pinger has timed out, send a timeout event
|
||||
Poll::Ready(Some(Ok(PingerEvent::Timeout)))
|
||||
}
|
||||
_ => {
|
||||
// the pinger is still waiting for a pong, send another ping
|
||||
Poll::Ready(Some(Ok(PingerEvent::Ping)))
|
||||
}
|
||||
}
|
||||
}
|
||||
PingState::TimedOut(_) => {
|
||||
self.timeout_sent = true;
|
||||
Poll::Ready(Some(Ok(PingerEvent::Timeout)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The element type produced by a [`IntervalPingerStream`], representing either a new [`Ping`]
|
||||
/// message to send, or an indication that the peer should be timed out.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) enum PingerEvent {
|
||||
/// A new [`Ping`] message should be sent.
|
||||
Ping,
|
||||
|
||||
/// The peer should be timed out.
|
||||
Timeout,
|
||||
}
|
||||
|
||||
/// A type of [`Pinger`] that uses an interval and a timeout to determine when to send a ping and
|
||||
/// when to consider the peer timed out.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct IntervalTimeoutPinger {
|
||||
/// The interval pinger stream.
|
||||
interval_stream: IntervalStream,
|
||||
|
||||
/// The pinger stream we are using.
|
||||
pinger_stream: PingerStream,
|
||||
|
||||
/// The timeout duration for each ping.
|
||||
timeout: Duration,
|
||||
|
||||
/// The Interval that determines when to timeout the peer and send another ping.
|
||||
sleep: Option<IntervalStream>,
|
||||
}
|
||||
|
||||
impl IntervalTimeoutPinger {
|
||||
/// Creates a new [`IntervalTimeoutPinger`] with the given max missed pongs, interval duration,
|
||||
/// and timeout duration.
|
||||
pub(crate) fn new(
|
||||
max_missed: u8,
|
||||
interval_duration: Duration,
|
||||
timeout_duration: Duration,
|
||||
) -> Self {
|
||||
Self {
|
||||
interval_stream: IntervalStream::new(interval(interval_duration)),
|
||||
pinger_stream: PingerStream { pinger: Pinger::new(max_missed), timeout_sent: false },
|
||||
timeout: timeout_duration,
|
||||
sleep: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Mark a pong as received, and transition the pinger to the `Ready` state if it was in the
|
||||
/// `WaitingForPong` state. Unsets the sleep timer.
|
||||
pub(crate) fn pong_received(&mut self) -> Result<(), PingerError> {
|
||||
self.interval_stream.as_mut().reset();
|
||||
self.pinger_stream.pinger.pong_received()?;
|
||||
self.sleep = None;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Waits until the pinger sends a timeout event by exhausting the stream.
|
||||
pub(crate) async fn wait_for_timeout(&mut self) {
|
||||
while let Some(Ok(PingerEvent::Ping)) = self.next().await {}
|
||||
}
|
||||
|
||||
/// Returns the current state of the pinger.
|
||||
pub(crate) fn state(&self) -> &PingState {
|
||||
self.pinger_stream.pinger.state()
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for IntervalTimeoutPinger {
|
||||
type Item = Result<PingerEvent, PingerError>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
// if the pinger state is None, we should also return None regardless of the sleep or
|
||||
// interval state
|
||||
|
||||
// if we have a sleep timer, prefer that over the interval stream
|
||||
if let Some(inner_sleep) = this.sleep.as_mut() {
|
||||
// if the sleep is pending, we should return pending (we are waiting for a timeout)
|
||||
let pinned_sleep = Pin::new(inner_sleep);
|
||||
ready!(pinned_sleep.poll_next(cx));
|
||||
|
||||
// let's reset the interval, because the first one returns immediately when created
|
||||
// using `interval`
|
||||
let mut interval = interval(this.timeout);
|
||||
interval.reset();
|
||||
|
||||
// the sleep has elapsed, create a new sleep for the next timeout interval, then send a
|
||||
// new ping
|
||||
this.sleep = Some(IntervalStream::new(interval));
|
||||
|
||||
Pin::new(&mut this.pinger_stream).poll_next(cx)
|
||||
} else {
|
||||
// first poll the interval stream, if it is ready, send a ping
|
||||
let res = ready!(this.interval_stream.poll_next_unpin(cx));
|
||||
if res.is_none() {
|
||||
// this should never happen (the Stream impl of IntervalStream never is always Some)
|
||||
return Poll::Ready(None)
|
||||
}
|
||||
|
||||
let pinned_stream = Pin::new(&mut this.pinger_stream);
|
||||
let stream_res = ready!(pinned_stream.poll_next(cx));
|
||||
|
||||
// let's reset the interval, because the first one returns immediately when created
|
||||
// using `interval`
|
||||
let mut interval = interval(this.timeout);
|
||||
interval.reset();
|
||||
|
||||
this.sleep = Some(IntervalStream::new(interval));
|
||||
Poll::Ready(stream_res)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use tokio::select;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn send_many_pings() {
|
||||
// tests the simple pinger by sending many pings without pongs
|
||||
let mut pinger = Pinger::new(3);
|
||||
|
||||
pinger.next_state().unwrap();
|
||||
assert_eq!(*pinger.state(), PingState::WaitingForPong(0));
|
||||
|
||||
pinger.next_state().unwrap();
|
||||
assert_eq!(*pinger.state(), PingState::WaitingForPong(1));
|
||||
|
||||
pinger.next_state().unwrap();
|
||||
assert_eq!(*pinger.state(), PingState::WaitingForPong(2));
|
||||
|
||||
pinger.next_state().unwrap();
|
||||
assert_eq!(*pinger.state(), PingState::TimedOut(3));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn send_many_pings_with_pongs() {
|
||||
// tests the simple pinger by sending many pings with pongs
|
||||
let mut pinger = Pinger::new(3);
|
||||
|
||||
pinger.next_state().unwrap();
|
||||
assert_eq!(*pinger.state(), PingState::WaitingForPong(0));
|
||||
|
||||
pinger.pong_received().unwrap();
|
||||
assert_eq!(*pinger.state(), PingState::Ready);
|
||||
|
||||
pinger.next_state().unwrap();
|
||||
assert_eq!(*pinger.state(), PingState::WaitingForPong(0));
|
||||
|
||||
pinger.pong_received().unwrap();
|
||||
assert_eq!(*pinger.state(), PingState::Ready);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn send_many_pings_stream() {
|
||||
let mut pinger_stream = PingerStream { pinger: Pinger::new(3), timeout_sent: false };
|
||||
|
||||
assert_eq!(pinger_stream.poll().unwrap().unwrap(), PingerEvent::Ping);
|
||||
assert_eq!(pinger_stream.poll().unwrap().unwrap(), PingerEvent::Ping);
|
||||
assert_eq!(pinger_stream.poll().unwrap().unwrap(), PingerEvent::Ping);
|
||||
assert_eq!(pinger_stream.poll().unwrap().unwrap(), PingerEvent::Timeout);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_many_pings_interval_timeout() {
|
||||
// we should wait for the interval to elapse, just like the interval-only version
|
||||
// TODO: should the timeout ever be less than the interval?
|
||||
let mut pinger =
|
||||
IntervalTimeoutPinger::new(3, Duration::from_millis(20), Duration::from_millis(10));
|
||||
|
||||
assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping);
|
||||
assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping);
|
||||
assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping);
|
||||
assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Timeout);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_many_pings_interval_timeout_with_pongs() {
|
||||
// we should wait for the interval to elapse and receive a pong before the timeout elapses
|
||||
|
||||
let mut pinger =
|
||||
IntervalTimeoutPinger::new(3, Duration::from_millis(20), Duration::from_millis(10));
|
||||
|
||||
assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping);
|
||||
assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping);
|
||||
|
||||
pinger.pong_received().unwrap();
|
||||
|
||||
assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping);
|
||||
assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping);
|
||||
assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping);
|
||||
assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Timeout);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn check_timing_over_interval() {
|
||||
// send pongs after a ping event, timing the interval between the two
|
||||
let mut pinger =
|
||||
IntervalTimeoutPinger::new(3, Duration::from_millis(20), Duration::from_millis(10));
|
||||
|
||||
assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping);
|
||||
pinger.pong_received().unwrap();
|
||||
|
||||
// wait for the interval to elapse, and compare it to the interval ping
|
||||
// to avoid flakiness let's do 25?
|
||||
let sleep = tokio::time::sleep(Duration::from_millis(25));
|
||||
let wait_for_timeout = pinger.next();
|
||||
|
||||
select! {
|
||||
_ = sleep => panic!("interval should have elapsed"),
|
||||
_ = wait_for_timeout => {}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn check_timing_under_interval() {
|
||||
// send pongs after a ping event, timing the interval between the two
|
||||
let mut pinger =
|
||||
IntervalTimeoutPinger::new(3, Duration::from_millis(20), Duration::from_millis(10));
|
||||
|
||||
assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping);
|
||||
pinger.pong_received().unwrap();
|
||||
|
||||
// wait for the interval to elapse, and compare it to the interval ping
|
||||
// to avoid flakiness let's do 15?
|
||||
let sleep = tokio::time::sleep(Duration::from_millis(15));
|
||||
let next_ping = pinger.next();
|
||||
|
||||
select! {
|
||||
_ = sleep => {}
|
||||
_ = next_ping => panic!("sleep should have elapsed first")
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn check_timing_before_timeout() {
|
||||
// send pongs after a ping event, timing the interval between the two
|
||||
let mut pinger =
|
||||
IntervalTimeoutPinger::new(3, Duration::from_millis(20), Duration::from_millis(10));
|
||||
|
||||
assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping);
|
||||
pinger.pong_received().unwrap();
|
||||
|
||||
// wait ~20ms for the next ping
|
||||
let next_ping = pinger.next().await.unwrap().unwrap();
|
||||
assert_eq!(next_ping, PingerEvent::Ping);
|
||||
|
||||
// ensure that a <10ms sleep completes first
|
||||
let sleep = tokio::time::sleep(Duration::from_millis(5));
|
||||
let next_ping = pinger.next();
|
||||
|
||||
select! {
|
||||
_ = sleep => {}
|
||||
_ = next_ping => panic!("sleep should have before re-sending a ping")
|
||||
}
|
||||
|
||||
// check that we are in the WaitingForPong(0) state (we should not have timed out the first
|
||||
// ping yet)
|
||||
let curr_state = *pinger.state();
|
||||
assert_eq!(curr_state, PingState::WaitingForPong(0));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn check_timing_after_timeout() {
|
||||
// send pongs after a ping event, timing the interval between the two
|
||||
let mut pinger =
|
||||
IntervalTimeoutPinger::new(3, Duration::from_millis(20), Duration::from_millis(10));
|
||||
|
||||
assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping);
|
||||
pinger.pong_received().unwrap();
|
||||
|
||||
// wait ~20ms for the next ping
|
||||
let next_ping = pinger.next().await.unwrap().unwrap();
|
||||
assert_eq!(next_ping, PingerEvent::Ping);
|
||||
|
||||
// ensure that the ping completes before a >10ms sleep
|
||||
let sleep = tokio::time::sleep(Duration::from_millis(15));
|
||||
let next_ping = pinger.next();
|
||||
|
||||
select! {
|
||||
_ = sleep => panic!("ping retry should have completed before sleep"),
|
||||
_ = next_ping => {}
|
||||
}
|
||||
|
||||
// check that we are in the WaitingForPong(1) state (we should have timed out the first
|
||||
// ping)
|
||||
let curr_state = *pinger.state();
|
||||
assert_eq!(curr_state, PingState::WaitingForPong(1));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn check_timing_after_second_timeout() {
|
||||
// send pongs after a ping event, timing the interval between the two
|
||||
let mut pinger =
|
||||
IntervalTimeoutPinger::new(3, Duration::from_millis(20), Duration::from_millis(10));
|
||||
|
||||
assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping);
|
||||
pinger.pong_received().unwrap();
|
||||
|
||||
// wait ~20ms for the next ping
|
||||
let next_ping = pinger.next().await.unwrap().unwrap();
|
||||
assert_eq!(next_ping, PingerEvent::Ping);
|
||||
|
||||
// wait another ~10ms for the next ping
|
||||
let next_ping = pinger.next().await.unwrap().unwrap();
|
||||
assert_eq!(next_ping, PingerEvent::Ping);
|
||||
|
||||
// ensure that the ping completes before a >10ms sleep
|
||||
let sleep = tokio::time::sleep(Duration::from_millis(15));
|
||||
let next_ping = pinger.next();
|
||||
|
||||
select! {
|
||||
_ = sleep => panic!("ping retry should have completed before sleep"),
|
||||
_ = next_ping => {}
|
||||
}
|
||||
|
||||
// check that we are in the WaitingForPong(2) state (we should have timed out the second
|
||||
// ping)
|
||||
let curr_state = *pinger.state();
|
||||
assert_eq!(curr_state, PingState::WaitingForPong(2));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn check_timing_after_last_timeout() {
|
||||
// send pongs after a ping event, timing the interval between the two
|
||||
let mut pinger =
|
||||
IntervalTimeoutPinger::new(3, Duration::from_millis(20), Duration::from_millis(10));
|
||||
|
||||
assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping);
|
||||
pinger.pong_received().unwrap();
|
||||
|
||||
// wait ~20ms for the next ping
|
||||
let next_ping = pinger.next().await.unwrap().unwrap();
|
||||
assert_eq!(next_ping, PingerEvent::Ping);
|
||||
|
||||
// wait another ~10ms for the next ping
|
||||
let next_ping = pinger.next().await.unwrap().unwrap();
|
||||
assert_eq!(next_ping, PingerEvent::Ping);
|
||||
|
||||
// wait another ~10ms for the last ping
|
||||
let next_ping = pinger.next().await.unwrap().unwrap();
|
||||
assert_eq!(next_ping, PingerEvent::Ping);
|
||||
|
||||
// ensure that the ping completes before a >10ms sleep
|
||||
let sleep = tokio::time::sleep(Duration::from_millis(15));
|
||||
let next_ping = pinger.next();
|
||||
|
||||
let ping_res = select! {
|
||||
_ = sleep => panic!("ping retry should have completed before sleep"),
|
||||
res = next_ping => {
|
||||
res.expect("stream should not be empty yet")
|
||||
}
|
||||
};
|
||||
|
||||
assert_eq!(ping_res.unwrap(), PingerEvent::Timeout);
|
||||
|
||||
// check that we are in the TimedOut(3) state (we should have timed out after the last ping)
|
||||
let curr_state = *pinger.state();
|
||||
assert_eq!(curr_state, PingState::TimedOut(3));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn timeout_with_pongs() {
|
||||
// we should wait for the interval to elapse and receive a pong before the timeout elapses
|
||||
let mut pinger =
|
||||
IntervalTimeoutPinger::new(3, Duration::from_millis(20), Duration::from_millis(10));
|
||||
|
||||
assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping);
|
||||
assert_eq!(pinger.next().await.unwrap().unwrap(), PingerEvent::Ping);
|
||||
|
||||
pinger.pong_received().unwrap();
|
||||
|
||||
// let's wait for the timeout to elapse (3 ping timeouts + interval + 10ms for flake
|
||||
// protection)
|
||||
let sleep = tokio::time::sleep(Duration::from_millis(60));
|
||||
let wait_for_timeout = pinger.wait_for_timeout();
|
||||
|
||||
select! {
|
||||
_ = sleep => panic!("timeout should have elapsed by now"),
|
||||
_ = wait_for_timeout => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3,7 +3,7 @@
|
||||
mod status;
|
||||
pub use status::Status;
|
||||
|
||||
mod version;
|
||||
pub mod version;
|
||||
pub use version::EthVersion;
|
||||
|
||||
pub mod forkid;
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
use std::str::FromStr;
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::p2pstream::CapabilityMessage;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Error)]
|
||||
#[error("Unknown eth protocol version: {0}")]
|
||||
pub struct ParseVersionError(String);
|
||||
@ -16,6 +18,19 @@ pub enum EthVersion {
|
||||
Eth67 = 67,
|
||||
}
|
||||
|
||||
impl EthVersion {
|
||||
/// Returns the total number of messages the protocol version supports.
|
||||
pub fn total_messages(&self) -> u8 {
|
||||
match self {
|
||||
EthVersion::Eth66 => 15,
|
||||
EthVersion::Eth67 => {
|
||||
// eth/67 is eth/66 minus GetNodeData and NodeData messages
|
||||
13
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Allow for converting from a `&str` to an `EthVersion`.
|
||||
///
|
||||
/// # Example
|
||||
@ -86,6 +101,13 @@ impl From<EthVersion> for &'static str {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<EthVersion> for CapabilityMessage {
|
||||
#[inline]
|
||||
fn from(v: EthVersion) -> CapabilityMessage {
|
||||
CapabilityMessage { name: String::from("eth"), version: v as usize }
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::{EthVersion, ParseVersionError};
|
||||
|
||||
Reference in New Issue
Block a user