fix(net): encode ping pong as snappy (#555)

This commit is contained in:
Matthias Seitz
2022-12-21 14:24:41 +01:00
committed by GitHub
parent f3c79ff61c
commit 151420df58

View File

@ -9,7 +9,7 @@ use bytes::{Buf, Bytes, BytesMut};
use futures::{Sink, SinkExt, StreamExt};
use metrics::counter;
use pin_project::pin_project;
use reth_rlp::{Decodable, DecodeError, Encodable, EMPTY_STRING_CODE};
use reth_rlp::{Decodable, DecodeError, Encodable, EMPTY_LIST_CODE};
use serde::{Deserialize, Serialize};
use std::{
collections::{BTreeSet, HashMap, VecDeque},
@ -195,6 +195,14 @@ impl<S> P2PStream<S> {
self.disconnecting
}
/// Queues in a _snappy_ encoded [`P2PMessage::Pong`] message.
fn send_pong(&mut self) {
let pong = P2PMessage::Pong;
let mut pong_bytes = BytesMut::with_capacity(pong.length());
pong.encode(&mut pong_bytes);
self.outgoing_messages.push_back(pong_bytes.into());
}
/// Starts to gracefully disconnect the connection by sending a Disconnect message and stop
/// reading new messages.
///
@ -331,17 +339,8 @@ where
let id = *bytes.first().ok_or(P2PStreamError::EmptyProtocolMessage)?;
match id {
_ if id == P2PMessageID::Ping as u8 => {
// we have received a ping, so we will send a pong
let pong = P2PMessage::Pong;
let mut pong_bytes = BytesMut::with_capacity(pong.length());
pong.encode(&mut pong_bytes);
// check if the buffer is full
if this.outgoing_messages.len() >= MAX_P2P_CAPACITY {
return Poll::Ready(Some(Err(P2PStreamError::SendBufferFull)))
}
// continue to the next message if there is one
this.outgoing_messages.push_back(pong_bytes.into());
tracing::trace!("Received Ping, Sending Pong");
this.send_pong();
}
_ if id == P2PMessageID::Disconnect as u8 => {
let reason = DisconnectReason::decode(&mut &decompress_buf[1..]).map_err(|err| {
@ -360,7 +359,6 @@ where
))))
}
_ if id == P2PMessageID::Pong as u8 => {
// TODO: do we need to decode the pong?
// if we were waiting for a pong, this will reset the pinger state
this.pinger.on_pong()?
}
@ -616,10 +614,16 @@ impl Encodable for P2PMessage {
P2PMessage::Hello(msg) => msg.encode(out),
P2PMessage::Disconnect(msg) => msg.encode(out),
P2PMessage::Ping => {
out.put_u8(EMPTY_STRING_CODE);
// Ping payload is _always_ snappy encoded
out.put_u8(0x01);
out.put_u8(0x00);
out.put_u8(EMPTY_LIST_CODE);
}
P2PMessage::Pong => {
out.put_u8(EMPTY_STRING_CODE);
// Pong payload is _always_ snappy encoded
out.put_u8(0x01);
out.put_u8(0x00);
out.put_u8(EMPTY_LIST_CODE);
}
}
}
@ -628,8 +632,9 @@ impl Encodable for P2PMessage {
let payload_len = match self {
P2PMessage::Hello(msg) => msg.length(),
P2PMessage::Disconnect(msg) => msg.length(),
P2PMessage::Ping => 1,
P2PMessage::Pong => 1,
// id + snappy encoded payload
P2PMessage::Ping => 3, // len([0x01, 0x00, 0xc0]) = 3
P2PMessage::Pong => 3, // len([0x01, 0x00, 0xc0]) = 3
};
payload_len + 1 // (1 for length of p2p message id)
}
@ -841,4 +846,26 @@ mod tests {
// make sure the server receives the message and asserts before ending the test
handle.await.unwrap();
}
#[test]
fn snappy_decode_encode_ping() {
let snappy_ping = b"\x02\x01\0\xc0";
let ping = P2PMessage::decode(&mut &snappy_ping[..]).unwrap();
assert!(matches!(ping, P2PMessage::Ping));
let mut buf = BytesMut::with_capacity(ping.length());
ping.encode(&mut buf);
assert_eq!(buf.as_ref(), &snappy_ping[..]);
}
#[test]
fn snappy_decode_encode_pong() {
let snappy_pong = b"\x03\x01\0\xc0";
let pong = P2PMessage::decode(&mut &snappy_pong[..]).unwrap();
assert!(matches!(pong, P2PMessage::Pong));
let mut buf = BytesMut::with_capacity(pong.length());
pong.encode(&mut buf);
assert_eq!(buf.as_ref(), &snappy_pong[..]);
}
}