mirror of
https://github.com/hl-archive-node/nanoreth.git
synced 2025-12-06 19:09:54 +00:00
feature: Add timeouts for handshake functions (#7295)
This commit is contained in:
@ -81,6 +81,9 @@ pub enum ECIESErrorImpl {
|
||||
/// a message from the (partially filled) buffer.
|
||||
#[error("stream closed due to not being readable")]
|
||||
UnreadableStream,
|
||||
// Error when data is not recieved from peer for a prolonged period.
|
||||
#[error("never recieved data from remote peer")]
|
||||
StreamTimeout,
|
||||
}
|
||||
|
||||
impl From<ECIESErrorImpl> for ECIESError {
|
||||
|
||||
@ -15,12 +15,18 @@ use std::{
|
||||
io,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncWrite},
|
||||
time::timeout,
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_stream::{Stream, StreamExt};
|
||||
use tokio_util::codec::{Decoder, Framed};
|
||||
use tracing::{instrument, trace};
|
||||
|
||||
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
|
||||
/// `ECIES` stream over TCP exchanging raw bytes
|
||||
#[derive(Debug)]
|
||||
#[pin_project::pin_project]
|
||||
@ -40,6 +46,27 @@ where
|
||||
transport: Io,
|
||||
secret_key: SecretKey,
|
||||
remote_id: PeerId,
|
||||
) -> Result<Self, ECIESError> {
|
||||
Self::connect_with_timeout(transport, secret_key, remote_id, HANDSHAKE_TIMEOUT).await
|
||||
}
|
||||
|
||||
/// Wrapper around connect_no_timeout which enforces a timeout.
|
||||
pub async fn connect_with_timeout(
|
||||
transport: Io,
|
||||
secret_key: SecretKey,
|
||||
remote_id: PeerId,
|
||||
timeout_limit: Duration,
|
||||
) -> Result<Self, ECIESError> {
|
||||
timeout(timeout_limit, Self::connect_without_timeout(transport, secret_key, remote_id))
|
||||
.await
|
||||
.map_err(|_| ECIESError::from(ECIESErrorImpl::StreamTimeout))?
|
||||
}
|
||||
|
||||
/// Connect to an `ECIES` server with no timeout.
|
||||
pub async fn connect_without_timeout(
|
||||
transport: Io,
|
||||
secret_key: SecretKey,
|
||||
remote_id: PeerId,
|
||||
) -> Result<Self, ECIESError> {
|
||||
let ecies = ECIESCodec::new_client(secret_key, remote_id)
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "invalid handshake"))?;
|
||||
@ -180,4 +207,42 @@ mod tests {
|
||||
// make sure the server receives the message and asserts before ending the test
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connection_should_timeout() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let server_key = SecretKey::new(&mut rand::thread_rng());
|
||||
|
||||
let _handle = tokio::spawn(async move {
|
||||
// Delay accepting the connection for longer than the client's timeout period
|
||||
tokio::time::sleep(Duration::from_secs(11)).await;
|
||||
let (incoming, _) = listener.accept().await.unwrap();
|
||||
let mut stream = ECIESStream::incoming(incoming, server_key).await.unwrap();
|
||||
|
||||
// use the stream to get the next message
|
||||
let message = stream.next().await.unwrap().unwrap();
|
||||
assert_eq!(message, Bytes::from("hello"));
|
||||
});
|
||||
|
||||
// create the server pubkey
|
||||
let server_id = pk2id(&server_key.public_key(SECP256K1));
|
||||
|
||||
let client_key = SecretKey::new(&mut rand::thread_rng());
|
||||
let outgoing = TcpStream::connect(addr).await.unwrap();
|
||||
|
||||
// Attempt to connect, expecting a timeout due to the server's delayed response
|
||||
let connect_result = ECIESStream::connect_with_timeout(
|
||||
outgoing,
|
||||
client_key,
|
||||
server_id,
|
||||
Duration::from_secs(1),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Assert that a timeout error occurred
|
||||
assert!(
|
||||
matches!(connect_result, Err(e) if e.to_string() == ECIESErrorImpl::StreamTimeout.to_string())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -40,6 +40,9 @@ pub enum EthStreamError {
|
||||
/// The number of transaction sizes.
|
||||
sizes_len: usize,
|
||||
},
|
||||
/// Error when data is not recieved from peer for a prolonged period.
|
||||
#[error("never recieved data from remote peer")]
|
||||
StreamTimeout,
|
||||
}
|
||||
|
||||
// === impl EthStreamError ===
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
use crate::{
|
||||
errors::{EthHandshakeError, EthStreamError},
|
||||
message::{EthBroadcastMessage, ProtocolBroadcastMessage},
|
||||
p2pstream::HANDSHAKE_TIMEOUT,
|
||||
types::{EthMessage, ProtocolMessage, Status},
|
||||
CanDisconnect, DisconnectReason, EthVersion,
|
||||
};
|
||||
@ -13,7 +14,9 @@ use reth_primitives::{
|
||||
use std::{
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::time::timeout;
|
||||
use tokio_stream::Stream;
|
||||
use tracing::{debug, trace};
|
||||
|
||||
@ -51,6 +54,27 @@ where
|
||||
/// handshake is completed successfully. This also returns the `Status` message sent by the
|
||||
/// remote peer.
|
||||
pub async fn handshake(
|
||||
self,
|
||||
status: Status,
|
||||
fork_filter: ForkFilter,
|
||||
) -> Result<(EthStream<S>, Status), EthStreamError> {
|
||||
self.handshake_with_timeout(status, fork_filter, HANDSHAKE_TIMEOUT).await
|
||||
}
|
||||
|
||||
/// Wrapper around handshake which enforces a timeout.
|
||||
pub async fn handshake_with_timeout(
|
||||
self,
|
||||
status: Status,
|
||||
fork_filter: ForkFilter,
|
||||
timeout_limit: Duration,
|
||||
) -> Result<(EthStream<S>, Status), EthStreamError> {
|
||||
timeout(timeout_limit, Self::handshake_without_timeout(self, status, fork_filter))
|
||||
.await
|
||||
.map_err(|_| EthStreamError::StreamTimeout)?
|
||||
}
|
||||
|
||||
/// Handshake with no timeout
|
||||
pub async fn handshake_without_timeout(
|
||||
mut self,
|
||||
status: Status,
|
||||
fork_filter: ForkFilter,
|
||||
@ -321,6 +345,8 @@ where
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use super::UnauthedEthStream;
|
||||
use crate::{
|
||||
errors::{EthHandshakeError, EthStreamError},
|
||||
@ -642,4 +668,53 @@ mod tests {
|
||||
// make sure the server receives the message and asserts before ending the test
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handshake_should_timeout() {
|
||||
let genesis = B256::random();
|
||||
let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());
|
||||
|
||||
let status = Status {
|
||||
version: EthVersion::Eth67 as u8,
|
||||
chain: NamedChain::Mainnet.into(),
|
||||
total_difficulty: U256::ZERO,
|
||||
blockhash: B256::random(),
|
||||
genesis,
|
||||
// Pass the current fork id.
|
||||
forkid: fork_filter.current(),
|
||||
};
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let local_addr = listener.local_addr().unwrap();
|
||||
|
||||
let status_clone = status;
|
||||
let fork_filter_clone = fork_filter.clone();
|
||||
let _handle = tokio::spawn(async move {
|
||||
// Delay accepting the connection for longer than the client's timeout period
|
||||
tokio::time::sleep(Duration::from_secs(11)).await;
|
||||
// roughly based off of the design of tokio::net::TcpListener
|
||||
let (incoming, _) = listener.accept().await.unwrap();
|
||||
let stream = PassthroughCodec::default().framed(incoming);
|
||||
let (_, their_status) = UnauthedEthStream::new(stream)
|
||||
.handshake(status_clone, fork_filter_clone)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// just make sure it equals our status (our status is a clone of their status)
|
||||
assert_eq!(their_status, status_clone);
|
||||
});
|
||||
|
||||
let outgoing = TcpStream::connect(local_addr).await.unwrap();
|
||||
let sink = PassthroughCodec::default().framed(outgoing);
|
||||
|
||||
// try to connect
|
||||
let handshake_result = UnauthedEthStream::new(sink)
|
||||
.handshake_with_timeout(status, fork_filter, Duration::from_secs(1))
|
||||
.await;
|
||||
|
||||
// Assert that a timeout error occurred
|
||||
assert!(
|
||||
matches!(handshake_result, Err(e) if e.to_string() == EthStreamError::StreamTimeout.to_string())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -40,7 +40,7 @@ const MAX_P2P_MESSAGE_ID: u8 = P2PMessageID::Pong as u8;
|
||||
|
||||
/// [`HANDSHAKE_TIMEOUT`] determines the amount of time to wait before determining that a `p2p`
|
||||
/// handshake has timed out.
|
||||
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
pub(crate) const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
|
||||
/// [`PING_TIMEOUT`] determines the amount of time to wait before determining that a `p2p` ping has
|
||||
/// timed out.
|
||||
|
||||
Reference in New Issue
Block a user