perf/refactor: use tokio_util::sync::PollSender for ActiveSession -> SessionManager messages (#4603)

This commit is contained in:
Collin
2023-09-22 17:09:39 -05:00
committed by GitHub
parent 3018054772
commit 675c6bfc39
5 changed files with 101 additions and 36 deletions

1
Cargo.lock generated
View File

@ -5748,6 +5748,7 @@ dependencies = [
"metrics", "metrics",
"reth-metrics-derive", "reth-metrics-derive",
"tokio", "tokio",
"tokio-util",
] ]
[[package]] [[package]]

View File

@ -18,6 +18,7 @@ metrics.workspace = true
# async # async
tokio = { workspace = true, features = ["full"], optional = true } tokio = { workspace = true, features = ["full"], optional = true }
futures = { workspace = true, optional = true } futures = { workspace = true, optional = true }
tokio-util = { workspace = true, optional = true }
[features] [features]
common = ["tokio", "futures"] common = ["tokio", "futures", "tokio-util"]

View File

@ -13,6 +13,7 @@ use tokio::sync::mpsc::{
error::{SendError, TryRecvError, TrySendError}, error::{SendError, TryRecvError, TrySendError},
OwnedPermit, OwnedPermit,
}; };
use tokio_util::sync::{PollSendError, PollSender};
/// Wrapper around [mpsc::unbounded_channel] that returns a new unbounded metered channel. /// Wrapper around [mpsc::unbounded_channel] that returns a new unbounded metered channel.
pub fn metered_unbounded_channel<T>( pub fn metered_unbounded_channel<T>(
@ -265,3 +266,65 @@ struct MeteredReceiverMetrics {
/// Number of messages received /// Number of messages received
messages_received: Counter, messages_received: Counter,
} }
/// A wrapper type around [PollSender](PollSender) that updates metrics on send.
#[derive(Debug)]
pub struct MeteredPollSender<T> {
/// The [PollSender](PollSender) that this wraps around
sender: PollSender<T>,
/// Holds metrics for this type
metrics: MeteredPollSenderMetrics,
}
impl<T: Send + 'static> MeteredPollSender<T> {
/// Creates a new [`MeteredPollSender`] wrapping around the provided [PollSender](PollSender)
pub fn new(sender: PollSender<T>, scope: &'static str) -> Self {
Self { sender, metrics: MeteredPollSenderMetrics::new(scope) }
}
/// Returns the underlying [PollSender](PollSender).
pub fn inner(&self) -> &PollSender<T> {
&self.sender
}
/// Calls the underlying [PollSender](PollSender)'s `poll_reserve`, incrementing the appropriate
/// metrics depending on the result.
pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), PollSendError<T>>> {
match self.sender.poll_reserve(cx) {
Poll::Ready(Ok(permit)) => Poll::Ready(Ok(permit)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => {
self.metrics.back_pressure.increment(1);
Poll::Pending
}
}
}
/// Calls the underlying [PollSender](PollSender)'s `send_item`, incrementing the appropriate
/// metrics depending on the result.
pub fn send_item(&mut self, item: T) -> Result<(), PollSendError<T>> {
match self.sender.send_item(item) {
Ok(()) => {
self.metrics.messages_sent.increment(1);
Ok(())
}
Err(error) => Err(error),
}
}
}
impl<T> Clone for MeteredPollSender<T> {
fn clone(&self) -> Self {
Self { sender: self.sender.clone(), metrics: self.metrics.clone() }
}
}
/// Throughput metrics for [MeteredPollSender]
#[derive(Clone, Metrics)]
#[metrics(dynamic = true)]
struct MeteredPollSenderMetrics {
/// Number of messages sent
messages_sent: Counter,
/// Number of delayed message deliveries caused by a full channel
back_pressure: Counter,
}

View File

