From b8e15fa10b6da6fa8650047a13b9887ceb3d683f Mon Sep 17 00:00:00 2001 From: Matthias Seitz Date: Tue, 12 Sep 2023 22:03:20 +0200 Subject: [PATCH] fix: ensure final message is always delivered (#4569) --- crates/metrics/src/common/mpsc.rs | 13 ++++ crates/net/network/src/session/active.rs | 93 +++++++++++++++--------- crates/net/network/src/session/mod.rs | 1 + 3 files changed, 74 insertions(+), 33 deletions(-) diff --git a/crates/metrics/src/common/mpsc.rs b/crates/metrics/src/common/mpsc.rs index 6148e5a87..d1a03dd6f 100644 --- a/crates/metrics/src/common/mpsc.rs +++ b/crates/metrics/src/common/mpsc.rs @@ -11,6 +11,7 @@ use std::{ use tokio::sync::mpsc::{ self, error::{SendError, TryRecvError, TrySendError}, + OwnedPermit, }; /// Wrapper around [mpsc::unbounded_channel] that returns a new unbounded metered channel. @@ -142,6 +143,18 @@ impl<T> MeteredSender<T> { Self { sender, metrics: MeteredSenderMetrics::new(scope) } } + /// Tries to acquire a permit to send a message. + /// + /// See also [Sender](mpsc::Sender)'s `try_reserve_owned`. + pub fn try_reserve_owned(&self) -> Result<OwnedPermit<T>, TrySendError<mpsc::Sender<T>>> { + self.sender.clone().try_reserve_owned() + } + + /// Returns the underlying [Sender](mpsc::Sender). + pub fn inner(&self) -> &mpsc::Sender<T> { + &self.sender + } + /// Calls the underlying [Sender](mpsc::Sender)'s `try_send`, incrementing the appropriate /// metrics depending on the result.
pub fn try_send(&self, message: T) -> Result<(), TrySendError> { diff --git a/crates/net/network/src/session/active.rs b/crates/net/network/src/session/active.rs index 180096241..64d0cf8cf 100644 --- a/crates/net/network/src/session/active.rs +++ b/crates/net/network/src/session/active.rs @@ -37,6 +37,7 @@ use tokio::{ time::Interval, }; use tokio_stream::wrappers::ReceiverStream; +use tokio_util::sync::PollSender; use tracing::{debug, info, trace}; /// Constants for timeout updating @@ -79,11 +80,11 @@ pub(crate) struct ActiveSession { pub(crate) to_session_manager: MeteredSender, /// A message that needs to be delivered to the session manager pub(crate) pending_message_to_session: Option, - /// Incoming request to send to delegate to the remote peer. + /// Incoming internal requests which are delegated to the remote peer. pub(crate) internal_request_tx: Fuse>, /// All requests sent to the remote peer we're waiting on a response pub(crate) inflight_requests: FnvHashMap, - /// All requests that were sent by the remote peer. + /// All requests that were sent by the remote peer and we're waiting on an internal response pub(crate) received_requests_from_remote: Vec, /// Buffered messages that should be handled and sent to the peer. pub(crate) queued_outgoing: VecDeque, @@ -94,6 +95,8 @@ pub(crate) struct ActiveSession { /// If an [ActiveSession] does not receive a response at all within this duration then it is /// considered a protocol violation and the session will initiate a drop. pub(crate) protocol_breach_request_timeout: Duration, + /// Used to reserve a slot to guarantee that the termination message is delivered + pub(crate) terminate_message: Option<(PollSender, ActiveSessionMessage)>, } impl ActiveSession { @@ -118,7 +121,7 @@ impl ActiveSession { /// Handle a message read from the connection. /// /// Returns an error if the message is considered to be in violation of the protocol. 
- fn on_incoming(&mut self, msg: EthMessage) -> OnIncomingMessageOutcome { + fn on_incoming_message(&mut self, msg: EthMessage) -> OnIncomingMessageOutcome { /// A macro that handles an incoming request /// This creates a new channel and tries to send the sender half to the session while /// storing the receiver half internally so the pending response can be polled. @@ -247,7 +250,7 @@ impl ActiveSession { } /// Handle a message received from the internal network - fn on_peer_message(&mut self, msg: PeerMessage) { + fn on_internal_peer_message(&mut self, msg: PeerMessage) { match msg { PeerMessage::NewBlockHashes(msg) => { self.queued_outgoing.push_back(EthMessage::NewBlockHashes(msg).into()); @@ -283,6 +286,8 @@ impl ActiveSession { } /// Handle a Response to the peer + /// + /// This will queue the response to be sent to the peer fn handle_outgoing_response(&mut self, id: u64, resp: PeerResponseResult) { match resp.try_into_message(id) { Ok(msg) => { @@ -355,25 +360,28 @@ impl ActiveSession { } /// Report back that this session has been closed. 
- fn emit_disconnect(&self) { + fn emit_disconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> { trace!(target: "net::session", remote_peer_id=?self.remote_peer_id, "emitting disconnect"); - // NOTE: we clone here so there's enough capacity to deliver this message - let _ = self.to_session_manager.clone().try_send(ActiveSessionMessage::Disconnected { + let msg = ActiveSessionMessage::Disconnected { peer_id: self.remote_peer_id, remote_addr: self.remote_addr, - }); + }; + + self.terminate_message = + Some((PollSender::new(self.to_session_manager.inner().clone()).clone(), msg)); + self.poll_terminate_message(cx).expect("message is set") } /// Report back that this session has been closed due to an error - fn close_on_error(&self, error: EthStreamError) { - // NOTE: we clone here so there's enough capacity to deliver this message - let _ = self.to_session_manager.clone().try_send( - ActiveSessionMessage::ClosedOnConnectionError { - peer_id: self.remote_peer_id, - remote_addr: self.remote_addr, - error, - }, - ); + fn close_on_error(&mut self, error: EthStreamError, cx: &mut Context<'_>) -> Poll<()> { + let msg = ActiveSessionMessage::ClosedOnConnectionError { + peer_id: self.remote_peer_id, + remote_addr: self.remote_addr, + error, + }; + self.terminate_message = + Some((PollSender::new(self.to_session_manager.inner().clone()).clone(), msg)); + self.poll_terminate_message(cx).expect("message is set") } /// Starts the disconnect process @@ -391,8 +399,7 @@ impl ActiveSession { // try to close the flush out the remaining Disconnect message let _ = ready!(self.conn.poll_close_unpin(cx)); - self.emit_disconnect(); - Poll::Ready(()) + self.emit_disconnect(cx) } /// Attempts to disconnect by sending the given disconnect reason @@ -404,8 +411,7 @@ impl ActiveSession { } Err(err) => { debug!(target: "net::session", ?err, remote_peer_id=?self.remote_peer_id, "could not send disconnect"); - self.close_on_error(err); - Poll::Ready(()) + self.close_on_error(err, cx) } } } @@ 
-443,6 +449,25 @@ impl ActiveSession { self.internal_request_timeout.store(request_timeout.as_millis() as u64, Ordering::Relaxed); self.internal_request_timeout_interval = tokio::time::interval(request_timeout); } + + /// If a termination message is queued this will try to send it + fn poll_terminate_message(&mut self, cx: &mut Context<'_>) -> Option<Poll<()>> { + let (mut tx, msg) = self.terminate_message.take()?; + match tx.poll_reserve(cx) { + Poll::Pending => { + self.terminate_message = Some((tx, msg)); + return Some(Poll::Pending) + } + Poll::Ready(Ok(())) => { + let _ = tx.send_item(msg); + } + Poll::Ready(Err(_)) => { + // channel closed + } + } + // terminate the task + Some(Poll::Ready(())) + } } impl Future for ActiveSession { @@ -451,6 +476,11 @@ impl Future for ActiveSession { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { let this = self.get_mut(); + // if the session is terminated we have to send the termination message before we can close + if let Some(terminate) = this.poll_terminate_message(cx) { + return terminate + } + if this.is_disconnecting() { return this.poll_disconnect(cx) } @@ -486,7 +516,7 @@ impl Future for ActiveSession { return this.try_disconnect(reason, cx) } SessionCommand::Message(msg) => { - this.on_peer_message(msg); + this.on_internal_peer_message(msg); } } } @@ -526,8 +556,7 @@ impl Future for ActiveSession { if let Err(err) = res { debug!(target: "net::session", ?err, remote_peer_id=?this.remote_peer_id, "failed to send message"); // notify the manager - this.close_on_error(err); - return Poll::Ready(()) + return this.close_on_error(err, cx) } } else { // no more messages to send over the wire @@ -571,8 +600,7 @@ impl Future for ActiveSession { break } else { debug!(target: "net::session", remote_peer_id=?this.remote_peer_id, "eth stream completed"); - this.emit_disconnect(); - return Poll::Ready(()) + return this.emit_disconnect(cx) } } Poll::Ready(Some(res)) => { @@ -580,15 +608,14 @@ impl Future for ActiveSession {
Ok(msg) => { trace!(target: "net::session", msg_id=?msg.message_id(), remote_peer_id=?this.remote_peer_id, "received eth message"); // decode and handle message - match this.on_incoming(msg) { + match this.on_incoming_message(msg) { OnIncomingMessageOutcome::Ok => { // handled successfully progress = true; } OnIncomingMessageOutcome::BadMessage { error, message } => { debug!(target: "net::session", ?error, msg=?message, remote_peer_id=?this.remote_peer_id, "received invalid protocol message"); - this.close_on_error(error); - return Poll::Ready(()) + return this.close_on_error(error, cx) } OnIncomingMessageOutcome::NoCapacity(msg) => { // failed to send due to lack of capacity @@ -599,8 +626,7 @@ impl Future for ActiveSession { } Err(err) => { debug!(target: "net::session", ?err, remote_peer_id=?this.remote_peer_id, "failed to receive message"); - this.close_on_error(err); - return Poll::Ready(()) + return this.close_on_error(err, cx) } } } @@ -612,8 +638,7 @@ impl Future for ActiveSession { } } - if this.internal_request_timeout_interval.poll_tick(cx).is_ready() { - let _ = this.internal_request_timeout_interval.poll_tick(cx); + while this.internal_request_timeout_interval.poll_tick(cx).is_ready() { // check for timed out requests if this.check_timed_out_requests(Instant::now()) { let _ = this.to_session_manager.clone().try_send( @@ -664,6 +689,7 @@ impl InflightRequest { matches!(self.request, RequestState::Waiting(_)) } + /// This will timeout the request by sending an error response to the internal channel fn timeout(&mut self) { let mut req = RequestState::TimedOut; std::mem::swap(&mut self.request, &mut req); @@ -866,6 +892,7 @@ mod tests { INITIAL_REQUEST_TIMEOUT.as_millis() as u64, )), protocol_breach_request_timeout: PROTOCOL_BREACH_REQUEST_TIMEOUT, + terminate_message: None, } } ev => { diff --git a/crates/net/network/src/session/mod.rs b/crates/net/network/src/session/mod.rs index 04d3ffe1d..b5624c879 100644 --- a/crates/net/network/src/session/mod.rs +++ 
b/crates/net/network/src/session/mod.rs @@ -458,6 +458,7 @@ impl SessionManager { ), internal_request_timeout: Arc::clone(&timeout), protocol_breach_request_timeout: self.protocol_breach_request_timeout, + terminate_message: None, }; self.spawn(session);