diff --git a/bin/reth/src/args/rpc_server_args.rs b/bin/reth/src/args/rpc_server_args.rs index 5dc7e3697..acdfd0cad 100644 --- a/bin/reth/src/args/rpc_server_args.rs +++ b/bin/reth/src/args/rpc_server_args.rs @@ -36,6 +36,10 @@ pub struct RpcServerArgs { #[arg(long = "http.api")] pub http_api: Option, + /// Http Corsdomain to allow request from + #[arg(long = "http.corsdomain")] + pub http_corsdomain: Option, + /// Enable the WS-RPC server #[arg(long)] pub ws: bool, @@ -136,7 +140,10 @@ impl RpcServerArgs { self.http_addr.unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED)), self.http_port.unwrap_or(DEFAULT_HTTP_RPC_PORT), ); - config = config.with_http_address(socket_address).with_http(ServerBuilder::new()); + config = config + .with_http_address(socket_address) + .with_http(ServerBuilder::new()) + .with_cors(self.http_corsdomain.clone().unwrap_or("".to_string())); } if self.ws { diff --git a/crates/rpc/rpc-builder/Cargo.toml b/crates/rpc/rpc-builder/Cargo.toml index e46bac036..3107e2699 100644 --- a/crates/rpc/rpc-builder/Cargo.toml +++ b/crates/rpc/rpc-builder/Cargo.toml @@ -20,6 +20,9 @@ jsonrpsee = { version = "0.16", features = ["server"] } strum = { version = "0.24", features = ["derive"] } serde = { version = "1.0", features = ["derive"] } +tower-http = { version = "0.3.4", features = ["full"] } +hyper = "0.14.24" +tower = {version = "0.4.13" , features = ["full"] } [dev-dependencies] reth-tracing = { path = "../../tracing" } diff --git a/crates/rpc/rpc-builder/src/lib.rs b/crates/rpc/rpc-builder/src/lib.rs index 3c88d94a8..1ffa73779 100644 --- a/crates/rpc/rpc-builder/src/lib.rs +++ b/crates/rpc/rpc-builder/src/lib.rs @@ -51,9 +51,10 @@ //! } //! ``` +use hyper::{http::HeaderValue, Method}; pub use jsonrpsee::server::ServerBuilder; use jsonrpsee::{ - core::{server::rpc_module::Methods, Error as RpcError}, + core::{server::host_filtering::AllowHosts, server::rpc_module::Methods, Error as RpcError}, server::{Server, ServerHandle}, RpcModule, }; @@ -72,6 +73,8 @@ use std::{ str::FromStr, }; use strum::{AsRefStr, EnumString, EnumVariantNames, ParseError, VariantNames}; +use tower::layer::util::{Identity, Stack}; +use tower_http::cors::{Any, CorsLayer}; /// The default port for the http server pub const DEFAULT_HTTP_RPC_PORT: u16 = 8545; @@ -500,6 +503,8 @@ where pub struct RpcServerConfig { /// Configs for JSON-RPC Http. http_server_config: Option, + /// Cors Domains + http_cors_domains: Option, /// Configs for WS server ws_server_config: Option, /// Address where to bind the http server to @@ -529,12 +534,16 @@ impl RpcServerConfig { pub fn ipc(config: IpcServerBuilder) -> Self { Self::default().with_ipc(config) } - /// Configures the http server pub fn with_http(mut self, config: ServerBuilder) -> Self { self.http_server_config = Some(config.http_only()); self } + /// Configure the corsdomains + pub fn with_cors(mut self, cors_domain: String) -> Self { + self.http_cors_domains = Some(cors_domain); + self + } /// Configures the ws server pub fn with_ws(mut self, config: ServerBuilder) -> Self { @@ -609,9 +618,18 @@ impl RpcServerConfig { ))); if let Some(builder) = self.http_server_config { - let http_server = builder.build(http_socket_addr).await?; - server.http_local_addr = http_server.local_addr().ok(); - server.http = Some(http_server); + let cors = Self::create_cors_layer(self.http_cors_domains).unwrap(); + if let Some(cors) = cors { + let middleware = tower::ServiceBuilder::new().layer(cors); + let http_server = + builder.set_middleware(middleware).build(http_socket_addr).await?; + server.http_local_addr = http_server.local_addr().ok(); + server.http = Some(HttpServer::WithCors(http_server)); + } else { + let http_server = builder.build(http_socket_addr).await?; + server.http_local_addr = http_server.local_addr().ok(); + server.http = Some(HttpServer::Plain(http_server)); + } } let ws_socket_addr = self.ws_addr.unwrap_or(SocketAddr::V4(SocketAddrV4::new( @@ -635,6 +653,38 @@ impl RpcServerConfig { Ok(server) } + + fn create_cors_layer(http_cors_domains: Option) -> Result, RpcError> { + let mut cors = None; + if let Some(domains) = http_cors_domains { + match domains.as_str() { + "*" => { + cors = Some( + CorsLayer::new() + .allow_methods([Method::GET, Method::POST]) + .allow_origin(Any) + .allow_headers(Any), + ); + } + "" => {} + _ => { + let origins = domains + .split(",") + .map(|domain| domain.parse::()) + .collect::, _>>(); + if let Ok(origins) = origins { + cors = Some( + CorsLayer::new() + .allow_methods([Method::GET, Method::POST]) + .allow_origin(origins) + .allow_headers(Any), + ); + } + } + } + } + Ok(cors) + } } /// Holds modules to be installed per transport type @@ -742,12 +792,19 @@ pub struct RpcServer { /// The address of the ws server ws_local_addr: Option, /// http server - http: Option, + http: Option, /// ws server ws: Option, /// ipc server ipc: Option, } +/// Http Servers Enum +pub enum HttpServer { + /// Http server + Plain(Server), + /// Http server with cors + WithCors(Server>), +} // === impl RpcServer === @@ -787,7 +844,14 @@ impl RpcServer { if let Some((server, module)) = self.http.and_then(|server| http.map(|module| (server, module))) { - handle.http = Some(server.start(module)?); + match server { + HttpServer::Plain(server) => { + handle.http = Some(server.start(module)?); + } + HttpServer::WithCors(server) => { + handle.http = Some(server.start(module)?); + } + } } if let Some((server, module)) = self.ws.and_then(|server| ws.map(|module| (server, module)))