@ -19,7 +19,7 @@ use reth_eth_wire::{
DisconnectReason, EthMessage, EthStream, P2PStream, DisconnectReason, EthMessage, EthStream, P2PStream,
}; };
use reth_interfaces::p2p::error::RequestError; use reth_interfaces::p2p::error::RequestError;
use reth_metrics::common::mpsc::MeteredSender; use reth_metrics::common::mpsc::MeteredPollSender;
use reth_net_common::bandwidth_meter::MeteredStream; use reth_net_common::bandwidth_meter::MeteredStream;
use reth_primitives::PeerId; use reth_primitives::PeerId;
use std::{ use std::{
@ -77,7 +77,7 @@ pub(crate) struct ActiveSession {
/// Incoming commands from the manager /// Incoming commands from the manager
pub(crate) commands_rx: ReceiverStream<SessionCommand>, pub(crate) commands_rx: ReceiverStream<SessionCommand>,
/// Sink to send messages to the [`SessionManager`](super::SessionManager). /// Sink to send messages to the [`SessionManager`](super::SessionManager).
pub(crate) to_session_manager: MeteredSender<ActiveSessionMessage>, pub(crate) to_session_manager: MeteredPollSender<ActiveSessionMessage>,
/// A message that needs to be delivered to the session manager /// A message that needs to be delivered to the session manager
pub(crate) pending_message_to_session: Option<ActiveSessionMessage>, pub(crate) pending_message_to_session: Option<ActiveSessionMessage>,
/// Incoming internal requests which are delegated to the remote peer. /// Incoming internal requests which are delegated to the remote peer.
@ -304,8 +304,9 @@ impl ActiveSession {
/// Returns the message if the bounded channel is currently unable to handle this message. /// Returns the message if the bounded channel is currently unable to handle this message.
#[allow(clippy::result_large_err)] #[allow(clippy::result_large_err)]
fn try_emit_broadcast(&self, message: PeerMessage) -> Result<(), ActiveSessionMessage> { fn try_emit_broadcast(&self, message: PeerMessage) -> Result<(), ActiveSessionMessage> {
match self let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };
.to_session_manager
match sender
.try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message }) .try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
{ {
Ok(_) => Ok(()), Ok(_) => Ok(()),
@ -329,8 +330,9 @@ impl ActiveSession {
/// Returns the message if the bounded channel is currently unable to handle this message. /// Returns the message if the bounded channel is currently unable to handle this message.
#[allow(clippy::result_large_err)] #[allow(clippy::result_large_err)]
fn try_emit_request(&self, message: PeerMessage) -> Result<(), ActiveSessionMessage> { fn try_emit_request(&self, message: PeerMessage) -> Result<(), ActiveSessionMessage> {
match self let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };
.to_session_manager
match sender
.try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message }) .try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
{ {
Ok(_) => Ok(()), Ok(_) => Ok(()),
@ -354,9 +356,8 @@ impl ActiveSession {
/// Notify the manager that the peer sent a bad message /// Notify the manager that the peer sent a bad message
fn on_bad_message(&self) { fn on_bad_message(&self) {
let _ = self let Some(sender) = self.to_session_manager.inner().get_ref() else { return };
.to_session_manager let _ = sender.try_send(ActiveSessionMessage::BadMessage { peer_id: self.remote_peer_id });
.try_send(ActiveSessionMessage::BadMessage { peer_id: self.remote_peer_id });
} }
/// Report back that this session has been closed. /// Report back that this session has been closed.
@ -367,8 +368,7 @@ impl ActiveSession {
remote_addr: self.remote_addr, remote_addr: self.remote_addr,
}; };
self.terminate_message = self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
Some((PollSender::new(self.to_session_manager.inner().clone()).clone(), msg));
self.poll_terminate_message(cx).expect("message is set") self.poll_terminate_message(cx).expect("message is set")
} }
@ -379,8 +379,7 @@ impl ActiveSession {
remote_addr: self.remote_addr, remote_addr: self.remote_addr,
error, error,
}; };
self.terminate_message = self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
Some((PollSender::new(self.to_session_manager.inner().clone()).clone(), msg));
self.poll_terminate_message(cx).expect("message is set") self.poll_terminate_message(cx).expect("message is set")
} }
@ -575,22 +574,19 @@ impl Future for ActiveSession {
} }
// try to resend the pending message that we could not send because the channel was // try to resend the pending message that we could not send because the channel was
// full. // full. [`PollSender`] will ensure that we're woken up again when the channel is
// ready to receive the message, and will only error if the channel is closed.
if let Some(msg) = this.pending_message_to_session.take() { if let Some(msg) = this.pending_message_to_session.take() {
match this.to_session_manager.try_send(msg) { match this.to_session_manager.poll_reserve(cx) {
Ok(_) => {} Poll::Ready(Ok(_)) => {
Err(err) => { let _ = this.to_session_manager.send_item(msg);
match err {
TrySendError::Full(msg) => {
this.pending_message_to_session = Some(msg);
// ensure we're woken up again
cx.waker().wake_by_ref();
break 'receive
}
TrySendError::Closed(_) => {}
}
} }
} Poll::Ready(Err(_)) => return Poll::Ready(()),
Poll::Pending => {
this.pending_message_to_session = Some(msg);
break 'receive
}
};
} }
match this.conn.poll_next_unpin(cx) { match this.conn.poll_next_unpin(cx) {
@ -641,9 +637,10 @@ impl Future for ActiveSession {
while this.internal_request_timeout_interval.poll_tick(cx).is_ready() { while this.internal_request_timeout_interval.poll_tick(cx).is_ready() {
// check for timed out requests // check for timed out requests
if this.check_timed_out_requests(Instant::now()) { if this.check_timed_out_requests(Instant::now()) {
let _ = this.to_session_manager.clone().try_send( if let Poll::Ready(Ok(_)) = this.to_session_manager.poll_reserve(cx) {
ActiveSessionMessage::ProtocolBreach { peer_id: this.remote_peer_id }, let msg = ActiveSessionMessage::ProtocolBreach { peer_id: this.remote_peer_id };
); this.pending_message_to_session = Some(msg);
}
} }
} }
@ -865,6 +862,7 @@ mod tests {
} => { } => {
let (_to_session_tx, messages_rx) = mpsc::channel(10); let (_to_session_tx, messages_rx) = mpsc::channel(10);
let (commands_to_session, commands_rx) = mpsc::channel(10); let (commands_to_session, commands_rx) = mpsc::channel(10);
let poll_sender = PollSender::new(self.active_session_tx.clone());
self.to_sessions.push(commands_to_session); self.to_sessions.push(commands_to_session);
@ -875,8 +873,8 @@ mod tests {
remote_capabilities: Arc::clone(&capabilities), remote_capabilities: Arc::clone(&capabilities),
session_id, session_id,
commands_rx: ReceiverStream::new(commands_rx), commands_rx: ReceiverStream::new(commands_rx),
to_session_manager: MeteredSender::new( to_session_manager: MeteredPollSender::new(
self.active_session_tx.clone(), poll_sender,
"network_active_session", "network_active_session",
), ),
pending_message_to_session: None, pending_message_to_session: None,

View File

@ -12,7 +12,7 @@ use reth_eth_wire::{
errors::EthStreamError, errors::EthStreamError,
DisconnectReason, EthVersion, HelloMessage, Status, UnauthedEthStream, UnauthedP2PStream, DisconnectReason, EthVersion, HelloMessage, Status, UnauthedEthStream, UnauthedP2PStream,
}; };
use reth_metrics::common::mpsc::MeteredSender; use reth_metrics::common::mpsc::MeteredPollSender;
use reth_net_common::{ use reth_net_common::{
bandwidth_meter::{BandwidthMeter, MeteredStream}, bandwidth_meter::{BandwidthMeter, MeteredStream},
stream::HasRemoteAddr, stream::HasRemoteAddr,
@ -34,6 +34,7 @@ use tokio::{
sync::{mpsc, oneshot}, sync::{mpsc, oneshot},
}; };
use tokio_stream::wrappers::ReceiverStream; use tokio_stream::wrappers::ReceiverStream;
use tokio_util::sync::PollSender;
use tracing::{instrument, trace}; use tracing::{instrument, trace};
mod active; mod active;
@ -95,7 +96,7 @@ pub struct SessionManager {
/// ///
/// When active session state is reached, the corresponding [`ActiveSessionHandle`] will get a /// When active session state is reached, the corresponding [`ActiveSessionHandle`] will get a
/// clone of this sender half. /// clone of this sender half.
active_session_tx: MeteredSender<ActiveSessionMessage>, active_session_tx: MeteredPollSender<ActiveSessionMessage>,
/// Receiver half that listens for [`ActiveSessionMessage`] produced by pending sessions. /// Receiver half that listens for [`ActiveSessionMessage`] produced by pending sessions.
active_session_rx: ReceiverStream<ActiveSessionMessage>, active_session_rx: ReceiverStream<ActiveSessionMessage>,
/// Used to measure inbound & outbound bandwidth across all managed streams /// Used to measure inbound & outbound bandwidth across all managed streams
@ -119,6 +120,7 @@ impl SessionManager {
) -> Self { ) -> Self {
let (pending_sessions_tx, pending_sessions_rx) = mpsc::channel(config.session_event_buffer); let (pending_sessions_tx, pending_sessions_rx) = mpsc::channel(config.session_event_buffer);
let (active_session_tx, active_session_rx) = mpsc::channel(config.session_event_buffer); let (active_session_tx, active_session_rx) = mpsc::channel(config.session_event_buffer);
let active_session_tx = PollSender::new(active_session_tx);
Self { Self {
next_id: 0, next_id: 0,
@ -135,7 +137,7 @@ impl SessionManager {
active_sessions: Default::default(), active_sessions: Default::default(),
pending_sessions_tx, pending_sessions_tx,
pending_session_rx: ReceiverStream::new(pending_sessions_rx), pending_session_rx: ReceiverStream::new(pending_sessions_rx),
active_session_tx: MeteredSender::new(active_session_tx, "network_active_session"), active_session_tx: MeteredPollSender::new(active_session_tx, "network_active_session"),
active_session_rx: ReceiverStream::new(active_session_rx), active_session_rx: ReceiverStream::new(active_session_rx),
bandwidth_meter, bandwidth_meter,
metrics: Default::default(), metrics: Default::default(),