test: add mock discovery testing (#139)

This commit is contained in:
Matthias Seitz
2022-10-26 14:33:13 +02:00
committed by GitHub
parent 61b8829bdf
commit 6c0e2753dd
4 changed files with 419 additions and 101 deletions

View File

@ -35,8 +35,12 @@ thiserror = "1.0"
url = "2.3"
hex = "0.4"
public-ip = "0.2"
rand = { version = "0.8", optional = true }
[dev-dependencies]
rand = "0.8"
tokio = { version = "1", features = ["full"] }
tracing-test = "0.2"
[features]
mock = ["rand"]

View File

@ -35,13 +35,12 @@ use secp256k1::SecretKey;
use std::{
cell::RefCell,
collections::{btree_map, hash_map::Entry, BTreeMap, HashMap, VecDeque},
future::Future,
io,
net::SocketAddr,
pin::Pin,
rc::Rc,
sync::Arc,
task::{Context, Poll},
task::{ready, Context, Poll},
time::{Duration, Instant, SystemTime, UNIX_EPOCH},
};
use tokio::{
@ -50,7 +49,7 @@ use tokio::{
task::{JoinHandle, JoinSet},
time::Interval,
};
use tokio_stream::{wrappers::ReceiverStream, StreamExt};
use tokio_stream::{wrappers::ReceiverStream, Stream, StreamExt};
use tracing::{debug, instrument, trace, warn};
pub mod bootnodes;
@ -62,6 +61,9 @@ pub use config::Discv4Config;
mod node;
pub use node::NodeRecord;
#[cfg(any(test, feature = "mock"))]
pub mod mock;
/// reexport to get public ip.
pub use public_ip;
@ -98,8 +100,8 @@ const NODE_LAST_SEEN_TIMEOUT: Duration = Duration::from_secs(24 * 60 * 60);
type EgressSender = mpsc::Sender<(Bytes, SocketAddr)>;
type EgressReceiver = mpsc::Receiver<(Bytes, SocketAddr)>;
type IngressSender = mpsc::Sender<IngressEvent>;
type IngressReceiver = mpsc::Receiver<IngressEvent>;
pub(crate) type IngressSender = mpsc::Sender<IngressEvent>;
pub(crate) type IngressReceiver = mpsc::Receiver<IngressEvent>;
type NodeRecordSender = OneshotSender<Vec<NodeRecord>>;
@ -348,10 +350,27 @@ impl Discv4Service {
}
/// Returns the address of the UDP socket
pub fn local_address(&self) -> SocketAddr {
pub fn local_addr(&self) -> SocketAddr {
self.local_address
}
/// Returns the ENR of this service.
pub fn local_enr(&self) -> NodeRecord {
self.local_enr
}
/// Returns mutable reference to ENR for testing.
#[cfg(test)]
pub fn local_enr_mut(&mut self) -> &mut NodeRecord {
&mut self.local_enr
}
/// Returns true if the given NodeId is currently in the bucket
pub fn contains_node(&self, id: NodeId) -> bool {
let key = kad_key(id);
self.kbuckets.get_index(&key).is_some()
}
/// Bootstraps the local node to join the DHT.
///
/// Bootstrapping is a multi-step operation that starts with a lookup of the local node's
@ -374,7 +393,9 @@ impl Discv4Service {
pub fn spawn(mut self) -> JoinHandle<()> {
tokio::task::spawn(async move {
self.bootstrap();
self.await
while let Some(event) = self.next().await {
trace!(?event, target = "net::disc", "processed");
}
})
}
@ -518,7 +539,7 @@ impl Discv4Service {
}
/// Encodes the packet, sends it and returns the hash.
fn send_packet(&mut self, msg: Message, to: SocketAddr) -> H256 {
pub(crate) fn send_packet(&mut self, msg: Message, to: SocketAddr) -> H256 {
let (payload, hash) = msg.encode(&self.secret_key);
trace!(r#type=?msg.msg_type(), ?to, ?hash, target = "net::disc", "sending packet");
let _ = self.egress.try_send((payload, to));
@ -555,14 +576,16 @@ impl Discv4Service {
}
if self.pending_pings.len() < MAX_NODES_PING {
self.send_ping(node, reason)
self.send_ping(node, reason);
} else {
self.queued_pings.push_back((node, reason))
self.queued_pings.push_back((node, reason));
}
}
/// Sends a ping message to the node's UDP address.
fn send_ping(&mut self, node: NodeRecord, reason: PingReason) {
///
/// Returns the echo hash of the ping message.
pub(crate) fn send_ping(&mut self, node: NodeRecord, reason: PingReason) -> H256 {
let remote_addr = node.udp_addr();
let id = node.id;
let ping =
@ -572,6 +595,7 @@ impl Discv4Service {
self.pending_pings
.insert(id, PingRequest { sent_at: Instant::now(), node, echo_hash, reason });
echo_hash
}
/// Message handler for an incoming `Pong`.
@ -820,7 +844,7 @@ impl Discv4Service {
/// To prevent traffic amplification attacks, implementations must verify that the sender of a
/// query participates in the discovery protocol. The sender of a packet is considered verified
/// if it has sent a valid Pong response with matching ping hash within the last 12 hours.
pub(crate) fn poll(&mut self, cx: &mut Context<'_>) -> Poll<()> {
pub(crate) fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Discv4Event> {
// trigger self lookup
if self.lookup_interval.poll_tick(cx).is_ready() {
let target = self.lookup_rotator.next(&self.local_enr.id);
@ -874,15 +898,19 @@ impl Discv4Service {
match msg {
Message::Ping(ping) => {
self.on_ping(ping, remote_addr, node_id, hash);
return Poll::Ready(Discv4Event::Ping)
}
Message::Pong(pong) => {
self.on_pong(pong, remote_addr, node_id);
return Poll::Ready(Discv4Event::Pong)
}
Message::FindNode(msg) => {
self.on_find_node(msg, remote_addr, node_id);
return Poll::Ready(Discv4Event::FindNode)
}
Message::Neighbours(msg) => {
self.on_neighbours(msg, remote_addr, node_id);
return Poll::Ready(Discv4Event::Neighbours)
}
}
}
@ -897,16 +925,31 @@ impl Discv4Service {
}
/// Endless future impl
impl Future for Discv4Service {
type Output = ();
impl Stream for Discv4Service {
type Item = Discv4Event;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.get_mut().poll(cx)
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Poll::Ready(Some(ready!(self.get_mut().poll(cx))))
}
}
/// The Event type the Service stream produces.
///
/// This is mainly used for testing purposes and represents messages the service processed
#[derive(Debug, Eq, PartialEq)]
pub enum Discv4Event {
/// A `Ping` message was handled.
Ping,
/// A `Pong` message was handled.
Pong,
/// A `FindNode` message was handled.
FindNode,
/// A `Neighbours` message was handled.
Neighbours,
}
/// Continuously reads new messages from the channel and writes them to the socket
async fn send_loop(udp: Arc<UdpSocket>, rx: EgressReceiver) {
pub(crate) async fn send_loop(udp: Arc<UdpSocket>, rx: EgressReceiver) {
let mut stream = ReceiverStream::new(rx);
while let Some((payload, to)) = stream.next().await {
match udp.send_to(&payload, to).await {
@ -921,7 +964,7 @@ async fn send_loop(udp: Arc<UdpSocket>, rx: EgressReceiver) {
}
/// Continuously awaits new incoming messages and sends them back through the channel.
async fn receive_loop(udp: Arc<UdpSocket>, tx: IngressSender, local_id: NodeId) {
pub(crate) async fn receive_loop(udp: Arc<UdpSocket>, tx: IngressSender, local_id: NodeId) {
loop {
let mut buf = [0; MAX_PACKET_SIZE];
let res = udp.recv_from(&mut buf).await;
@ -1204,37 +1247,18 @@ pub enum TableUpdate {
#[cfg(test)]
mod tests {
use super::*;
use crate::bootnodes::mainnet_nodes;
use rand::thread_rng;
use secp256k1::SECP256K1;
use std::str::FromStr;
use crate::{
bootnodes::mainnet_nodes,
mock::{create_discv4, create_discv4_with_config},
};
use tracing_test::traced_test;
async fn create() -> (Discv4, Discv4Service) {
create_with_config(Default::default()).await
}
async fn create_with_config(config: Discv4Config) -> (Discv4, Discv4Service) {
let mut rng = thread_rng();
let socket = SocketAddr::from_str("0.0.0.0:30303").unwrap();
let (secret_key, pk) = SECP256K1.generate_keypair(&mut rng);
let id = NodeId::from_slice(&pk.serialize_uncompressed()[1..]);
let external_addr = public_ip::addr().await.unwrap_or_else(|| socket.ip());
let local_enr = NodeRecord {
address: external_addr,
tcp_port: socket.port(),
udp_port: socket.port(),
id,
};
Discv4::bind(socket, local_enr, secret_key, config).await.unwrap()
}
#[tokio::test]
#[traced_test]
async fn test_pending_ping() {
let (_, mut service) = create().await;
let (_, mut service) = create_discv4().await;
let local_addr = service.local_address();
let local_addr = service.local_addr();
for idx in 0..MAX_NODES_PING {
let node = NodeRecord::new(local_addr, NodeId::random());
@ -1250,7 +1274,7 @@ mod tests {
async fn test_lookup() {
let all_nodes = mainnet_nodes();
let config = Discv4Config::builder().add_boot_nodes(all_nodes).build();
let (_discv4, mut service) = create_with_config(config).await;
let (_discv4, mut service) = create_discv4_with_config(config).await;
let mut updates = service.update_stream();

View File

@ -0,0 +1,343 @@
//! Mock discovery support
#![allow(missing_docs, unused)]
use crate::{
node::NodeRecord,
proto::{FindNode, Message, Neighbours, NodeEndpoint, Packet, Ping, Pong},
receive_loop, send_loop, Discv4, Discv4Config, Discv4Service, EgressSender, IngressEvent,
IngressReceiver, NodeId, SAFE_MAX_DATAGRAM_NEIGHBOUR_RECORDS,
};
use rand::{thread_rng, Rng, RngCore};
use reth_primitives::H256;
use secp256k1::{SecretKey, SECP256K1};
use std::{
collections::{HashMap, HashSet},
io,
net::{IpAddr, SocketAddr},
pin::Pin,
str::FromStr,
sync::Arc,
task::{Context, Poll},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use tokio::{
net::UdpSocket,
sync::mpsc,
task::{JoinHandle, JoinSet},
};
use tokio_stream::{Stream, StreamExt};
use tracing::error;
/// Mock discovery node
pub struct MockDiscovery {
local_addr: SocketAddr,
local_enr: NodeRecord,
secret_key: SecretKey,
udp: Arc<UdpSocket>,
_tasks: JoinSet<()>,
/// Receiver for incoming messages
ingress: IngressReceiver,
/// Sender for sending outgoing messages
egress: EgressSender,
pending_pongs: HashSet<NodeId>,
pending_neighbours: HashMap<NodeId, Vec<NodeRecord>>,
command_rx: mpsc::Receiver<MockCommand>,
}
impl MockDiscovery {
/// Creates a new instance and opens a socket
pub async fn new() -> io::Result<(Self, mpsc::Sender<MockCommand>)> {
let mut rng = thread_rng();
let socket = SocketAddr::from_str("0.0.0.0:0").unwrap();
let (secret_key, pk) = SECP256K1.generate_keypair(&mut rng);
let id = NodeId::from_slice(&pk.serialize_uncompressed()[1..]);
let socket = Arc::new(UdpSocket::bind(socket).await?);
let local_addr = socket.local_addr()?;
let local_enr = NodeRecord {
address: local_addr.ip(),
tcp_port: local_addr.port(),
udp_port: local_addr.port(),
id,
};
let (ingress_tx, ingress_rx) = mpsc::channel(128);
let (egress_tx, egress_rx) = mpsc::channel(128);
let mut tasks = JoinSet::<()>::new();
let udp = Arc::clone(&socket);
tasks.spawn(async move { receive_loop(udp, ingress_tx, local_enr.id).await });
let udp = Arc::clone(&socket);
tasks.spawn(async move { send_loop(udp, egress_rx).await });
let (tx, command_rx) = mpsc::channel(128);
let this = Self {
_tasks: tasks,
ingress: ingress_rx,
egress: egress_tx,
local_addr,
local_enr,
secret_key,
udp: socket,
pending_pongs: Default::default(),
pending_neighbours: Default::default(),
command_rx,
};
Ok((this, tx))
}
/// Spawn and consume the stream.
pub fn spawn(mut self) -> JoinHandle<()> {
tokio::task::spawn(async move {
let _: Vec<_> = self.collect().await;
})
}
/// Queue a pending pong.
pub fn queue_pong(&mut self, from: NodeId) {
self.pending_pongs.insert(from);
}
/// Queue a pending Neighbours response.
pub fn queue_neighbours(&mut self, target: NodeId, nodes: Vec<NodeRecord>) {
self.pending_neighbours.insert(target, nodes);
}
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
}
pub fn local_enr(&self) -> NodeRecord {
self.local_enr
}
/// Encodes the packet, sends it and returns the hash.
fn send_packet(&mut self, msg: Message, to: SocketAddr) -> H256 {
let (payload, hash) = msg.encode(&self.secret_key);
let _ = self.egress.try_send((payload, to));
hash
}
fn send_neighbours_timeout(&self) -> u64 {
(SystemTime::now().duration_since(UNIX_EPOCH).unwrap() + Duration::from_secs(30)).as_secs()
}
}
impl Stream for MockDiscovery {
type Item = MockEvent;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
// process all incoming commands
while let Poll::Ready(maybe_cmd) = this.command_rx.poll_recv(cx) {
if let Some(cmd) = maybe_cmd {
match cmd {
MockCommand::MockPong { node_id } => {
this.queue_pong(node_id);
}
MockCommand::MockNeighbours { target, nodes } => {
this.queue_neighbours(target, nodes);
}
}
} else {
return Poll::Ready(None)
}
}
while let Poll::Ready(Some(event)) = this.ingress.poll_recv(cx) {
match event {
IngressEvent::RecvError(_) => {}
IngressEvent::BadPacket(from, err, data) => {
error!(?from, ?err, packet=?hex::encode(&data), target = "net::disc", "bad packet");
}
IngressEvent::Packet(remote_addr, Packet { msg, node_id, hash }) => match msg {
Message::Ping(ping) => {
if this.pending_pongs.remove(&node_id) {
let pong = Pong { to: ping.from, echo: hash, expire: ping.expire };
let msg = Message::Pong(pong.clone());
this.send_packet(msg, remote_addr);
return Poll::Ready(Some(MockEvent::Pong {
ping,
pong,
to: remote_addr,
}))
}
}
Message::Pong(_) => {}
Message::FindNode(msg) => {
if let Some(nodes) = this.pending_neighbours.remove(&msg.id) {
let msg = Message::Neighbours(Neighbours {
nodes: nodes.clone(),
expire: this.send_neighbours_timeout(),
});
this.send_packet(msg, remote_addr);
return Poll::Ready(Some(MockEvent::Neighbours {
nodes,
to: remote_addr,
}))
}
}
Message::Neighbours(_) => {}
},
}
}
Poll::Pending
}
}
/// The event type the mock service produces
pub enum MockEvent {
Pong { ping: Ping, pong: Pong, to: SocketAddr },
Neighbours { nodes: Vec<NodeRecord>, to: SocketAddr },
}
/// Command for interacting with the `MockDiscovery` service
pub enum MockCommand {
MockPong { node_id: NodeId },
MockNeighbours { target: NodeId, nodes: Vec<NodeRecord> },
}
/// Creates a new testing instance for [`Discv4`] and its service
pub async fn create_discv4() -> (Discv4, Discv4Service) {
create_discv4_with_config(Default::default()).await
}
/// Creates a new testing instance for [`Discv4`] and its service with the given config.
pub async fn create_discv4_with_config(config: Discv4Config) -> (Discv4, Discv4Service) {
let mut rng = thread_rng();
let socket = SocketAddr::from_str("0.0.0.0:0").unwrap();
let (secret_key, pk) = SECP256K1.generate_keypair(&mut rng);
let id = NodeId::from_slice(&pk.serialize_uncompressed()[1..]);
let external_addr = public_ip::addr().await.unwrap_or_else(|| socket.ip());
let local_enr =
NodeRecord { address: external_addr, tcp_port: socket.port(), udp_port: socket.port(), id };
Discv4::bind(socket, local_enr, secret_key, config).await.unwrap()
}
pub fn rng_endpoint(rng: &mut impl Rng) -> NodeEndpoint {
let address = if rng.gen() {
let mut ip = [0u8; 4];
rng.fill_bytes(&mut ip);
IpAddr::V4(ip.into())
} else {
let mut ip = [0u8; 16];
rng.fill_bytes(&mut ip);
IpAddr::V6(ip.into())
};
NodeEndpoint { address, tcp_port: rng.gen(), udp_port: rng.gen() }
}
pub fn rng_record(rng: &mut impl RngCore) -> NodeRecord {
let NodeEndpoint { address, udp_port, tcp_port } = rng_endpoint(rng);
NodeRecord { address, tcp_port, udp_port, id: NodeId::random() }
}
pub fn rng_ipv6_record(rng: &mut impl RngCore) -> NodeRecord {
let mut ip = [0u8; 16];
rng.fill_bytes(&mut ip);
let address = IpAddr::V6(ip.into());
NodeRecord { address, tcp_port: rng.gen(), udp_port: rng.gen(), id: NodeId::random() }
}
pub fn rng_ipv4_record(rng: &mut impl RngCore) -> NodeRecord {
let mut ip = [0u8; 4];
rng.fill_bytes(&mut ip);
let address = IpAddr::V4(ip.into());
NodeRecord { address, tcp_port: rng.gen(), udp_port: rng.gen(), id: NodeId::random() }
}
pub fn rng_message(rng: &mut impl RngCore) -> Message {
match rng.gen_range(1..=4) {
1 => Message::Ping(Ping {
from: rng_endpoint(rng),
to: rng_endpoint(rng),
expire: rng.gen(),
}),
2 => Message::Pong(Pong { to: rng_endpoint(rng), echo: H256::random(), expire: rng.gen() }),
3 => Message::FindNode(FindNode { id: NodeId::random(), expire: rng.gen() }),
4 => {
let num: usize = rng.gen_range(1..=SAFE_MAX_DATAGRAM_NEIGHBOUR_RECORDS);
Message::Neighbours(Neighbours {
nodes: std::iter::repeat_with(|| rng_record(rng)).take(num).collect(),
expire: rng.gen(),
})
}
_ => unreachable!(),
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{Discv4Event, PingReason};
use std::net::{IpAddr, Ipv4Addr};
use tracing_test::traced_test;
/// This test creates two local UDP sockets. The mocked discovery service responds to specific
/// messages and we check the actual service receives answers
#[tokio::test(flavor = "multi_thread")]
#[traced_test]
async fn can_mock_discovery() {
let mut rng = thread_rng();
let (_, mut service) = create_discv4().await;
let (mut mockv4, mut cmd) = MockDiscovery::new().await.unwrap();
let mock_enr = mockv4.local_enr();
let mock_addr = mockv4.local_addr();
// we only want to test internally
service.local_enr_mut().address = IpAddr::V4(Ipv4Addr::UNSPECIFIED);
let discv_addr = service.local_addr();
let discv_enr = service.local_enr();
// make sure it responds with a Pong
mockv4.queue_pong(discv_enr.id);
// This sends a ping to the mock service
let echo_hash = service.send_ping(mock_enr, PingReason::Normal);
// process the mock pong
let event = mockv4.next().await.unwrap();
match event {
MockEvent::Pong { ping, pong, to } => {
assert_eq!(to, SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), discv_addr.port()));
assert_eq!(pong.echo, echo_hash);
}
MockEvent::Neighbours { .. } => {
unreachable!("invalid response")
}
}
// discovery service received mocked pong
let event = service.next().await.unwrap();
assert_eq!(event, Discv4Event::Pong);
assert!(service.contains_node(mock_enr.id));
let mock_nodes =
std::iter::repeat_with(|| rng_record(&mut rng)).take(5).collect::<Vec<_>>();
mockv4.queue_neighbours(discv_enr.id, mock_nodes.clone());
// start lookup
service.lookup_self();
let event = mockv4.next().await.unwrap();
match event {
MockEvent::Pong { .. } => {
unreachable!("invalid response")
}
MockEvent::Neighbours { nodes, to } => {
assert_eq!(to, SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), discv_addr.port()));
assert_eq!(nodes, mock_nodes);
}
}
// discovery service received mocked pong
let event = service.next().await.unwrap();
assert_eq!(event, Discv4Event::Neighbours);
}
}

View File

@ -386,66 +386,13 @@ impl Decodable for Octets {
#[cfg(test)]
mod tests {
use super::*;
use crate::SAFE_MAX_DATAGRAM_NEIGHBOUR_RECORDS;
use crate::{
mock::{rng_endpoint, rng_ipv4_record, rng_ipv6_record, rng_message},
SAFE_MAX_DATAGRAM_NEIGHBOUR_RECORDS,
};
use bytes::BytesMut;
use rand::{thread_rng, Rng, RngCore};
fn rng_endpoint(rng: &mut impl Rng) -> NodeEndpoint {
let address = if rng.gen() {
let mut ip = [0u8; 4];
rng.fill_bytes(&mut ip);
IpAddr::V4(ip.into())
} else {
let mut ip = [0u8; 16];
rng.fill_bytes(&mut ip);
IpAddr::V6(ip.into())
};
NodeEndpoint { address, tcp_port: rng.gen(), udp_port: rng.gen() }
}
fn rng_record(rng: &mut impl RngCore) -> NodeRecord {
let NodeEndpoint { address, udp_port, tcp_port } = rng_endpoint(rng);
NodeRecord { address, tcp_port, udp_port, id: NodeId::random() }
}
fn rng_ipv6_record(rng: &mut impl RngCore) -> NodeRecord {
let mut ip = [0u8; 16];
rng.fill_bytes(&mut ip);
let address = IpAddr::V6(ip.into());
NodeRecord { address, tcp_port: rng.gen(), udp_port: rng.gen(), id: NodeId::random() }
}
fn rng_ipv4_record(rng: &mut impl RngCore) -> NodeRecord {
let mut ip = [0u8; 4];
rng.fill_bytes(&mut ip);
let address = IpAddr::V4(ip.into());
NodeRecord { address, tcp_port: rng.gen(), udp_port: rng.gen(), id: NodeId::random() }
}
fn rng_message(rng: &mut impl RngCore) -> Message {
match rng.gen_range(1..=4) {
1 => Message::Ping(Ping {
from: rng_endpoint(rng),
to: rng_endpoint(rng),
expire: rng.gen(),
}),
2 => Message::Pong(Pong {
to: rng_endpoint(rng),
echo: H256::random(),
expire: rng.gen(),
}),
3 => Message::FindNode(FindNode { id: NodeId::random(), expire: rng.gen() }),
4 => {
let num: usize = rng.gen_range(1..=SAFE_MAX_DATAGRAM_NEIGHBOUR_RECORDS);
Message::Neighbours(Neighbours {
nodes: std::iter::repeat_with(|| rng_record(rng)).take(num).collect(),
expire: rng.gen(),
})
}
_ => unreachable!(),
}
}
#[test]
fn test_endpoint_ipv_v4() {
let mut rng = thread_rng();