Mirror of https://github.com/hl-archive-node/nanoreth.git, synced 2025-12-06 10:59:55 +00:00
refactor: remove FuturesUnordered in ipc (#7920)
@@ -9,7 +9,7 @@ use std::{
     pin::Pin,
     task::{Context, Poll},
 };
-use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
+use tokio::io::{AsyncRead, AsyncWrite};
 use tokio_util::codec::Framed;
 use tower::Service;
 
@@ -18,17 +18,6 @@ pub(crate) type JsonRpcStream<T> = Framed<T, StreamCodec>;
 #[pin_project::pin_project]
 pub(crate) struct IpcConn<T>(#[pin] pub(crate) T);
 
-impl<T> IpcConn<JsonRpcStream<T>>
-where
-    T: AsyncRead + AsyncWrite + Unpin,
-{
-    /// Create a response for when the server is busy and can't accept more requests.
-    pub(crate) async fn reject_connection(self) {
-        let mut parts = self.0.into_parts();
-        let _ = parts.io.write_all(b"Too many connections. Please try again later.").await;
-    }
-}
-
 impl<T> Stream for IpcConn<JsonRpcStream<T>>
 where
     T: AsyncRead + AsyncWrite,

@@ -27,8 +27,7 @@
 //! Utilities for handling async code.
 
 use std::sync::Arc;
-
-use tokio::sync::{watch, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::watch;
 
 #[derive(Debug, Clone)]
 pub(crate) struct StopHandle(watch::Receiver<()>);

@@ -59,27 +58,3 @@ impl ServerHandle
         self.0.closed().await
     }
 }
-
-/// Limits the number of connections.
-pub(crate) struct ConnectionGuard(Arc<Semaphore>);
-
-impl ConnectionGuard {
-    pub(crate) fn new(limit: usize) -> Self {
-        Self(Arc::new(Semaphore::new(limit)))
-    }
-
-    pub(crate) fn try_acquire(&self) -> Option<OwnedSemaphorePermit> {
-        match self.0.clone().try_acquire_owned() {
-            Ok(guard) => Some(guard),
-            Err(TryAcquireError::Closed) => {
-                unreachable!("Semaphore::Close is never called and can't be closed")
-            }
-            Err(TryAcquireError::NoPermits) => None,
-        }
-    }
-
-    #[allow(dead_code)]
-    pub(crate) fn available_connections(&self) -> usize {
-        self.0.available_permits()
-    }
-}

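The guard deleted above is a thin wrapper around tokio's Semaphore; jsonrpsee's ConnectionGuard, which the server switches to below, follows the same pattern. A minimal standalone sketch of the idea (the name Guard is hypothetical, not from this codebase):

use std::sync::Arc;

use tokio::sync::{OwnedSemaphorePermit, Semaphore};

// Hypothetical stand-in for the removed ConnectionGuard: each accepted
// connection holds an owned permit; dropping the permit frees the slot.
struct Guard(Arc<Semaphore>);

impl Guard {
    fn new(limit: usize) -> Self {
        Self(Arc::new(Semaphore::new(limit)))
    }

    // `None` means every permit is taken, i.e. the server is at capacity.
    fn try_acquire(&self) -> Option<OwnedSemaphorePermit> {
        self.0.clone().try_acquire_owned().ok()
    }
}

fn main() {
    let guard = Guard::new(1);
    let permit = guard.try_acquire().expect("first connection fits");
    assert!(guard.try_acquire().is_none()); // second connection is rejected
    drop(permit); // connection closed: the slot is freed
    assert!(guard.try_acquire().is_some());
}

Because the permit is owned rather than borrowed, it can travel into the spawned connection task and release the slot from wherever it is dropped.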
@@ -1,7 +1,5 @@
 //! IPC request handling adapted from [`jsonrpsee`] http request handling
 
-use std::sync::Arc;
-
 use futures::{stream::FuturesOrdered, StreamExt};
 use jsonrpsee::{
     batch_response_error,
@@ -17,6 +15,7 @@ use jsonrpsee::{
     },
     BatchResponseBuilder, MethodResponse, ResponsePayload,
 };
+use std::sync::Arc;
 use tokio::sync::OwnedSemaphorePermit;
 use tokio_util::either::Either;
 use tracing::instrument;

@@ -2,16 +2,17 @@
 
 use crate::server::{
     connection::{IpcConn, JsonRpcStream},
-    future::{ConnectionGuard, StopHandle},
+    future::StopHandle,
 };
-use futures::StreamExt;
-use futures_util::{future::Either, stream::FuturesUnordered};
+use futures_util::{future::Either, AsyncWriteExt};
 use interprocess::local_socket::tokio::{LocalSocketListener, LocalSocketStream};
 use jsonrpsee::{
     core::TEN_MB_SIZE_BYTES,
     server::{
         middleware::rpc::{RpcLoggerLayer, RpcServiceT},
-        AlreadyStoppedError, IdProvider, RandomIntegerIdProvider,
+        AlreadyStoppedError, ConnectionGuard, ConnectionPermit, IdProvider,
+        RandomIntegerIdProvider,
     },
     BoundedSubscriptions, MethodSink, Methods,
 };
@@ -24,10 +25,10 @@ use std::{
 };
 use tokio::{
     io::{AsyncRead, AsyncWrite},
-    sync::{oneshot, watch, OwnedSemaphorePermit},
+    sync::{oneshot, watch},
 };
 use tower::{layer::util::Identity, Layer, Service};
-use tracing::{debug, trace, warn, Instrument};
+use tracing::{debug, instrument, trace, warn, Instrument};
 // re-export so can be used during builder setup
 use crate::{
     server::{
@@ -150,68 +151,44 @@ where
         // signal that we're ready to accept connections
         on_ready.send(Ok(())).ok();
 
-        let message_buffer_capacity = self.cfg.message_buffer_capacity;
-        let max_request_body_size = self.cfg.max_request_body_size;
-        let max_response_body_size = self.cfg.max_response_body_size;
-        let max_log_length = self.cfg.max_log_length;
-        let id_provider = self.id_provider;
-        let max_subscriptions_per_connection = self.cfg.max_subscriptions_per_connection;
-
         let mut id: u32 = 0;
         let connection_guard = ConnectionGuard::new(self.cfg.max_connections as usize);
 
-        let mut connections = FuturesUnordered::new();
         let stopped = stop_handle.clone().shutdown();
         tokio::pin!(stopped);
 
+        let (drop_on_completion, mut process_connection_awaiter) = mpsc::channel::<()>(1);
+
         trace!("accepting ipc connections");
         loop {
             match try_accept_conn(&listener, stopped).await {
                 AcceptConnection::Established { local_socket_stream, stop } => {
-                    trace!("established new connection");
-                    let ipc = IpcConn(tokio_util::codec::Decoder::framed(
-                        StreamCodec::stream_incoming(),
-                        local_socket_stream.compat(),
-                    ));
-
-                    let conn = match connection_guard.try_acquire() {
-                        Some(conn) => conn,
-                        None => {
-                            warn!("Too many IPC connections. Please try again later.");
-                            connections.push(tokio::spawn(ipc.reject_connection().in_current_span()));
-                            stopped = stop;
-                            continue;
-                        }
+                    let Some(conn_permit) = connection_guard.try_acquire() else {
+                        let (mut _reader, mut writer) = local_socket_stream.into_split();
+                        let _ = writer.write_all(b"Too many connections. Please try again later.").await;
+                        drop((_reader, writer));
+                        stopped = stop;
+                        continue;
                     };
 
-                    let (tx, rx) = mpsc::channel::<String>(message_buffer_capacity as usize);
-                    let method_sink = MethodSink::new_with_limit(tx, max_response_body_size);
-                    let tower_service = TowerServiceNoHttp {
-                        inner: ServiceData {
-                            methods: methods.clone(),
-                            max_request_body_size,
-                            max_response_body_size,
-                            max_log_length,
-                            id_provider: id_provider.clone(),
-                            stop_handle: stop_handle.clone(),
-                            max_subscriptions_per_connection,
-                            conn_id: id,
-                            conn: Arc::new(conn),
-                            bounded_subscriptions: BoundedSubscriptions::new(
-                                max_subscriptions_per_connection,
-                            ),
-                            method_sink,
-                        },
+                    let max_conns = connection_guard.max_connections();
+                    let curr_conns = max_conns - connection_guard.available_connections();
+                    trace!("Accepting new connection {}/{}", curr_conns, max_conns);
+
+                    let conn_permit = Arc::new(conn_permit);
+
+                    process_connection(ProcessConnection{
+                        http_middleware: &self.http_middleware,
                         rpc_middleware: self.rpc_middleware.clone(),
-                    };
-
-                    let service = self.http_middleware.service(tower_service);
-                    connections.push(tokio::spawn(process_connection(
-                        ipc,
-                        service,
-                        stop_handle.clone(),
-                        rx,
-                    ).in_current_span()));
+                        conn_permit,
+                        conn_id: id,
+                        server_cfg: self.cfg.clone(),
+                        stop_handle: stop_handle.clone(),
+                        drop_on_completion: drop_on_completion.clone(),
+                        methods: methods.clone(),
+                        id_provider: self.id_provider.clone(),
+                        local_socket_stream,
+                    });
 
                     id = id.wrapping_add(1);
                     stopped = stop;

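The new rejection path writes the error text onto the raw socket before any framing is set up. A rough sketch of the same behavior, with an in-memory tokio duplex pipe standing in for the interprocess LocalSocketStream (the real code splits that stream and writes via futures_util's AsyncWriteExt):

use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};

// Split the stream, write the rejection message, then drop both halves so
// the peer reads the text followed by EOF.
#[tokio::main]
async fn main() {
    let (server_side, mut client_side) = duplex(1024);
    let (_reader, mut writer) = tokio::io::split(server_side);
    let _ = writer.write_all(b"Too many connections. Please try again later.").await;
    drop((_reader, writer));

    let mut buf = String::new();
    client_side.read_to_string(&mut buf).await.unwrap();
    assert_eq!(buf, "Too many connections. Please try again later.");
}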
@@ -224,11 +201,14 @@ where
             }
         }
 
-        // FuturesUnordered won't poll anything until this line but because the
-        // tasks are spawned (so that they can progress independently)
-        // then this just makes sure that all tasks are completed before
-        // returning from this function.
-        while connections.next().await.is_some() {}
+        // Drop the last Sender
+        drop(drop_on_completion);
+
+        // Once this channel is closed it is safe to assume that all connections have been gracefully shutdown
+        while process_connection_awaiter.recv().await.is_some() {
+            // Generally, messages should not be sent across this channel,
+            // but we'll loop here to wait for `None` just to be on the safe side
+        }
     }
 }
 
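The channel introduced above replaces FuturesUnordered as the shutdown barrier: nothing is ever sent on it, but mpsc::Receiver::recv only resolves to None once every clone of the Sender has been dropped. A minimal sketch of the pattern:

use tokio::sync::mpsc;

// Every spawned task owns a clone of a Sender<()>; once the last clone is
// dropped, the Receiver yields `None`, proving all tasks have finished.
#[tokio::main]
async fn main() {
    let (drop_on_completion, mut awaiter) = mpsc::channel::<()>(1);

    for i in 0..3 {
        let done = drop_on_completion.clone();
        tokio::spawn(async move {
            // ... handle one connection ...
            println!("task {i} finished");
            drop(done); // releases this task's handle on the channel
        });
    }

    // Drop the last Sender held by the accept loop itself.
    drop(drop_on_completion);

    // `recv` returns `None` only after every Sender clone is gone.
    while awaiter.recv().await.is_some() {}
    println!("all connections have shut down");
}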
@@ -279,30 +259,22 @@ pub struct IpcServerStartError {
 pub(crate) struct ServiceData {
     /// Registered server methods.
     pub(crate) methods: Methods,
-    /// Max request body size.
-    pub(crate) max_request_body_size: u32,
-    /// Max request body size.
-    pub(crate) max_response_body_size: u32,
-    /// Max length for logging for request and response
-    ///
-    /// Logs bigger than this limit will be truncated.
-    pub(crate) max_log_length: u32,
     /// Subscription ID provider.
     pub(crate) id_provider: Arc<dyn IdProvider>,
     /// Stop handle.
     pub(crate) stop_handle: StopHandle,
-    /// Max subscriptions per connection.
-    pub(crate) max_subscriptions_per_connection: u32,
     /// Connection ID
     pub(crate) conn_id: u32,
-    /// Handle to hold a `connection permit`.
-    pub(crate) conn: Arc<OwnedSemaphorePermit>,
+    /// Connection Permit.
+    pub(crate) conn_permit: Arc<ConnectionPermit>,
     /// Limits the number of subscriptions for this connection
     pub(crate) bounded_subscriptions: BoundedSubscriptions,
     /// Sink that is used to send back responses to the connection.
     ///
     /// This is used for subscriptions.
     pub(crate) method_sink: MethodSink,
+    /// ServerConfig
+    pub(crate) server_cfg: Settings,
 }
 
 /// Similar to [`tower::ServiceBuilder`] but doesn't
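Holding the permit inside ServiceData ties the connection slot's lifetime to the per-connection service itself. A small sketch, assuming (as the imports above suggest) that jsonrpsee's ConnectionPermit is an owned tokio semaphore permit:

use std::sync::Arc;

use tokio::sync::{OwnedSemaphorePermit, Semaphore};

// The permit field keeps the connection's slot occupied for as long as the
// per-connection service data stays alive, mirroring `conn_permit` above.
struct ConnService {
    _permit: Arc<OwnedSemaphorePermit>,
}

fn main() {
    let sem = Arc::new(Semaphore::new(1));
    let permit = sem.clone().try_acquire_owned().unwrap();
    let svc = ConnService { _permit: Arc::new(permit) };
    assert_eq!(sem.available_permits(), 0); // slot held while the service lives
    drop(svc);
    assert_eq!(sem.available_permits(), 1); // slot freed once it is dropped
}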
@@ -407,21 +379,21 @@ where
 
         let cfg = RpcServiceCfg::CallsAndSubscriptions {
             bounded_subscriptions: BoundedSubscriptions::new(
-                self.inner.max_subscriptions_per_connection,
+                self.inner.server_cfg.max_subscriptions_per_connection,
             ),
             id_provider: self.inner.id_provider.clone(),
             sink: self.inner.method_sink.clone(),
         };
 
-        let max_response_body_size = self.inner.max_response_body_size as usize;
-        let max_request_body_size = self.inner.max_request_body_size as usize;
+        let max_response_body_size = self.inner.server_cfg.max_response_body_size as usize;
+        let max_request_body_size = self.inner.server_cfg.max_request_body_size as usize;
+        let conn = self.inner.conn_permit.clone();
         let rpc_service = self.rpc_middleware.service(RpcService::new(
             self.inner.methods.clone(),
             max_response_body_size,
             self.inner.conn_id as usize,
             cfg,
         ));
-        let conn = self.inner.conn.clone();
         // an ipc connection needs to handle read+write concurrently
         // even if the underlying rpc handler spawns the actual work or is does a lot of async any
         // additional overhead performed by `handle_request` can result in I/O latencies, for
@@ -443,9 +415,81 @@ where
     }
 }
 
+struct ProcessConnection<'a, HttpMiddleware, RpcMiddleware> {
+    http_middleware: &'a tower::ServiceBuilder<HttpMiddleware>,
+    rpc_middleware: RpcServiceBuilder<RpcMiddleware>,
+    conn_permit: Arc<ConnectionPermit>,
+    conn_id: u32,
+    server_cfg: Settings,
+    stop_handle: StopHandle,
+    drop_on_completion: mpsc::Sender<()>,
+    methods: Methods,
+    id_provider: Arc<dyn IdProvider>,
+    local_socket_stream: LocalSocketStream,
+}
+
 /// Spawns the IPC connection onto a new task
-async fn process_connection<S, T>(
-    conn: IpcConn<JsonRpcStream<T>>,
+#[instrument(name = "connection", skip_all, fields(conn_id = %params.conn_id), level = "INFO")]
+fn process_connection<'b, RpcMiddleware, HttpMiddleware>(
+    params: ProcessConnection<'_, HttpMiddleware, RpcMiddleware>,
+) where
+    RpcMiddleware: Layer<RpcService> + Clone + Send + 'static,
+    for<'a> <RpcMiddleware as Layer<RpcService>>::Service: RpcServiceT<'a>,
+    HttpMiddleware: Layer<TowerServiceNoHttp<RpcMiddleware>> + Send + 'static,
+    <HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service: Send
+        + Service<
+            String,
+            Response = Option<String>,
+            Error = Box<dyn std::error::Error + Send + Sync + 'static>,
+        >,
+    <<HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service as Service<String>>::Future:
+        Send + Unpin,
+{
+    let ProcessConnection {
+        http_middleware,
+        rpc_middleware,
+        conn_permit,
+        conn_id,
+        server_cfg,
+        stop_handle,
+        drop_on_completion,
+        id_provider,
+        methods,
+        local_socket_stream,
+    } = params;
+
+    let ipc = IpcConn(tokio_util::codec::Decoder::framed(
+        StreamCodec::stream_incoming(),
+        local_socket_stream.compat(),
+    ));
+
+    let (tx, rx) = mpsc::channel::<String>(server_cfg.message_buffer_capacity as usize);
+    let method_sink = MethodSink::new_with_limit(tx, server_cfg.max_response_body_size);
+    let tower_service = TowerServiceNoHttp {
+        inner: ServiceData {
+            methods,
+            id_provider,
+            stop_handle: stop_handle.clone(),
+            server_cfg: server_cfg.clone(),
+            conn_id,
+            conn_permit,
+            bounded_subscriptions: BoundedSubscriptions::new(
+                server_cfg.max_subscriptions_per_connection,
+            ),
+            method_sink,
+        },
+        rpc_middleware,
+    };
+
+    let service = http_middleware.service(tower_service);
+    tokio::spawn(async move {
+        to_ipc_service(ipc, service, stop_handle, rx).in_current_span().await;
+        drop(drop_on_completion)
+    });
+}
+
+async fn to_ipc_service<S, T>(
+    ipc: IpcConn<JsonRpcStream<T>>,
     service: S,
     stop_handle: StopHandle,
     rx: mpsc::Receiver<String>,
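The #[instrument] attribute added to process_connection opens a tracing span per connection that carries conn_id, and in_current_span() on the spawned future keeps later log lines inside it. A minimal sketch of the span behavior (hypothetical process function; tracing-subscriber assumed for output):

use tracing::{info, instrument};

// `skip_all` keeps arguments out of the span, while `fields(conn_id = %conn_id)`
// records the id explicitly, so every event inside inherits it.
#[instrument(name = "connection", skip_all, fields(conn_id = %conn_id))]
fn process(conn_id: u32) {
    info!("established new connection"); // rendered as: connection{conn_id=7}: ...
}

fn main() {
    tracing_subscriber::fmt().init();
    process(7);
}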
@@ -457,7 +501,7 @@ async fn process_connection<S, T>(
 {
     let rx_item = ReceiverStream::new(rx);
     let conn = IpcConnDriver {
-        conn,
+        conn: ipc,
         service,
         pending_calls: Default::default(),
         items: Default::default(),
@@ -799,6 +843,7 @@ mod tests {
         types::Request,
         PendingSubscriptionSink, RpcModule, SubscriptionMessage,
     };
+    use reth_tracing::init_test_tracing;
     use tokio::sync::broadcast;
     use tokio_stream::wrappers::BroadcastStream;
 
@@ -864,6 +909,7 @@ mod tests {
 
     #[tokio::test]
     async fn can_set_the_max_request_body_size() {
+        init_test_tracing();
         let endpoint = dummy_endpoint();
         let server = Builder::default().max_request_body_size(100).build(&endpoint);
         let mut module = RpcModule::new(());
@@ -888,8 +934,43 @@ mod tests {
         assert!(response.is_err());
     }
 
+    #[tokio::test]
+    async fn can_set_max_connections() {
+        init_test_tracing();
+
+        let endpoint = dummy_endpoint();
+        let server = Builder::default().max_connections(2).build(&endpoint);
+        let mut module = RpcModule::new(());
+        module.register_method("anything", |_, _| "succeed").unwrap();
+        let handle = server.start(module).await.unwrap();
+        tokio::spawn(handle.stopped());
+
+        let client1 = IpcClientBuilder::default().build(endpoint.clone()).await.unwrap();
+        let client2 = IpcClientBuilder::default().build(endpoint.clone()).await.unwrap();
+        let client3 = IpcClientBuilder::default().build(endpoint.clone()).await.unwrap();
+
+        let response1: Result<String, Error> = client1.request("anything", rpc_params![]).await;
+        let response2: Result<String, Error> = client2.request("anything", rpc_params![]).await;
+        let response3: Result<String, Error> = client3.request("anything", rpc_params![]).await;
+
+        assert!(response1.is_ok());
+        assert!(response2.is_ok());
+        // Third connection is rejected
+        assert!(response3.is_err());
+
+        // Decrement connection count
+        drop(client2);
+        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
+
+        // Can connect again
+        let client4 = IpcClientBuilder::default().build(endpoint.clone()).await.unwrap();
+        let response4: Result<String, Error> = client4.request("anything", rpc_params![]).await;
+        assert!(response4.is_ok());
+    }
+
     #[tokio::test]
     async fn test_rpc_request() {
+        init_test_tracing();
         let endpoint = dummy_endpoint();
         let server = Builder::default().build(&endpoint);
         let mut module = RpcModule::new(());