refactor: remove futureUnordered in ipc (#7920)

This commit is contained in:
Abner Zheng
2024-04-29 23:30:42 +08:00
committed by GitHub
parent a8cd1f71a0
commit fd8fdcfd4b
4 changed files with 162 additions and 118 deletions

View File

@ -9,7 +9,7 @@ use std::{
pin::Pin,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::Framed;
use tower::Service;
@ -18,17 +18,6 @@ pub(crate) type JsonRpcStream<T> = Framed<T, StreamCodec>;
#[pin_project::pin_project]
pub(crate) struct IpcConn<T>(#[pin] pub(crate) T);
impl<T> IpcConn<JsonRpcStream<T>>
where
T: AsyncRead + AsyncWrite + Unpin,
{
/// Create a response for when the server is busy and can't accept more requests.
pub(crate) async fn reject_connection(self) {
let mut parts = self.0.into_parts();
let _ = parts.io.write_all(b"Too many connections. Please try again later.").await;
}
}
impl<T> Stream for IpcConn<JsonRpcStream<T>>
where
T: AsyncRead + AsyncWrite,

View File

@ -27,8 +27,7 @@
//! Utilities for handling async code.
use std::sync::Arc;
use tokio::sync::{watch, OwnedSemaphorePermit, Semaphore, TryAcquireError};
use tokio::sync::watch;
#[derive(Debug, Clone)]
pub(crate) struct StopHandle(watch::Receiver<()>);
@ -59,27 +58,3 @@ impl ServerHandle {
self.0.closed().await
}
}
/// Limits the number of connections.
pub(crate) struct ConnectionGuard(Arc<Semaphore>);
impl ConnectionGuard {
pub(crate) fn new(limit: usize) -> Self {
Self(Arc::new(Semaphore::new(limit)))
}
pub(crate) fn try_acquire(&self) -> Option<OwnedSemaphorePermit> {
match self.0.clone().try_acquire_owned() {
Ok(guard) => Some(guard),
Err(TryAcquireError::Closed) => {
unreachable!("Semaphore::Close is never called and can't be closed")
}
Err(TryAcquireError::NoPermits) => None,
}
}
#[allow(dead_code)]
pub(crate) fn available_connections(&self) -> usize {
self.0.available_permits()
}
}

View File

@ -1,7 +1,5 @@
//! IPC request handling adapted from [`jsonrpsee`] http request handling
use std::sync::Arc;
use futures::{stream::FuturesOrdered, StreamExt};
use jsonrpsee::{
batch_response_error,
@ -17,6 +15,7 @@ use jsonrpsee::{
},
BatchResponseBuilder, MethodResponse, ResponsePayload,
};
use std::sync::Arc;
use tokio::sync::OwnedSemaphorePermit;
use tokio_util::either::Either;
use tracing::instrument;

View File

@ -2,16 +2,17 @@
use crate::server::{
connection::{IpcConn, JsonRpcStream},
future::{ConnectionGuard, StopHandle},
future::StopHandle,
};
use futures::StreamExt;
use futures_util::{future::Either, stream::FuturesUnordered};
use futures_util::{future::Either, AsyncWriteExt};
use interprocess::local_socket::tokio::{LocalSocketListener, LocalSocketStream};
use jsonrpsee::{
core::TEN_MB_SIZE_BYTES,
server::{
middleware::rpc::{RpcLoggerLayer, RpcServiceT},
AlreadyStoppedError, IdProvider, RandomIntegerIdProvider,
AlreadyStoppedError, ConnectionGuard, ConnectionPermit, IdProvider,
RandomIntegerIdProvider,
},
BoundedSubscriptions, MethodSink, Methods,
};
@ -24,10 +25,10 @@ use std::{
};
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::{oneshot, watch, OwnedSemaphorePermit},
sync::{oneshot, watch},
};
use tower::{layer::util::Identity, Layer, Service};
use tracing::{debug, trace, warn, Instrument};
use tracing::{debug, instrument, trace, warn, Instrument};
// re-export so can be used during builder setup
use crate::{
server::{
@ -150,68 +151,44 @@ where
// signal that we're ready to accept connections
on_ready.send(Ok(())).ok();
let message_buffer_capacity = self.cfg.message_buffer_capacity;
let max_request_body_size = self.cfg.max_request_body_size;
let max_response_body_size = self.cfg.max_response_body_size;
let max_log_length = self.cfg.max_log_length;
let id_provider = self.id_provider;
let max_subscriptions_per_connection = self.cfg.max_subscriptions_per_connection;
let mut id: u32 = 0;
let connection_guard = ConnectionGuard::new(self.cfg.max_connections as usize);
let mut connections = FuturesUnordered::new();
let stopped = stop_handle.clone().shutdown();
tokio::pin!(stopped);
let (drop_on_completion, mut process_connection_awaiter) = mpsc::channel::<()>(1);
trace!("accepting ipc connections");
loop {
match try_accept_conn(&listener, stopped).await {
AcceptConnection::Established { local_socket_stream, stop } => {
trace!("established new connection");
let ipc = IpcConn(tokio_util::codec::Decoder::framed(
StreamCodec::stream_incoming(),
local_socket_stream.compat(),
));
let conn = match connection_guard.try_acquire() {
Some(conn) => conn,
None => {
warn!("Too many IPC connections. Please try again later.");
connections.push(tokio::spawn(ipc.reject_connection().in_current_span()));
stopped = stop;
continue;
}
let Some(conn_permit) = connection_guard.try_acquire() else {
let (mut _reader, mut writer) = local_socket_stream.into_split();
let _ = writer.write_all(b"Too many connections. Please try again later.").await;
drop((_reader, writer));
stopped = stop;
continue;
};
let (tx, rx) = mpsc::channel::<String>(message_buffer_capacity as usize);
let method_sink = MethodSink::new_with_limit(tx, max_response_body_size);
let tower_service = TowerServiceNoHttp {
inner: ServiceData {
methods: methods.clone(),
max_request_body_size,
max_response_body_size,
max_log_length,
id_provider: id_provider.clone(),
stop_handle: stop_handle.clone(),
max_subscriptions_per_connection,
conn_id: id,
conn: Arc::new(conn),
bounded_subscriptions: BoundedSubscriptions::new(
max_subscriptions_per_connection,
),
method_sink,
},
let max_conns = connection_guard.max_connections();
let curr_conns = max_conns - connection_guard.available_connections();
trace!("Accepting new connection {}/{}", curr_conns, max_conns);
let conn_permit = Arc::new(conn_permit);
process_connection(ProcessConnection{
http_middleware: &self.http_middleware,
rpc_middleware: self.rpc_middleware.clone(),
};
let service = self.http_middleware.service(tower_service);
connections.push(tokio::spawn(process_connection(
ipc,
service,
stop_handle.clone(),
rx,
).in_current_span()));
conn_permit,
conn_id: id,
server_cfg: self.cfg.clone(),
stop_handle: stop_handle.clone(),
drop_on_completion: drop_on_completion.clone(),
methods: methods.clone(),
id_provider: self.id_provider.clone(),
local_socket_stream,
});
id = id.wrapping_add(1);
stopped = stop;
@ -224,11 +201,14 @@ where
}
}
// FuturesUnordered won't poll anything until this line but because the
// tasks are spawned (so that they can progress independently)
// then this just makes sure that all tasks are completed before
// returning from this function.
while connections.next().await.is_some() {}
// Drop the last Sender
drop(drop_on_completion);
// Once this channel is closed it is safe to assume that all connections have been gracefully shutdown
while process_connection_awaiter.recv().await.is_some() {
// Generally, messages should not be sent across this channel,
// but we'll loop here to wait for `None` just to be on the safe side
}
}
}
@ -279,30 +259,22 @@ pub struct IpcServerStartError {
pub(crate) struct ServiceData {
/// Registered server methods.
pub(crate) methods: Methods,
/// Max request body size.
pub(crate) max_request_body_size: u32,
/// Max request body size.
pub(crate) max_response_body_size: u32,
/// Max length for logging for request and response
///
/// Logs bigger than this limit will be truncated.
pub(crate) max_log_length: u32,
/// Subscription ID provider.
pub(crate) id_provider: Arc<dyn IdProvider>,
/// Stop handle.
pub(crate) stop_handle: StopHandle,
/// Max subscriptions per connection.
pub(crate) max_subscriptions_per_connection: u32,
/// Connection ID
pub(crate) conn_id: u32,
/// Handle to hold a `connection permit`.
pub(crate) conn: Arc<OwnedSemaphorePermit>,
/// Connection Permit.
pub(crate) conn_permit: Arc<ConnectionPermit>,
/// Limits the number of subscriptions for this connection
pub(crate) bounded_subscriptions: BoundedSubscriptions,
/// Sink that is used to send back responses to the connection.
///
/// This is used for subscriptions.
pub(crate) method_sink: MethodSink,
/// ServerConfig
pub(crate) server_cfg: Settings,
}
/// Similar to [`tower::ServiceBuilder`] but doesn't
@ -407,21 +379,21 @@ where
let cfg = RpcServiceCfg::CallsAndSubscriptions {
bounded_subscriptions: BoundedSubscriptions::new(
self.inner.max_subscriptions_per_connection,
self.inner.server_cfg.max_subscriptions_per_connection,
),
id_provider: self.inner.id_provider.clone(),
sink: self.inner.method_sink.clone(),
};
let max_response_body_size = self.inner.max_response_body_size as usize;
let max_request_body_size = self.inner.max_request_body_size as usize;
let max_response_body_size = self.inner.server_cfg.max_response_body_size as usize;
let max_request_body_size = self.inner.server_cfg.max_request_body_size as usize;
let conn = self.inner.conn_permit.clone();
let rpc_service = self.rpc_middleware.service(RpcService::new(
self.inner.methods.clone(),
max_response_body_size,
self.inner.conn_id as usize,
cfg,
));
let conn = self.inner.conn.clone();
// an ipc connection needs to handle read+write concurrently
// even if the underlying rpc handler spawns the actual work or is does a lot of async any
// additional overhead performed by `handle_request` can result in I/O latencies, for
@ -443,9 +415,81 @@ where
}
}
struct ProcessConnection<'a, HttpMiddleware, RpcMiddleware> {
http_middleware: &'a tower::ServiceBuilder<HttpMiddleware>,
rpc_middleware: RpcServiceBuilder<RpcMiddleware>,
conn_permit: Arc<ConnectionPermit>,
conn_id: u32,
server_cfg: Settings,
stop_handle: StopHandle,
drop_on_completion: mpsc::Sender<()>,
methods: Methods,
id_provider: Arc<dyn IdProvider>,
local_socket_stream: LocalSocketStream,
}
/// Spawns the IPC connection onto a new task
async fn process_connection<S, T>(
conn: IpcConn<JsonRpcStream<T>>,
#[instrument(name = "connection", skip_all, fields(conn_id = %params.conn_id), level = "INFO")]
fn process_connection<'b, RpcMiddleware, HttpMiddleware>(
params: ProcessConnection<'_, HttpMiddleware, RpcMiddleware>,
) where
RpcMiddleware: Layer<RpcService> + Clone + Send + 'static,
for<'a> <RpcMiddleware as Layer<RpcService>>::Service: RpcServiceT<'a>,
HttpMiddleware: Layer<TowerServiceNoHttp<RpcMiddleware>> + Send + 'static,
<HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service: Send
+ Service<
String,
Response = Option<String>,
Error = Box<dyn std::error::Error + Send + Sync + 'static>,
>,
<<HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service as Service<String>>::Future:
Send + Unpin,
{
let ProcessConnection {
http_middleware,
rpc_middleware,
conn_permit,
conn_id,
server_cfg,
stop_handle,
drop_on_completion,
id_provider,
methods,
local_socket_stream,
} = params;
let ipc = IpcConn(tokio_util::codec::Decoder::framed(
StreamCodec::stream_incoming(),
local_socket_stream.compat(),
));
let (tx, rx) = mpsc::channel::<String>(server_cfg.message_buffer_capacity as usize);
let method_sink = MethodSink::new_with_limit(tx, server_cfg.max_response_body_size);
let tower_service = TowerServiceNoHttp {
inner: ServiceData {
methods,
id_provider,
stop_handle: stop_handle.clone(),
server_cfg: server_cfg.clone(),
conn_id,
conn_permit,
bounded_subscriptions: BoundedSubscriptions::new(
server_cfg.max_subscriptions_per_connection,
),
method_sink,
},
rpc_middleware,
};
let service = http_middleware.service(tower_service);
tokio::spawn(async {
to_ipc_service(ipc, service, stop_handle, rx).in_current_span().await;
drop(drop_on_completion)
});
}
async fn to_ipc_service<S, T>(
ipc: IpcConn<JsonRpcStream<T>>,
service: S,
stop_handle: StopHandle,
rx: mpsc::Receiver<String>,
@ -457,7 +501,7 @@ async fn process_connection<S, T>(
{
let rx_item = ReceiverStream::new(rx);
let conn = IpcConnDriver {
conn,
conn: ipc,
service,
pending_calls: Default::default(),
items: Default::default(),
@ -799,6 +843,7 @@ mod tests {
types::Request,
PendingSubscriptionSink, RpcModule, SubscriptionMessage,
};
use reth_tracing::init_test_tracing;
use tokio::sync::broadcast;
use tokio_stream::wrappers::BroadcastStream;
@ -864,6 +909,7 @@ mod tests {
#[tokio::test]
async fn can_set_the_max_request_body_size() {
init_test_tracing();
let endpoint = dummy_endpoint();
let server = Builder::default().max_request_body_size(100).build(&endpoint);
let mut module = RpcModule::new(());
@ -888,8 +934,43 @@ mod tests {
assert!(response.is_err());
}
#[tokio::test]
async fn can_set_max_connections() {
init_test_tracing();
let endpoint = dummy_endpoint();
let server = Builder::default().max_connections(2).build(&endpoint);
let mut module = RpcModule::new(());
module.register_method("anything", |_, _| "succeed").unwrap();
let handle = server.start(module).await.unwrap();
tokio::spawn(handle.stopped());
let client1 = IpcClientBuilder::default().build(endpoint.clone()).await.unwrap();
let client2 = IpcClientBuilder::default().build(endpoint.clone()).await.unwrap();
let client3 = IpcClientBuilder::default().build(endpoint.clone()).await.unwrap();
let response1: Result<String, Error> = client1.request("anything", rpc_params![]).await;
let response2: Result<String, Error> = client2.request("anything", rpc_params![]).await;
let response3: Result<String, Error> = client3.request("anything", rpc_params![]).await;
assert!(response1.is_ok());
assert!(response2.is_ok());
// Third connection is rejected
assert!(response3.is_err());
// Decrement connection count
drop(client2);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Can connect again
let client4 = IpcClientBuilder::default().build(endpoint.clone()).await.unwrap();
let response4: Result<String, Error> = client4.request("anything", rpc_params![]).await;
assert!(response4.is_ok());
}
#[tokio::test]
async fn test_rpc_request() {
init_test_tracing();
let endpoint = dummy_endpoint();
let server = Builder::default().build(&endpoint);
let mut module = RpcModule::new(());