refactor(net): separate Sink and Stream (#217)

This commit is contained in:
Matthias Seitz
2022-11-16 20:10:11 +01:00
committed by GitHub
parent bb83d8a528
commit 75a6d06301
3 changed files with 74 additions and 72 deletions

View File

@ -59,8 +59,8 @@ pub enum P2PStreamError {
EmptyProtocolMessage,
#[error(transparent)]
PingerError(#[from] PingerError),
#[error("ping timed out with {0} retries")]
PingTimeout(u8),
#[error("ping timed out with")]
PingTimeout,
#[error(transparent)]
ParseVersionError(#[from] SharedCapabilityError),
#[error("mismatched protocol version in Hello message. expected: {expected:?}, got: {got:?}")]

View File

@ -203,23 +203,21 @@ where
#[cfg(test)]
mod tests {
use super::UnauthedEthStream;
use crate::{
capability::Capability,
p2pstream::{HelloMessage, ProtocolVersion, UnauthedP2PStream},
types::{broadcast::BlockHashNumber, EthMessage, Status},
types::{broadcast::BlockHashNumber, EthMessage, EthVersion, Status},
EthStream, PassthroughCodec,
};
use ethers_core::types::Chain;
use futures::{SinkExt, StreamExt};
use reth_ecies::{stream::ECIESStream, util::pk2id};
use reth_primitives::{ForkFilter, H256, U256};
use secp256k1::{SecretKey, SECP256K1};
use tokio::net::{TcpListener, TcpStream};
use tokio_util::codec::Decoder;
use crate::{capability::Capability, types::EthVersion};
use ethers_core::types::Chain;
use reth_primitives::{ForkFilter, H256, U256};
use super::UnauthedEthStream;
#[tokio::test]
async fn can_handshake() {
let genesis = H256::random();
@ -341,7 +339,7 @@ mod tests {
handle.await.unwrap();
}
#[tokio::test]
#[tokio::test(flavor = "multi_thread")]
async fn ethstream_over_p2p() {
// create a p2p stream and server, then confirm that the two are authed
// create tcpstream
@ -350,8 +348,8 @@ mod tests {
let server_key = SecretKey::new(&mut rand::thread_rng());
let test_msg = EthMessage::NewBlockHashes(
vec![
BlockHashNumber { hash: reth_primitives::H256::random(), number: 5 },
BlockHashNumber { hash: reth_primitives::H256::random(), number: 6 },
BlockHashNumber { hash: H256::random(), number: 5 },
BlockHashNumber { hash: H256::random(), number: 6 },
]
.into(),
);
@ -415,6 +413,7 @@ mod tests {
let unauthed_stream = UnauthedP2PStream::new(sink);
let (p2p_stream, _) = unauthed_stream.handshake(client_hello).await.unwrap();
let (mut client_stream, _) =
UnauthedEthStream::new(p2p_stream).handshake(status, fork_filter).await.unwrap();

View File

@ -14,7 +14,7 @@ use std::{
fmt::Display,
io,
pin::Pin,
task::{Context, Poll},
task::{ready, Context, Poll},
time::Duration,
};
use tokio_stream::Stream;
@ -186,24 +186,6 @@ where
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();
// try to send any buffered outgoing messages
while let Some(message) = this.outgoing_messages.pop_front() {
// let pinned_inner = Pin::new(&mut this.inner);
match Pin::new(&mut this.inner).poll_ready(cx) {
Poll::Ready(Ok(())) => {
if let Err(e) = Pin::new(&mut this.inner).start_send(message) {
return Poll::Ready(Some(Err(P2PStreamError::Io(e))))
}
}
Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(P2PStreamError::Io(e)))),
Poll::Pending => {
// we need to buffer the message and try again later
this.outgoing_messages.push_front(message);
break
}
}
}
// poll the pinger to determine if we should send a ping
match this.pinger.poll_ping(cx) {
Poll::Pending => {}
@ -212,38 +194,24 @@ where
let mut ping_bytes = BytesMut::new();
P2PMessage::Ping.encode(&mut ping_bytes);
if Pin::new(&mut this.inner).poll_ready(cx).is_ready() {
// send the ping message
Pin::new(&mut this.inner).start_send(ping_bytes.into())?
} else {
// check if the buffer is full
if this.outgoing_messages.len() >= MAX_P2P_CAPACITY {
return Poll::Ready(Some(Err(P2PStreamError::SendBufferFull)))
}
// if the sink is not ready, buffer the message
this.outgoing_messages.push_back(ping_bytes.into());
// check if the buffer is full
if this.outgoing_messages.len() >= MAX_P2P_CAPACITY {
return Poll::Ready(Some(Err(P2PStreamError::SendBufferFull)))
}
// if the sink is not ready, buffer the message
this.outgoing_messages.push_back(ping_bytes.into());
}
_ => {
// encode the disconnect message
let mut disconnect_bytes = BytesMut::new();
P2PMessage::Disconnect(DisconnectReason::PingTimeout).encode(&mut disconnect_bytes);
if Pin::new(&mut this.inner).poll_ready(cx).is_ready() {
// send the disconnect message
Pin::new(&mut this.inner).start_send(disconnect_bytes.into())?
} else {
// check if the buffer is full
if this.outgoing_messages.len() >= MAX_P2P_CAPACITY {
return Poll::Ready(Some(Err(P2PStreamError::SendBufferFull)))
}
// clear any buffered messages so that the next message will be disconnect
this.outgoing_messages.clear();
this.outgoing_messages.push_back(disconnect_bytes.freeze());
// if the sink is not ready, buffer the message
this.outgoing_messages.push_back(disconnect_bytes.into());
}
// since the ping stream has timed out, let's send a None
// End the stream after ping related error
return Poll::Ready(None)
}
}
@ -264,18 +232,11 @@ where
let mut pong_bytes = BytesMut::new();
P2PMessage::Pong.encode(&mut pong_bytes);
if Pin::new(&mut this.inner).poll_ready(cx).is_ready() {
// send the pong message
Pin::new(&mut this.inner).start_send(pong_bytes.into())?
} else {
// check if the buffer is full
if this.outgoing_messages.len() >= MAX_P2P_CAPACITY {
return Poll::Ready(Some(Err(P2PStreamError::SendBufferFull)))
}
// if the sink is not ready, buffer the message
this.outgoing_messages.push_back(pong_bytes.into());
// check if the buffer is full
if this.outgoing_messages.len() >= MAX_P2P_CAPACITY {
return Poll::Ready(Some(Err(P2PStreamError::SendBufferFull)))
}
this.outgoing_messages.push_back(pong_bytes.into());
// continue to the next message if there is one
} else if id == P2PMessageID::Disconnect as u8 {
@ -330,13 +291,36 @@ where
{
type Error = P2PStreamError;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_ready(cx).map_err(Into::into)
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let mut this = self.as_mut();
match this.inner.poll_ready_unpin(cx) {
Poll::Pending => {}
Poll::Ready(Err(err)) => return Poll::Ready(Err(P2PStreamError::Io(err))),
Poll::Ready(Ok(())) => {
let flushed = this.poll_flush(cx);
if flushed.is_ready() {
return flushed
}
}
}
if self.outgoing_messages.len() < MAX_P2P_CAPACITY {
// still has capacity
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
let this = self.project();
// ensure we have free capacity
if this.outgoing_messages.len() >= MAX_P2P_CAPACITY {
return Err(P2PStreamError::SendBufferFull)
}
let mut compressed = BytesMut::zeroed(1 + snap::raw::max_compress_len(item.len() - 1));
// all messages sent in this stream are subprotocol messages, so we need to switch the
@ -348,16 +332,35 @@ where
// id)
compressed.truncate(compressed_size + 1);
this.inner.start_send(compressed.freeze())?;
this.outgoing_messages.push_back(compressed.freeze());
Ok(())
}
/// Returns Poll::Ready(Ok(())) when no buffered items remain and the sink has been successfully
/// closed.
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_flush(cx).map_err(Into::into)
let mut this = self.project();
loop {
match ready!(this.inner.as_mut().poll_flush(cx)) {
Err(err) => return Poll::Ready(Err(err.into())),
Ok(()) => {
if let Some(message) = this.outgoing_messages.pop_front() {
if let Err(err) = this.inner.as_mut().start_send(message) {
return Poll::Ready(Err(err.into()))
}
} else {
return Poll::Ready(Ok(()))
}
}
}
}
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_close(cx).map_err(Into::into)
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
ready!(self.as_mut().poll_flush(cx))?;
Poll::Ready(Ok(()))
}
}