mirror of
https://github.com/hl-archive-node/nanoreth.git
synced 2025-12-06 10:59:55 +00:00
fix: ensure final message is always delivered (#4569)
This commit is contained in:
@ -11,6 +11,7 @@ use std::{
|
||||
use tokio::sync::mpsc::{
|
||||
self,
|
||||
error::{SendError, TryRecvError, TrySendError},
|
||||
OwnedPermit,
|
||||
};
|
||||
|
||||
/// Wrapper around [mpsc::unbounded_channel] that returns a new unbounded metered channel.
|
||||
@ -142,6 +143,18 @@ impl<T> MeteredSender<T> {
|
||||
Self { sender, metrics: MeteredSenderMetrics::new(scope) }
|
||||
}
|
||||
|
||||
/// Tries to acquire a permit to send a message.
|
||||
///
|
||||
/// See also [Sender](mpsc::Sender)'s `try_reserve_owned`.
|
||||
pub fn try_reserve_owned(&self) -> Result<OwnedPermit<T>, TrySendError<mpsc::Sender<T>>> {
|
||||
self.sender.clone().try_reserve_owned()
|
||||
}
|
||||
|
||||
/// Returns the underlying [Sender](mpsc::Sender).
|
||||
pub fn inner(&self) -> &mpsc::Sender<T> {
|
||||
&self.sender
|
||||
}
|
||||
|
||||
/// Calls the underlying [Sender](mpsc::Sender)'s `try_send`, incrementing the appropriate
|
||||
/// metrics depending on the result.
|
||||
pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
|
||||
|
||||
@ -37,6 +37,7 @@ use tokio::{
|
||||
time::Interval,
|
||||
};
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_util::sync::PollSender;
|
||||
use tracing::{debug, info, trace};
|
||||
|
||||
/// Constants for timeout updating
|
||||
@ -79,11 +80,11 @@ pub(crate) struct ActiveSession {
|
||||
pub(crate) to_session_manager: MeteredSender<ActiveSessionMessage>,
|
||||
/// A message that needs to be delivered to the session manager
|
||||
pub(crate) pending_message_to_session: Option<ActiveSessionMessage>,
|
||||
/// Incoming request to send to delegate to the remote peer.
|
||||
/// Incoming internal requests which are delegated to the remote peer.
|
||||
pub(crate) internal_request_tx: Fuse<ReceiverStream<PeerRequest>>,
|
||||
/// All requests sent to the remote peer we're waiting on a response
|
||||
pub(crate) inflight_requests: FnvHashMap<u64, InflightRequest>,
|
||||
/// All requests that were sent by the remote peer.
|
||||
/// All requests that were sent by the remote peer and we're waiting on an internal response
|
||||
pub(crate) received_requests_from_remote: Vec<ReceivedRequest>,
|
||||
/// Buffered messages that should be handled and sent to the peer.
|
||||
pub(crate) queued_outgoing: VecDeque<OutgoingMessage>,
|
||||
@ -94,6 +95,8 @@ pub(crate) struct ActiveSession {
|
||||
/// If an [ActiveSession] does not receive a response at all within this duration then it is
|
||||
/// considered a protocol violation and the session will initiate a drop.
|
||||
pub(crate) protocol_breach_request_timeout: Duration,
|
||||
/// Used to reserve a slot to guarantee that the termination message is delivered
|
||||
pub(crate) terminate_message: Option<(PollSender<ActiveSessionMessage>, ActiveSessionMessage)>,
|
||||
}
|
||||
|
||||
impl ActiveSession {
|
||||
@ -118,7 +121,7 @@ impl ActiveSession {
|
||||
/// Handle a message read from the connection.
|
||||
///
|
||||
/// Returns an error if the message is considered to be in violation of the protocol.
|
||||
fn on_incoming(&mut self, msg: EthMessage) -> OnIncomingMessageOutcome {
|
||||
fn on_incoming_message(&mut self, msg: EthMessage) -> OnIncomingMessageOutcome {
|
||||
/// A macro that handles an incoming request
|
||||
/// This creates a new channel and tries to send the sender half to the session while
|
||||
/// storing the receiver half internally so the pending response can be polled.
|
||||
@ -247,7 +250,7 @@ impl ActiveSession {
|
||||
}
|
||||
|
||||
/// Handle a message received from the internal network
|
||||
fn on_peer_message(&mut self, msg: PeerMessage) {
|
||||
fn on_internal_peer_message(&mut self, msg: PeerMessage) {
|
||||
match msg {
|
||||
PeerMessage::NewBlockHashes(msg) => {
|
||||
self.queued_outgoing.push_back(EthMessage::NewBlockHashes(msg).into());
|
||||
@ -283,6 +286,8 @@ impl ActiveSession {
|
||||
}
|
||||
|
||||
/// Handle a Response to the peer
|
||||
///
|
||||
/// This will queue the response to be sent to the peer
|
||||
fn handle_outgoing_response(&mut self, id: u64, resp: PeerResponseResult) {
|
||||
match resp.try_into_message(id) {
|
||||
Ok(msg) => {
|
||||
@ -355,25 +360,28 @@ impl ActiveSession {
|
||||
}
|
||||
|
||||
/// Report back that this session has been closed.
|
||||
fn emit_disconnect(&self) {
|
||||
fn emit_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
|
||||
trace!(target: "net::session", remote_peer_id=?self.remote_peer_id, "emitting disconnect");
|
||||
// NOTE: we clone here so there's enough capacity to deliver this message
|
||||
let _ = self.to_session_manager.clone().try_send(ActiveSessionMessage::Disconnected {
|
||||
let msg = ActiveSessionMessage::Disconnected {
|
||||
peer_id: self.remote_peer_id,
|
||||
remote_addr: self.remote_addr,
|
||||
});
|
||||
};
|
||||
|
||||
self.terminate_message =
|
||||
Some((PollSender::new(self.to_session_manager.inner().clone()).clone(), msg));
|
||||
self.poll_terminate_message(cx).expect("message is set")
|
||||
}
|
||||
|
||||
/// Report back that this session has been closed due to an error
|
||||
fn close_on_error(&self, error: EthStreamError) {
|
||||
// NOTE: we clone here so there's enough capacity to deliver this message
|
||||
let _ = self.to_session_manager.clone().try_send(
|
||||
ActiveSessionMessage::ClosedOnConnectionError {
|
||||
fn close_on_error(&mut self, error: EthStreamError, cx: &mut Context<'_>) -> Poll<()> {
|
||||
let msg = ActiveSessionMessage::ClosedOnConnectionError {
|
||||
peer_id: self.remote_peer_id,
|
||||
remote_addr: self.remote_addr,
|
||||
error,
|
||||
},
|
||||
);
|
||||
};
|
||||
self.terminate_message =
|
||||
Some((PollSender::new(self.to_session_manager.inner().clone()).clone(), msg));
|
||||
self.poll_terminate_message(cx).expect("message is set")
|
||||
}
|
||||
|
||||
/// Starts the disconnect process
|
||||
@ -391,8 +399,7 @@ impl ActiveSession {
|
||||
|
||||
// try to close the flush out the remaining Disconnect message
|
||||
let _ = ready!(self.conn.poll_close_unpin(cx));
|
||||
self.emit_disconnect();
|
||||
Poll::Ready(())
|
||||
self.emit_disconnect(cx)
|
||||
}
|
||||
|
||||
/// Attempts to disconnect by sending the given disconnect reason
|
||||
@ -404,8 +411,7 @@ impl ActiveSession {
|
||||
}
|
||||
Err(err) => {
|
||||
debug!(target: "net::session", ?err, remote_peer_id=?self.remote_peer_id, "could not send disconnect");
|
||||
self.close_on_error(err);
|
||||
Poll::Ready(())
|
||||
self.close_on_error(err, cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -443,6 +449,25 @@ impl ActiveSession {
|
||||
self.internal_request_timeout.store(request_timeout.as_millis() as u64, Ordering::Relaxed);
|
||||
self.internal_request_timeout_interval = tokio::time::interval(request_timeout);
|
||||
}
|
||||
|
||||
/// If a termination message is queued this will try to send it
|
||||
fn poll_terminate_message(&mut self, cx: &mut Context<'_>) -> Option<Poll<()>> {
|
||||
let (mut tx, msg) = self.terminate_message.take()?;
|
||||
match tx.poll_reserve(cx) {
|
||||
Poll::Pending => {
|
||||
self.terminate_message = Some((tx, msg));
|
||||
return Some(Poll::Pending)
|
||||
}
|
||||
Poll::Ready(Ok(())) => {
|
||||
let _ = tx.send_item(msg);
|
||||
}
|
||||
Poll::Ready(Err(_)) => {
|
||||
// channel closed
|
||||
}
|
||||
}
|
||||
// terminate the task
|
||||
Some(Poll::Ready(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Future for ActiveSession {
|
||||
@ -451,6 +476,11 @@ impl Future for ActiveSession {
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
|
||||
// if the session is terminate we have to send the termination message before we can close
|
||||
if let Some(terminate) = this.poll_terminate_message(cx) {
|
||||
return terminate
|
||||
}
|
||||
|
||||
if this.is_disconnecting() {
|
||||
return this.poll_disconnect(cx)
|
||||
}
|
||||
@ -486,7 +516,7 @@ impl Future for ActiveSession {
|
||||
return this.try_disconnect(reason, cx)
|
||||
}
|
||||
SessionCommand::Message(msg) => {
|
||||
this.on_peer_message(msg);
|
||||
this.on_internal_peer_message(msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -526,8 +556,7 @@ impl Future for ActiveSession {
|
||||
if let Err(err) = res {
|
||||
debug!(target: "net::session", ?err, remote_peer_id=?this.remote_peer_id, "failed to send message");
|
||||
// notify the manager
|
||||
this.close_on_error(err);
|
||||
return Poll::Ready(())
|
||||
return this.close_on_error(err, cx)
|
||||
}
|
||||
} else {
|
||||
// no more messages to send over the wire
|
||||
@ -571,8 +600,7 @@ impl Future for ActiveSession {
|
||||
break
|
||||
} else {
|
||||
debug!(target: "net::session", remote_peer_id=?this.remote_peer_id, "eth stream completed");
|
||||
this.emit_disconnect();
|
||||
return Poll::Ready(())
|
||||
return this.emit_disconnect(cx)
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(res)) => {
|
||||
@ -580,15 +608,14 @@ impl Future for ActiveSession {
|
||||
Ok(msg) => {
|
||||
trace!(target: "net::session", msg_id=?msg.message_id(), remote_peer_id=?this.remote_peer_id, "received eth message");
|
||||
// decode and handle message
|
||||
match this.on_incoming(msg) {
|
||||
match this.on_incoming_message(msg) {
|
||||
OnIncomingMessageOutcome::Ok => {
|
||||
// handled successfully
|
||||
progress = true;
|
||||
}
|
||||
OnIncomingMessageOutcome::BadMessage { error, message } => {
|
||||
debug!(target: "net::session", ?error, msg=?message, remote_peer_id=?this.remote_peer_id, "received invalid protocol message");
|
||||
this.close_on_error(error);
|
||||
return Poll::Ready(())
|
||||
return this.close_on_error(error, cx)
|
||||
}
|
||||
OnIncomingMessageOutcome::NoCapacity(msg) => {
|
||||
// failed to send due to lack of capacity
|
||||
@ -599,8 +626,7 @@ impl Future for ActiveSession {
|
||||
}
|
||||
Err(err) => {
|
||||
debug!(target: "net::session", ?err, remote_peer_id=?this.remote_peer_id, "failed to receive message");
|
||||
this.close_on_error(err);
|
||||
return Poll::Ready(())
|
||||
return this.close_on_error(err, cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -612,8 +638,7 @@ impl Future for ActiveSession {
|
||||
}
|
||||
}
|
||||
|
||||
if this.internal_request_timeout_interval.poll_tick(cx).is_ready() {
|
||||
let _ = this.internal_request_timeout_interval.poll_tick(cx);
|
||||
while this.internal_request_timeout_interval.poll_tick(cx).is_ready() {
|
||||
// check for timed out requests
|
||||
if this.check_timed_out_requests(Instant::now()) {
|
||||
let _ = this.to_session_manager.clone().try_send(
|
||||
@ -664,6 +689,7 @@ impl InflightRequest {
|
||||
matches!(self.request, RequestState::Waiting(_))
|
||||
}
|
||||
|
||||
/// This will timeout the request by sending an error response to the internal channel
|
||||
fn timeout(&mut self) {
|
||||
let mut req = RequestState::TimedOut;
|
||||
std::mem::swap(&mut self.request, &mut req);
|
||||
@ -866,6 +892,7 @@ mod tests {
|
||||
INITIAL_REQUEST_TIMEOUT.as_millis() as u64,
|
||||
)),
|
||||
protocol_breach_request_timeout: PROTOCOL_BREACH_REQUEST_TIMEOUT,
|
||||
terminate_message: None,
|
||||
}
|
||||
}
|
||||
ev => {
|
||||
|
||||
@ -458,6 +458,7 @@ impl SessionManager {
|
||||
),
|
||||
internal_request_timeout: Arc::clone(&timeout),
|
||||
protocol_breach_request_timeout: self.protocol_breach_request_timeout,
|
||||
terminate_message: None,
|
||||
};
|
||||
|
||||
self.spawn(session);
|
||||
|
||||
Reference in New Issue
Block a user