diff --git a/Cargo.lock b/Cargo.lock index 0ee5b47fc..3bf50b5dc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5327,7 +5327,6 @@ dependencies = [ "tokio-util", "tower", "tracing", - "tracing-test", ] [[package]] @@ -7400,29 +7399,6 @@ dependencies = [ "tracing-log", ] -[[package]] -name = "tracing-test" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a2c0ff408fe918a94c428a3f2ad04e4afd5c95bbc08fcf868eff750c15728a4" -dependencies = [ - "lazy_static", - "tracing-core", - "tracing-subscriber", - "tracing-test-macro", -] - -[[package]] -name = "tracing-test-macro" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "258bc1c4f8e2e73a977812ab339d503e6feeb92700f6d07a6de4d321522d5c08" -dependencies = [ - "lazy_static", - "quote 1.0.28", - "syn 1.0.109", -] - [[package]] name = "triehash" version = "0.8.4" diff --git a/crates/rpc/ipc/Cargo.toml b/crates/rpc/ipc/Cargo.toml index 417dfd823..3c6d3832b 100644 --- a/crates/rpc/ipc/Cargo.toml +++ b/crates/rpc/ipc/Cargo.toml @@ -30,5 +30,4 @@ bytes = { workspace = true } thiserror = { workspace = true } [dev-dependencies] -tracing-test = "0.2" tokio-stream = { workspace = true, features = ["sync"] } diff --git a/crates/rpc/ipc/src/server/connection.rs b/crates/rpc/ipc/src/server/connection.rs index ff0bd4c00..e502a27de 100644 --- a/crates/rpc/ipc/src/server/connection.rs +++ b/crates/rpc/ipc/src/server/connection.rs @@ -1,8 +1,10 @@ //! A IPC connection. use crate::stream_codec::StreamCodec; -use futures::{ready, Sink, Stream, StreamExt}; +use futures::{ready, stream::FuturesUnordered, Sink, Stream, StreamExt}; use std::{ + collections::VecDeque, + future::Future, io, marker::PhantomData, pin::Pin, @@ -10,6 +12,7 @@ use std::{ }; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio_util::codec::Framed; +use tower::Service; pub(crate) type JsonRpcStream = Framed; @@ -113,3 +116,80 @@ where self.project().0.poll_close(cx) } } + +/// Drives an [IpcConn] forward. +/// +/// This forwards received requests from the connection to the service and sends responses to the +/// connection. +/// +/// This future terminates when the connection is closed. +#[pin_project::pin_project] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub(crate) struct IpcConnDriver { + #[pin] + pub(crate) conn: IpcConn>, + pub(crate) service: S, + #[pin] + pub(crate) pending_calls: FuturesUnordered, + pub(crate) items: VecDeque, +} + +impl IpcConnDriver { + /// Add a new item to the send queue. + pub(crate) fn push_back(&mut self, item: String) { + self.items.push_back(item); + } +} + +impl Future for IpcConnDriver +where + S: Service> + Send + 'static, + S::Error: Into>, + S::Future: Send, + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + loop { + // process calls + if !this.pending_calls.is_empty() { + while let Poll::Ready(Some(res)) = this.pending_calls.as_mut().poll_next(cx) { + let item = match res { + Ok(Some(resp)) => resp, + Ok(None) => continue, + Err(err) => err.into().to_string(), + }; + this.items.push_back(item); + } + } + + // write to the sink + while this.conn.as_mut().poll_ready(cx).is_ready() { + if let Some(item) = this.items.pop_front() { + if let Err(err) = this.conn.as_mut().start_send(item) { + tracing::warn!("IPC response failed: {:?}", err); + return Poll::Ready(()) + } + } else { + break + } + } + + // read from the stream + match ready!(this.conn.as_mut().poll_next(cx)) { + Some(Ok(item)) => { + let call = this.service.call(item); + this.pending_calls.push(call); + } + Some(Err(err)) => { + tracing::warn!("IPC request failed: {:?}", err); + return Poll::Ready(()) + } + None => return Poll::Ready(()), + } + } + } +} diff --git a/crates/rpc/ipc/src/server/mod.rs b/crates/rpc/ipc/src/server/mod.rs index d30a5ae76..7911f037d 100644 --- a/crates/rpc/ipc/src/server/mod.rs +++ b/crates/rpc/ipc/src/server/mod.rs @@ -4,7 +4,7 @@ use crate::server::{ connection::{Incoming, IpcConn, JsonRpcStream}, future::{ConnectionGuard, FutureDriver, StopHandle}, }; -use futures::{FutureExt, SinkExt, Stream, StreamExt}; +use futures::{FutureExt, Stream, StreamExt}; use jsonrpsee::{ core::{Error, TEN_MB_SIZE_BYTES}, server::{logger::Logger, IdProvider, RandomIntegerIdProvider, ServerHandle}, @@ -25,6 +25,7 @@ use tower::{layer::util::Identity, Service}; use tracing::{debug, trace, warn}; // re-export so can be used during builder setup +use crate::server::connection::IpcConnDriver; pub use parity_tokio_ipc::Endpoint; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; @@ -285,7 +286,7 @@ impl Service for TowerService { /// Spawns the IPC connection onto a new task async fn spawn_connection( conn: IpcConn>, - mut service: S, + service: S, mut stop_handle: StopHandle, rx: mpsc::Receiver, ) where @@ -296,51 +297,29 @@ async fn spawn_connection( { let task = tokio::task::spawn(async move { let rx_item = ReceiverStream::new(rx); + let conn = IpcConnDriver { + conn, + service, + pending_calls: Default::default(), + items: Default::default(), + }; tokio::pin!(conn, rx_item); loop { - let item = tokio::select! { - res = conn.next() => { - match res { - Some(Ok(request)) => { - // handle the RPC request - match service.call(request).await { - Ok(Some(resp)) => { - resp - }, - Ok(None) => { - continue - }, - Err(err) => err.into().to_string(), - } - }, - Some(Err(e)) => { - tracing::warn!("IPC request failed: {:?}", e); - break - } - None => { - return - } - } + tokio::select! { + _ = &mut conn => { + break } item = rx_item.next() => { - match item { - Some(item) => item, - None => { - continue - } + if let Some(item) = item { + conn.push_back(item); } } _ = stop_handle.shutdown() => { + // shutdown break } }; - - // send item over ipc - if let Err(err) = conn.send(item).await { - warn!("Failed to send IPC response: {:?}", err); - break - } } }); @@ -593,7 +572,6 @@ mod tests { use parity_tokio_ipc::dummy_endpoint; use tokio::sync::broadcast; use tokio_stream::wrappers::BroadcastStream; - use tracing_test::traced_test; async fn pipe_from_stream_with_bounded_buffer( pending: PendingSubscriptionSink, @@ -641,7 +619,6 @@ mod tests { } #[tokio::test] - #[traced_test] async fn test_rpc_request() { let endpoint = dummy_endpoint(); let server = Builder::default().build(&endpoint).unwrap(); @@ -672,7 +649,6 @@ mod tests { } #[tokio::test(flavor = "multi_thread")] - #[traced_test] async fn test_rpc_subscription() { let endpoint = dummy_endpoint(); let server = Builder::default().build(&endpoint).unwrap();