diff --git a/Cargo.lock b/Cargo.lock index 831dcf51b..af3eb7ed7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -539,6 +539,28 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "async-stream" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.39", +] + [[package]] name = "async-trait" version = "0.1.74" @@ -5889,6 +5911,7 @@ version = "0.1.0-alpha.13" dependencies = [ "alloy-rlp", "arbitrary", + "async-stream", "async-trait", "bytes", "derive_more", diff --git a/crates/net/eth-wire/Cargo.toml b/crates/net/eth-wire/Cargo.toml index 88af65762..29e6adcf0 100644 --- a/crates/net/eth-wire/Cargo.toml +++ b/crates/net/eth-wire/Cargo.toml @@ -53,6 +53,7 @@ secp256k1 = { workspace = true, features = ["global-context", "rand-std", "recov arbitrary = { workspace = true, features = ["derive"] } proptest.workspace = true proptest-derive.workspace = true +async-stream = "0.3" [features] default = ["serde"] diff --git a/crates/net/eth-wire/src/capability.rs b/crates/net/eth-wire/src/capability.rs index 8b64aa16a..5f799b2ac 100644 --- a/crates/net/eth-wire/src/capability.rs +++ b/crates/net/eth-wire/src/capability.rs @@ -317,6 +317,14 @@ impl SharedCapability { } } + /// Returns the eth version if it's the `eth` capability. + pub fn eth_version(&self) -> Option { + match self { + SharedCapability::Eth { version, .. } => Some(*version), + _ => None, + } + } + /// Returns the message ID offset of the current capability. /// /// This represents the message ID offset for the first message of the eth capability in the @@ -375,8 +383,8 @@ impl SharedCapabilities { /// Returns the negotiated eth version if it is shared. #[inline] - pub fn eth_version(&self) -> Result { - self.eth().map(|cap| cap.version()) + pub fn eth_version(&self) -> Result { + self.eth().map(|cap| cap.eth_version().expect("is eth; qed")) } /// Returns true if the shared capabilities contain the given capability. @@ -438,6 +446,18 @@ impl SharedCapabilities { ) -> Result<&SharedCapability, UnsupportedCapabilityError> { self.find(cap).ok_or_else(|| UnsupportedCapabilityError { capability: cap.clone() }) } + + /// Returns the number of shared capabilities. + #[inline] + pub fn len(&self) -> usize { + self.0.len() + } + + /// Returns true if there are no shared capabilities. + #[inline] + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } } /// Determines the offsets for each shared capability between the input list of peer diff --git a/crates/net/eth-wire/src/ethstream.rs b/crates/net/eth-wire/src/ethstream.rs index f1162f7ee..35ee8d276 100644 --- a/crates/net/eth-wire/src/ethstream.rs +++ b/crates/net/eth-wire/src/ethstream.rs @@ -166,6 +166,7 @@ where #[pin_project] #[derive(Debug)] pub struct EthStream { + /// Negotiated eth version. version: EthVersion, #[pin] inner: S, @@ -174,26 +175,31 @@ pub struct EthStream { impl EthStream { /// Creates a new unauthed [`EthStream`] from a provided stream. You will need /// to manually handshake a peer. + #[inline] pub fn new(version: EthVersion, inner: S) -> Self { Self { version, inner } } /// Returns the eth version. + #[inline] pub fn version(&self) -> EthVersion { self.version } /// Returns the underlying stream. + #[inline] pub fn inner(&self) -> &S { &self.inner } /// Returns mutable access to the underlying stream. + #[inline] pub fn inner_mut(&mut self) -> &mut S { &mut self.inner } /// Consumes this type and returns the wrapped stream. + #[inline] pub fn into_inner(self) -> S { self.inner } diff --git a/crates/net/eth-wire/src/hello.rs b/crates/net/eth-wire/src/hello.rs index f1e684865..e992021f4 100644 --- a/crates/net/eth-wire/src/hello.rs +++ b/crates/net/eth-wire/src/hello.rs @@ -49,6 +49,7 @@ impl HelloMessageWithProtocols { } /// Returns the raw [HelloMessage] without the additional protocol information. + #[inline] pub fn message(&self) -> HelloMessage { HelloMessage { protocol_version: self.protocol_version, @@ -69,6 +70,25 @@ impl HelloMessageWithProtocols { id: self.id, } } + + /// Returns true if the set of protocols contains the given protocol. + #[inline] + pub fn contains_protocol(&self, protocol: &Protocol) -> bool { + self.protocols.iter().any(|p| p.cap == protocol.cap) + } + + /// Adds a new protocol to the set. + /// + /// Returns an error if the protocol already exists. + #[inline] + pub fn try_add_protocol(&mut self, protocol: Protocol) -> Result<(), Protocol> { + if self.contains_protocol(&protocol) { + Err(protocol) + } else { + self.protocols.push(protocol); + Ok(()) + } + } } // TODO: determine if we should allow for the extra fields at the end like EIP-706 suggests diff --git a/crates/net/eth-wire/src/multiplex.rs b/crates/net/eth-wire/src/multiplex.rs index f76be4fbe..0cb9cd62f 100644 --- a/crates/net/eth-wire/src/multiplex.rs +++ b/crates/net/eth-wire/src/multiplex.rs @@ -16,45 +16,89 @@ use std::{ task::{ready, Context, Poll}, }; -use bytes::{Bytes, BytesMut}; -use futures::{pin_mut, Sink, SinkExt, Stream, StreamExt, TryStream, TryStreamExt}; -use tokio::sync::{mpsc, mpsc::UnboundedSender}; -use tokio_stream::wrappers::UnboundedReceiverStream; - use crate::{ capability::{Capability, SharedCapabilities, SharedCapability, UnsupportedCapabilityError}, - errors::P2PStreamError, - CanDisconnect, DisconnectReason, P2PStream, + errors::{EthStreamError, P2PStreamError}, + CanDisconnect, DisconnectReason, EthStream, P2PStream, Status, UnauthedEthStream, }; +use bytes::{Bytes, BytesMut}; +use futures::{pin_mut, Sink, SinkExt, Stream, StreamExt, TryStream, TryStreamExt}; +use reth_primitives::ForkFilter; +use tokio::sync::{mpsc, mpsc::UnboundedSender}; +use tokio_stream::wrappers::UnboundedReceiverStream; /// A Stream and Sink type that wraps a raw rlpx stream [P2PStream] and handles message ID /// multiplexing. #[derive(Debug)] pub struct RlpxProtocolMultiplexer { - /// The raw p2p stream - conn: P2PStream, - /// All the subprotocols that are multiplexed on top of the raw p2p stream - protocols: Vec, + inner: MultiplexInner, } impl RlpxProtocolMultiplexer { /// Wraps the raw p2p stream pub fn new(conn: P2PStream) -> Self { - Self { conn, protocols: Default::default() } + Self { + inner: MultiplexInner { + conn, + protocols: Default::default(), + out_buffer: Default::default(), + }, + } } - /// Installs a new protocol on top of the raw p2p stream - pub fn install_protocol( + /// Installs a new protocol on top of the raw p2p stream. + /// + /// This accepts a closure that receives a [ProtocolConnection] that will yield messages for the + /// given capability. + pub fn install_protocol( &mut self, - _cap: Capability, - _st: S, - ) -> Result<(), UnsupportedCapabilityError> { - todo!() + cap: &Capability, + f: F, + ) -> Result<(), UnsupportedCapabilityError> + where + F: FnOnce(ProtocolConnection) -> Proto, + Proto: Stream + Send + 'static, + { + self.inner.install_protocol(cap, f) } /// Returns the [SharedCapabilities] of the underlying raw p2p stream pub fn shared_capabilities(&self) -> &SharedCapabilities { - self.conn.shared_capabilities() + self.inner.shared_capabilities() + } + + /// Converts this multiplexer into a [RlpxSatelliteStream] with the given primary protocol. + pub fn into_satellite_stream( + self, + cap: &Capability, + primary: F, + ) -> Result, P2PStreamError> + where + F: FnOnce(ProtocolProxy) -> Primary, + { + let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned() + else { + return Err(P2PStreamError::CapabilityNotShared) + }; + + let (to_primary, from_wire) = mpsc::unbounded_channel(); + let (to_wire, from_primary) = mpsc::unbounded_channel(); + let proxy = ProtocolProxy { + shared_cap: shared_cap.clone(), + from_wire: UnboundedReceiverStream::new(from_wire), + to_wire, + }; + + let st = primary(proxy); + Ok(RlpxSatelliteStream { + inner: self.inner, + primary: PrimaryProtocol { + to_primary, + from_primary: UnboundedReceiverStream::new(from_primary), + st, + shared_cap, + }, + }) } /// Converts this multiplexer into a [RlpxSatelliteStream] with the given primary protocol. @@ -62,7 +106,7 @@ impl RlpxProtocolMultiplexer { /// Returns an error if the primary protocol is not supported by the remote or the handshake /// failed. pub async fn into_satellite_stream_with_handshake( - mut self, + self, cap: &Capability, handshake: F, ) -> Result, Err> @@ -71,6 +115,34 @@ impl RlpxProtocolMultiplexer { Fut: Future>, St: Stream> + Sink + Unpin, P2PStreamError: Into, + { + self.into_satellite_stream_with_tuple_handshake(cap, move |proxy| async move { + let st = handshake(proxy).await?; + Ok((st, ())) + }) + .await + .map(|(st, _)| st) + } + + /// Converts this multiplexer into a [RlpxSatelliteStream] with the given primary protocol. + /// + /// Returns an error if the primary protocol is not supported by the remote or the handshake + /// failed. + /// + /// This accepts a closure that does a handshake with the remote peer and returns a tuple of the + /// primary stream and extra data. + /// + /// See also [UnauthedEthStream::handshake] + pub async fn into_satellite_stream_with_tuple_handshake( + mut self, + cap: &Capability, + handshake: F, + ) -> Result<(RlpxSatelliteStream, Extra), Err> + where + F: FnOnce(ProtocolProxy) -> Fut, + Fut: Future>, + St: Stream> + Sink + Unpin, + P2PStreamError: Into, { let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned() else { @@ -80,7 +152,7 @@ impl RlpxProtocolMultiplexer { let (to_primary, from_wire) = mpsc::unbounded_channel(); let (to_wire, mut from_primary) = mpsc::unbounded_channel(); let proxy = ProtocolProxy { - cap: shared_cap.clone(), + shared_cap: shared_cap.clone(), from_wire: UnboundedReceiverStream::new(from_wire), to_wire, }; @@ -92,45 +164,118 @@ impl RlpxProtocolMultiplexer { // complete loop { tokio::select! { - Some(Ok(msg)) = self.conn.next() => { + Some(Ok(msg)) = self.inner.conn.next() => { // Ensure the message belongs to the primary protocol let offset = msg[0]; - if let Some(cap) = self.conn.shared_capabilities().find_by_relative_offset(offset) { - if cap == &shared_cap { + if let Some(cap) = self.shared_capabilities().find_by_relative_offset(offset).cloned() { + if cap == shared_cap { // delegate to primary let _ = to_primary.send(msg); } else { // delegate to satellite - for proto in &self.protocols { - if proto.cap == *cap { - // TODO: need some form of backpressure here so buffering can't be abused - proto.send_raw(msg); - break - } - } + self.inner.delegate_message(&cap, msg); } } else { return Err(P2PStreamError::UnknownReservedMessageId(offset).into()) } } Some(msg) = from_primary.recv() => { - self.conn.send(msg).await.map_err(Into::into)?; + self.inner.conn.send(msg).await.map_err(Into::into)?; } res = &mut f => { - let primary = res?; - return Ok(RlpxSatelliteStream { - conn: self.conn, - to_primary, - from_primary: UnboundedReceiverStream::new(from_primary), - primary, - primary_capability: shared_cap, - satellites: self.protocols, - out_buffer: Default::default(), - }) + let (st, extra) = res?; + return Ok((RlpxSatelliteStream { + inner: self.inner, + primary: PrimaryProtocol { + to_primary, + from_primary: UnboundedReceiverStream::new(from_primary), + st, + shared_cap, + } + }, extra)) } } } } + + /// Converts this multiplexer into a [RlpxSatelliteStream] with eth protocol as the given + /// primary protocol. + pub async fn into_eth_satellite_stream( + self, + status: Status, + fork_filter: ForkFilter, + ) -> Result<(RlpxSatelliteStream>, Status), EthStreamError> + where + St: Stream> + Sink + Unpin, + { + let eth_cap = self.inner.conn.shared_capabilities().eth_version()?; + self.into_satellite_stream_with_tuple_handshake( + &Capability::eth(eth_cap), + move |proxy| async move { + UnauthedEthStream::new(proxy).handshake(status, fork_filter).await + }, + ) + .await + } +} + +#[derive(Debug)] +struct MultiplexInner { + /// The raw p2p stream + conn: P2PStream, + /// All the subprotocols that are multiplexed on top of the raw p2p stream + protocols: Vec, + /// Buffer for outgoing messages on the wire. + out_buffer: VecDeque, +} + +impl MultiplexInner { + fn shared_capabilities(&self) -> &SharedCapabilities { + self.conn.shared_capabilities() + } + + /// Delegates a message to the matching protocol. + fn delegate_message(&mut self, cap: &SharedCapability, msg: BytesMut) -> bool { + for proto in &self.protocols { + if proto.shared_cap == *cap { + proto.send_raw(msg); + return true + } + } + false + } + + fn install_protocol( + &mut self, + cap: &Capability, + f: F, + ) -> Result<(), UnsupportedCapabilityError> + where + F: FnOnce(ProtocolConnection) -> Proto, + Proto: Stream + Send + 'static, + { + let shared_cap = + self.conn.shared_capabilities().ensure_matching_capability(cap).cloned()?; + let (to_satellite, rx) = mpsc::unbounded_channel(); + let proto_conn = ProtocolConnection { from_wire: UnboundedReceiverStream::new(rx) }; + let st = f(proto_conn); + let st = ProtocolStream { shared_cap, to_satellite, satellite_st: Box::pin(st) }; + self.protocols.push(st); + Ok(()) + } +} + +/// Represents a protocol in the multiplexer that is used as the primary protocol. +#[derive(Debug)] +struct PrimaryProtocol { + /// Channel to send messages to the primary protocol. + to_primary: UnboundedSender, + /// Receiver for messages from the primary protocol. + from_primary: UnboundedReceiverStream, + /// Shared capability of the primary protocol. + shared_cap: SharedCapability, + /// The primary stream. + st: Primary, } /// A Stream and Sink type that acts as a wrapper around a primary RLPx subprotocol (e.g. "eth") @@ -138,7 +283,7 @@ impl RlpxProtocolMultiplexer { /// Only emits and sends _non-empty_ messages #[derive(Debug)] pub struct ProtocolProxy { - cap: SharedCapability, + shared_cap: SharedCapability, /// Receives _non-empty_ messages from the wire from_wire: UnboundedReceiverStream, /// Sends _non-empty_ messages from the wire @@ -163,7 +308,7 @@ impl ProtocolProxy { #[inline] fn mask_msg_id(&self, msg: Bytes) -> Bytes { let mut masked_bytes = BytesMut::zeroed(msg.len()); - masked_bytes[0] = msg[0] + self.cap.relative_message_id_offset(); + masked_bytes[0] = msg[0] + self.shared_cap.relative_message_id_offset(); masked_bytes[1..].copy_from_slice(&msg[1..]); masked_bytes.freeze() } @@ -175,7 +320,7 @@ impl ProtocolProxy { /// If the message is empty. #[inline] fn unmask_id(&self, mut msg: BytesMut) -> BytesMut { - msg[0] -= self.cap.relative_message_id_offset(); + msg[0] -= self.shared_cap.relative_message_id_offset(); msg } } @@ -237,20 +382,60 @@ impl Stream for ProtocolConnection { } /// A Stream and Sink type that acts as a wrapper around a primary RLPx subprotocol (e.g. "eth") -/// [EthStream](crate::EthStream) and can also handle additional subprotocols. +/// [EthStream] and can also handle additional subprotocols. #[derive(Debug)] pub struct RlpxSatelliteStream { - /// The raw p2p stream - conn: P2PStream, - to_primary: UnboundedSender, - from_primary: UnboundedReceiverStream, - primary: Primary, - primary_capability: SharedCapability, - satellites: Vec, - out_buffer: VecDeque, + inner: MultiplexInner, + primary: PrimaryProtocol, } -impl RlpxSatelliteStream {} +impl RlpxSatelliteStream { + /// Installs a new protocol on top of the raw p2p stream. + /// + /// This accepts a closure that receives a [ProtocolConnection] that will yield messages for the + /// given capability. + pub fn install_protocol( + &mut self, + cap: &Capability, + f: F, + ) -> Result<(), UnsupportedCapabilityError> + where + F: FnOnce(ProtocolConnection) -> Proto, + Proto: Stream + Send + 'static, + { + self.inner.install_protocol(cap, f) + } + + /// Returns the primary protocol. + #[inline] + pub fn primary(&self) -> &Primary { + &self.primary.st + } + + /// Returns mutable access to the primary protocol. + #[inline] + pub fn primary_mut(&mut self) -> &mut Primary { + &mut self.primary.st + } + + /// Returns the underlying [P2PStream]. + #[inline] + pub fn inner(&self) -> &P2PStream { + &self.inner.conn + } + + /// Returns mutable access to the underlying [P2PStream]. + #[inline] + pub fn inner_mut(&mut self) -> &mut P2PStream { + &mut self.inner.conn + } + + /// Consumes this type and returns the wrapped [P2PStream]. + #[inline] + pub fn into_inner(self) -> P2PStream { + self.inner.conn + } +} impl Stream for RlpxSatelliteStream where @@ -265,16 +450,16 @@ where loop { // first drain the primary stream - if let Poll::Ready(Some(msg)) = this.primary.try_poll_next_unpin(cx) { + if let Poll::Ready(Some(msg)) = this.primary.st.try_poll_next_unpin(cx) { return Poll::Ready(Some(msg)) } - let mut out_ready = true; + let mut conn_ready = true; loop { - match this.conn.poll_ready_unpin(cx) { + match this.inner.conn.poll_ready_unpin(cx) { Poll::Ready(_) => { - if let Some(msg) = this.out_buffer.pop_front() { - if let Err(err) = this.conn.start_send_unpin(msg) { + if let Some(msg) = this.inner.out_buffer.pop_front() { + if let Err(err) = this.inner.conn.start_send_unpin(msg) { return Poll::Ready(Some(Err(err.into()))) } } else { @@ -282,7 +467,7 @@ where } } Poll::Pending => { - out_ready = false; + conn_ready = false; break } } @@ -290,9 +475,9 @@ where // advance primary out loop { - match this.from_primary.poll_next_unpin(cx) { + match this.primary.from_primary.poll_next_unpin(cx) { Poll::Ready(Some(msg)) => { - this.out_buffer.push_back(msg); + this.inner.out_buffer.push_back(msg); } Poll::Ready(None) => { // primary closed @@ -303,16 +488,16 @@ where } // advance all satellites - for idx in (0..this.satellites.len()).rev() { - let mut proto = this.satellites.swap_remove(idx); + for idx in (0..this.inner.protocols.len()).rev() { + let mut proto = this.inner.protocols.swap_remove(idx); loop { match proto.poll_next_unpin(cx) { Poll::Ready(Some(msg)) => { - this.out_buffer.push_back(msg); + this.inner.out_buffer.push_back(msg); } Poll::Ready(None) => return Poll::Ready(None), Poll::Pending => { - this.satellites.push(proto); + this.inner.protocols.push(proto); break } } @@ -322,21 +507,21 @@ where let mut delegated = false; loop { // pull messages from connection - match this.conn.poll_next_unpin(cx) { + match this.inner.conn.poll_next_unpin(cx) { Poll::Ready(Some(Ok(msg))) => { delegated = true; let offset = msg[0]; // delegate the multiplexed message to the correct protocol if let Some(cap) = - this.conn.shared_capabilities().find_by_relative_offset(offset) + this.inner.conn.shared_capabilities().find_by_relative_offset(offset) { - if cap == &this.primary_capability { + if cap == &this.primary.shared_cap { // delegate to primary - let _ = this.to_primary.send(msg); + let _ = this.primary.to_primary.send(msg); } else { - // delegate to satellite - for proto in &this.satellites { - if proto.cap == *cap { + // delegate to installed satellite if any + for proto in &this.inner.protocols { + if proto.shared_cap == *cap { proto.send_raw(msg); break } @@ -358,7 +543,7 @@ where } } - if !delegated || !out_ready || this.out_buffer.is_empty() { + if !conn_ready || (!delegated && this.inner.out_buffer.is_empty()) { return Poll::Pending } } @@ -368,41 +553,41 @@ where impl Sink for RlpxSatelliteStream where St: Stream> + Sink + Unpin, - Primary: Sink + Unpin, + Primary: Sink + Unpin, P2PStreamError: Into<>::Error>, { type Error = >::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - if let Err(err) = ready!(this.conn.poll_ready_unpin(cx)) { + if let Err(err) = ready!(this.inner.conn.poll_ready_unpin(cx)) { return Poll::Ready(Err(err.into())) } - if let Err(err) = ready!(this.primary.poll_ready_unpin(cx)) { + if let Err(err) = ready!(this.primary.st.poll_ready_unpin(cx)) { return Poll::Ready(Err(err)) } Poll::Ready(Ok(())) } fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { - self.get_mut().primary.start_send_unpin(item) + self.get_mut().primary.st.start_send_unpin(item) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.get_mut().conn.poll_flush_unpin(cx).map_err(Into::into) + self.get_mut().inner.conn.poll_flush_unpin(cx).map_err(Into::into) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.get_mut().conn.poll_close_unpin(cx).map_err(Into::into) + self.get_mut().inner.conn.poll_close_unpin(cx).map_err(Into::into) } } /// Wraps a RLPx subprotocol and handles message ID multiplexing. struct ProtocolStream { - cap: SharedCapability, + shared_cap: SharedCapability, /// the channel shared with the satellite stream to_satellite: UnboundedSender, - satellite_st: Pin>>, + satellite_st: Pin + Send>>, } impl ProtocolStream { @@ -413,7 +598,7 @@ impl ProtocolStream { /// If the message is empty. #[inline] fn mask_msg_id(&self, mut msg: BytesMut) -> Bytes { - msg[0] += self.cap.relative_message_id_offset(); + msg[0] += self.shared_cap.relative_message_id_offset(); msg.freeze() } @@ -424,7 +609,7 @@ impl ProtocolStream { /// If the message is empty. #[inline] fn unmask_id(&self, mut msg: BytesMut) -> BytesMut { - msg[0] -= self.cap.relative_message_id_offset(); + msg[0] -= self.shared_cap.relative_message_id_offset(); msg } @@ -446,7 +631,7 @@ impl Stream for ProtocolStream { impl fmt::Debug for ProtocolStream { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("ProtocolStream").field("cap", &self.cap).finish_non_exhaustive() + f.debug_struct("ProtocolStream").field("cap", &self.shared_cap).finish_non_exhaustive() } } @@ -454,10 +639,13 @@ impl fmt::Debug for ProtocolStream { mod tests { use super::*; use crate::{ - test_utils::{connect_passthrough, eth_handshake, eth_hello}, + test_utils::{ + connect_passthrough, eth_handshake, eth_hello, + proto::{test_hello, TestProtoMessage}, + }, UnauthedEthStream, UnauthedP2PStream, }; - use tokio::net::TcpListener; + use tokio::{net::TcpListener, sync::oneshot}; use tokio_util::codec::Decoder; #[tokio::test] @@ -487,7 +675,6 @@ mod tests { let eth = conn.shared_capabilities().eth().unwrap().clone(); let multiplexer = RlpxProtocolMultiplexer::new(conn); - let _satellite = multiplexer .into_satellite_stream_with_handshake( eth.capability().as_ref(), @@ -498,4 +685,94 @@ mod tests { .await .unwrap(); } + + /// A test that install a satellite stream eth+test protocol and sends messages between them. + #[tokio::test(flavor = "multi_thread")] + async fn eth_test_protocol_satellite() { + reth_tracing::init_test_tracing(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let local_addr = listener.local_addr().unwrap(); + let (status, fork_filter) = eth_handshake(); + let other_status = status; + let other_fork_filter = fork_filter.clone(); + let _handle = tokio::spawn(async move { + let (incoming, _) = listener.accept().await.unwrap(); + let stream = crate::PassthroughCodec::default().framed(incoming); + let (server_hello, _) = test_hello(); + let (conn, _) = UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap(); + + let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn) + .into_eth_satellite_stream(other_status, other_fork_filter) + .await + .unwrap(); + + st.install_protocol(&TestProtoMessage::capability(), |mut conn| { + async_stream::stream! { + yield TestProtoMessage::ping().encoded(); + let msg = conn.next().await.unwrap(); + let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap(); + assert_eq!(msg, TestProtoMessage::pong()); + + yield TestProtoMessage::message("hello").encoded(); + let msg = conn.next().await.unwrap(); + let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap(); + assert_eq!(msg, TestProtoMessage::message("good bye!")); + + yield TestProtoMessage::message("good bye!").encoded(); + + futures::future::pending::<()>().await; + unreachable!() + } + }) + .unwrap(); + + loop { + let _ = st.next().await; + } + }); + + let conn = connect_passthrough(local_addr, test_hello().0).await; + let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn) + .into_eth_satellite_stream(status, fork_filter) + .await + .unwrap(); + + let (tx, mut rx) = oneshot::channel(); + + st.install_protocol(&TestProtoMessage::capability(), |mut conn| { + async_stream::stream! { + let msg = conn.next().await.unwrap(); + let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap(); + assert_eq!(msg, TestProtoMessage::ping()); + + yield TestProtoMessage::pong().encoded(); + + let msg = conn.next().await.unwrap(); + let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap(); + assert_eq!(msg, TestProtoMessage::message("hello")); + + yield TestProtoMessage::message("good bye!").encoded(); + + let msg = conn.next().await.unwrap(); + let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap(); + assert_eq!(msg, TestProtoMessage::message("good bye!")); + + tx.send(()).unwrap(); + + futures::future::pending::<()>().await; + unreachable!() + } + }) + .unwrap(); + + loop { + tokio::select! { + _ = &mut rx => { + break + } + _ = st.next() => { + } + } + } + } } diff --git a/crates/net/eth-wire/src/test_utils.rs b/crates/net/eth-wire/src/test_utils.rs index 01bd9a048..ceaa8206c 100644 --- a/crates/net/eth-wire/src/test_utils.rs +++ b/crates/net/eth-wire/src/test_utils.rs @@ -55,3 +55,106 @@ pub async fn connect_passthrough( p2p_stream } + +/// A Rplx subprotocol for testing +pub mod proto { + use super::*; + use crate::{capability::Capability, protocol::Protocol}; + use bytes::{Buf, BufMut, BytesMut}; + + /// Returns a new testing `HelloMessage` with eth and the test protocol + pub fn test_hello() -> (HelloMessageWithProtocols, SecretKey) { + let mut handshake = eth_hello(); + handshake.0.protocols.push(TestProtoMessage::protocol()); + handshake + } + + #[repr(u8)] + #[derive(Clone, Copy, Debug, PartialEq, Eq)] + pub enum TestProtoMessageId { + Ping = 0x00, + Pong = 0x01, + Message = 0x02, + } + + #[derive(Clone, Debug, PartialEq, Eq)] + pub enum TestProtoMessageKind { + Message(String), + Ping, + Pong, + } + + /// An `test` protocol message, containing a message ID and payload. + #[derive(Clone, Debug, PartialEq, Eq)] + pub struct TestProtoMessage { + pub message_type: TestProtoMessageId, + pub message: TestProtoMessageKind, + } + + impl TestProtoMessage { + /// Returns the capability for the `test` protocol. + pub fn capability() -> Capability { + Capability::new_static("test", 1) + } + + /// Returns the protocol for the `test` protocol. + pub fn protocol() -> Protocol { + Protocol::new(Self::capability(), 3) + } + + /// Creates a ping message + pub fn ping() -> Self { + Self { message_type: TestProtoMessageId::Ping, message: TestProtoMessageKind::Ping } + } + + /// Creates a pong message + pub fn pong() -> Self { + Self { message_type: TestProtoMessageId::Pong, message: TestProtoMessageKind::Pong } + } + + /// Creates a message + pub fn message(msg: impl Into) -> Self { + Self { + message_type: TestProtoMessageId::Message, + message: TestProtoMessageKind::Message(msg.into()), + } + } + + /// Creates a new `TestProtoMessage` with the given message ID and payload. + pub fn encoded(&self) -> BytesMut { + let mut buf = BytesMut::new(); + buf.put_u8(self.message_type as u8); + match &self.message { + TestProtoMessageKind::Ping => {} + TestProtoMessageKind::Pong => {} + TestProtoMessageKind::Message(msg) => { + buf.put(msg.as_bytes()); + } + } + buf + } + + /// Decodes a `TestProtoMessage` from the given message buffer. + pub fn decode_message(buf: &mut &[u8]) -> Option { + if buf.is_empty() { + return None; + } + let id = buf[0]; + buf.advance(1); + let message_type = match id { + 0x00 => TestProtoMessageId::Ping, + 0x01 => TestProtoMessageId::Pong, + 0x02 => TestProtoMessageId::Message, + _ => return None, + }; + let message = match message_type { + TestProtoMessageId::Ping => TestProtoMessageKind::Ping, + TestProtoMessageId::Pong => TestProtoMessageKind::Pong, + TestProtoMessageId::Message => { + TestProtoMessageKind::Message(String::from_utf8_lossy(&buf[..]).into_owned()) + } + }; + Some(Self { message_type, message }) + } + } +} diff --git a/crates/net/eth-wire/src/types/status.rs b/crates/net/eth-wire/src/types/status.rs index 5f8044944..0925b7d97 100644 --- a/crates/net/eth-wire/src/types/status.rs +++ b/crates/net/eth-wire/src/types/status.rs @@ -66,6 +66,11 @@ impl Status { Default::default() } + /// Sets the [EthVersion] for the status. + pub fn set_eth_version(&mut self, version: EthVersion) { + self.version = version as u8; + } + /// Create a [`StatusBuilder`] from the given [`ChainSpec`] and head block. /// /// Sets the `chain` and `genesis`, `blockhash`, and `forkid` fields based on the [`ChainSpec`] diff --git a/crates/net/network/src/protocol.rs b/crates/net/network/src/protocol.rs index adcfb75f2..cc9ed51c9 100644 --- a/crates/net/network/src/protocol.rs +++ b/crates/net/network/src/protocol.rs @@ -9,7 +9,12 @@ use reth_eth_wire::{ use reth_network_api::Direction; use reth_primitives::BytesMut; use reth_rpc_types::PeerId; -use std::{fmt, net::SocketAddr, pin::Pin}; +use std::{ + fmt, + net::SocketAddr, + ops::{Deref, DerefMut}, + pin::Pin, +}; /// A trait that allows to offer additional RLPx-based application-level protocols when establishing /// a peer-to-peer connection. @@ -113,6 +118,57 @@ impl RlpxSubProtocols { pub fn push(&mut self, protocol: impl IntoRlpxSubProtocol) { self.protocols.push(protocol.into_rlpx_sub_protocol()); } + + /// Returns all additional protocol handlers that should be announced to the remote during the + /// Rlpx handshake on an incoming connection. + pub(crate) fn on_incoming(&self, socket_addr: SocketAddr) -> RlpxSubProtocolHandlers { + RlpxSubProtocolHandlers( + self.protocols + .iter() + .filter_map(|protocol| protocol.0.on_incoming(socket_addr)) + .collect(), + ) + } + + /// Returns all additional protocol handlers that should be announced to the remote during the + /// Rlpx handshake on an outgoing connection. + pub(crate) fn on_outgoing( + &self, + socket_addr: SocketAddr, + peer_id: PeerId, + ) -> RlpxSubProtocolHandlers { + RlpxSubProtocolHandlers( + self.protocols + .iter() + .filter_map(|protocol| protocol.0.on_outgoing(socket_addr, peer_id)) + .collect(), + ) + } +} + +/// A set of additional RLPx-based sub-protocol connection handlers. +#[derive(Default)] +pub(crate) struct RlpxSubProtocolHandlers(Vec>); + +impl RlpxSubProtocolHandlers { + /// Returns all handlers. + pub(crate) fn into_iter(self) -> impl Iterator> { + self.0.into_iter() + } +} + +impl Deref for RlpxSubProtocolHandlers { + type Target = Vec>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for RlpxSubProtocolHandlers { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } } pub(crate) trait DynProtocolHandler: fmt::Debug + Send + Sync + 'static { @@ -156,7 +212,7 @@ pub(crate) trait DynConnectionHandler: Send + Sync + 'static { ) -> OnNotSupported; fn into_connection( - self, + self: Box, direction: Direction, peer_id: PeerId, conn: ProtocolConnection, @@ -181,11 +237,11 @@ where } fn into_connection( - self, + self: Box, direction: Direction, peer_id: PeerId, conn: ProtocolConnection, ) -> Pin + Send + 'static>> { - Box::pin(T::into_connection(self, direction, peer_id, conn)) + Box::pin(T::into_connection(*self, direction, peer_id, conn)) } } diff --git a/crates/net/network/src/session/active.rs b/crates/net/network/src/session/active.rs index 25e53a194..a8424713c 100644 --- a/crates/net/network/src/session/active.rs +++ b/crates/net/network/src/session/active.rs @@ -4,6 +4,7 @@ use crate::{ message::{NewBlockMessage, PeerMessage, PeerRequest, PeerResponse, PeerResponseResult}, session::{ config::INITIAL_REQUEST_TIMEOUT, + conn::EthRlpxConnection, handle::{ActiveSessionMessage, SessionCommand}, SessionId, }, @@ -11,16 +12,16 @@ use crate::{ use core::sync::atomic::Ordering; use fnv::FnvHashMap; use futures::{stream::Fuse, SinkExt, StreamExt}; -use reth_ecies::stream::ECIESStream; + use reth_eth_wire::{ capability::Capabilities, errors::{EthHandshakeError, EthStreamError, P2PStreamError}, message::{EthBroadcastMessage, RequestPair}, - DisconnectP2P, DisconnectReason, EthMessage, EthStream, P2PStream, + DisconnectP2P, DisconnectReason, EthMessage, }; use reth_interfaces::p2p::error::RequestError; use reth_metrics::common::mpsc::MeteredPollSender; -use reth_net_common::bandwidth_meter::MeteredStream; + use reth_primitives::PeerId; use std::{ collections::VecDeque, @@ -32,7 +33,6 @@ use std::{ time::{Duration, Instant}, }; use tokio::{ - net::TcpStream, sync::{mpsc::error::TrySendError, oneshot}, time::Interval, }; @@ -51,11 +51,6 @@ const SAMPLE_IMPACT: f64 = 0.1; /// Amount of RTTs before timeout const TIMEOUT_SCALING: u32 = 3; -/// The type of the underlying peer network connection. -// This type is boxed because the underlying stream is ~6KB, -// mostly coming from `P2PStream`'s `snap::Encoder` (2072), and `ECIESStream` (3600). -pub type PeerConnection = Box>>>>; - /// The type that advances an established session by listening for incoming messages (from local /// node or read from connection) and emitting events back to the /// [`SessionManager`](super::SessionManager). @@ -70,7 +65,7 @@ pub(crate) struct ActiveSession { /// Keeps track of request ids. pub(crate) next_id: u64, /// The underlying connection. - pub(crate) conn: PeerConnection, + pub(crate) conn: EthRlpxConnection, /// Identifier of the node we're connected to. pub(crate) remote_peer_id: PeerId, /// The address we're connected to. @@ -771,16 +766,19 @@ mod tests { handle::PendingSessionEvent, start_pending_incoming_session, }; - use reth_ecies::util::pk2id; + use reth_ecies::{stream::ECIESStream, util::pk2id}; use reth_eth_wire::{ - GetBlockBodies, HelloMessageWithProtocols, Status, StatusBuilder, UnauthedEthStream, - UnauthedP2PStream, + EthStream, GetBlockBodies, HelloMessageWithProtocols, P2PStream, Status, StatusBuilder, + UnauthedEthStream, UnauthedP2PStream, }; - use reth_net_common::bandwidth_meter::BandwidthMeter; + use reth_net_common::bandwidth_meter::{BandwidthMeter, MeteredStream}; use reth_primitives::{ForkFilter, Hardfork, MAINNET}; use secp256k1::{SecretKey, SECP256K1}; use std::time::Duration; - use tokio::{net::TcpListener, sync::mpsc}; + use tokio::{ + net::{TcpListener, TcpStream}, + sync::mpsc, + }; /// Returns a testing `HelloMessage` and new secretkey fn eth_hello(server_key: &SecretKey) -> HelloMessageWithProtocols { @@ -856,6 +854,7 @@ mod tests { self.hello.clone(), self.status, self.fork_filter.clone(), + Default::default(), )); let mut stream = ReceiverStream::new(pending_sessions_rx); diff --git a/crates/net/network/src/session/conn.rs b/crates/net/network/src/session/conn.rs new file mode 100644 index 000000000..a94b3406e --- /dev/null +++ b/crates/net/network/src/session/conn.rs @@ -0,0 +1,156 @@ +//! Connection types for a session + +use futures::{Sink, Stream}; +use reth_ecies::stream::ECIESStream; +use reth_eth_wire::{ + errors::EthStreamError, + message::EthBroadcastMessage, + multiplex::{ProtocolProxy, RlpxSatelliteStream}, + EthMessage, EthStream, EthVersion, P2PStream, +}; +use reth_net_common::bandwidth_meter::MeteredStream; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; +use tokio::net::TcpStream; + +/// The type of the underlying peer network connection. +pub type EthPeerConnection = EthStream>>>; + +/// Various connection types that at least support the ETH protocol. +pub type EthSatelliteConnection = + RlpxSatelliteStream>, EthStream>; + +/// Connection types that support the ETH protocol. +/// +/// Either a [`EthPeerConnection`] or an [`EthSatelliteConnection`]. +// This type is boxed because the underlying stream is ~6KB, +// mostly coming from `P2PStream`'s `snap::Encoder` (2072), and `ECIESStream` (3600). +#[derive(Debug)] +pub enum EthRlpxConnection { + /// A That only supports the ETH protocol. + EthOnly(Box), + /// A connection that supports the ETH protocol and __at least one other__ RLPx protocol. + Satellite(Box), +} + +impl EthRlpxConnection { + /// Returns the negotiated ETH version. + #[inline] + pub(crate) fn version(&self) -> EthVersion { + match self { + Self::EthOnly(conn) => conn.version(), + Self::Satellite(conn) => conn.primary().version(), + } + } + + /// Consumes this type and returns the wrapped [P2PStream]. + #[inline] + pub(crate) fn into_inner(self) -> P2PStream>> { + match self { + Self::EthOnly(conn) => conn.into_inner(), + Self::Satellite(conn) => conn.into_inner(), + } + } + + /// Returns mutable access to the underlying stream. + #[inline] + pub(crate) fn inner_mut(&mut self) -> &mut P2PStream>> { + match self { + Self::EthOnly(conn) => conn.inner_mut(), + Self::Satellite(conn) => conn.inner_mut(), + } + } + + /// Returns access to the underlying stream. + #[inline] + pub(crate) fn inner(&self) -> &P2PStream>> { + match self { + Self::EthOnly(conn) => conn.inner(), + Self::Satellite(conn) => conn.inner(), + } + } + + /// Same as [`Sink::start_send`] but accepts a [`EthBroadcastMessage`] instead. + #[inline] + pub fn start_send_broadcast( + &mut self, + item: EthBroadcastMessage, + ) -> Result<(), EthStreamError> { + match self { + Self::EthOnly(conn) => conn.start_send_broadcast(item), + Self::Satellite(conn) => conn.primary_mut().start_send_broadcast(item), + } + } +} + +impl From for EthRlpxConnection { + #[inline] + fn from(conn: EthPeerConnection) -> Self { + Self::EthOnly(Box::new(conn)) + } +} + +impl From for EthRlpxConnection { + #[inline] + fn from(conn: EthSatelliteConnection) -> Self { + Self::Satellite(Box::new(conn)) + } +} + +macro_rules! delegate_call { + ($self:ident.$method:ident($($args:ident),+)) => { + unsafe { + match $self.get_unchecked_mut() { + Self::EthOnly(l) => Pin::new_unchecked(l).$method($($args),+), + Self::Satellite(r) => Pin::new_unchecked(r).$method($($args),+), + } + } + } +} + +impl Stream for EthRlpxConnection { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + delegate_call!(self.poll_next(cx)) + } +} + +impl Sink for EthRlpxConnection { + type Error = EthStreamError; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + delegate_call!(self.poll_ready(cx)) + } + + fn start_send(self: Pin<&mut Self>, item: EthMessage) -> Result<(), Self::Error> { + delegate_call!(self.start_send(item)) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + delegate_call!(self.poll_flush(cx)) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + delegate_call!(self.poll_close(cx)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn assert_eth_stream() + where + St: Stream> + Sink, + { + } + + #[test] + fn test_eth_stream_variants() { + assert_eth_stream::(); + assert_eth_stream::(); + } +} diff --git a/crates/net/network/src/session/handle.rs b/crates/net/network/src/session/handle.rs index 44ca04197..9e8f4ec2a 100644 --- a/crates/net/network/src/session/handle.rs +++ b/crates/net/network/src/session/handle.rs @@ -1,9 +1,7 @@ //! Session handles. - -use super::active::PeerConnection; use crate::{ message::PeerMessage, - session::{Direction, SessionId}, + session::{conn::EthRlpxConnection, Direction, SessionId}, }; use reth_ecies::ECIESError; use reth_eth_wire::{ @@ -174,7 +172,7 @@ pub enum PendingSessionEvent { status: Arc, /// The actual connection stream which can be used to send and receive `eth` protocol /// messages - conn: PeerConnection, + conn: EthRlpxConnection, /// The direction of the session, either `Inbound` or `Outgoing` direction: Direction, /// The remote node's user agent, usually containing the client name and version diff --git a/crates/net/network/src/session/mod.rs b/crates/net/network/src/session/mod.rs index 863964ac9..98e5f9e59 100644 --- a/crates/net/network/src/session/mod.rs +++ b/crates/net/network/src/session/mod.rs @@ -40,15 +40,16 @@ use tracing::{instrument, trace}; mod active; mod config; +mod conn; mod handle; pub use crate::message::PeerRequestSender; +use crate::protocol::{IntoRlpxSubProtocol, RlpxSubProtocolHandlers, RlpxSubProtocols}; pub use config::{SessionLimits, SessionsConfig}; pub use handle::{ ActiveSessionHandle, ActiveSessionMessage, PendingSessionEvent, PendingSessionHandle, SessionCommand, }; - -use crate::protocol::{IntoRlpxSubProtocol, RlpxSubProtocols}; +use reth_eth_wire::multiplex::RlpxProtocolMultiplexer; pub use reth_network_api::{Direction, PeerInfo}; /// Internal identifier for active sessions. @@ -228,6 +229,7 @@ impl SessionManager { let hello_message = self.hello_message.clone(); let status = self.status; let fork_filter = self.fork_filter.clone(); + let extra_handlers = self.extra_protocols.on_incoming(remote_addr); self.spawn(start_pending_incoming_session( disconnect_rx, session_id, @@ -238,6 +240,7 @@ impl SessionManager { hello_message, status, fork_filter, + extra_handlers, )); let handle = PendingSessionHandle { @@ -261,6 +264,7 @@ impl SessionManager { let fork_filter = self.fork_filter.clone(); let status = self.status; let band_with_meter = self.bandwidth_meter.clone(); + let extra_handlers = self.extra_protocols.on_outgoing(remote_addr, remote_peer_id); self.spawn(start_pending_outbound_session( disconnect_rx, pending_events, @@ -272,6 +276,7 @@ impl SessionManager { status, fork_filter, band_with_meter, + extra_handlers, )); let handle = PendingSessionHandle { @@ -757,6 +762,7 @@ pub(crate) async fn start_pending_incoming_session( hello: HelloMessageWithProtocols, status: Status, fork_filter: ForkFilter, + extra_handlers: RlpxSubProtocolHandlers, ) { authenticate( disconnect_rx, @@ -769,6 +775,7 @@ pub(crate) async fn start_pending_incoming_session( hello, status, fork_filter, + extra_handlers, ) .await } @@ -787,6 +794,7 @@ async fn start_pending_outbound_session( status: Status, fork_filter: ForkFilter, bandwidth_meter: BandwidthMeter, + extra_handlers: RlpxSubProtocolHandlers, ) { let stream = match TcpStream::connect(remote_addr).await { Ok(stream) => { @@ -818,6 +826,7 @@ async fn start_pending_outbound_session( hello, status, fork_filter, + extra_handlers, ) .await } @@ -835,6 +844,7 @@ async fn authenticate( hello: HelloMessageWithProtocols, status: Status, fork_filter: ForkFilter, + extra_handlers: RlpxSubProtocolHandlers, ) { let local_addr = stream.inner().local_addr().ok(); let stream = match get_eciess_stream(stream, secret_key, direction).await { @@ -863,6 +873,7 @@ async fn authenticate( hello, status, fork_filter, + extra_handlers, ) .boxed(); @@ -900,7 +911,10 @@ async fn get_eciess_stream( /// Authenticate the stream via handshake /// -/// On Success return the authenticated stream as [`PendingSessionEvent`] +/// On Success return the authenticated stream as [`PendingSessionEvent`]. +/// +/// If additional [RlpxSubProtocolHandlers] are provided, the hello message will be updated to also +/// negotiate the additional protocols. #[allow(clippy::too_many_arguments)] async fn authenticate_stream( stream: UnauthedP2PStream>>, @@ -908,10 +922,14 @@ async fn authenticate_stream( remote_addr: SocketAddr, local_addr: Option, direction: Direction, - hello: HelloMessageWithProtocols, - status: Status, + mut hello: HelloMessageWithProtocols, + mut status: Status, fork_filter: ForkFilter, + mut extra_handlers: RlpxSubProtocolHandlers, ) -> PendingSessionEvent { + // Add extra protocols to the hello message + extra_handlers.retain(|handler| hello.try_add_protocol(handler.protocol()).is_ok()); + // conduct the p2p handshake and return the authenticated stream let (p2p_stream, their_hello) = match stream.handshake(hello).await { Ok(stream_res) => stream_res, @@ -925,8 +943,8 @@ async fn authenticate_stream( } }; - // Ensure we negotiated eth protocol - let version = match p2p_stream.shared_capabilities().eth_version() { + // Ensure we negotiated mandatory eth protocol + let eth_version = match p2p_stream.shared_capabilities().eth_version() { Ok(version) => version, Err(err) => { return PendingSessionEvent::Disconnected { @@ -938,22 +956,45 @@ async fn authenticate_stream( } }; - // if the hello handshake was successful we can try status handshake - // - // Before trying status handshake, set up the version to shared_capability - let status = Status { version, ..status }; - let eth_unauthed = UnauthedEthStream::new(p2p_stream); - let (eth_stream, their_status) = match eth_unauthed.handshake(status, fork_filter).await { - Ok(stream_res) => stream_res, - Err(err) => { - return PendingSessionEvent::Disconnected { - remote_addr, - session_id, - direction, - error: Some(err), + let (conn, their_status) = if p2p_stream.shared_capabilities().len() == 1 { + // if the hello handshake was successful we can try status handshake + // + // Before trying status handshake, set up the version to negotiated shared version + status.set_eth_version(eth_version); + let eth_unauthed = UnauthedEthStream::new(p2p_stream); + let (eth_stream, their_status) = match eth_unauthed.handshake(status, fork_filter).await { + Ok(stream_res) => stream_res, + Err(err) => { + return PendingSessionEvent::Disconnected { + remote_addr, + session_id, + direction, + error: Some(err), + } } + }; + (eth_stream.into(), their_status) + } else { + // Multiplex the stream with the extra protocols + let (mut multiplex_stream, their_status) = RlpxProtocolMultiplexer::new(p2p_stream) + .into_eth_satellite_stream(status, fork_filter) + .await + .unwrap(); + + // install additional handlers + for handler in extra_handlers.into_iter() { + let cap = handler.protocol().cap; + let remote_peer_id = their_hello.id; + multiplex_stream + .install_protocol(&cap, move |conn| { + handler.into_connection(direction, remote_peer_id, conn) + }) + .ok(); } + + (multiplex_stream.into(), their_status) }; + PendingSessionEvent::Established { session_id, remote_addr, @@ -961,7 +1002,7 @@ async fn authenticate_stream( peer_id: their_hello.id, capabilities: Arc::new(Capabilities::from(their_hello.capabilities)), status: Arc::new(their_status), - conn: Box::new(eth_stream), + conn, direction, client_id: their_hello.client_version, } diff --git a/crates/net/network/src/test_utils/testnet.rs b/crates/net/network/src/test_utils/testnet.rs index 1ac94ae33..75fc124b7 100644 --- a/crates/net/network/src/test_utils/testnet.rs +++ b/crates/net/network/src/test_utils/testnet.rs @@ -4,6 +4,7 @@ use crate::{ builder::ETH_REQUEST_CHANNEL_CAPACITY, error::NetworkError, eth_requests::EthRequestHandler, + protocol::IntoRlpxSubProtocol, transactions::{TransactionsHandle, TransactionsManager}, NetworkConfig, NetworkConfigBuilder, NetworkEvent, NetworkEvents, NetworkHandle, NetworkManager, @@ -340,6 +341,11 @@ where self.network.num_connected_peers() } + /// Adds an additional protocol handler to the peer. + pub fn add_rlpx_sub_protocol(&mut self, protocol: impl IntoRlpxSubProtocol) { + self.network.add_rlpx_sub_protocol(protocol); + } + /// Returns a handle to the peer's network. pub fn peer_handle(&self) -> PeerHandle { PeerHandle { diff --git a/crates/net/network/tests/it/main.rs b/crates/net/network/tests/it/main.rs index 1d65a90cc..a277252dc 100644 --- a/crates/net/network/tests/it/main.rs +++ b/crates/net/network/tests/it/main.rs @@ -2,6 +2,7 @@ mod big_pooled_txs_req; mod clique; mod connect; mod geth; +mod multiplex; mod requests; mod session; mod startup; diff --git a/crates/net/network/tests/it/multiplex.rs b/crates/net/network/tests/it/multiplex.rs new file mode 100644 index 000000000..3026a8efe --- /dev/null +++ b/crates/net/network/tests/it/multiplex.rs @@ -0,0 +1,330 @@ +//! Testing gossiping of transactions. + +use crate::multiplex::proto::{PingPongProtoMessage, PingPongProtoMessageKind}; +use futures::{Stream, StreamExt}; +use reth_eth_wire::{ + capability::SharedCapabilities, multiplex::ProtocolConnection, protocol::Protocol, +}; +use reth_network::{ + protocol::{ConnectionHandler, OnNotSupported, ProtocolHandler}, + test_utils::Testnet, +}; +use reth_network_api::Direction; +use reth_primitives::BytesMut; +use reth_provider::test_utils::MockEthProvider; +use reth_rpc_types::PeerId; +use std::{ + net::SocketAddr, + pin::Pin, + task::{ready, Context, Poll}, +}; +use tokio::sync::{mpsc, oneshot}; +use tokio_stream::wrappers::UnboundedReceiverStream; + +/// A simple Rplx subprotocol for +mod proto { + use super::*; + use reth_eth_wire::capability::Capability; + use reth_primitives::{Buf, BufMut}; + + #[repr(u8)] + #[derive(Clone, Copy, Debug, PartialEq, Eq)] + pub enum PingPongProtoMessageId { + Ping = 0x00, + Pong = 0x01, + PingMessage = 0x02, + PongMessage = 0x03, + } + + #[derive(Clone, Debug, PartialEq, Eq)] + pub enum PingPongProtoMessageKind { + Ping, + Pong, + PingMessage(String), + PongMessage(String), + } + + /// An protocol message, containing a message ID and payload. + #[derive(Clone, Debug, PartialEq, Eq)] + pub struct PingPongProtoMessage { + pub message_type: PingPongProtoMessageId, + pub message: PingPongProtoMessageKind, + } + + impl PingPongProtoMessage { + /// Returns the capability for the `ping` protocol. + pub fn capability() -> Capability { + Capability::new_static("ping", 1) + } + + /// Returns the protocol for the `test` protocol. + pub fn protocol() -> Protocol { + Protocol::new(Self::capability(), 4) + } + + /// Creates a ping message + pub fn ping() -> Self { + Self { + message_type: PingPongProtoMessageId::Ping, + message: PingPongProtoMessageKind::Ping, + } + } + + /// Creates a pong message + pub fn pong() -> Self { + Self { + message_type: PingPongProtoMessageId::Pong, + message: PingPongProtoMessageKind::Pong, + } + } + + /// Creates a ping message + pub fn ping_message(msg: impl Into) -> Self { + Self { + message_type: PingPongProtoMessageId::PingMessage, + message: PingPongProtoMessageKind::PingMessage(msg.into()), + } + } + /// Creates a ping message + pub fn pong_message(msg: impl Into) -> Self { + Self { + message_type: PingPongProtoMessageId::PongMessage, + message: PingPongProtoMessageKind::PongMessage(msg.into()), + } + } + + /// Creates a new `TestProtoMessage` with the given message ID and payload. + pub fn encoded(&self) -> BytesMut { + let mut buf = BytesMut::new(); + buf.put_u8(self.message_type as u8); + match &self.message { + PingPongProtoMessageKind::Ping => {} + PingPongProtoMessageKind::Pong => {} + PingPongProtoMessageKind::PingMessage(msg) => { + buf.put(msg.as_bytes()); + } + PingPongProtoMessageKind::PongMessage(msg) => { + buf.put(msg.as_bytes()); + } + } + buf + } + + /// Decodes a `TestProtoMessage` from the given message buffer. + pub fn decode_message(buf: &mut &[u8]) -> Option { + if buf.is_empty() { + return None; + } + let id = buf[0]; + buf.advance(1); + let message_type = match id { + 0x00 => PingPongProtoMessageId::Ping, + 0x01 => PingPongProtoMessageId::Pong, + 0x02 => PingPongProtoMessageId::PingMessage, + 0x03 => PingPongProtoMessageId::PongMessage, + _ => return None, + }; + let message = match message_type { + PingPongProtoMessageId::Ping => PingPongProtoMessageKind::Ping, + PingPongProtoMessageId::Pong => PingPongProtoMessageKind::Pong, + PingPongProtoMessageId::PingMessage => PingPongProtoMessageKind::PingMessage( + String::from_utf8_lossy(&buf[..]).into_owned(), + ), + PingPongProtoMessageId::PongMessage => PingPongProtoMessageKind::PongMessage( + String::from_utf8_lossy(&buf[..]).into_owned(), + ), + }; + Some(Self { message_type, message }) + } + } +} + +#[derive(Debug)] +struct PingPongProtoHandler { + state: ProtocolState, +} + +impl ProtocolHandler for PingPongProtoHandler { + type ConnectionHandler = PingPongConnectionHandler; + + fn on_incoming(&self, _socket_addr: SocketAddr) -> Option { + Some(PingPongConnectionHandler { state: self.state.clone() }) + } + + fn on_outgoing( + &self, + _socket_addr: SocketAddr, + _peer_id: PeerId, + ) -> Option { + Some(PingPongConnectionHandler { state: self.state.clone() }) + } +} + +#[derive(Clone, Debug)] +struct ProtocolState { + events: mpsc::UnboundedSender, +} + +#[derive(Debug)] +#[allow(dead_code)] +enum ProtocolEvent { + Established { + direction: Direction, + peer_id: PeerId, + to_connection: mpsc::UnboundedSender, + }, +} + +enum Command { + /// Send a ping message to the peer. + PingMessage { + msg: String, + /// The response will be sent to this channel. + response: oneshot::Sender, + }, +} + +struct PingPongConnectionHandler { + state: ProtocolState, +} + +impl ConnectionHandler for PingPongConnectionHandler { + type Connection = PingPongProtoConnection; + + fn protocol(&self) -> Protocol { + PingPongProtoMessage::protocol() + } + + fn on_unsupported_by_peer( + self, + _supported: &SharedCapabilities, + _direction: Direction, + _peer_id: PeerId, + ) -> OnNotSupported { + OnNotSupported::KeepAlive + } + + fn into_connection( + self, + direction: Direction, + _peer_id: PeerId, + conn: ProtocolConnection, + ) -> Self::Connection { + let (tx, rx) = mpsc::unbounded_channel(); + self.state + .events + .send(ProtocolEvent::Established { direction, peer_id: _peer_id, to_connection: tx }) + .ok(); + PingPongProtoConnection { + conn, + initial_ping: direction.is_outgoing().then(PingPongProtoMessage::ping), + commands: UnboundedReceiverStream::new(rx), + pending_pong: None, + } + } +} + +struct PingPongProtoConnection { + conn: ProtocolConnection, + initial_ping: Option, + commands: UnboundedReceiverStream, + pending_pong: Option>, +} + +impl Stream for PingPongProtoConnection { + type Item = BytesMut; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + if let Some(initial_ping) = this.initial_ping.take() { + return Poll::Ready(Some(initial_ping.encoded())); + } + + loop { + if let Poll::Ready(Some(cmd)) = this.commands.poll_next_unpin(cx) { + return match cmd { + Command::PingMessage { msg, response } => { + this.pending_pong = Some(response); + Poll::Ready(Some(PingPongProtoMessage::ping_message(msg).encoded())) + } + } + } + let Some(msg) = ready!(this.conn.poll_next_unpin(cx)) else { + return Poll::Ready(None); + }; + + let Some(msg) = PingPongProtoMessage::decode_message(&mut &msg[..]) else { + return Poll::Ready(None); + }; + + match msg.message { + PingPongProtoMessageKind::Ping => { + return Poll::Ready(Some(PingPongProtoMessage::pong().encoded())); + } + PingPongProtoMessageKind::Pong => {} + PingPongProtoMessageKind::PingMessage(msg) => { + return Poll::Ready(Some(PingPongProtoMessage::pong_message(msg).encoded())); + } + PingPongProtoMessageKind::PongMessage(msg) => { + if let Some(sender) = this.pending_pong.take() { + sender.send(msg).ok(); + } + continue + } + } + + return Poll::Pending; + } + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_proto_multiplex() { + reth_tracing::init_test_tracing(); + let provider = MockEthProvider::default(); + let mut net = Testnet::create_with(2, provider.clone()).await; + + let (tx, mut from_peer0) = mpsc::unbounded_channel(); + net.peers_mut()[0] + .add_rlpx_sub_protocol(PingPongProtoHandler { state: ProtocolState { events: tx } }); + + let (tx, mut from_peer1) = mpsc::unbounded_channel(); + net.peers_mut()[1] + .add_rlpx_sub_protocol(PingPongProtoHandler { state: ProtocolState { events: tx } }); + + let handle = net.spawn(); + // connect all the peers + handle.connect_peers().await; + + let peer0_to_peer1 = from_peer0.recv().await.unwrap(); + let peer0_conn = match peer0_to_peer1 { + ProtocolEvent::Established { direction: _, peer_id, to_connection } => { + assert_eq!(peer_id, *handle.peers()[1].peer_id()); + to_connection + } + }; + + let peer1_to_peer0 = from_peer1.recv().await.unwrap(); + let peer1_conn = match peer1_to_peer0 { + ProtocolEvent::Established { direction: _, peer_id, to_connection } => { + assert_eq!(peer_id, *handle.peers()[0].peer_id()); + to_connection + } + }; + + let (tx, rx) = oneshot::channel(); + // send a ping message from peer0 to peer1 + peer0_conn.send(Command::PingMessage { msg: "hello!".to_string(), response: tx }).unwrap(); + + let response = rx.await.unwrap(); + assert_eq!(response, "hello!"); + + let (tx, rx) = oneshot::channel(); + // send a ping message from peer1 to peer0 + peer1_conn + .send(Command::PingMessage { msg: "hello from peer1!".to_string(), response: tx }) + .unwrap(); + + let response = rx.await.unwrap(); + assert_eq!(response, "hello from peer1!"); +}