mirror of
https://github.com/hl-archive-node/nanoreth.git
synced 2025-12-06 10:59:55 +00:00
feat: add --http.corsdomain (#1305)
This commit is contained in:
@ -20,6 +20,9 @@ jsonrpsee = { version = "0.16", features = ["server"] }
|
||||
|
||||
strum = { version = "0.24", features = ["derive"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
tower-http = { version = "0.3.4", features = ["full"] }
|
||||
hyper = "0.14.24"
|
||||
tower = {version = "0.4.13" , features = ["full"] }
|
||||
|
||||
[dev-dependencies]
|
||||
reth-tracing = { path = "../../tracing" }
|
||||
|
||||
@ -51,9 +51,10 @@
|
||||
//! }
|
||||
//! ```
|
||||
|
||||
use hyper::{http::HeaderValue, Method};
|
||||
pub use jsonrpsee::server::ServerBuilder;
|
||||
use jsonrpsee::{
|
||||
core::{server::rpc_module::Methods, Error as RpcError},
|
||||
core::{server::host_filtering::AllowHosts, server::rpc_module::Methods, Error as RpcError},
|
||||
server::{Server, ServerHandle},
|
||||
RpcModule,
|
||||
};
|
||||
@ -72,6 +73,8 @@ use std::{
|
||||
str::FromStr,
|
||||
};
|
||||
use strum::{AsRefStr, EnumString, EnumVariantNames, ParseError, VariantNames};
|
||||
use tower::layer::util::{Identity, Stack};
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
|
||||
/// The default port for the http server
|
||||
pub const DEFAULT_HTTP_RPC_PORT: u16 = 8545;
|
||||
@ -500,6 +503,8 @@ where
|
||||
pub struct RpcServerConfig {
|
||||
/// Configs for JSON-RPC Http.
|
||||
http_server_config: Option<ServerBuilder>,
|
||||
/// Cors Domains
|
||||
http_cors_domains: Option<String>,
|
||||
/// Configs for WS server
|
||||
ws_server_config: Option<ServerBuilder>,
|
||||
/// Address where to bind the http server to
|
||||
@ -529,12 +534,16 @@ impl RpcServerConfig {
|
||||
pub fn ipc(config: IpcServerBuilder) -> Self {
|
||||
Self::default().with_ipc(config)
|
||||
}
|
||||
|
||||
/// Configures the http server
|
||||
pub fn with_http(mut self, config: ServerBuilder) -> Self {
|
||||
self.http_server_config = Some(config.http_only());
|
||||
self
|
||||
}
|
||||
/// Configure the corsdomains
|
||||
pub fn with_cors(mut self, cors_domain: String) -> Self {
|
||||
self.http_cors_domains = Some(cors_domain);
|
||||
self
|
||||
}
|
||||
|
||||
/// Configures the ws server
|
||||
pub fn with_ws(mut self, config: ServerBuilder) -> Self {
|
||||
@ -609,9 +618,18 @@ impl RpcServerConfig {
|
||||
)));
|
||||
|
||||
if let Some(builder) = self.http_server_config {
|
||||
let http_server = builder.build(http_socket_addr).await?;
|
||||
server.http_local_addr = http_server.local_addr().ok();
|
||||
server.http = Some(http_server);
|
||||
let cors = Self::create_cors_layer(self.http_cors_domains).unwrap();
|
||||
if let Some(cors) = cors {
|
||||
let middleware = tower::ServiceBuilder::new().layer(cors);
|
||||
let http_server =
|
||||
builder.set_middleware(middleware).build(http_socket_addr).await?;
|
||||
server.http_local_addr = http_server.local_addr().ok();
|
||||
server.http = Some(HttpServer::WithCors(http_server));
|
||||
} else {
|
||||
let http_server = builder.build(http_socket_addr).await?;
|
||||
server.http_local_addr = http_server.local_addr().ok();
|
||||
server.http = Some(HttpServer::Plain(http_server));
|
||||
}
|
||||
}
|
||||
|
||||
let ws_socket_addr = self.ws_addr.unwrap_or(SocketAddr::V4(SocketAddrV4::new(
|
||||
@ -635,6 +653,38 @@ impl RpcServerConfig {
|
||||
|
||||
Ok(server)
|
||||
}
|
||||
|
||||
fn create_cors_layer(http_cors_domains: Option<String>) -> Result<Option<CorsLayer>, RpcError> {
|
||||
let mut cors = None;
|
||||
if let Some(domains) = http_cors_domains {
|
||||
match domains.as_str() {
|
||||
"*" => {
|
||||
cors = Some(
|
||||
CorsLayer::new()
|
||||
.allow_methods([Method::GET, Method::POST])
|
||||
.allow_origin(Any)
|
||||
.allow_headers(Any),
|
||||
);
|
||||
}
|
||||
"" => {}
|
||||
_ => {
|
||||
let origins = domains
|
||||
.split(",")
|
||||
.map(|domain| domain.parse::<HeaderValue>())
|
||||
.collect::<Result<Vec<HeaderValue>, _>>();
|
||||
if let Ok(origins) = origins {
|
||||
cors = Some(
|
||||
CorsLayer::new()
|
||||
.allow_methods([Method::GET, Method::POST])
|
||||
.allow_origin(origins)
|
||||
.allow_headers(Any),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(cors)
|
||||
}
|
||||
}
|
||||
|
||||
/// Holds modules to be installed per transport type
|
||||
@ -742,12 +792,19 @@ pub struct RpcServer {
|
||||
/// The address of the ws server
|
||||
ws_local_addr: Option<SocketAddr>,
|
||||
/// http server
|
||||
http: Option<Server>,
|
||||
http: Option<HttpServer>,
|
||||
/// ws server
|
||||
ws: Option<Server>,
|
||||
/// ipc server
|
||||
ipc: Option<IpcServer>,
|
||||
}
|
||||
/// Http Servers Enum
|
||||
pub enum HttpServer {
|
||||
/// Http server
|
||||
Plain(Server),
|
||||
/// Http server with cors
|
||||
WithCors(Server<Stack<CorsLayer, Identity>>),
|
||||
}
|
||||
|
||||
// === impl RpcServer ===
|
||||
|
||||
@ -787,7 +844,14 @@ impl RpcServer {
|
||||
if let Some((server, module)) =
|
||||
self.http.and_then(|server| http.map(|module| (server, module)))
|
||||
{
|
||||
handle.http = Some(server.start(module)?);
|
||||
match server {
|
||||
HttpServer::Plain(server) => {
|
||||
handle.http = Some(server.start(module)?);
|
||||
}
|
||||
HttpServer::WithCors(server) => {
|
||||
handle.http = Some(server.start(module)?);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some((server, module)) = self.ws.and_then(|server| ws.map(|module| (server, module)))
|
||||
|
||||
Reference in New Issue
Block a user