feat: support subscriptions over IPC (#2667)

This commit is contained in:
Matthias Seitz
2023-05-15 15:42:48 +02:00
committed by GitHub
parent ae6691dd36
commit 11cd0d4753
4 changed files with 272 additions and 47 deletions

View File

@ -16,6 +16,7 @@ futures = "0.3"
parity-tokio-ipc = "0.9.0"
tokio = { version = "1", features = ["net", "time", "rt-multi-thread"] }
tokio-util = { version = "0.7", features = ["codec"] }
tokio-stream = "0.1"
async-trait = "0.1"
pin-project = "1.0"
tower = "0.4"
@ -29,6 +30,7 @@ thiserror = "1.0.37"
[dev-dependencies]
tracing-test = "0.2"
tokio-stream = { version = "0.1", features = ["sync"] }
[features]
client = ["jsonrpsee/client", "jsonrpsee/async-client"]

View File

@ -10,9 +10,14 @@ use jsonrpsee::{
server::{
logger,
logger::{Logger, TransportProtocol},
IdProvider,
},
types::{error::ErrorCode, ErrorObject, Id, InvalidRequest, Notification, Params, Request},
MethodCallback, Methods,
types::{
error::{reject_too_many_subscriptions, ErrorCode},
ErrorObject, Id, InvalidRequest, Notification, Params, Request,
},
BoundedSubscriptions, CallOrSubscription, MethodCallback, MethodSink, Methods,
SubscriptionState,
};
use std::sync::Arc;
use tokio::sync::OwnedSemaphorePermit;
@ -32,16 +37,19 @@ pub(crate) struct CallData<'a, L: Logger> {
conn_id: usize,
logger: &'a L,
methods: &'a Methods,
id_provider: &'a dyn IdProvider,
sink: &'a MethodSink,
max_response_body_size: u32,
max_log_length: u32,
request_start: L::Instant,
bounded_subscriptions: BoundedSubscriptions,
}
// Batch responses must be sent back as a single message so we read the results from each
// request in the batch and read the results off of a new channel, `rx_batch`, and then send the
// complete batch response back to the client over `tx`.
#[instrument(name = "batch", skip(b), level = "TRACE")]
pub(crate) async fn process_batch_request<L>(b: Batch<'_, L>) -> String
pub(crate) async fn process_batch_request<L>(b: Batch<'_, L>) -> Option<String>
where
L: Logger,
{
@ -56,7 +64,9 @@ where
.into_iter()
.filter_map(|v| {
if let Ok(req) = serde_json::from_str::<Request<'_>>(v.get()) {
Some(Either::Right(execute_call(req, call.clone())))
Some(Either::Right(async {
execute_call(req, call.clone()).await.into_response()
}))
} else if let Ok(_notif) = serde_json::from_str::<Notif<'_>>(v.get()) {
// notifications should not be answered.
got_notif = true;
@ -77,31 +87,31 @@ where
while let Some(response) = pending_calls.next().await {
if let Err(too_large) = batch_response.append(&response) {
return too_large
return Some(too_large)
}
}
if got_notif && batch_response.is_empty() {
String::new()
None
} else {
batch_response.finish()
Some(batch_response.finish())
}
} else {
batch_response_error(Id::Null, ErrorObject::from(ErrorCode::ParseError))
Some(batch_response_error(Id::Null, ErrorObject::from(ErrorCode::ParseError)))
}
}
pub(crate) async fn process_single_request<L: Logger>(
data: Vec<u8>,
call: CallData<'_, L>,
) -> MethodResponse {
) -> Option<CallOrSubscription> {
if let Ok(req) = serde_json::from_slice::<Request<'_>>(&data) {
execute_call_with_tracing(req, call).await
} else if let Ok(notif) = serde_json::from_slice::<Notif<'_>>(&data) {
execute_notification(notif, call.max_log_length)
Some(execute_call_with_tracing(req, call).await)
} else if serde_json::from_slice::<Notif<'_>>(&data).is_ok() {
None
} else {
let (id, code) = prepare_error(&data);
MethodResponse::error(id, ErrorObject::from(code))
Some(CallOrSubscription::Call(MethodResponse::error(id, ErrorObject::from(code))))
}
}
@ -109,21 +119,24 @@ pub(crate) async fn process_single_request<L: Logger>(
pub(crate) async fn execute_call_with_tracing<'a, L: Logger>(
req: Request<'a>,
call: CallData<'_, L>,
) -> MethodResponse {
) -> CallOrSubscription {
execute_call(req, call).await
}
pub(crate) async fn execute_call<L: Logger>(
req: Request<'_>,
call: CallData<'_, L>,
) -> MethodResponse {
) -> CallOrSubscription {
let CallData {
methods,
logger,
max_response_body_size,
max_log_length,
conn_id,
id_provider,
sink,
logger,
request_start,
bounded_subscriptions,
} = call;
rx_log_from_json(&req, call.max_log_length);
@ -140,7 +153,8 @@ pub(crate) async fn execute_call<L: Logger>(
logger::MethodKind::Unknown,
TransportProtocol::Http,
);
MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound))
let response = MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound));
CallOrSubscription::Call(response)
}
Some((name, method)) => match method {
MethodCallback::Sync(callback) => {
@ -150,7 +164,8 @@ pub(crate) async fn execute_call<L: Logger>(
logger::MethodKind::MethodCall,
TransportProtocol::Http,
);
(callback)(id, params, max_response_body_size as usize)
let response = (callback)(id, params, max_response_body_size as usize);
CallOrSubscription::Call(response)
}
MethodCallback::Async(callback) => {
logger.on_call(
@ -161,23 +176,50 @@ pub(crate) async fn execute_call<L: Logger>(
);
let id = id.into_owned();
let params = params.into_owned();
(callback)(id, params, conn_id, max_response_body_size as usize).await
let response =
(callback)(id, params, conn_id, max_response_body_size as usize).await;
CallOrSubscription::Call(response)
}
MethodCallback::Subscription(_) | MethodCallback::Unsubscription(_) => {
MethodCallback::Subscription(callback) => {
if let Some(p) = bounded_subscriptions.acquire() {
let conn_state =
SubscriptionState { conn_id, id_provider, subscription_permit: p };
match callback(id, params, sink.clone(), conn_state).await {
Ok(r) => CallOrSubscription::Subscription(r),
Err(id) => {
let response = MethodResponse::error(
id,
ErrorObject::from(ErrorCode::InternalError),
);
CallOrSubscription::Call(response)
}
}
} else {
let response = MethodResponse::error(
id,
reject_too_many_subscriptions(bounded_subscriptions.max()),
);
CallOrSubscription::Call(response)
}
}
MethodCallback::Unsubscription(callback) => {
logger.on_call(
name,
params.clone(),
logger::MethodKind::Unknown,
TransportProtocol::Http,
logger::MethodKind::Unsubscription,
TransportProtocol::WebSocket,
);
tracing::error!("Subscriptions not supported on HTTP");
MethodResponse::error(id, ErrorObject::from(ErrorCode::InternalError))
// Don't adhere to any resource or subscription limits; always let unsubscribing
// happen!
let result = callback(id, params, conn_id, max_response_body_size as usize);
CallOrSubscription::Call(result)
}
},
};
tx_log_from_str(&response.result, max_log_length);
logger.on_result(name, response.success, request_start, TransportProtocol::Http);
tx_log_from_str(&response.as_response().result, max_log_length);
logger.on_result(name, response.as_response().success, request_start, TransportProtocol::Http);
response
}
@ -198,10 +240,26 @@ pub(crate) struct HandleRequest<L: Logger> {
pub(crate) batch_requests_supported: bool,
pub(crate) logger: L,
pub(crate) conn: Arc<OwnedSemaphorePermit>,
pub(crate) bounded_subscriptions: BoundedSubscriptions,
pub(crate) method_sink: MethodSink,
pub(crate) id_provider: Arc<dyn IdProvider>,
}
pub(crate) async fn handle_request<L: Logger>(request: String, input: HandleRequest<L>) -> String {
let HandleRequest { methods, max_response_body_size, max_log_length, logger, conn, .. } = input;
pub(crate) async fn handle_request<L: Logger>(
request: String,
input: HandleRequest<L>,
) -> Option<String> {
let HandleRequest {
methods,
max_response_body_size,
max_log_length,
logger,
conn,
bounded_subscriptions,
method_sink,
id_provider,
..
} = input;
enum Kind {
Single,
@ -223,14 +281,25 @@ pub(crate) async fn handle_request<L: Logger>(request: String, input: HandleRequ
conn_id: 0,
logger: &logger,
methods: &methods,
id_provider: &*id_provider,
sink: &method_sink,
max_response_body_size,
max_log_length,
request_start,
bounded_subscriptions,
};
// Single request or notification
let res = if matches!(request_kind, Kind::Single) {
let response = process_single_request(request.into_bytes(), call).await;
response.result
match response {
Some(CallOrSubscription::Call(response)) => Some(response.result),
Some(CallOrSubscription::Subscription(_)) => {
// subscription responses are sent directly over the sink, return a response here
// would lead to duplicate responses for the subscription response
None
}
None => None,
}
} else {
process_batch_request(Batch { data: request.into_bytes(), call }).await
};

View File

@ -8,7 +8,7 @@ use futures::{FutureExt, SinkExt, Stream, StreamExt};
use jsonrpsee::{
core::{Error, TEN_MB_SIZE_BYTES},
server::{logger::Logger, IdProvider, RandomIntegerIdProvider, ServerHandle},
Methods,
BoundedSubscriptions, MethodSink, Methods,
};
use std::{
future::Future,
@ -26,6 +26,8 @@ use tracing::{trace, warn};
// re-export so can be used during builder setup
pub use parity_tokio_ipc::Endpoint;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
mod connection;
mod future;
@ -104,6 +106,7 @@ impl IpcServer {
}
}
let message_buffer_capacity = self.cfg.message_buffer_capacity;
let max_request_body_size = self.cfg.max_request_body_size;
let max_response_body_size = self.cfg.max_response_body_size;
let max_log_length = self.cfg.max_log_length;
@ -141,6 +144,9 @@ impl IpcServer {
}
};
let (tx, rx) = mpsc::channel::<String>(message_buffer_capacity as usize);
let method_sink =
MethodSink::new_with_limit(tx, max_response_body_size, max_log_length);
let tower_service = TowerService {
inner: ServiceData {
methods: methods.clone(),
@ -153,11 +159,20 @@ impl IpcServer {
conn_id: id,
logger,
conn: Arc::new(conn),
bounded_subscriptions: BoundedSubscriptions::new(
max_subscriptions_per_connection,
),
method_sink,
},
};
let service = self.service_builder.service(tower_service);
connections.add(Box::pin(spawn_connection(ipc, service, stop_handle.clone())));
connections.add(Box::pin(spawn_connection(
ipc,
service,
stop_handle.clone(),
rx,
)));
id = id.wrapping_add(1);
}
@ -183,7 +198,7 @@ impl std::fmt::Debug for IpcServer {
}
}
/// Data required by the server to handle requests.
/// Data required by the server to handle requests received via an IPC connection
#[derive(Debug, Clone)]
#[allow(unused)]
pub(crate) struct ServiceData<L: Logger> {
@ -209,6 +224,12 @@ pub(crate) struct ServiceData<L: Logger> {
pub(crate) logger: L,
/// Handle to hold a `connection permit`.
pub(crate) conn: Arc<OwnedSemaphorePermit>,
/// Limits the number of subscriptions for this connection
pub(crate) bounded_subscriptions: BoundedSubscriptions,
/// Sink that is used to send back responses to the connection.
///
/// This is used for subscriptions.
pub(crate) method_sink: MethodSink,
}
/// JsonRPSee service compatible with `tower`.
@ -221,7 +242,12 @@ pub struct TowerService<L: Logger> {
}
impl<L: Logger> Service<String> for TowerService<L> {
type Response = String;
/// The response of a handled RPC call
///
/// This is an `Option` because subscriptions and call responses are handled differently.
/// This will be `Some` for calls, and `None` for subscriptions, because the subscription
/// response will be emitted via the `method_sink`.
type Response = Option<String>;
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
@ -244,31 +270,45 @@ impl<L: Logger> Service<String> for TowerService<L> {
batch_requests_supported: true,
logger: self.inner.logger.clone(),
conn: self.inner.conn.clone(),
bounded_subscriptions: self.inner.bounded_subscriptions.clone(),
method_sink: self.inner.method_sink.clone(),
id_provider: self.inner.id_provider.clone(),
};
Box::pin(ipc::handle_request(request, data).map(Ok))
}
}
/// Spawns the connection in a new task
/// Spawns the IPC connection onto a new task
async fn spawn_connection<S, T>(
conn: IpcConn<JsonRpcStream<T>>,
mut service: S,
mut stop_handle: StopHandle,
rx: mpsc::Receiver<String>,
) where
S: Service<String, Response = String> + Send + 'static,
S: Service<String, Response = Option<String>> + Send + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
S::Future: Send,
T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let task = tokio::task::spawn(async move {
tokio::pin!(conn);
let rx_item = ReceiverStream::new(rx);
tokio::pin!(conn, rx_item);
loop {
let request = tokio::select! {
let item = tokio::select! {
res = conn.next() => {
match res {
Some(Ok(request)) => {
request
// handle the RPC request
match service.call(request).await {
Ok(Some(resp)) => {
resp
},
Ok(None) => {
continue
},
Err(err) => err.into().to_string(),
}
},
Some(Err(e)) => {
tracing::warn!("Request failed: {:?}", e);
@ -279,19 +319,21 @@ async fn spawn_connection<S, T>(
}
}
}
item = rx_item.next() => {
match item {
Some(item) => item,
None => {
continue
}
}
}
_ = stop_handle.shutdown() => {
break
}
};
// handle the RPC request
let resp = match service.call(request).await {
Ok(resp) => resp,
Err(err) => err.into().to_string(),
};
// send back
if let Err(err) = conn.send(resp).await {
// send item over ipc
if let Err(err) = conn.send(item).await {
warn!("Failed to send response: {:?}", err);
break
}
@ -352,6 +394,8 @@ pub struct Settings {
max_connections: u32,
/// Maximum number of subscriptions per connection.
max_subscriptions_per_connection: u32,
/// Number of messages that server is allowed `buffer` until backpressure kicks in.
message_buffer_capacity: u32,
/// Custom tokio runtime to run the server on.
tokio_runtime: Option<tokio::runtime::Handle>,
}
@ -364,6 +408,7 @@ impl Default for Settings {
max_log_length: 4096,
max_connections: 100,
max_subscriptions_per_connection: 1024,
message_buffer_capacity: 1024,
tokio_runtime: None,
}
}
@ -421,6 +466,28 @@ impl<B, L> Builder<B, L> {
self
}
/// The server enforces backpressure which means that
/// `n` messages can be buffered and if the client
/// can't keep with up the server.
///
/// This `capacity` is applied per connection and
/// applies globally on the connection which implies
/// all JSON-RPC messages.
///
/// For example if a subscription produces plenty of new items
/// and the client can't keep up then no new messages are handled.
///
/// If this limit is exceeded then the server will "back-off"
/// and only accept new messages once the client reads pending messages.
///
/// # Panics
///
/// Panics if the buffer capacity is 0.
pub fn set_message_buffer_capacity(mut self, c: u32) -> Self {
self.settings.message_buffer_capacity = c;
self
}
/// Add a logger to the builder [`Logger`].
pub fn set_logger<T: Logger>(self, logger: T) -> Builder<B, T> {
Builder {
@ -514,10 +581,61 @@ impl<B, L> Builder<B, L> {
mod tests {
use super::*;
use crate::client::IpcClientBuilder;
use jsonrpsee::{core::client::ClientT, rpc_params, RpcModule};
use futures::future::{select, Either};
use jsonrpsee::{
core::client::{ClientT, Subscription, SubscriptionClientT},
rpc_params, PendingSubscriptionSink, RpcModule, SubscriptionMessage,
};
use parity_tokio_ipc::dummy_endpoint;
use tokio::sync::broadcast;
use tokio_stream::wrappers::BroadcastStream;
use tracing_test::traced_test;
async fn pipe_from_stream_with_bounded_buffer(
pending: PendingSubscriptionSink,
stream: BroadcastStream<usize>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let sink = pending.accept().await.unwrap();
let closed = sink.closed();
futures::pin_mut!(closed, stream);
loop {
match select(closed, stream.next()).await {
// subscription closed.
Either::Left((_, _)) => break Ok(()),
// received new item from the stream.
Either::Right((Some(Ok(item)), c)) => {
let notif = SubscriptionMessage::from_json(&item)?;
// NOTE: this will block until there a spot in the queue
// and you might want to do something smarter if it's
// critical that "the most recent item" must be sent when it is produced.
if sink.send(notif).await.is_err() {
break Ok(())
}
closed = c;
}
// Send back back the error.
Either::Right((Some(Err(e)), _)) => break Err(e.into()),
// Stream is closed.
Either::Right((None, _)) => break Ok(()),
}
}
}
// Naive example that broadcasts the produced values to all active subscribers.
fn produce_items(tx: broadcast::Sender<usize>) {
for c in 1..=100 {
std::thread::sleep(std::time::Duration::from_millis(1));
let _ = tx.send(c);
}
}
#[tokio::test]
#[traced_test]
async fn test_rpc_request() {
@ -533,4 +651,39 @@ mod tests {
let response: String = client.request("eth_chainId", rpc_params![]).await.unwrap();
assert_eq!(response, msg);
}
#[tokio::test(flavor = "multi_thread")]
#[traced_test]
async fn test_rpc_subscription() {
let endpoint = dummy_endpoint();
let server = Builder::default().build(&endpoint).unwrap();
let (tx, _rx) = broadcast::channel::<usize>(16);
let mut module = RpcModule::new(tx.clone());
std::thread::spawn(move || produce_items(tx));
module
.register_subscription(
"subscribe_hello",
"s_hello",
"unsubscribe_hello",
|_, pending, tx| async move {
let rx = tx.subscribe();
let stream = BroadcastStream::new(rx);
pipe_from_stream_with_bounded_buffer(pending, stream).await?;
Ok(())
},
)
.unwrap();
let handle = server.start(module).await.unwrap();
tokio::spawn(handle.stopped());
let client = IpcClientBuilder::default().build(endpoint).await.unwrap();
let sub: Subscription<usize> =
client.subscribe("subscribe_hello", rpc_params![], "unsubscribe_hello").await.unwrap();
let items = sub.take(16).collect::<Vec<_>>().await;
assert_eq!(items.len(), 16);
}
}