test(net): add active session tests (#230)

* test(net): add active session tests

* more tests
This commit is contained in:
Matthias Seitz
2022-11-22 06:22:43 +01:00
committed by GitHub
parent f7c6ae5858
commit 46e4ad9744
5 changed files with 317 additions and 7 deletions

View File

@ -21,6 +21,19 @@ pub enum EthStreamError {
MessageTooBig(usize), MessageTooBig(usize),
} }
// === impl EthStreamError ===
impl EthStreamError {
/// Returns the [`DisconnectReason`] if the error is a disconnect message
pub fn as_disconnected(&self) -> Option<DisconnectReason> {
if let EthStreamError::P2PStreamError(err) = self {
err.as_disconnected()
} else {
None
}
}
}
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
#[allow(missing_docs)] #[allow(missing_docs)]
pub enum HandshakeError { pub enum HandshakeError {
@ -73,6 +86,19 @@ pub enum P2PStreamError {
Disconnected(DisconnectReason), Disconnected(DisconnectReason),
} }
// === impl P2PStreamError ===
impl P2PStreamError {
/// Returns the [`DisconnectReason`] if it is the `Disconnected` variant.
pub fn as_disconnected(&self) -> Option<DisconnectReason> {
if let P2PStreamError::Disconnected(reason) = self {
Some(*reason)
} else {
None
}
}
}
/// Errors when conducting a p2p handshake /// Errors when conducting a p2p handshake
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
pub enum P2PHandshakeError { pub enum P2PHandshakeError {

View File

@ -141,6 +141,11 @@ impl<S> EthStream<S> {
pub fn inner_mut(&mut self) -> &mut S { pub fn inner_mut(&mut self) -> &mut S {
&mut self.inner &mut self.inner
} }
/// Consumes this type and returns the wrapped stream.
pub fn into_inner(self) -> S {
self.inner
}
} }
impl<S, E> EthStream<S> impl<S, E> EthStream<S>

View File

@ -23,5 +23,5 @@ pub use types::*;
pub use crate::{ pub use crate::{
ethstream::{EthStream, UnauthedEthStream}, ethstream::{EthStream, UnauthedEthStream},
p2pstream::{DisconnectReason, HelloMessage, P2PStream, UnauthedP2PStream}, p2pstream::{DisconnectReason, HelloMessage, P2PStream, ProtocolVersion, UnauthedP2PStream},
}; };

View File

@ -16,7 +16,6 @@ use reth_eth_wire::{
message::{EthBroadcastMessage, RequestPair}, message::{EthBroadcastMessage, RequestPair},
DisconnectReason, EthMessage, EthStream, P2PStream, DisconnectReason, EthMessage, EthStream, P2PStream,
}; };
use reth_primitives::PeerId; use reth_primitives::PeerId;
use std::{ use std::{
collections::VecDeque, collections::VecDeque,
@ -243,7 +242,7 @@ impl ActiveSession {
} }
/// Report back that this session has been closed. /// Report back that this session has been closed.
fn disconnect(&self) { fn emit_disconnect(&self) {
// NOTE: we clone here so there's enough capacity to deliver this message // NOTE: we clone here so there's enough capacity to deliver this message
let _ = self.to_session.clone().try_send(ActiveSessionMessage::Disconnected { let _ = self.to_session.clone().try_send(ActiveSessionMessage::Disconnected {
peer_id: self.remote_peer_id, peer_id: self.remote_peer_id,
@ -260,6 +259,11 @@ impl ActiveSession {
error, error,
}); });
} }
/// Starts the disconnect process
fn start_disconnect(&mut self, reason: DisconnectReason) {
self.conn.inner_mut().start_disconnect(reason);
}
} }
impl Future for ActiveSession { impl Future for ActiveSession {
@ -271,7 +275,7 @@ impl Future for ActiveSession {
if this.is_disconnecting() { if this.is_disconnecting() {
// try to close the flush out the remaining Disconnect message // try to close the flush out the remaining Disconnect message
let _ = ready!(this.conn.poll_close_unpin(cx)); let _ = ready!(this.conn.poll_close_unpin(cx));
this.disconnect(); this.emit_disconnect();
return Poll::Ready(()) return Poll::Ready(())
} }
@ -293,7 +297,7 @@ impl Future for ActiveSession {
SessionCommand::Disconnect { reason } => { SessionCommand::Disconnect { reason } => {
let reason = let reason =
reason.unwrap_or(DisconnectReason::DisconnectRequested); reason.unwrap_or(DisconnectReason::DisconnectRequested);
this.conn.inner_mut().start_disconnect(reason); this.start_disconnect(reason);
} }
SessionCommand::Message(msg) => { SessionCommand::Message(msg) => {
this.on_peer_message(msg); this.on_peer_message(msg);
@ -345,7 +349,14 @@ impl Future for ActiveSession {
loop { loop {
match this.conn.poll_next_unpin(cx) { match this.conn.poll_next_unpin(cx) {
Poll::Pending => break, Poll::Pending => break,
Poll::Ready(None) => return Poll::Pending, Poll::Ready(None) => {
if this.is_disconnecting() {
break
} else {
this.emit_disconnect();
return Poll::Ready(())
}
}
Poll::Ready(Some(res)) => { Poll::Ready(Some(res)) => {
progress = true; progress = true;
match res { match res {
@ -401,3 +412,271 @@ impl From<EthBroadcastMessage> for OutgoingMessage {
OutgoingMessage::Broadcast(value) OutgoingMessage::Broadcast(value)
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use crate::session::{handle::PendingSessionEvent, start_pending_incoming_session};
use reth_ecies::util::pk2id;
use reth_eth_wire::{
EthVersion, HelloMessage, NewPooledTransactionHashes, ProtocolVersion, Status,
StatusBuilder, UnauthedEthStream, UnauthedP2PStream,
};
use reth_primitives::{ForkFilter, Hardfork};
use secp256k1::{SecretKey, SECP256K1};
use std::time::Duration;
use tokio::net::TcpListener;
/// Returns a testing `HelloMessage` and new secretkey
fn eth_hello(server_key: &SecretKey) -> HelloMessage {
HelloMessage {
protocol_version: ProtocolVersion::V5,
client_version: "reth/1.0.0".to_string(),
capabilities: vec![EthVersion::Eth67.into()],
port: 30303,
id: pk2id(&server_key.public_key(SECP256K1)),
}
}
struct SessionBuilder {
remote_capabilities: Arc<Capabilities>,
active_session_tx: mpsc::Sender<ActiveSessionMessage>,
active_session_rx: ReceiverStream<ActiveSessionMessage>,
to_sessions: Vec<mpsc::Sender<SessionCommand>>,
secret_key: SecretKey,
local_peer_id: PeerId,
hello: HelloMessage,
status: Status,
fork_filter: ForkFilter,
next_id: usize,
}
impl SessionBuilder {
fn next_id(&mut self) -> SessionId {
let id = self.next_id;
self.next_id += 1;
SessionId(id)
}
/// Connects a new Eth stream and executes the given closure with that established stream
fn with_client_stream<F, O>(
&self,
local_addr: SocketAddr,
f: F,
) -> Pin<Box<dyn Future<Output = ()> + Send + Sync>>
where
F: FnOnce(EthStream<P2PStream<ECIESStream<TcpStream>>>) -> O + Send + Sync + 'static,
O: Future<Output = ()> + Send + Sync,
{
let status = self.status;
let fork_filter = self.fork_filter.clone();
let local_peer_id = self.local_peer_id;
let mut hello = self.hello.clone();
let key = SecretKey::new(&mut rand::thread_rng());
hello.id = pk2id(&key.public_key(SECP256K1));
Box::pin(async move {
let outgoing = TcpStream::connect(local_addr).await.unwrap();
let sink = ECIESStream::connect(outgoing, key, local_peer_id).await.unwrap();
let (p2p_stream, _) = UnauthedP2PStream::new(sink).handshake(hello).await.unwrap();
let (client_stream, _) = UnauthedEthStream::new(p2p_stream)
.handshake(status, fork_filter)
.await
.unwrap();
f(client_stream).await
})
}
async fn connect_incoming(&mut self, stream: TcpStream) -> ActiveSession {
let remote_addr = stream.local_addr().unwrap();
let session_id = self.next_id();
let (_disconnect_tx, disconnect_rx) = oneshot::channel();
let (pending_sessions_tx, pending_sessions_rx) = mpsc::channel(1);
tokio::task::spawn(start_pending_incoming_session(
disconnect_rx,
session_id,
stream,
pending_sessions_tx,
remote_addr,
self.secret_key,
self.hello.clone(),
self.status,
self.fork_filter.clone(),
));
let mut stream = ReceiverStream::new(pending_sessions_rx);
match stream.next().await.unwrap() {
PendingSessionEvent::Established {
session_id,
remote_addr,
peer_id,
capabilities,
status: _,
conn,
} => {
let (_to_session_tx, messages_rx) = mpsc::channel(10);
let (commands_to_session, commands_rx) = mpsc::channel(10);
self.to_sessions.push(commands_to_session);
ActiveSession {
next_id: 0,
remote_peer_id: peer_id,
remote_addr,
remote_capabilities: Arc::clone(&capabilities),
session_id,
commands_rx: ReceiverStream::new(commands_rx),
to_session: self.active_session_tx.clone(),
request_tx: ReceiverStream::new(messages_rx).fuse(),
inflight_requests: Default::default(),
conn,
queued_outgoing: Default::default(),
received_requests: Default::default(),
}
}
_ => {
panic!("unexpected message")
}
}
}
}
impl Default for SessionBuilder {
fn default() -> Self {
let (active_session_tx, active_session_rx) = mpsc::channel(100);
let (secret_key, pk) = SECP256K1.generate_keypair(&mut rand::thread_rng());
let local_peer_id = pk2id(&pk);
Self {
next_id: 0,
remote_capabilities: Arc::new(Capabilities::from(vec![])),
active_session_tx,
active_session_rx: ReceiverStream::new(active_session_rx),
to_sessions: vec![],
hello: eth_hello(&secret_key),
secret_key,
local_peer_id,
status: StatusBuilder::default().build(),
fork_filter: Hardfork::Frontier.fork_filter(),
}
}
}
#[tokio::test(flavor = "multi_thread")]
async fn test_disconnect() {
let mut builder = SessionBuilder::default();
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
let expected_disconnect = DisconnectReason::UselessPeer;
let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
let msg = client_stream.next().await.unwrap().unwrap_err();
assert_eq!(msg.as_disconnected().unwrap(), expected_disconnect);
});
tokio::task::spawn(async move {
let (incoming, _) = listener.accept().await.unwrap();
let mut session = builder.connect_incoming(incoming).await;
session.start_disconnect(expected_disconnect);
session.await
});
fut.await;
}
#[tokio::test(flavor = "multi_thread")]
async fn handle_dropped_stream() {
let mut builder = SessionBuilder::default();
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
let fut = builder.with_client_stream(local_addr, move |client_stream| async move {
drop(client_stream);
tokio::time::sleep(Duration::from_secs(1)).await
});
let (tx, rx) = oneshot::channel();
tokio::task::spawn(async move {
let (incoming, _) = listener.accept().await.unwrap();
let session = builder.connect_incoming(incoming).await;
session.await;
tx.send(()).unwrap();
});
tokio::task::spawn(fut);
rx.await.unwrap();
}
#[tokio::test(flavor = "multi_thread")]
async fn test_send_many_messages() {
let mut builder = SessionBuilder::default();
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
let num_messages = 10_000;
let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
for _ in 0..num_messages {
client_stream
.send(EthMessage::NewPooledTransactionHashes(NewPooledTransactionHashes(
vec![],
)))
.await
.unwrap();
}
});
let (tx, rx) = oneshot::channel();
tokio::task::spawn(async move {
let (incoming, _) = listener.accept().await.unwrap();
let session = builder.connect_incoming(incoming).await;
session.await;
tx.send(()).unwrap();
});
tokio::task::spawn(fut);
rx.await.unwrap();
}
#[tokio::test(flavor = "multi_thread")]
async fn test_keep_alive() {
let mut builder = SessionBuilder::default();
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
let fut = builder.with_client_stream(local_addr, move |mut client_stream| async move {
let _ = tokio::time::timeout(Duration::from_secs(60), client_stream.next()).await;
client_stream.into_inner().disconnect(DisconnectReason::UselessPeer).await.unwrap();
});
let (tx, rx) = oneshot::channel();
tokio::task::spawn(async move {
let (incoming, _) = listener.accept().await.unwrap();
let session = builder.connect_incoming(incoming).await;
session.await;
tx.send(()).unwrap();
});
tokio::task::spawn(fut);
rx.await.unwrap();
}
}

View File

@ -478,7 +478,7 @@ pub struct ExceedsSessionLimit(usize);
/// Starts the authentication process for a connection initiated by a remote peer. /// Starts the authentication process for a connection initiated by a remote peer.
/// ///
/// This will wait for the _incoming_ handshake request and answer it. /// This will wait for the _incoming_ handshake request and answer it.
async fn start_pending_incoming_session( pub(crate) async fn start_pending_incoming_session(
disconnect_rx: oneshot::Receiver<()>, disconnect_rx: oneshot::Receiver<()>,
session_id: SessionId, session_id: SessionId,
stream: TcpStream, stream: TcpStream,