fix: handle both compressed and uncompressed disconnect reason decoding (#6862)

This commit is contained in:
Dan Cline
2024-02-29 10:31:11 -05:00
committed by GitHub
parent 9468527aad
commit 7d36206dfe

View File

@ -418,6 +418,28 @@ where
return Poll::Ready(Some(Err(P2PStreamError::EmptyProtocolMessage))) return Poll::Ready(Some(Err(P2PStreamError::EmptyProtocolMessage)))
} }
// first decode disconnect reasons, because they can be encoded in a variety of forms
// over the wire, in both snappy compressed and uncompressed forms.
//
// see: [crate::disconnect::tests::test_decode_known_reasons]
let id = bytes[0];
if id == P2PMessageID::Disconnect as u8 {
// We can't handle the error here because disconnect reasons are encoded as both:
// * snappy compressed, AND
// * uncompressed
// over the network.
//
// If the decoding succeeds, we already checked the id and know this is a
// disconnect message, so we can return with the reason.
//
// If the decoding fails, we continue, and will attempt to decode it again if the
// message is snappy compressed. Failure handling in that step is the primary point
// where an error is returned if the disconnect reason is malformed.
if let Ok(reason) = DisconnectReason::decode(&mut &bytes[1..]) {
return Poll::Ready(Some(Err(P2PStreamError::Disconnected(reason))))
}
}
// first check that the compressed message length does not exceed the max // first check that the compressed message length does not exceed the max
// payload size // payload size
let decompressed_len = snap::raw::decompress_len(&bytes[1..])?; let decompressed_len = snap::raw::decompress_len(&bytes[1..])?;
@ -443,7 +465,6 @@ where
err err
})?; })?;
let id = bytes[0];
match id { match id {
_ if id == P2PMessageID::Ping as u8 => { _ if id == P2PMessageID::Ping as u8 => {
trace!("Received Ping, Sending Pong"); trace!("Received Ping, Sending Pong");
@ -452,15 +473,6 @@ where
// that happens, the pong will never be sent. // that happens, the pong will never be sent.
cx.waker().wake_by_ref(); cx.waker().wake_by_ref();
} }
_ if id == P2PMessageID::Disconnect as u8 => {
let reason = DisconnectReason::decode(&mut &decompress_buf[1..]).map_err(|err| {
debug!(
%err, msg=%hex::encode(&decompress_buf[1..]), "Failed to decode disconnect message from peer"
);
err
})?;
return Poll::Ready(Some(Err(P2PStreamError::Disconnected(reason))))
}
_ if id == P2PMessageID::Hello as u8 => { _ if id == P2PMessageID::Hello as u8 => {
// we have received a hello message outside of the handshake, so we will return // we have received a hello message outside of the handshake, so we will return
// an error // an error
@ -472,6 +484,20 @@ where
// if we were waiting for a pong, this will reset the pinger state // if we were waiting for a pong, this will reset the pinger state
this.pinger.on_pong()? this.pinger.on_pong()?
} }
_ if id == P2PMessageID::Disconnect as u8 => {
// At this point, the `decempres_buf` contains the snappy decompressed
// disconnect message.
//
// It's possible we already tried to RLP decode this, but it was snappy
// compressed, so we need to RLP decode it again.
let reason = DisconnectReason::decode(&mut &decompress_buf[1..]).map_err(|err| {
debug!(
%err, msg=%hex::encode(&decompress_buf[1..]), "Failed to decode disconnect message from peer"
);
err
})?;
return Poll::Ready(Some(Err(P2PStreamError::Disconnected(reason))))
}
_ if id > MAX_P2P_MESSAGE_ID && id <= MAX_RESERVED_MESSAGE_ID => { _ if id > MAX_P2P_MESSAGE_ID && id <= MAX_RESERVED_MESSAGE_ID => {
// we have received an unknown reserved message // we have received an unknown reserved message
return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(id)))) return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(id))))
@ -850,6 +876,52 @@ mod tests {
handle.await.unwrap(); handle.await.unwrap();
} }
#[tokio::test]
async fn test_can_disconnect_weird_disconnect_encoding() {
reth_tracing::init_test_tracing();
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
let expected_disconnect = DisconnectReason::SubprotocolSpecific;
let handle = tokio::spawn(async move {
// roughly based off of the design of tokio::net::TcpListener
let (incoming, _) = listener.accept().await.unwrap();
let stream = crate::PassthroughCodec::default().framed(incoming);
let (server_hello, _) = eth_hello();
let (mut p2p_stream, _) =
UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();
// Unrolled `disconnect` method, without compression
p2p_stream.outgoing_messages.clear();
let disconnect = P2PMessage::Disconnect(DisconnectReason::SubprotocolSpecific);
let mut buf = BytesMut::with_capacity(disconnect.length());
disconnect.encode(&mut buf);
p2p_stream.outgoing_messages.push_back(buf.freeze());
p2p_stream.disconnecting = true;
p2p_stream.close().await.unwrap();
});
let outgoing = TcpStream::connect(local_addr).await.unwrap();
let sink = crate::PassthroughCodec::default().framed(outgoing);
let (client_hello, _) = eth_hello();
let (mut p2p_stream, _) =
UnauthedP2PStream::new(sink).handshake(client_hello).await.unwrap();
let err = p2p_stream.next().await.unwrap().unwrap_err();
match err {
P2PStreamError::Disconnected(reason) => assert_eq!(reason, expected_disconnect),
e => panic!("unexpected err: {e}"),
}
handle.await.unwrap();
}
#[tokio::test] #[tokio::test]
async fn test_handshake_passthrough() { async fn test_handshake_passthrough() {
// create a p2p stream and server, then confirm that the two are authed // create a p2p stream and server, then confirm that the two are authed