refactor(net): misc P2Pstream refactor (#336)

* refactor(net): misc P2Pstream refactor

* update note
This commit is contained in:
Matthias Seitz
2022-12-05 20:49:22 +01:00
committed by Georgios Konstantopoulos
parent 074e69cafb
commit b40546b999
3 changed files with 71 additions and 67 deletions

View File

@ -40,7 +40,7 @@ impl<S> UnauthedEthStream<S> {
impl<S, E> UnauthedEthStream<S>
where
S: Stream<Item = Result<bytes::BytesMut, E>> + Sink<bytes::Bytes, Error = E> + Unpin,
S: Stream<Item = Result<BytesMut, E>> + Sink<Bytes, Error = E> + Unpin,
EthStreamError: From<E>,
{
/// Consumes the [`UnauthedEthStream`] and returns an [`EthStream`] after the `Status`
@ -54,7 +54,8 @@ where
tracing::trace!("sending eth status ...");
// we need to encode and decode here on our own because we don't have an `EthStream` yet
let mut our_status_bytes = BytesMut::new();
// The max length for a status with TTD is: <msg id = 1 byte> + <rlp(status) = 88 byte>
let mut our_status_bytes = BytesMut::with_capacity(1 + 88);
ProtocolMessage::from(EthMessage::Status(status)).encode(&mut our_status_bytes);
let our_status_bytes = our_status_bytes.freeze();
self.inner.send(our_status_bytes).await?;
@ -277,7 +278,7 @@ mod tests {
let handle = tokio::spawn(async move {
// roughly based off of the design of tokio::net::TcpListener
let (incoming, _) = listener.accept().await.unwrap();
let stream = crate::PassthroughCodec::default().framed(incoming);
let stream = PassthroughCodec::default().framed(incoming);
let (_, their_status) = UnauthedEthStream::new(stream)
.handshake(status_clone, fork_filter_clone)
.await
@ -288,7 +289,7 @@ mod tests {
});
let outgoing = TcpStream::connect(local_addr).await.unwrap();
let sink = crate::PassthroughCodec::default().framed(outgoing);
let sink = PassthroughCodec::default().framed(outgoing);
// try to connect
let (_, their_status) =
@ -307,8 +308,8 @@ mod tests {
let local_addr = listener.local_addr().unwrap();
let test_msg = EthMessage::NewBlockHashes(
vec![
BlockHashNumber { hash: reth_primitives::H256::random(), number: 5 },
BlockHashNumber { hash: reth_primitives::H256::random(), number: 6 },
BlockHashNumber { hash: H256::random(), number: 5 },
BlockHashNumber { hash: H256::random(), number: 6 },
]
.into(),
);
@ -342,8 +343,8 @@ mod tests {
let server_key = SecretKey::new(&mut rand::thread_rng());
let test_msg = EthMessage::NewBlockHashes(
vec![
BlockHashNumber { hash: reth_primitives::H256::random(), number: 5 },
BlockHashNumber { hash: reth_primitives::H256::random(), number: 6 },
BlockHashNumber { hash: H256::random(), number: 5 },
BlockHashNumber { hash: H256::random(), number: 6 },
]
.into(),
);

View File

@ -67,7 +67,7 @@ impl<S> UnauthedP2PStream<S> {
impl<S> UnauthedP2PStream<S>
where
S: Stream<Item = Result<BytesMut, io::Error>> + Sink<Bytes, Error = io::Error> + Unpin,
S: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
{
/// Consumes the `UnauthedP2PStream` and returns a `P2PStream` after the `Hello` handshake is
/// completed successfully. This also returns the `Hello` message sent by the remote peer.
@ -218,8 +218,9 @@ impl<S> P2PStream<S> {
pub fn start_disconnect(&mut self, reason: DisconnectReason) {
// clear any buffered messages and queue in
self.outgoing_messages.clear();
let mut buf = BytesMut::new();
P2PMessage::Disconnect(reason).encode(&mut buf);
let disconnect = P2PMessage::Disconnect(reason);
let mut buf = BytesMut::with_capacity(disconnect.length());
disconnect.encode(&mut buf);
self.outgoing_messages.push_back(buf.freeze());
}
}
@ -241,14 +242,14 @@ where
// protocol
impl<S> Stream for P2PStream<S>
where
S: Stream<Item = Result<BytesMut, io::Error>> + Sink<Bytes, Error = io::Error> + Unpin,
S: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
{
type Item = Result<BytesMut, P2PStreamError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();
let this = self.get_mut();
if *this.disconnecting {
if this.disconnecting {
// if disconnecting, stop reading messages
return Poll::Ready(None)
}
@ -271,12 +272,7 @@ where
}
_ => {
// encode the disconnect message
let mut disconnect_bytes = BytesMut::new();
P2PMessage::Disconnect(DisconnectReason::PingTimeout).encode(&mut disconnect_bytes);
// clear any buffered messages so that the next message will be disconnect
this.outgoing_messages.clear();
this.outgoing_messages.push_back(disconnect_bytes.freeze());
this.start_disconnect(DisconnectReason::PingTimeout);
// End the stream after ping related error
return Poll::Ready(None)
@ -285,7 +281,7 @@ where
// we should loop here to ensure we don't return Poll::Pending if we have a message to
// return behind any pings we need to respond to
while let Poll::Ready(res) = this.inner.as_mut().poll_next(cx) {
while let Poll::Ready(res) = this.inner.poll_next_unpin(cx) {
let bytes = match res {
Some(Ok(bytes)) => bytes,
Some(Err(err)) => return Poll::Ready(Some(Err(err.into()))),
@ -293,57 +289,64 @@ where
};
let id = *bytes.first().ok_or(P2PStreamError::EmptyProtocolMessage)?;
if id == P2PMessageID::Ping as u8 {
// TODO: do we need to decode the ping?
// we have received a ping, so we will send a pong
let mut pong_bytes = BytesMut::new();
P2PMessage::Pong.encode(&mut pong_bytes);
// check if the buffer is full
if this.outgoing_messages.len() >= MAX_P2P_CAPACITY {
return Poll::Ready(Some(Err(P2PStreamError::SendBufferFull)))
match id {
_ if id == P2PMessageID::Ping as u8 => {
// we have received a ping, so we will send a pong
let pong = P2PMessage::Pong;
let mut pong_bytes = BytesMut::with_capacity(pong.length());
pong.encode(&mut pong_bytes);
// check if the buffer is full
if this.outgoing_messages.len() >= MAX_P2P_CAPACITY {
return Poll::Ready(Some(Err(P2PStreamError::SendBufferFull)))
}
// continue to the next message if there is one
this.outgoing_messages.push_back(pong_bytes.into());
}
this.outgoing_messages.push_back(pong_bytes.into());
// continue to the next message if there is one
} else if id == P2PMessageID::Disconnect as u8 {
let reason = DisconnectReason::decode(&mut &bytes[1..])?;
return Poll::Ready(Some(Err(P2PStreamError::Disconnected(reason))))
} else if id == P2PMessageID::Hello as u8 {
// we have received a hello message outside of the handshake, so we will return an
// error
return Poll::Ready(Some(Err(P2PStreamError::HandshakeError(
P2PHandshakeError::HelloNotInHandshake,
))))
} else if id == P2PMessageID::Pong as u8 {
// TODO: do we need to decode the pong?
// if we were waiting for a pong, this will reset the pinger state
this.pinger.on_pong()?
} else if id > MAX_P2P_MESSAGE_ID && id <= MAX_RESERVED_MESSAGE_ID {
// we have received an unknown reserved message
return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(id))))
} else {
// first check that the compressed message length does not exceed the max message
// size
let decompressed_len = snap::raw::decompress_len(&bytes[1..])?;
if decompressed_len > MAX_PAYLOAD_SIZE {
return Poll::Ready(Some(Err(P2PStreamError::MessageTooBig {
message_size: decompressed_len,
max_size: MAX_PAYLOAD_SIZE,
})))
_ if id == P2PMessageID::Disconnect as u8 => {
let reason = DisconnectReason::decode(&mut &bytes[1..])?;
return Poll::Ready(Some(Err(P2PStreamError::Disconnected(reason))))
}
_ if id == P2PMessageID::Hello as u8 => {
// we have received a hello message outside of the handshake, so we will return
// an error
return Poll::Ready(Some(Err(P2PStreamError::HandshakeError(
P2PHandshakeError::HelloNotInHandshake,
))))
}
_ if id == P2PMessageID::Pong as u8 => {
// TODO: do we need to decode the pong?
// if we were waiting for a pong, this will reset the pinger state
this.pinger.on_pong()?
}
_ if id > MAX_P2P_MESSAGE_ID && id <= MAX_RESERVED_MESSAGE_ID => {
// we have received an unknown reserved message
return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(id))))
}
_ => {
// first check that the compressed message length does not exceed the max
// message size
let decompressed_len = snap::raw::decompress_len(&bytes[1..])?;
if decompressed_len > MAX_PAYLOAD_SIZE {
return Poll::Ready(Some(Err(P2PStreamError::MessageTooBig {
message_size: decompressed_len,
max_size: MAX_PAYLOAD_SIZE,
})))
}
// then decompress the message
let mut decompress_buf = BytesMut::zeroed(decompressed_len + 1);
// then decompress the message
let mut decompress_buf = BytesMut::zeroed(decompressed_len + 1);
// we have a subprotocol message that needs to be sent in the stream.
// first, switch the message id based on offset so the next layer can decode it
// without being aware of the p2p stream's state (shared capabilities / the message
// id offset)
decompress_buf[0] = bytes[0] - this.shared_capability.offset();
this.decoder.decompress(&bytes[1..], &mut decompress_buf[1..])?;
// we have a subprotocol message that needs to be sent in the stream.
// first, switch the message id based on offset so the next layer can decode it
// without being aware of the p2p stream's state (shared capabilities / the
// message id offset)
decompress_buf[0] = bytes[0] - this.shared_capability.offset();
this.decoder.decompress(&bytes[1..], &mut decompress_buf[1..])?;
return Poll::Ready(Some(Ok(decompress_buf)))
return Poll::Ready(Some(Ok(decompress_buf)))
}
}
}

View File

@ -289,7 +289,7 @@ pub enum EthMessageID {
}
impl Encodable for EthMessageID {
fn encode(&self, out: &mut dyn bytes::BufMut) {
fn encode(&self, out: &mut dyn BufMut) {
out.put_u8(*self as u8);
}
fn length(&self) -> usize {