refactor: remove WsHttpServerKind enum and simplify server launch (#7531)

Co-authored-by: Matthias Seitz <matthias.seitz@outlook.de>
This commit is contained in:
Sean Matt
2024-04-26 09:47:12 -04:00
committed by GitHub
parent b6b2cf816e
commit 7f0e81e476
3 changed files with 91 additions and 153 deletions

View File

@ -3,7 +3,7 @@ use tower_http::cors::{AllowOrigin, Any, CorsLayer};
/// Error thrown when parsing cors domains went wrong
#[derive(Debug, thiserror::Error)]
pub(crate) enum CorsDomainError {
pub enum CorsDomainError {
#[error("{domain} is an invalid header value")]
InvalidHeader { domain: String },
#[error("wildcard origin (`*`) cannot be passed as part of a list: {input}")]

View File

@ -1,4 +1,4 @@
use crate::RethRpcModule;
use crate::{cors::CorsDomainError, RethRpcModule};
use reth_ipc::server::IpcServerStartError;
use std::{io, io::ErrorKind, net::SocketAddr};
@ -57,6 +57,9 @@ pub enum RpcError {
/// IO error.
error: io::Error,
},
/// Cors parsing error.
#[error(transparent)]
Cors(#[from] CorsDomainError),
/// Http and WS server configured on the same port but with conflicting settings.
#[error(transparent)]
WsHttpSamePortError(#[from] WsHttpSamePortError),

View File

@ -156,8 +156,8 @@
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
use crate::{
auth::AuthRpcModule, error::WsHttpSamePortError, metrics::RpcRequestMetrics,
RpcModuleSelection::Selection,
auth::AuthRpcModule, cors::CorsDomainError, error::WsHttpSamePortError,
metrics::RpcRequestMetrics, RpcModuleSelection::Selection,
};
use constants::*;
use error::{RpcError, ServerKind};
@ -1623,6 +1623,16 @@ impl RpcServerConfig {
self.build(&modules).await?.start(modules).await
}
/// Creates the [CorsLayer] if any
fn maybe_cors_layer(cors: Option<String>) -> Result<Option<CorsLayer>, CorsDomainError> {
cors.as_deref().map(cors::create_cors_layer).transpose()
}
/// Creates the [AuthLayer] if any
fn maybe_jwt_layer(&self) -> Option<AuthLayer<JwtAuthValidator>> {
self.jwt_secret.clone().map(|secret| AuthLayer::new(JwtAuthValidator::new(secret)))
}
/// Builds the ws and http server(s).
///
/// If both are on the same port, they are combined into one server.
@ -1634,7 +1644,6 @@ impl RpcServerConfig {
Ipv4Addr::LOCALHOST,
DEFAULT_HTTP_RPC_PORT,
)));
let jwt_secret = self.jwt_secret.clone();
let ws_socket_addr = self
.ws_addr
@ -1660,33 +1669,39 @@ impl RpcServerConfig {
}
.cloned();
let secret = self.jwt_secret.clone();
// we merge this into one server using the http setup
self.ws_server_config.take();
modules.config.ensure_ws_http_identical()?;
let builder = self.http_server_config.take().expect("http_server_config is Some");
let (server, addr) = WsHttpServerKind::build(
builder,
http_socket_addr,
cors,
secret,
ServerKind::WsHttp(http_socket_addr),
modules
.http
.as_ref()
.or(modules.ws.as_ref())
.map(RpcRequestMetrics::same_port)
.unwrap_or_default(),
)
.await?;
let server = builder
.set_http_middleware(
tower::ServiceBuilder::new()
.option_layer(Self::maybe_cors_layer(cors)?)
.option_layer(self.maybe_jwt_layer()),
)
.set_rpc_middleware(
RpcServiceBuilder::new().layer(
modules
.http
.as_ref()
.or(modules.ws.as_ref())
.map(RpcRequestMetrics::same_port)
.unwrap_or_default(),
),
)
.build(http_socket_addr)
.await
.map_err(|err| RpcError::server_error(err, ServerKind::WsHttp(http_socket_addr)))?;
let addr = server
.local_addr()
.map_err(|err| RpcError::server_error(err, ServerKind::WsHttp(http_socket_addr)))?;
return Ok(WsHttpServer {
http_local_addr: Some(addr),
ws_local_addr: Some(addr),
server: WsHttpServers::SamePort(server),
jwt_secret,
jwt_secret: self.jwt_secret.clone(),
})
}
@ -1696,32 +1711,48 @@ impl RpcServerConfig {
let mut ws_local_addr = None;
let mut ws_server = None;
if let Some(builder) = self.ws_server_config.take() {
let builder = builder.ws_only();
let (server, addr) = WsHttpServerKind::build(
builder,
ws_socket_addr,
self.ws_cors_domains.take(),
self.jwt_secret.clone(),
ServerKind::WS(ws_socket_addr),
modules.ws.as_ref().map(RpcRequestMetrics::ws).unwrap_or_default(),
)
.await?;
let server = builder
.ws_only()
.set_http_middleware(
tower::ServiceBuilder::new()
.option_layer(Self::maybe_cors_layer(self.ws_cors_domains.clone())?)
.option_layer(self.maybe_jwt_layer()),
)
.set_rpc_middleware(
RpcServiceBuilder::new()
.layer(modules.ws.as_ref().map(RpcRequestMetrics::ws).unwrap_or_default()),
)
.build(ws_socket_addr)
.await
.map_err(|err| RpcError::server_error(err, ServerKind::WS(ws_socket_addr)))?;
let addr = server
.local_addr()
.map_err(|err| RpcError::server_error(err, ServerKind::WS(ws_socket_addr)))?;
ws_local_addr = Some(addr);
ws_server = Some(server);
}
if let Some(builder) = self.http_server_config.take() {
let builder = builder.http_only();
let (server, addr) = WsHttpServerKind::build(
builder,
http_socket_addr,
self.http_cors_domains.take(),
self.jwt_secret.clone(),
ServerKind::Http(http_socket_addr),
modules.http.as_ref().map(RpcRequestMetrics::http).unwrap_or_default(),
)
.await?;
http_local_addr = Some(addr);
let server = builder
.http_only()
.set_http_middleware(
tower::ServiceBuilder::new()
.option_layer(Self::maybe_cors_layer(self.http_cors_domains.clone())?)
.option_layer(self.maybe_jwt_layer()),
)
.set_rpc_middleware(
RpcServiceBuilder::new().layer(
modules.http.as_ref().map(RpcRequestMetrics::http).unwrap_or_default(),
),
)
.build(http_socket_addr)
.await
.map_err(|err| RpcError::server_error(err, ServerKind::Http(http_socket_addr)))?;
let local_addr = server
.local_addr()
.map_err(|err| RpcError::server_error(err, ServerKind::Http(http_socket_addr)))?;
http_local_addr = Some(local_addr);
http_server = Some(server);
}
@ -1729,7 +1760,7 @@ impl RpcServerConfig {
http_local_addr,
ws_local_addr,
server: WsHttpServers::DifferentPort { http: http_server, ws: ws_server },
jwt_secret,
jwt_secret: self.jwt_secret.clone(),
})
}
@ -1945,6 +1976,15 @@ struct WsHttpServer {
jwt_secret: Option<JwtSecret>,
}
// Define the type alias with detailed type complexity
type WsHttpServerKind = Server<
Stack<
tower::util::Either<AuthLayer<JwtAuthValidator>, Identity>,
Stack<tower::util::Either<CorsLayer, Identity>, Identity>,
>,
Stack<RpcRequestMetrics, Identity>,
>;
/// Enum for holding the http and ws servers in all possible combinations.
enum WsHttpServers {
/// Both servers are on the same port
@ -1966,13 +2006,13 @@ impl WsHttpServers {
let mut http_handle = None;
let mut ws_handle = None;
match self {
WsHttpServers::SamePort(both) => {
WsHttpServers::SamePort(server) => {
// Make sure http and ws modules are identical, since we currently can't run
// different modules on same server
config.ensure_ws_http_identical()?;
if let Some(module) = http_module.or(ws_module) {
let handle = both.start(module).await;
let handle = server.start(module);
http_handle = Some(handle.clone());
ws_handle = Some(handle);
}
@ -1981,12 +2021,12 @@ impl WsHttpServers {
if let Some((server, module)) =
http.and_then(|server| http_module.map(|module| (server, module)))
{
http_handle = Some(server.start(module).await);
http_handle = Some(server.start(module));
}
if let Some((server, module)) =
ws.and_then(|server| ws_module.map(|module| (server, module)))
{
ws_handle = Some(server.start(module).await);
ws_handle = Some(server.start(module));
}
}
}
@ -2001,111 +2041,6 @@ impl Default for WsHttpServers {
}
}
/// Http Servers Enum
#[allow(clippy::type_complexity)]
enum WsHttpServerKind {
/// Http server
Plain(Server<Identity, Stack<RpcRequestMetrics, Identity>>),
/// Http server with cors
WithCors(Server<Stack<CorsLayer, Identity>, Stack<RpcRequestMetrics, Identity>>),
/// Http server with auth
WithAuth(
Server<Stack<AuthLayer<JwtAuthValidator>, Identity>, Stack<RpcRequestMetrics, Identity>>,
),
/// Http server with cors and auth
WithCorsAuth(
Server<
Stack<AuthLayer<JwtAuthValidator>, Stack<CorsLayer, Identity>>,
Stack<RpcRequestMetrics, Identity>,
>,
),
}
// === impl WsHttpServerKind ===
impl WsHttpServerKind {
/// Starts the server and returns the handle
async fn start(self, module: RpcModule<()>) -> ServerHandle {
match self {
WsHttpServerKind::Plain(server) => server.start(module),
WsHttpServerKind::WithCors(server) => server.start(module),
WsHttpServerKind::WithAuth(server) => server.start(module),
WsHttpServerKind::WithCorsAuth(server) => server.start(module),
}
}
/// Builds the server according to the given config parameters.
///
/// Returns the address of the started server.
async fn build(
builder: ServerBuilder<Identity, Identity>,
socket_addr: SocketAddr,
cors_domains: Option<String>,
jwt_secret: Option<JwtSecret>,
server_kind: ServerKind,
metrics: RpcRequestMetrics,
) -> Result<(Self, SocketAddr), RpcError> {
if let Some(cors) = cors_domains.as_deref().map(cors::create_cors_layer) {
let cors = cors.map_err(|err| RpcError::Custom(err.to_string()))?;
if let Some(secret) = jwt_secret {
// stack cors and auth layers
let middleware = tower::ServiceBuilder::new()
.layer(cors)
.layer(AuthLayer::new(JwtAuthValidator::new(secret.clone())));
let server = builder
.set_http_middleware(middleware)
.set_rpc_middleware(RpcServiceBuilder::new().layer(metrics))
.build(socket_addr)
.await
.map_err(|err| RpcError::server_error(err, server_kind))?;
let local_addr =
server.local_addr().map_err(|err| RpcError::server_error(err, server_kind))?;
let server = WsHttpServerKind::WithCorsAuth(server);
Ok((server, local_addr))
} else {
let middleware = tower::ServiceBuilder::new().layer(cors);
let server = builder
.set_http_middleware(middleware)
.set_rpc_middleware(RpcServiceBuilder::new().layer(metrics))
.build(socket_addr)
.await
.map_err(|err| RpcError::server_error(err, server_kind))?;
let local_addr =
server.local_addr().map_err(|err| RpcError::server_error(err, server_kind))?;
let server = WsHttpServerKind::WithCors(server);
Ok((server, local_addr))
}
} else if let Some(secret) = jwt_secret {
// jwt auth layered service
let middleware = tower::ServiceBuilder::new()
.layer(AuthLayer::new(JwtAuthValidator::new(secret.clone())));
let server = builder
.set_http_middleware(middleware)
.set_rpc_middleware(RpcServiceBuilder::new().layer(metrics))
.build(socket_addr)
.await
.map_err(|err| RpcError::server_error(err, ServerKind::Auth(socket_addr)))?;
let local_addr =
server.local_addr().map_err(|err| RpcError::server_error(err, server_kind))?;
let server = WsHttpServerKind::WithAuth(server);
Ok((server, local_addr))
} else {
// plain server without any middleware
let server = builder
.set_rpc_middleware(RpcServiceBuilder::new().layer(metrics))
.build(socket_addr)
.await
.map_err(|err| RpcError::server_error(err, server_kind))?;
let local_addr =
server.local_addr().map_err(|err| RpcError::server_error(err, server_kind))?;
let server = WsHttpServerKind::Plain(server);
Ok((server, local_addr))
}
}
}
/// Container type for each transport ie. http, ws, and ipc server
pub struct RpcServer {
/// Configured ws,http servers