feat(eth-wire): RLP encode then compress (#535)

This commit is contained in:
Dan Cline
2022-12-20 07:33:47 -05:00
committed by GitHub
parent 3df86187d1
commit aab385c84a
6 changed files with 163 additions and 277 deletions

1
Cargo.lock generated
View File

@ -3468,6 +3468,7 @@ dependencies = [
"reth-ecies",
"reth-primitives",
"reth-rlp",
"reth-tracing",
"secp256k1",
"serde",
"smol_str",

View File

@ -30,10 +30,11 @@ smol_str = { version = "0.1", features = ["serde"] }
metrics = "0.20.1"
[dev-dependencies]
test-fuzz = "3.0.4"
reth-ecies = { path = "../ecies" }
reth-tracing = { path = "../../tracing" }
ethers-core = { git = "https://github.com/gakonst/ethers-rs", default-features = false }
test-fuzz = "3.0.4"
tokio-util = { version = "0.7.4", features = ["io", "codec"] }
hex-literal = "0.3"
rand = "0.8"

View File

@ -1,7 +1,7 @@
//! Disconnect
use bytes::Buf;
use reth_rlp::{Decodable, DecodeError, Encodable};
use reth_rlp::{Decodable, DecodeError, Encodable, Header};
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use thiserror::Error;
@ -102,64 +102,50 @@ impl TryFrom<u8> for DisconnectReason {
}
/// The [`Encodable`](reth_rlp::Encodable) implementation for [`DisconnectReason`] encodes the
/// disconnect reason as RLP, and prepends a snappy header to the RLP bytes.
/// disconnect reason in a single-element RLP list.
impl Encodable for DisconnectReason {
fn encode(&self, out: &mut dyn bytes::BufMut) {
// disconnect reasons are snappy encoded as follows:
// [0x02, 0x04, 0xc1, rlp(reason as u8)]
// this is 4 bytes
out.put_u8(0x02);
out.put_u8(0x04);
vec![*self as u8].encode(out);
}
fn length(&self) -> usize {
// disconnect reasons are snappy encoded as follows:
// [0x02, 0x04, 0xc1, rlp(reason as u8)]
// this is 4 bytes
4
vec![*self as u8].length()
}
}
/// The [`Decodable`](reth_rlp::Decodable) implementation for [`DisconnectReason`] assumes that the
/// input is snappy compressed.
/// The [`Decodable`](reth_rlp::Decodable) implementation for [`DisconnectReason`] supports either
/// a disconnect reason encoded a single byte or a RLP list containing the disconnect reason.
impl Decodable for DisconnectReason {
fn decode(buf: &mut &[u8]) -> Result<Self, DecodeError> {
if buf.is_empty() {
return Err(DecodeError::InputTooShort)
} else if buf.len() > 2 {
return Err(DecodeError::Overflow)
}
// encoded as a single byte
let reason_byte = if buf.len() == 1 {
u8::decode(buf)?
} else if buf.len() <= 4 {
// in any disconnect encoding, headers precede and do not wrap the reason, so we should
// advance to the end of the buffer
buf.advance(buf.len() - 1);
// geth rlp encodes [`DisconnectReason::DisconnectRequested`] as 0x00 and not as empty
// string 0x80
if buf[0] == 0x00 {
DisconnectReason::DisconnectRequested as u8
} else {
// the reason is encoded at the end of the snappy encoded bytes
u8::decode(buf)?
if buf.len() > 1 {
// this should be a list, so decode the list header. this should advance the buffer so
// buf[0] is the first (and only) element of the list.
let header = Header::decode(buf)?;
if !header.list {
return Err(DecodeError::UnexpectedString)
}
} else {
return Err(DecodeError::Custom("invalid disconnect reason length"))
};
}
let reason = DisconnectReason::try_from(reason_byte)
.map_err(|_| DecodeError::Custom("unknown disconnect reason"))?;
Ok(reason)
// geth rlp encodes [`DisconnectReason::DisconnectRequested`] as 0x00 and not as empty
// string 0x80
if buf[0] == 0x00 {
buf.advance(1);
Ok(DisconnectReason::DisconnectRequested)
} else {
DisconnectReason::try_from(u8::decode(buf)?)
.map_err(|_| DecodeError::Custom("unknown disconnect reason"))
}
}
}
#[cfg(test)]
mod tests {
use crate::{
p2pstream::{P2PMessage, P2PMessageID},
DisconnectReason,
};
use crate::{p2pstream::P2PMessage, DisconnectReason};
use reth_rlp::{Decodable, Encodable};
fn all_reasons() -> Vec<DisconnectReason> {
@ -198,7 +184,12 @@ mod tests {
#[test]
fn test_reason_too_short() {
assert!(DisconnectReason::decode(&mut &[0u8][..]).is_err())
assert!(DisconnectReason::decode(&mut &[0u8; 0][..]).is_err())
}
#[test]
fn test_reason_too_long() {
assert!(DisconnectReason::decode(&mut &[0u8; 3][..]).is_err())
}
#[test]
@ -215,123 +206,18 @@ mod tests {
}
}
#[test]
fn disconnect_snappy_encoding_parity() {
// encode disconnect using our `Encodable` implementation
let disconnect = P2PMessage::Disconnect(DisconnectReason::DisconnectRequested);
let mut disconnect_encoded = Vec::new();
disconnect.encode(&mut disconnect_encoded);
let mut disconnect_raw = Vec::new();
// encode [DisconnectRequested]
// DisconnectRequested will be converted to 0x80 (encoding of 0) in Encodable::encode
Encodable::encode(&vec![0x00u8], &mut disconnect_raw);
let mut snappy_encoder = snap::raw::Encoder::new();
let disconnect_compressed = snappy_encoder.compress_vec(&disconnect_raw).unwrap();
let mut disconnect_expected = vec![P2PMessageID::Disconnect as u8];
disconnect_expected.extend(&disconnect_compressed);
// ensure that the two encodings are equal
assert_eq!(
disconnect_expected, disconnect_encoded,
"left: {disconnect_expected:#x?}, right: {disconnect_encoded:#x?}"
);
// also ensure that the length is correct
assert_eq!(
disconnect_expected.len(),
P2PMessage::Disconnect(DisconnectReason::DisconnectRequested).length()
);
}
#[test]
fn disconnect_snappy_decoding_parity() {
// encode disconnect using our `Encodable` implementation
let disconnect = P2PMessage::Disconnect(DisconnectReason::DisconnectRequested);
let mut disconnect_encoded = Vec::new();
disconnect.encode(&mut disconnect_encoded);
// try to decode using Decodable
let p2p_message = P2PMessage::decode(&mut &disconnect_encoded[..]).unwrap();
assert_eq!(p2p_message, P2PMessage::Disconnect(DisconnectReason::DisconnectRequested));
// finally decode the encoded message with snappy
let mut snappy_decoder = snap::raw::Decoder::new();
// the message id is not compressed, only compress the latest bits
let decompressed = snappy_decoder.decompress_vec(&disconnect_encoded[1..]).unwrap();
let mut disconnect_raw = Vec::new();
// encode [DisconnectRequested]
// DisconnectRequested will be converted to 0x80 (encoding of 0) in Encodable::encode
Encodable::encode(&vec![0x00u8], &mut disconnect_raw);
assert_eq!(decompressed, disconnect_raw);
}
#[test]
fn test_decode_known_reasons() {
let all_reasons = vec![
// non-snappy, encoding the disconnect reason as a single byte
"0180",
"0101",
"0102",
"0103",
"0104",
"0105",
"0106",
"0107",
"0108",
"0109",
"010a",
"010b",
"0110",
// non-snappy, encoding the disconnect reason in a list
"01c180",
"01c101",
"01c102",
"01c103",
"01c104",
"01c105",
"01c106",
"01c107",
"01c108",
"01c109",
"01c10a",
"01c10b",
"01c110",
// snappy, compressing a single byte
"010080",
"010001",
"010002",
"010003",
"010004",
"010005",
"010006",
"010007",
"010008",
"010009",
"01000a",
"01000b",
"010010",
// TODO: just saw this format once, not really sure what this format even is
"01010003",
"01010000",
// snappy, encoded the disconnect reason as a list
"010204c180",
"010204c101",
"010204c102",
"010204c103",
"010204c104",
"010204c105",
"010204c106",
"010204c107",
"010204c108",
"010204c109",
"010204c10a",
"010204c10b",
"010204c110",
// encoding the disconnect reason as a single byte
"0100", // 0x00 case
"0180", // second 0x00 case
"0101", "0102", "0103", "0104", "0105", "0106", "0107", "0108", "0109", "010a", "010b",
"0110", // encoding the disconnect reason in a list
"01c100", // 0x00 case
"01c180", // second 0x00 case
"01c101", "01c102", "01c103", "01c104", "01c105", "01c106", "01c107", "01c108",
"01c109", "01c10a", "01c10b", "01c110",
];
for reason in all_reasons {
@ -345,7 +231,7 @@ mod tests {
#[test]
fn test_decode_disconnect_requested() {
let reason = "01010000";
let reason = "0100";
let reason = hex::decode(reason).unwrap();
match P2PMessage::decode(&mut &reason[..]).unwrap() {
P2PMessage::Disconnect(DisconnectReason::DisconnectRequested) => {}

View File

@ -108,47 +108,9 @@ mod tests {
use secp256k1::{SecretKey, SECP256K1};
use crate::{
capability::Capability,
p2pstream::{P2PMessage, P2PMessageID},
EthVersion, HelloMessage, ProtocolVersion,
capability::Capability, p2pstream::P2PMessage, EthVersion, HelloMessage, ProtocolVersion,
};
#[test]
fn test_pong_snappy_encoding_parity() {
// encode pong using our `Encodable` implementation
let pong = P2PMessage::Pong;
let mut pong_encoded = Vec::new();
pong.encode(&mut pong_encoded);
// the definition of pong is 0x80 (an empty rlp string)
let pong_raw = vec![EMPTY_STRING_CODE];
let mut snappy_encoder = snap::raw::Encoder::new();
let pong_compressed = snappy_encoder.compress_vec(&pong_raw).unwrap();
let mut pong_expected = vec![P2PMessageID::Pong as u8];
pong_expected.extend(&pong_compressed);
// ensure that the two encodings are equal
assert_eq!(
pong_expected, pong_encoded,
"left: {pong_expected:#x?}, right: {pong_encoded:#x?}"
);
// also ensure that the length is correct
assert_eq!(pong_expected.len(), P2PMessage::Pong.length());
// try to decode using Decodable
let p2p_message = P2PMessage::decode(&mut &pong_expected[..]).unwrap();
assert_eq!(p2p_message, P2PMessage::Pong);
// finally decode the encoded message with snappy
let mut snappy_decoder = snap::raw::Decoder::new();
// the message id is not compressed, only compress the latest bits
let decompressed = snappy_decoder.decompress_vec(&pong_encoded[1..]).unwrap();
assert_eq!(decompressed, pong_raw);
}
#[test]
fn test_hello_encoding_round_trip() {
let secret_key = SecretKey::new(&mut rand::thread_rng());

View File

@ -199,14 +199,43 @@ impl<S> P2PStream<S> {
/// reading new messages.
///
/// Once disconnect process has started, the [`Stream`] will terminate immediately.
pub fn start_disconnect(&mut self, reason: DisconnectReason) {
///
/// # Errors
///
/// Returns an error only if the message fails to compress.
pub fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), snap::Error> {
// clear any buffered messages and queue in
self.outgoing_messages.clear();
let disconnect = P2PMessage::Disconnect(reason);
let mut buf = BytesMut::with_capacity(disconnect.length());
disconnect.encode(&mut buf);
self.outgoing_messages.push_back(buf.freeze());
tracing::trace!(
fromlen=%buf.len(),
msg=%hex::encode(&buf),
"Compressing disconnect message",
);
let mut compressed = BytesMut::zeroed(1 + snap::raw::max_compress_len(buf.len() - 1));
let compressed_size = self.encoder.compress(&buf[1..], &mut compressed[1..])?;
// truncate the compressed buffer to the actual compressed size (plus one for the message
// id)
compressed.truncate(compressed_size + 1);
// we do not add the capability offset because the disconnect message is a `p2p` reserved
// message
compressed[0] = buf[0];
tracing::trace!(
tolen=%compressed.len(),
compressed=%hex::encode(&compressed),
"Compressed disconnect message",
);
self.outgoing_messages.push_back(compressed.freeze());
self.disconnecting = true;
Ok(())
}
}
@ -216,9 +245,10 @@ where
{
/// Disconnects the connection by sending a disconnect message.
///
/// This future resolves once the disconnect message has been sent.
/// This future resolves once the disconnect message has been sent and the stream has been
/// closed.
pub async fn disconnect(mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
self.start_disconnect(reason);
self.start_disconnect(reason)?;
self.close().await
}
}
@ -257,7 +287,7 @@ where
}
_ => {
// encode the disconnect message
this.start_disconnect(DisconnectReason::PingTimeout);
this.start_disconnect(DisconnectReason::PingTimeout)?;
// End the stream after ping related error
return Poll::Ready(None)
@ -273,8 +303,32 @@ where
None => return Poll::Ready(None),
};
let id = *bytes.first().ok_or(P2PStreamError::EmptyProtocolMessage)?;
// first check that the compressed message length does not exceed the max
// payload size
let decompressed_len = snap::raw::decompress_len(&bytes[1..])?;
if decompressed_len > MAX_PAYLOAD_SIZE {
return Poll::Ready(Some(Err(P2PStreamError::MessageTooBig {
message_size: decompressed_len,
max_size: MAX_PAYLOAD_SIZE,
})))
}
// create a buffer to hold the decompressed message, adding a byte to the length for
// the message ID byte, which is the first byte in this buffer
let mut decompress_buf = BytesMut::zeroed(decompressed_len + 1);
tracing::trace!(
fromlen=%bytes.len(),
tolen=%decompress_buf.len(),
msg=%hex::encode(&bytes),
"Decompressing message",
);
// each message following a successful handshake is compressed with snappy, so we need
// to decompress the message before we can decode it.
this.decoder.decompress(&bytes[1..], &mut decompress_buf[1..])?;
let id = *bytes.first().ok_or(P2PStreamError::EmptyProtocolMessage)?;
match id {
_ if id == P2PMessageID::Ping as u8 => {
// we have received a ping, so we will send a pong
@ -290,9 +344,9 @@ where
this.outgoing_messages.push_back(pong_bytes.into());
}
_ if id == P2PMessageID::Disconnect as u8 => {
let reason = DisconnectReason::decode(&mut &bytes[1..]).map_err(|err| {
let reason = DisconnectReason::decode(&mut &decompress_buf[1..]).map_err(|err| {
tracing::warn!(
?err, msg=%hex::encode(&bytes[1..]), "Failed to decode disconnect message from peer"
?err, msg=%hex::encode(&decompress_buf[1..]), "Failed to decode disconnect message from peer"
);
err
})?;
@ -315,25 +369,30 @@ where
return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(id))))
}
_ => {
// first check that the compressed message length does not exceed the max
// message size
let decompressed_len = snap::raw::decompress_len(&bytes[1..])?;
if decompressed_len > MAX_PAYLOAD_SIZE {
return Poll::Ready(Some(Err(P2PStreamError::MessageTooBig {
message_size: decompressed_len,
max_size: MAX_PAYLOAD_SIZE,
})))
}
// we have received a message that is outside the `p2p` reserved message space,
// so it is a subprotocol message.
// then decompress the message
let mut decompress_buf = BytesMut::zeroed(decompressed_len + 1);
// we have a subprotocol message that needs to be sent in the stream.
// first, switch the message id based on offset so the next layer can decode it
// without being aware of the p2p stream's state (shared capabilities / the
// message id offset)
// Peers must be able to identify messages meant for different subprotocols
// using a single message ID byte, and those messages must be distinct from the
// lower-level `p2p` messages.
//
// To ensure that messages for subprotocols are distinct from messages meant
// for the `p2p` capability, message IDs 0x00 - 0x0f are reserved for `p2p`
// messages, so subprotocol messages must have an ID of 0x10 or higher.
//
// To ensure that messages for two different capabilities are distinct from
// each other, all shared capabilities are first ordered lexicographically.
// Message IDs are then reserved in this order, starting at 0x10, reserving a
// message ID for each message the capability supports.
//
// For example, if the shared capabilities are `eth/67` (containing 10
// messages), and "qrs/65" (containing 8 messages):
//
// * The special case of `p2p`: `p2p` is reserved message IDs 0x00 - 0x0f.
// * `eth/67` is reserved message IDs 0x10 - 0x19.
// * `qrs/65` is reserved message IDs 0x1a - 0x21.
//
decompress_buf[0] = bytes[0] - this.shared_capability.offset();
this.decoder.decompress(&bytes[1..], &mut decompress_buf[1..])?;
return Poll::Ready(Some(Ok(decompress_buf)))
}
@ -380,17 +439,22 @@ where
return Err(P2PStreamError::SendBufferFull)
}
let mut compressed = BytesMut::zeroed(1 + snap::raw::max_compress_len(item.len() - 1));
tracing::trace!(
fromlen=%item.len(),
msg=%hex::encode(&item),
"Compressing message",
);
// all messages sent in this stream are subprotocol messages, so we need to switch the
// message id based on the offset
compressed[0] = item[0] + this.shared_capability.offset();
let mut compressed = BytesMut::zeroed(1 + snap::raw::max_compress_len(item.len() - 1));
let compressed_size = this.encoder.compress(&item[1..], &mut compressed[1..])?;
// truncate the compressed buffer to the actual compressed size (plus one for the message
// id)
compressed.truncate(compressed_size + 1);
// all messages sent in this stream are subprotocol messages, so we need to switch the
// message id based on the offset
compressed[0] = item[0] + this.shared_capability.offset();
this.outgoing_messages.push_back(compressed.freeze());
Ok(())
@ -552,13 +616,9 @@ impl Encodable for P2PMessage {
P2PMessage::Hello(msg) => msg.encode(out),
P2PMessage::Disconnect(msg) => msg.encode(out),
P2PMessage::Ping => {
out.put_u8(0x01);
out.put_u8(0x00);
out.put_u8(EMPTY_STRING_CODE);
}
P2PMessage::Pong => {
out.put_u8(0x01);
out.put_u8(0x00);
out.put_u8(EMPTY_STRING_CODE);
}
}
@ -568,8 +628,8 @@ impl Encodable for P2PMessage {
let payload_len = match self {
P2PMessage::Hello(msg) => msg.length(),
P2PMessage::Disconnect(msg) => msg.length(),
P2PMessage::Ping => 3, // len([0x01, 0x00, 0x80]) = 3
P2PMessage::Pong => 3, // len([0x01, 0x00, 0x80]) = 3
P2PMessage::Ping => 1,
P2PMessage::Pong => 1,
};
payload_len + 1 // (1 for length of p2p message id)
}
@ -588,13 +648,11 @@ impl Decodable for P2PMessage {
P2PMessageID::Hello => Ok(P2PMessage::Hello(HelloMessage::decode(buf)?)),
P2PMessageID::Disconnect => Ok(P2PMessage::Disconnect(DisconnectReason::decode(buf)?)),
P2PMessageID::Ping => {
// len([0x01, 0x00, 0x80]) = 3
buf.advance(3);
buf.advance(1);
Ok(P2PMessage::Ping)
}
P2PMessageID::Pong => {
// len([0x01, 0x00, 0x80]) = 3
buf.advance(3);
buf.advance(1);
Ok(P2PMessage::Pong)
}
}
@ -683,7 +741,6 @@ mod tests {
use super::*;
use crate::{DisconnectReason, EthVersion};
use reth_ecies::util::pk2id;
use reth_rlp::EMPTY_STRING_CODE;
use secp256k1::{SecretKey, SECP256K1};
use tokio::net::{TcpListener, TcpStream};
use tokio_util::codec::Decoder;
@ -703,6 +760,7 @@ mod tests {
#[tokio::test]
async fn test_can_disconnect() {
reth_tracing::init_tracing();
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
@ -732,7 +790,7 @@ mod tests {
let err = p2p_stream.next().await.unwrap().unwrap_err();
match err {
P2PStreamError::Disconnected(reason) => assert_eq!(reason, expected_disconnect),
_ => panic!("unexpected err"),
e => panic!("unexpected err: {e}"),
}
}
@ -783,40 +841,4 @@ mod tests {
// make sure the server receives the message and asserts before ending the test
handle.await.unwrap();
}
#[test]
fn test_ping_snappy_encoding_parity() {
// encode ping using our `Encodable` implementation
let ping = P2PMessage::Ping;
let mut ping_encoded = Vec::new();
ping.encode(&mut ping_encoded);
// the definition of ping is 0x80 (an empty rlp string)
let ping_raw = vec![EMPTY_STRING_CODE];
let mut snappy_encoder = snap::raw::Encoder::new();
let ping_compressed = snappy_encoder.compress_vec(&ping_raw).unwrap();
let mut ping_expected = vec![P2PMessageID::Ping as u8];
ping_expected.extend(&ping_compressed);
// ensure that the two encodings are equal
assert_eq!(
ping_expected, ping_encoded,
"left: {ping_expected:#x?}, right: {ping_encoded:#x?}"
);
// also ensure that the length is correct
assert_eq!(ping_expected.len(), P2PMessage::Ping.length());
// try to decode using Decodable
let p2p_message = P2PMessage::decode(&mut &ping_expected[..]).unwrap();
assert_eq!(p2p_message, P2PMessage::Ping);
// finally decode the encoded message with snappy
let mut snappy_decoder = snap::raw::Decoder::new();
// the message id is not compressed, only compress the latest bits
let decompressed = snappy_decoder.decompress_vec(&ping_encoded[1..]).unwrap();
assert_eq!(decompressed, ping_raw);
}
}

View File

@ -12,7 +12,7 @@ use futures::{stream::Fuse, SinkExt, StreamExt};
use reth_ecies::stream::ECIESStream;
use reth_eth_wire::{
capability::Capabilities,
error::{EthStreamError, HandshakeError},
error::{EthStreamError, HandshakeError, P2PStreamError},
message::{EthBroadcastMessage, RequestPair},
DisconnectReason, EthMessage, EthStream, P2PStream,
};
@ -290,8 +290,12 @@ impl ActiveSession {
}
/// Starts the disconnect process
fn start_disconnect(&mut self, reason: DisconnectReason) {
self.conn.inner_mut().start_disconnect(reason);
fn start_disconnect(&mut self, reason: DisconnectReason) -> Result<(), EthStreamError> {
self.conn
.inner_mut()
.start_disconnect(reason)
.map_err(P2PStreamError::from)
.map_err(Into::into)
}
/// Flushes the disconnect message and emits the corresponding message
@ -348,8 +352,18 @@ impl Future for ActiveSession {
SessionCommand::Disconnect { reason } => {
let reason =
reason.unwrap_or(DisconnectReason::DisconnectRequested);
this.start_disconnect(reason);
return this.poll_disconnect(cx)
// try to disconnect
match this.start_disconnect(reason) {
Ok(()) => {
// we're done
return this.poll_disconnect(cx)
}
Err(err) => {
error!(target: "net::session", ?err, remote_peer_id=?this.remote_peer_id, "could not send disconnect");
this.close_on_error(err);
return Poll::Ready(())
}
}
}
SessionCommand::Message(msg) => {
this.on_peer_message(msg);
@ -660,7 +674,7 @@ mod tests {
let (incoming, _) = listener.accept().await.unwrap();
let mut session = builder.connect_incoming(incoming).await;
session.start_disconnect(expected_disconnect);
session.start_disconnect(expected_disconnect).unwrap();
session.await
});