Cap mux simple (#5577)

Signed-off-by: Emilia Hane <elsaemiliaevahane@gmail.com>
This commit is contained in:
Emilia Hane
2023-12-08 09:21:01 +01:00
committed by GitHub
parent 27da72cd57
commit cd4d6c52b0
12 changed files with 703 additions and 25 deletions

View File

@ -21,6 +21,7 @@ reth-metrics.workspace = true
metrics.workspace = true
bytes.workspace = true
derive_more = "0.99.17"
thiserror.workspace = true
serde = { workspace = true, optional = true }
tokio = { workspace = true, features = ["full"] }
@ -38,10 +39,12 @@ proptest = { workspace = true, optional = true }
proptest-derive = { workspace = true, optional = true }
[dev-dependencies]
reth-net-common.workspace = true
reth-primitives = { workspace = true, features = ["arbitrary"] }
reth-tracing.workspace = true
ethers-core = { workspace = true, default-features = false }
test-fuzz = "4"
tokio-util = { workspace = true, features = ["io", "codec"] }
rand.workspace = true

View File

@ -8,6 +8,7 @@ use crate::{
EthMessage, EthMessageID, EthVersion,
};
use alloy_rlp::{Decodable, Encodable, RlpDecodable, RlpEncodable};
use derive_more::{Deref, DerefMut};
use reth_codecs::add_arbitrary_tests;
use reth_primitives::bytes::{BufMut, Bytes};
#[cfg(feature = "serde")]
@ -249,14 +250,23 @@ pub enum SharedCapability {
/// This represents the message ID offset for the first message of the eth capability in
/// the message id space.
offset: u8,
/// The number of messages of this capability. Needed to calculate range of message IDs in
/// demuxing.
messages: u8,
},
}
impl SharedCapability {
/// Creates a new [`SharedCapability`] based on the given name, offset, and version.
/// Creates a new [`SharedCapability`] based on the given name, offset, version (and messages
/// if the capability is custom).
///
/// Returns an error if the offset is equal or less than [`MAX_RESERVED_MESSAGE_ID`].
pub(crate) fn new(name: &str, version: u8, offset: u8) -> Result<Self, SharedCapabilityError> {
pub(crate) fn new(
name: &str,
version: u8,
offset: u8,
messages: u8,
) -> Result<Self, SharedCapabilityError> {
if offset <= MAX_RESERVED_MESSAGE_ID {
return Err(SharedCapabilityError::ReservedMessageIdOffset(offset))
}
@ -266,6 +276,7 @@ impl SharedCapability {
_ => Ok(Self::UnknownCapability {
cap: Capability::new(name.to_string(), version as usize),
offset,
messages,
}),
}
}
@ -324,10 +335,10 @@ impl SharedCapability {
}
/// Returns the number of protocol messages supported by this capability.
pub fn num_messages(&self) -> Result<u8, SharedCapabilityError> {
pub fn num_messages(&self) -> u8 {
match self {
SharedCapability::Eth { version: _version, .. } => Ok(EthMessageID::max() + 1),
_ => Err(SharedCapabilityError::UnknownCapability),
SharedCapability::Eth { version: _version, .. } => EthMessageID::max() + 1,
SharedCapability::UnknownCapability { messages, .. } => *messages,
}
}
}
@ -335,7 +346,7 @@ impl SharedCapability {
/// Non-empty,ordered list of recognized shared capabilities.
///
/// Shared capabilities are ordered alphabetically by case sensitive name.
#[derive(Debug)]
#[derive(Debug, Clone, Deref, DerefMut, PartialEq, Eq)]
pub struct SharedCapabilities(Vec<SharedCapability>);
impl SharedCapabilities {
@ -500,9 +511,14 @@ pub fn shared_capability_offsets(
for name in shared_capability_names {
let proto_version = shared_capabilities.get(&name).expect("shared; qed");
let shared_capability = SharedCapability::new(&name, proto_version.version as u8, offset)?;
let shared_capability = SharedCapability::new(
&name,
proto_version.version as u8,
offset,
proto_version.messages,
)?;
offset += proto_version.messages;
offset += shared_capability.num_messages();
shared_with_offsets.push(shared_capability);
}
@ -519,9 +535,6 @@ pub enum SharedCapabilityError {
/// Unsupported `eth` version.
#[error(transparent)]
UnsupportedVersion(#[from] ParseVersionError),
/// Cannot determine the number of messages for unknown capabilities.
#[error("cannot determine the number of messages for unknown capabilities")]
UnknownCapability,
/// Thrown when the message id for a [SharedCapability] overlaps with the reserved p2p message
/// id space [`MAX_RESERVED_MESSAGE_ID`].
#[error("message id offset `{0}` is reserved")]
@ -541,7 +554,7 @@ mod tests {
#[test]
fn from_eth_68() {
let capability = SharedCapability::new("eth", 68, MAX_RESERVED_MESSAGE_ID + 1).unwrap();
let capability = SharedCapability::new("eth", 68, MAX_RESERVED_MESSAGE_ID + 1, 13).unwrap();
assert_eq!(capability.name(), "eth");
assert_eq!(capability.version(), 68);
@ -556,7 +569,7 @@ mod tests {
#[test]
fn from_eth_67() {
let capability = SharedCapability::new("eth", 67, MAX_RESERVED_MESSAGE_ID + 1).unwrap();
let capability = SharedCapability::new("eth", 67, MAX_RESERVED_MESSAGE_ID + 1, 13).unwrap();
assert_eq!(capability.name(), "eth");
assert_eq!(capability.version(), 67);
@ -571,7 +584,7 @@ mod tests {
#[test]
fn from_eth_66() {
let capability = SharedCapability::new("eth", 66, MAX_RESERVED_MESSAGE_ID + 1).unwrap();
let capability = SharedCapability::new("eth", 66, MAX_RESERVED_MESSAGE_ID + 1, 15).unwrap();
assert_eq!(capability.name(), "eth");
assert_eq!(capability.version(), 66);

View File

@ -150,7 +150,7 @@ impl Decodable for DisconnectReason {
/// lower-level disconnect functions (such as those that exist in the `p2p` protocol) if the
/// underlying stream supports it.
#[async_trait::async_trait]
pub trait CanDisconnect<T>: Sink<T> + Unpin + Sized {
pub trait CanDisconnect<T>: Sink<T> + Unpin {
/// Disconnects from the underlying stream, using a [`DisconnectReason`] as disconnect
/// information if the stream implements a protocol that can carry the additional disconnect
/// metadata.

View File

@ -1,6 +1,8 @@
//! Error handling for (`EthStream`)[crate::EthStream]
use crate::{
errors::P2PStreamError, version::ParseVersionError, DisconnectReason, EthMessageID, EthVersion,
errors::{MuxDemuxError, P2PStreamError},
version::ParseVersionError,
DisconnectReason, EthMessageID, EthVersion,
};
use reth_primitives::{Chain, GotExpected, GotExpectedBoxed, ValidationError, B256};
use std::io;
@ -13,6 +15,9 @@ pub enum EthStreamError {
/// Error of the underlying P2P connection.
P2PStreamError(#[from] P2PStreamError),
#[error(transparent)]
/// Error of the underlying de-/muxed P2P connection.
MuxDemuxError(#[from] MuxDemuxError),
#[error(transparent)]
/// Failed to parse peer's version.
ParseVersionError(#[from] ParseVersionError),
#[error(transparent)]
@ -43,6 +48,8 @@ impl EthStreamError {
pub fn as_disconnected(&self) -> Option<DisconnectReason> {
if let EthStreamError::P2PStreamError(err) = self {
err.as_disconnected()
} else if let EthStreamError::MuxDemuxError(MuxDemuxError::P2PStreamError(err)) = self {
err.as_disconnected()
} else {
None
}

View File

@ -1,7 +1,9 @@
//! Error types for stream variants
mod eth;
mod muxdemux;
mod p2p;
pub use eth::*;
pub use muxdemux::*;
pub use p2p::*;

View File

@ -0,0 +1,47 @@
use thiserror::Error;
use crate::capability::{SharedCapabilityError, UnsupportedCapabilityError};
use super::P2PStreamError;
/// Errors thrown by de-/muxing.
#[derive(Error, Debug)]
pub enum MuxDemuxError {
/// Error of the underlying P2P connection.
#[error(transparent)]
P2PStreamError(#[from] P2PStreamError),
/// Stream is in use by secondary stream impeding disconnect.
#[error("secondary streams are still running")]
StreamInUse,
/// Stream has already been set up for this capability stream type.
#[error("stream already init for stream type")]
StreamAlreadyExists,
/// Capability stream type is not shared with peer on underlying p2p connection.
#[error("stream type is not shared on this p2p connection")]
CapabilityNotShared,
/// Capability stream type has not been configured in [`crate::muxdemux::MuxDemuxer`].
#[error("stream type is not configured")]
CapabilityNotConfigured,
/// Capability stream type has not been configured for
/// [`crate::capability::SharedCapabilities`] type.
#[error("stream type is not recognized")]
CapabilityNotRecognized,
/// Message ID is out of range.
#[error("message id out of range, {0}")]
MessageIdOutOfRange(u8),
/// Demux channel failed.
#[error("sending demuxed bytes to secondary stream failed")]
SendIngressBytesFailed,
/// Mux channel failed.
#[error("sending bytes from secondary stream to mux failed")]
SendEgressBytesFailed,
/// Attempt to disconnect the p2p stream via a stream clone.
#[error("secondary stream cannot disconnect p2p stream")]
CannotDisconnectP2PStream,
/// Shared capability error.
#[error(transparent)]
SharedCapabilityError(#[from] SharedCapabilityError),
/// Capability not supported on the p2p connection.
#[error(transparent)]
UnsupportedCapabilityError(#[from] UnsupportedCapabilityError),
}

View File

@ -283,7 +283,7 @@ where
// start_disconnect, which would ideally be a part of the CanDisconnect trait, or at
// least similar.
//
// Other parts of reth do not need traits like CanDisconnect because they work
// Other parts of reth do not yet need traits like CanDisconnect because atm they work
// exclusively with EthStream<P2PStream<S>>, where the inner P2PStream is accessible,
// allowing for its start_disconnect method to be called.
//

View File

@ -21,6 +21,7 @@ pub mod errors;
mod ethstream;
mod hello;
pub mod multiplex;
pub mod muxdemux;
mod p2pstream;
mod pinger;
pub mod protocol;
@ -37,11 +38,14 @@ pub use tokio_util::codec::{
};
pub use crate::{
capability::Capability,
disconnect::{CanDisconnect, DisconnectReason},
ethstream::{EthStream, UnauthedEthStream, MAX_MESSAGE_SIZE},
hello::{HelloMessage, HelloMessageBuilder, HelloMessageWithProtocols},
muxdemux::{MuxDemuxStream, StreamClone},
p2pstream::{
P2PMessage, P2PMessageID, P2PStream, ProtocolVersion, UnauthedP2PStream,
DisconnectP2P, P2PMessage, P2PMessageID, P2PStream, ProtocolVersion, UnauthedP2PStream,
MAX_RESERVED_MESSAGE_ID,
},
types::EthVersion,
};

View File

@ -0,0 +1,592 @@
//! [`MuxDemuxer`] allows for multiple capability streams to share the same p2p connection. De-/
//! muxing the connection offers two stream types [`MuxDemuxStream`] and [`StreamClone`].
//! [`MuxDemuxStream`] is the main stream that wraps the p2p connection, only this stream can
//! advance transfer across the network. One [`MuxDemuxStream`] can have many [`StreamClone`]s,
//! these are weak clones of the stream and depend on advancing the [`MuxDemuxStream`] to make
//! progress.
//!
//! [`MuxDemuxer`] filters bytes according to message ID offset. The message ID offset is
//! negotiated upon start of the p2p connection. Bytes received by polling the [`MuxDemuxStream`]
//! or a [`StreamClone`] are specific to the capability stream wrapping it. When received the
//! message IDs are unmasked so that all message IDs start at 0x0. [`MuxDemuxStream`] and
//! [`StreamClone`] mask message IDs before sinking bytes to the [`MuxDemuxer`].
//!
//! For example, `EthStream<MuxDemuxStream<P2PStream<S>>>` is the main capability stream.
//! Subsequent capability streams clone the p2p connection via EthStream.
//!
//! When [`MuxDemuxStream`] is polled, [`MuxDemuxer`] receives bytes from the network. If these
//! bytes belong to the capability stream wrapping the [`MuxDemuxStream`] then they are passed up
//! directly. If these bytes however belong to another capability stream, then they are buffered
//! on a channel. When [`StreamClone`] is polled, bytes are read from this buffer. Similarly
//! [`StreamClone`] buffers egress bytes for [`MuxDemuxer`] that are read and sent to the network
//! when [`MuxDemuxStream`] is polled.
use std::{
collections::HashMap,
pin::Pin,
task::{ready, Context, Poll},
};
use derive_more::{Deref, DerefMut};
use futures::{Sink, SinkExt, StreamExt};
use reth_primitives::bytes::{Bytes, BytesMut};
use tokio::sync::mpsc;
use tokio_stream::Stream;
use crate::{
capability::{Capability, SharedCapabilities, SharedCapability},
errors::MuxDemuxError,
CanDisconnect, DisconnectP2P, DisconnectReason,
};
use MuxDemuxError::*;
/// Stream MUX/DEMUX acts like a regular stream and sink for the owning stream, and handles bytes
/// belonging to other streams over their respective channels.
#[derive(Debug)]
pub struct MuxDemuxer<S> {
// receive and send muxed p2p outputs
inner: S,
// owner of the stream. stores message id offset for this capability.
owner: SharedCapability,
// receive muxed p2p inputs from stream clones
mux: mpsc::UnboundedReceiver<Bytes>,
// send demuxed p2p outputs to app
demux: HashMap<SharedCapability, mpsc::UnboundedSender<BytesMut>>,
// sender to mux stored to make new stream clones
mux_tx: mpsc::UnboundedSender<Bytes>,
// capabilities supported by underlying p2p stream (makes testing easier to store here too).
shared_capabilities: SharedCapabilities,
}
/// The main stream on top of the p2p stream. Wraps [`MuxDemuxer`] and enforces it can't be dropped
/// before all secondary streams are dropped (stream clones).
#[derive(Debug, Deref, DerefMut)]
pub struct MuxDemuxStream<S>(MuxDemuxer<S>);
impl<S> MuxDemuxStream<S> {
/// Creates a new [`MuxDemuxer`].
pub fn try_new(
inner: S,
cap: Capability,
shared_capabilities: SharedCapabilities,
) -> Result<Self, MuxDemuxError> {
let owner = Self::shared_cap(&cap, &shared_capabilities)?.clone();
let demux = HashMap::new();
let (mux_tx, mux) = mpsc::unbounded_channel();
Ok(Self(MuxDemuxer { inner, owner, mux, demux, mux_tx, shared_capabilities }))
}
/// Clones the stream if the given capability stream type is shared on the underlying p2p
/// connection.
pub fn try_clone_stream(&mut self, cap: &Capability) -> Result<StreamClone, MuxDemuxError> {
let cap = self.shared_capabilities.ensure_matching_capability(cap)?.clone();
let ingress = self.reg_new_ingress_buffer(&cap)?;
let mux_tx = self.mux_tx.clone();
Ok(StreamClone { stream: ingress, sink: mux_tx, cap })
}
/// Starts a graceful disconnect.
pub fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), MuxDemuxError>
where
S: DisconnectP2P,
{
if !self.can_drop() {
return Err(StreamInUse)
}
self.inner.start_disconnect(reason).map_err(|e| e.into())
}
/// Returns `true` if the connection is about to disconnect.
pub fn is_disconnecting(&self) -> bool
where
S: DisconnectP2P,
{
self.inner.is_disconnecting()
}
/// Shared capabilities of underlying p2p connection as negotiated by peers at connection
/// open.
pub fn shared_capabilities(&self) -> &SharedCapabilities {
&self.shared_capabilities
}
fn shared_cap<'a>(
cap: &Capability,
shared_capabilities: &'a SharedCapabilities,
) -> Result<&'a SharedCapability, MuxDemuxError> {
for shared_cap in shared_capabilities.iter_caps() {
match shared_cap {
SharedCapability::Eth { .. } if cap.is_eth() => return Ok(shared_cap),
SharedCapability::UnknownCapability { cap: unknown_cap, .. }
if cap == unknown_cap =>
{
return Ok(shared_cap)
}
_ => continue,
}
}
Err(CapabilityNotShared)
}
fn reg_new_ingress_buffer(
&mut self,
cap: &SharedCapability,
) -> Result<mpsc::UnboundedReceiver<BytesMut>, MuxDemuxError> {
if let Some(tx) = self.demux.get(cap) {
if !tx.is_closed() {
return Err(StreamAlreadyExists)
}
}
let (ingress_tx, ingress) = mpsc::unbounded_channel();
self.demux.insert(cap.clone(), ingress_tx);
Ok(ingress)
}
fn unmask_msg_id(&self, id: &mut u8) -> Result<&SharedCapability, MuxDemuxError> {
for cap in self.shared_capabilities.iter_caps() {
let offset = cap.relative_message_id_offset();
let next_offset = offset + cap.num_messages();
if *id < next_offset {
*id -= offset;
return Ok(cap)
}
}
Err(MessageIdOutOfRange(*id))
}
/// Masks message id with offset relative to the message id suffix reserved for capability
/// message ids. The p2p stream further masks the message id (todo: mask whole message id at
/// once to avoid copying message to mutate id byte or sink BytesMut).
fn mask_msg_id(&self, msg: Bytes) -> Bytes {
let mut masked_bytes = BytesMut::zeroed(msg.len());
masked_bytes[0] = msg[0] + self.owner.relative_message_id_offset();
masked_bytes[1..].copy_from_slice(&msg[1..]);
masked_bytes.freeze()
}
/// Checks if all clones of this shared stream have been dropped, if true then returns //
/// function to drop the stream.
fn can_drop(&mut self) -> bool {
for tx in self.demux.values() {
if !tx.is_closed() {
return false
}
}
true
}
}
impl<S, E> Stream for MuxDemuxStream<S>
where
S: Stream<Item = Result<BytesMut, E>> + CanDisconnect<Bytes> + Unpin,
MuxDemuxError: From<E> + From<<S as Sink<Bytes>>::Error>,
{
type Item = Result<BytesMut, MuxDemuxError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut send_count = 0;
let mut mux_exhausted = false;
loop {
// send buffered bytes from `StreamClone`s. try send at least as many messages as
// there are stream clones.
if self.inner.poll_ready_unpin(cx).is_ready() {
if let Poll::Ready(Some(item)) = self.mux.poll_recv(cx) {
self.inner.start_send_unpin(item)?;
if send_count < self.demux.len() {
send_count += 1;
continue
}
} else {
mux_exhausted = true;
}
}
// advances the wire and either yields message for the owner or delegates message to a
// stream clone
let res = self.inner.poll_next_unpin(cx);
if res.is_pending() {
// no message is received. continue to send messages from stream clones as long as
// there are messages left to send.
if !mux_exhausted && self.inner.poll_ready_unpin(cx).is_ready() {
continue
}
// flush before returning pending
_ = self.inner.poll_flush_unpin(cx)?;
}
let mut bytes = match ready!(res) {
Some(Ok(bytes)) => bytes,
Some(Err(err)) => {
_ = self.inner.poll_flush_unpin(cx)?;
return Poll::Ready(Some(Err(err.into())))
}
None => {
_ = self.inner.poll_flush_unpin(cx)?;
return Poll::Ready(None)
}
};
// normalize message id suffix for capability
let cap = self.unmask_msg_id(&mut bytes[0])?;
// yield message for main stream
if *cap == self.owner {
_ = self.inner.poll_flush_unpin(cx)?;
return Poll::Ready(Some(Ok(bytes)))
}
// delegate message for stream clone
let tx = self.demux.get(cap).ok_or(CapabilityNotConfigured)?;
tx.send(bytes).map_err(|_| SendIngressBytesFailed)?;
}
}
}
impl<S, E> Sink<Bytes> for MuxDemuxStream<S>
where
S: Sink<Bytes, Error = E> + CanDisconnect<Bytes> + Unpin,
MuxDemuxError: From<E>,
{
type Error = MuxDemuxError;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready_unpin(cx).map_err(Into::into)
}
fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
let item = self.mask_msg_id(item);
self.inner.start_send_unpin(item).map_err(|e| e.into())
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_flush_unpin(cx).map_err(Into::into)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
while let Ok(item) = self.mux.try_recv() {
self.inner.start_send_unpin(item)?;
}
_ = self.inner.poll_flush_unpin(cx)?;
self.inner.poll_close_unpin(cx).map_err(Into::into)
}
}
#[async_trait::async_trait]
impl<S, E> CanDisconnect<Bytes> for MuxDemuxStream<S>
where
S: Sink<Bytes, Error = E> + CanDisconnect<Bytes> + Unpin + Send + Sync,
MuxDemuxError: From<E>,
{
async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), MuxDemuxError> {
if self.can_drop() {
return self.inner.disconnect(reason).await.map_err(Into::into)
}
Err(StreamInUse)
}
}
/// More or less a weak clone of the stream wrapped in [`MuxDemuxer`] but the bytes belonging to
/// other capabilities have been filtered out.
#[derive(Debug)]
pub struct StreamClone {
// receive bytes from de-/muxer
stream: mpsc::UnboundedReceiver<BytesMut>,
// send bytes to de-/muxer
sink: mpsc::UnboundedSender<Bytes>,
// message id offset for capability holding this clone
cap: SharedCapability,
}
impl StreamClone {
fn mask_msg_id(&self, msg: Bytes) -> Bytes {
let mut masked_bytes = BytesMut::zeroed(msg.len());
masked_bytes[0] = msg[0] + self.cap.relative_message_id_offset();
masked_bytes[1..].copy_from_slice(&msg[1..]);
masked_bytes.freeze()
}
}
impl Stream for StreamClone {
type Item = BytesMut;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.stream.poll_recv(cx)
}
}
impl Sink<Bytes> for StreamClone {
type Error = MuxDemuxError;
fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
let item = self.mask_msg_id(item);
self.sink.send(item).map_err(|_| SendEgressBytesFailed)
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
}
#[async_trait::async_trait]
impl CanDisconnect<Bytes> for StreamClone {
async fn disconnect(&mut self, _reason: DisconnectReason) -> Result<(), MuxDemuxError> {
Err(CannotDisconnectP2PStream)
}
}
#[cfg(test)]
mod test {
use std::{net::SocketAddr, pin::Pin};
use futures::{Future, SinkExt, StreamExt};
use reth_ecies::util::pk2id;
use reth_primitives::{
bytes::{BufMut, Bytes, BytesMut},
ForkFilter, Hardfork, MAINNET,
};
use secp256k1::{SecretKey, SECP256K1};
use tokio::{
net::{TcpListener, TcpStream},
task::JoinHandle,
};
use tokio_util::codec::{Decoder, Framed, LengthDelimitedCodec};
use crate::{
capability::{Capability, SharedCapabilities},
muxdemux::MuxDemuxStream,
protocol::Protocol,
EthVersion, HelloMessageWithProtocols, Status, StatusBuilder, StreamClone,
UnauthedEthStream, UnauthedP2PStream,
};
const ETH_68_CAP: Capability = Capability::eth(EthVersion::Eth68);
const ETH_68_PROTOCOL: Protocol = Protocol::new(ETH_68_CAP, 13);
const CUSTOM_CAP: Capability = Capability::new_static("snap", 1);
const CUSTOM_CAP_PROTOCOL: Protocol = Protocol::new(CUSTOM_CAP, 10);
// message IDs `0x00` and `0x01` are normalized for the custom protocol stream
const CUSTOM_REQUEST: [u8; 5] = [0x00, 0x00, 0x01, 0x0, 0xc0];
const CUSTOM_RESPONSE: [u8; 5] = [0x01, 0x00, 0x01, 0x0, 0xc0];
fn shared_caps_eth68() -> SharedCapabilities {
let local_capabilities: Vec<Protocol> = vec![ETH_68_PROTOCOL];
let peer_capabilities: Vec<Capability> = vec![ETH_68_CAP];
SharedCapabilities::try_new(local_capabilities, peer_capabilities).unwrap()
}
fn shared_caps_eth68_and_custom() -> SharedCapabilities {
let local_capabilities: Vec<Protocol> = vec![ETH_68_PROTOCOL, CUSTOM_CAP_PROTOCOL];
let peer_capabilities: Vec<Capability> = vec![ETH_68_CAP, CUSTOM_CAP];
SharedCapabilities::try_new(local_capabilities, peer_capabilities).unwrap()
}
struct ConnectionBuilder {
local_addr: SocketAddr,
local_hello: HelloMessageWithProtocols,
status: Status,
fork_filter: ForkFilter,
}
impl ConnectionBuilder {
fn new() -> Self {
let (_secret_key, pk) = SECP256K1.generate_keypair(&mut rand::thread_rng());
let hello = HelloMessageWithProtocols::builder(pk2id(&pk))
.protocol(ETH_68_PROTOCOL)
.protocol(CUSTOM_CAP_PROTOCOL)
.build();
let local_addr = "127.0.0.1:30303".parse().unwrap();
Self {
local_hello: hello,
local_addr,
status: StatusBuilder::default().build(),
fork_filter: MAINNET
.hardfork_fork_filter(Hardfork::Frontier)
.expect("The Frontier fork filter should exist on mainnet"),
}
}
/// Connects a custom sub protocol stream and executes the given closure with that
/// established stream (main stream is eth).
fn with_connect_custom_protocol<F, G>(
self,
f_local: F,
f_remote: G,
) -> (JoinHandle<BytesMut>, JoinHandle<BytesMut>)
where
F: FnOnce(StreamClone) -> Pin<Box<(dyn Future<Output = BytesMut> + Send)>>
+ Send
+ Sync
+ Send
+ 'static,
G: FnOnce(StreamClone) -> Pin<Box<(dyn Future<Output = BytesMut> + Send)>>
+ Send
+ Sync
+ Send
+ 'static,
{
let local_addr = self.local_addr;
let local_hello = self.local_hello.clone();
let status = self.status;
let fork_filter = self.fork_filter.clone();
let local_handle = tokio::spawn(async move {
let local_listener = TcpListener::bind(local_addr).await.unwrap();
let (incoming, _) = local_listener.accept().await.unwrap();
let stream = crate::PassthroughCodec::default().framed(incoming);
let protocol_proxy =
connect_protocol(stream, local_hello, status, fork_filter).await;
f_local(protocol_proxy).await
});
let remote_key = SecretKey::new(&mut rand::thread_rng());
let remote_id = pk2id(&remote_key.public_key(SECP256K1));
let mut remote_hello = self.local_hello.clone();
remote_hello.id = remote_id;
let fork_filter = self.fork_filter.clone();
let remote_handle = tokio::spawn(async move {
let outgoing = TcpStream::connect(local_addr).await.unwrap();
let stream = crate::PassthroughCodec::default().framed(outgoing);
let protocol_proxy =
connect_protocol(stream, remote_hello, status, fork_filter).await;
f_remote(protocol_proxy).await
});
(local_handle, remote_handle)
}
}
async fn connect_protocol(
stream: Framed<TcpStream, LengthDelimitedCodec>,
hello: HelloMessageWithProtocols,
status: Status,
fork_filter: ForkFilter,
) -> StreamClone {
let unauthed_stream = UnauthedP2PStream::new(stream);
let (p2p_stream, _) = unauthed_stream.handshake(hello).await.unwrap();
// ensure that the two share capabilities
assert_eq!(*p2p_stream.shared_capabilities(), shared_caps_eth68_and_custom(),);
let shared_caps = p2p_stream.shared_capabilities().clone();
let main_cap = shared_caps.eth().unwrap();
let proxy_server =
MuxDemuxStream::try_new(p2p_stream, main_cap.capability().into_owned(), shared_caps)
.expect("should start mxdmx stream");
let (mut main_stream, _) =
UnauthedEthStream::new(proxy_server).handshake(status, fork_filter).await.unwrap();
let protocol_proxy =
main_stream.inner_mut().try_clone_stream(&CUSTOM_CAP).expect("should clone stream");
tokio::spawn(async move {
loop {
_ = main_stream.next().await.unwrap()
}
});
protocol_proxy
}
#[test]
fn test_unmask_msg_id() {
let mut msg = BytesMut::with_capacity(1);
msg.put_u8(0x07); // eth msg id
let mxdmx_stream =
MuxDemuxStream::try_new((), Capability::eth(EthVersion::Eth67), shared_caps_eth68())
.unwrap();
_ = mxdmx_stream.unmask_msg_id(&mut msg[0]).unwrap();
assert_eq!(msg.as_ref(), &[0x07]);
}
#[test]
fn test_mask_msg_id() {
let mut msg = BytesMut::with_capacity(2);
msg.put_u8(0x10); // eth msg id
msg.put_u8(0x20); // some msg data
let mxdmx_stream =
MuxDemuxStream::try_new((), Capability::eth(EthVersion::Eth66), shared_caps_eth68())
.unwrap();
let egress_bytes = mxdmx_stream.mask_msg_id(msg.freeze());
assert_eq!(egress_bytes.as_ref(), &[0x10, 0x20]);
}
#[test]
fn test_unmask_msg_id_cap_not_in_shared_range() {
let mut msg = BytesMut::with_capacity(1);
msg.put_u8(0x11);
let mxdmx_stream =
MuxDemuxStream::try_new((), Capability::eth(EthVersion::Eth68), shared_caps_eth68())
.unwrap();
assert!(mxdmx_stream.unmask_msg_id(&mut msg[0]).is_err());
}
#[tokio::test(flavor = "multi_thread")]
async fn test_mux_demux() {
let builder = ConnectionBuilder::new();
let request = Bytes::from(&CUSTOM_REQUEST[..]);
let response = Bytes::from(&CUSTOM_RESPONSE[..]);
let expected_request = request.clone();
let expected_response = response.clone();
let (local_handle, remote_handle) = builder.with_connect_custom_protocol(
// send request from local addr
|mut protocol_proxy| {
Box::pin(async move {
protocol_proxy.send(request).await.unwrap();
protocol_proxy.next().await.unwrap()
})
},
// respond from remote addr
|mut protocol_proxy| {
Box::pin(async move {
let request = protocol_proxy.next().await.unwrap();
protocol_proxy.send(response).await.unwrap();
request
})
},
);
let (local_res, remote_res) = tokio::join!(local_handle, remote_handle);
// remote address receives request
assert_eq!(expected_request, remote_res.unwrap().freeze());
// local address receives response
assert_eq!(expected_response, local_res.unwrap().freeze());
}
}

View File

@ -301,11 +301,6 @@ impl<S> P2PStream<S> {
&self.shared_capabilities
}
/// Returns `true` if the connection is about to disconnect.
pub fn is_disconnecting(&self) -> bool {
self.disconnecting
}
/// Returns `true` if the stream has outgoing capacity.
fn has_outgoing_capacity(&self) -> bool {
self.outgoing_messages.len() < self.outgoing_message_buffer_capacity
@ -326,7 +321,16 @@ impl<S> P2PStream<S> {
ping.encode(&mut ping_bytes);
self.outgoing_messages.push_back(ping_bytes.freeze());
}
}
pub trait DisconnectP2P {
/// Starts to gracefully disconnect.
fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError>;
/// Returns `true` if the connection is about to disconnect.
fn is_disconnecting(&self) -> bool;
}
impl<S> DisconnectP2P for P2PStream<S> {
/// Starts to gracefully disconnect the connection by sending a Disconnect message and stop
/// reading new messages.
///
@ -335,7 +339,7 @@ impl<S> P2PStream<S> {
/// # Errors
///
/// Returns an error only if the message fails to compress.
pub fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), snap::Error> {
fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
// clear any buffered messages and queue in
self.outgoing_messages.clear();
let disconnect = P2PMessage::Disconnect(reason);
@ -365,6 +369,10 @@ impl<S> P2PStream<S> {
self.disconnecting = true;
Ok(())
}
fn is_disconnecting(&self) -> bool {
self.disconnecting
}
}
impl<S> P2PStream<S>