mirror of
https://github.com/hl-archive-node/nanoreth.git
synced 2025-12-06 10:59:55 +00:00
chore: some multiplex followup (#5553)
This commit is contained in:
@ -376,17 +376,24 @@ impl SharedCapabilities {
|
||||
|
||||
/// Returns the matching shared capability for the given capability offset.
|
||||
///
|
||||
/// `offset` is the multiplexed message id offset of the capability relative to
|
||||
/// [`MAX_RESERVED_MESSAGE_ID`].
|
||||
/// `offset` is the multiplexed message id offset of the capability relative to the reserved
|
||||
/// message id space. In other words, counting starts at [`MAX_RESERVED_MESSAGE_ID`] + 1, which
|
||||
/// corresponds to the first non-reserved message id.
|
||||
///
|
||||
/// For example: `offset == 0` corresponds to the first shared message across the shared
|
||||
/// capabilities and will return the first shared capability that supports messages.
|
||||
#[inline]
|
||||
pub fn find_by_relative_offset(&self, offset: u8) -> Option<&SharedCapability> {
|
||||
self.find_by_offset(offset.saturating_add(MAX_RESERVED_MESSAGE_ID))
|
||||
self.find_by_offset(offset.saturating_add(MAX_RESERVED_MESSAGE_ID + 1))
|
||||
}
|
||||
|
||||
/// Returns the matching shared capability for the given capability offset.
|
||||
///
|
||||
/// `offset` is the multiplexed message id offset of the capability that includes the reserved
|
||||
/// message id space.
|
||||
///
|
||||
/// This will always return None if `offset` is less than or equal to
|
||||
/// [`MAX_RESERVED_MESSAGE_ID`] because the reserved message id space is not shared.
|
||||
#[inline]
|
||||
pub fn find_by_offset(&self, offset: u8) -> Option<&SharedCapability> {
|
||||
let mut iter = self.0.iter();
|
||||
@ -637,12 +644,14 @@ mod tests {
|
||||
|
||||
let shared = SharedCapabilities::try_new(local_capabilities, peer_capabilities).unwrap();
|
||||
|
||||
assert!(shared.find_by_relative_offset(0).is_none());
|
||||
let shared_eth = shared.find_by_relative_offset(1).unwrap();
|
||||
let shared_eth = shared.find_by_relative_offset(0).unwrap();
|
||||
assert_eq!(shared_eth.name(), "eth");
|
||||
|
||||
let shared_eth = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 1).unwrap();
|
||||
assert_eq!(shared_eth.name(), "eth");
|
||||
|
||||
// reserved message id space
|
||||
assert!(shared.find_by_offset(MAX_RESERVED_MESSAGE_ID).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -654,15 +663,14 @@ mod tests {
|
||||
|
||||
let shared = SharedCapabilities::try_new(local_capabilities, peer_capabilities).unwrap();
|
||||
|
||||
assert!(shared.find_by_relative_offset(0).is_none());
|
||||
let shared_eth = shared.find_by_relative_offset(1).unwrap();
|
||||
let shared_eth = shared.find_by_relative_offset(0).unwrap();
|
||||
assert_eq!(shared_eth.name(), proto.cap.name);
|
||||
|
||||
let shared_eth = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 1).unwrap();
|
||||
assert_eq!(shared_eth.name(), proto.cap.name);
|
||||
|
||||
// the 5th shared message is the last message of the aaa capability
|
||||
let shared_eth = shared.find_by_relative_offset(5).unwrap();
|
||||
// the 5th shared message (0,1,2,3,4) is the last message of the aaa capability
|
||||
let shared_eth = shared.find_by_relative_offset(4).unwrap();
|
||||
assert_eq!(shared_eth.name(), proto.cap.name);
|
||||
let shared_eth = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 5).unwrap();
|
||||
assert_eq!(shared_eth.name(), proto.cap.name);
|
||||
|
||||
@ -65,15 +65,16 @@ impl<St> RlpxProtocolMultiplexer<St> {
|
||||
mut self,
|
||||
cap: &Capability,
|
||||
handshake: F,
|
||||
) -> Result<RlpxSatelliteStream<St, Primary>, Self>
|
||||
) -> Result<RlpxSatelliteStream<St, Primary>, Err>
|
||||
where
|
||||
F: FnOnce(ProtocolProxy) -> Fut,
|
||||
Fut: Future<Output = Result<Primary, Err>>,
|
||||
St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
|
||||
P2PStreamError: Into<Err>,
|
||||
{
|
||||
let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
|
||||
else {
|
||||
return Err(self)
|
||||
return Err(P2PStreamError::CapabilityNotShared.into())
|
||||
};
|
||||
|
||||
let (to_primary, from_wire) = mpsc::unbounded_channel();
|
||||
@ -87,20 +88,36 @@ impl<St> RlpxProtocolMultiplexer<St> {
|
||||
let f = handshake(proxy);
|
||||
pin_mut!(f);
|
||||
|
||||
// handle messages until the handshake is complete
|
||||
// this polls the connection and the primary stream concurrently until the handshake is
|
||||
// complete
|
||||
loop {
|
||||
// TODO error handling
|
||||
tokio::select! {
|
||||
Some(Ok(msg)) = self.conn.next() => {
|
||||
// TODO handle multiplex
|
||||
let _ = to_primary.send(msg);
|
||||
// Ensure the message belongs to the primary protocol
|
||||
let offset = msg[0];
|
||||
if let Some(cap) = self.conn.shared_capabilities().find_by_relative_offset(offset) {
|
||||
if cap == &shared_cap {
|
||||
// delegate to primary
|
||||
let _ = to_primary.send(msg);
|
||||
} else {
|
||||
// delegate to satellite
|
||||
for proto in &self.protocols {
|
||||
if proto.cap == *cap {
|
||||
// TODO: need some form of backpressure here so buffering can't be abused
|
||||
proto.send_raw(msg);
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Err(P2PStreamError::UnknownReservedMessageId(offset).into())
|
||||
}
|
||||
}
|
||||
Some(msg) = from_primary.recv() => {
|
||||
// TODO error handling
|
||||
self.conn.send(msg).await.unwrap();
|
||||
self.conn.send(msg).await.map_err(Into::into)?;
|
||||
}
|
||||
res = &mut f => {
|
||||
let Ok(primary) = res else { return Err(self) };
|
||||
let primary = res?;
|
||||
return Ok(RlpxSatelliteStream {
|
||||
conn: self.conn,
|
||||
to_primary,
|
||||
@ -117,24 +134,47 @@ impl<St> RlpxProtocolMultiplexer<St> {
|
||||
}
|
||||
|
||||
/// A Stream and Sink type that acts as a wrapper around a primary RLPx subprotocol (e.g. "eth")
|
||||
///
|
||||
/// Only emits and sends _non-empty_ messages
|
||||
#[derive(Debug)]
|
||||
pub struct ProtocolProxy {
|
||||
cap: SharedCapability,
|
||||
/// Receives _non-empty_ messages from the wire
|
||||
from_wire: UnboundedReceiverStream<BytesMut>,
|
||||
/// Sends _non-empty_ messages from the wire
|
||||
to_wire: UnboundedSender<Bytes>,
|
||||
}
|
||||
|
||||
impl ProtocolProxy {
|
||||
/// Sends a _non-empty_ message on the wire.
|
||||
fn try_send(&self, msg: Bytes) -> Result<(), io::Error> {
|
||||
if msg.is_empty() {
|
||||
// message must not be empty
|
||||
return Err(io::ErrorKind::InvalidInput.into())
|
||||
}
|
||||
self.to_wire.send(self.mask_msg_id(msg)).map_err(|_| io::ErrorKind::BrokenPipe.into())
|
||||
}
|
||||
|
||||
/// Masks the message ID of a message to be sent on the wire.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If the message is empty.
|
||||
#[inline]
|
||||
fn mask_msg_id(&self, msg: Bytes) -> Bytes {
|
||||
// TODO handle empty messages
|
||||
let mut masked_bytes = BytesMut::zeroed(msg.len());
|
||||
masked_bytes[0] = msg[0] + self.cap.relative_message_id_offset();
|
||||
masked_bytes[1..].copy_from_slice(&msg[1..]);
|
||||
masked_bytes.freeze()
|
||||
}
|
||||
|
||||
/// Unmasks the message ID of a message received from the wire.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If the message is empty.
|
||||
#[inline]
|
||||
fn unmask_id(&self, mut msg: BytesMut) -> BytesMut {
|
||||
// TODO handle empty messages
|
||||
msg[0] -= self.cap.relative_message_id_offset();
|
||||
msg
|
||||
}
|
||||
@ -157,8 +197,7 @@ impl Sink<Bytes> for ProtocolProxy {
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
|
||||
let msg = self.mask_msg_id(item);
|
||||
self.to_wire.send(msg).map_err(|_| io::ErrorKind::BrokenPipe.into())
|
||||
self.get_mut().try_send(item)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
@ -181,7 +220,7 @@ impl CanDisconnect<Bytes> for ProtocolProxy {
|
||||
}
|
||||
}
|
||||
|
||||
/// A connection channel to receive messages for the negotiated protocol.
|
||||
/// A connection channel to receive _non_empty_ messages for the negotiated protocol.
|
||||
///
|
||||
/// This is a [Stream] that returns raw bytes of the received messages for this protocol.
|
||||
#[derive(Debug)]
|
||||
@ -287,34 +326,28 @@ where
|
||||
Poll::Ready(Some(Ok(msg))) => {
|
||||
delegated = true;
|
||||
let offset = msg[0];
|
||||
// find the protocol that matches the offset
|
||||
// TODO optimize this by keeping a better index
|
||||
let mut lowest_satellite = None;
|
||||
// find the protocol with the lowest offset that is greater than the message
|
||||
// offset
|
||||
for (i, proto) in this.satellites.iter().enumerate() {
|
||||
let proto_offset = proto.cap.relative_message_id_offset();
|
||||
if proto_offset >= offset {
|
||||
if let Some((_, lowest_offset)) = lowest_satellite {
|
||||
if proto_offset < lowest_offset {
|
||||
lowest_satellite = Some((i, proto_offset));
|
||||
// delegate the multiplexed message to the correct protocol
|
||||
if let Some(cap) =
|
||||
this.conn.shared_capabilities().find_by_relative_offset(offset)
|
||||
{
|
||||
if cap == &this.primary_capability {
|
||||
// delegate to primary
|
||||
let _ = this.to_primary.send(msg);
|
||||
} else {
|
||||
// delegate to satellite
|
||||
for proto in &this.satellites {
|
||||
if proto.cap == *cap {
|
||||
proto.send_raw(msg);
|
||||
break
|
||||
}
|
||||
} else {
|
||||
lowest_satellite = Some((i, proto_offset));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(
|
||||
offset,
|
||||
)
|
||||
.into())))
|
||||
}
|
||||
|
||||
if let Some((idx, lowest_offset)) = lowest_satellite {
|
||||
if lowest_offset < this.primary_capability.relative_message_id_offset()
|
||||
{
|
||||
// delegate to satellite
|
||||
this.satellites[idx].send_raw(msg);
|
||||
continue
|
||||
}
|
||||
}
|
||||
// delegate to primary
|
||||
let _ = this.to_primary.send(msg);
|
||||
}
|
||||
Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))),
|
||||
Poll::Ready(None) => {
|
||||
@ -373,18 +406,29 @@ struct ProtocolStream {
|
||||
}
|
||||
|
||||
impl ProtocolStream {
|
||||
/// Masks the message ID of a message to be sent on the wire.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If the message is empty.
|
||||
#[inline]
|
||||
fn mask_msg_id(&self, mut msg: BytesMut) -> Bytes {
|
||||
// TODO handle empty messages
|
||||
msg[0] += self.cap.relative_message_id_offset();
|
||||
msg.freeze()
|
||||
}
|
||||
|
||||
/// Unmasks the message ID of a message received from the wire.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If the message is empty.
|
||||
#[inline]
|
||||
fn unmask_id(&self, mut msg: BytesMut) -> BytesMut {
|
||||
// TODO handle empty messages
|
||||
msg[0] -= self.cap.relative_message_id_offset();
|
||||
msg
|
||||
}
|
||||
|
||||
/// Sends the message to the satellite stream.
|
||||
fn send_raw(&self, msg: BytesMut) {
|
||||
let _ = self.to_satellite.send(self.unmask_id(msg));
|
||||
}
|
||||
@ -396,7 +440,7 @@ impl Stream for ProtocolStream {
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = self.get_mut();
|
||||
let msg = ready!(this.satellite_st.as_mut().poll_next(cx));
|
||||
Poll::Ready(msg.map(|msg| this.mask_msg_id(msg)))
|
||||
Poll::Ready(msg.filter(|msg| !msg.is_empty()).map(|msg| this.mask_msg_id(msg)))
|
||||
}
|
||||
}
|
||||
|
||||
@ -408,15 +452,13 @@ impl fmt::Debug for ProtocolStream {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_util::codec::Decoder;
|
||||
|
||||
use super::*;
|
||||
use crate::{
|
||||
test_utils::{connect_passthrough, eth_handshake, eth_hello},
|
||||
UnauthedEthStream, UnauthedP2PStream,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_util::codec::Decoder;
|
||||
|
||||
#[tokio::test]
|
||||
async fn eth_satellite() {
|
||||
|
||||
@ -228,9 +228,10 @@ where
|
||||
///
|
||||
/// See also <https://github.com/ethereum/devp2p/blob/master/rlpx.md#message-id-based-multiplexing>
|
||||
///
|
||||
/// This stream emits Bytes that start with the normalized message id, so that the first byte of
|
||||
/// each message starts from 0. If this stream only supports a single capability, for example `eth`
|
||||
/// then the first byte of each message will match [EthMessageID](crate::types::EthMessageID).
|
||||
/// This stream emits _non-empty_ Bytes that start with the normalized message id, so that the first
|
||||
/// byte of each message starts from 0. If this stream only supports a single capability, for
|
||||
/// example `eth` then the first byte of each message will match
|
||||
/// [EthMessageID](crate::types::EthMessageID).
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct P2PStream<S> {
|
||||
@ -405,6 +406,11 @@ where
|
||||
None => return Poll::Ready(None),
|
||||
};
|
||||
|
||||
if bytes.is_empty() {
|
||||
// empty messages are not allowed
|
||||
return Poll::Ready(Some(Err(P2PStreamError::EmptyProtocolMessage)))
|
||||
}
|
||||
|
||||
// first check that the compressed message length does not exceed the max
|
||||
// payload size
|
||||
let decompressed_len = snap::raw::decompress_len(&bytes[1..])?;
|
||||
@ -430,7 +436,7 @@ where
|
||||
err
|
||||
})?;
|
||||
|
||||
let id = *bytes.first().ok_or(P2PStreamError::EmptyProtocolMessage)?;
|
||||
let id = bytes[0];
|
||||
match id {
|
||||
_ if id == P2PMessageID::Ping as u8 => {
|
||||
trace!("Received Ping, Sending Pong");
|
||||
|
||||
Reference in New Issue
Block a user