feat: add --http.corsdomain (#1305)

This commit is contained in:
Tirth Patel
2023-02-21 08:50:23 -03:30
committed by GitHub
parent 20aceb750c
commit 42b1fc1f5b
3 changed files with 82 additions and 8 deletions

View File

@ -20,6 +20,9 @@ jsonrpsee = { version = "0.16", features = ["server"] }
strum = { version = "0.24", features = ["derive"] }
serde = { version = "1.0", features = ["derive"] }
tower-http = { version = "0.3.4", features = ["full"] }
hyper = "0.14.24"
tower = {version = "0.4.13" , features = ["full"] }
[dev-dependencies]
reth-tracing = { path = "../../tracing" }

View File

@ -51,9 +51,10 @@
//! }
//! ```
use hyper::{http::HeaderValue, Method};
pub use jsonrpsee::server::ServerBuilder;
use jsonrpsee::{
core::{server::rpc_module::Methods, Error as RpcError},
core::{server::host_filtering::AllowHosts, server::rpc_module::Methods, Error as RpcError},
server::{Server, ServerHandle},
RpcModule,
};
@ -72,6 +73,8 @@ use std::{
str::FromStr,
};
use strum::{AsRefStr, EnumString, EnumVariantNames, ParseError, VariantNames};
use tower::layer::util::{Identity, Stack};
use tower_http::cors::{Any, CorsLayer};
/// The default port for the http server
pub const DEFAULT_HTTP_RPC_PORT: u16 = 8545;
@ -500,6 +503,8 @@ where
pub struct RpcServerConfig {
/// Configs for JSON-RPC Http.
http_server_config: Option<ServerBuilder>,
/// Cors Domains
http_cors_domains: Option<String>,
/// Configs for WS server
ws_server_config: Option<ServerBuilder>,
/// Address where to bind the http server to
@ -529,12 +534,16 @@ impl RpcServerConfig {
pub fn ipc(config: IpcServerBuilder) -> Self {
Self::default().with_ipc(config)
}
/// Configures the http server
pub fn with_http(mut self, config: ServerBuilder) -> Self {
self.http_server_config = Some(config.http_only());
self
}
/// Configure the corsdomains
pub fn with_cors(mut self, cors_domain: String) -> Self {
self.http_cors_domains = Some(cors_domain);
self
}
/// Configures the ws server
pub fn with_ws(mut self, config: ServerBuilder) -> Self {
@ -609,9 +618,18 @@ impl RpcServerConfig {
)));
if let Some(builder) = self.http_server_config {
let http_server = builder.build(http_socket_addr).await?;
server.http_local_addr = http_server.local_addr().ok();
server.http = Some(http_server);
let cors = Self::create_cors_layer(self.http_cors_domains).unwrap();
if let Some(cors) = cors {
let middleware = tower::ServiceBuilder::new().layer(cors);
let http_server =
builder.set_middleware(middleware).build(http_socket_addr).await?;
server.http_local_addr = http_server.local_addr().ok();
server.http = Some(HttpServer::WithCors(http_server));
} else {
let http_server = builder.build(http_socket_addr).await?;
server.http_local_addr = http_server.local_addr().ok();
server.http = Some(HttpServer::Plain(http_server));
}
}
let ws_socket_addr = self.ws_addr.unwrap_or(SocketAddr::V4(SocketAddrV4::new(
@ -635,6 +653,38 @@ impl RpcServerConfig {
Ok(server)
}
fn create_cors_layer(http_cors_domains: Option<String>) -> Result<Option<CorsLayer>, RpcError> {
let mut cors = None;
if let Some(domains) = http_cors_domains {
match domains.as_str() {
"*" => {
cors = Some(
CorsLayer::new()
.allow_methods([Method::GET, Method::POST])
.allow_origin(Any)
.allow_headers(Any),
);
}
"" => {}
_ => {
let origins = domains
.split(",")
.map(|domain| domain.parse::<HeaderValue>())
.collect::<Result<Vec<HeaderValue>, _>>();
if let Ok(origins) = origins {
cors = Some(
CorsLayer::new()
.allow_methods([Method::GET, Method::POST])
.allow_origin(origins)
.allow_headers(Any),
);
}
}
}
}
Ok(cors)
}
}
/// Holds modules to be installed per transport type
@ -742,12 +792,19 @@ pub struct RpcServer {
/// The address of the ws server
ws_local_addr: Option<SocketAddr>,
/// http server
http: Option<Server>,
http: Option<HttpServer>,
/// ws server
ws: Option<Server>,
/// ipc server
ipc: Option<IpcServer>,
}
/// Http Servers Enum
pub enum HttpServer {
/// Http server
Plain(Server),
/// Http server with cors
WithCors(Server<Stack<CorsLayer, Identity>>),
}
// === impl RpcServer ===
@ -787,7 +844,14 @@ impl RpcServer {
if let Some((server, module)) =
self.http.and_then(|server| http.map(|module| (server, module)))
{
handle.http = Some(server.start(module)?);
match server {
HttpServer::Plain(server) => {
handle.http = Some(server.start(module)?);
}
HttpServer::WithCors(server) => {
handle.http = Some(server.start(module)?);
}
}
}
if let Some((server, module)) = self.ws.and_then(|server| ws.map(|module| (server, module)))