diff --git a/Cargo.lock b/Cargo.lock index 3741e4d35..b02e7eccb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2376,9 +2376,9 @@ dependencies = [ [[package]] name = "hyper" -version = "0.14.23" +version = "0.14.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "034711faac9d2166cb1baf1a2fb0b60b1f277f8492fd72176c17f3515e1abd3c" +checksum = "5e011372fa0b68db8350aa7a248930ecc7839bf46d8485577d69f117a75f164c" dependencies = [ "bytes", "futures-channel", @@ -2820,6 +2820,20 @@ dependencies = [ "jsonrpsee-types", ] +[[package]] +name = "jsonwebtoken" +version = "8.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09f4f04699947111ec1733e71778d763555737579e44b85844cae8e1940a1828" +dependencies = [ + "base64 0.13.1", + "pem", + "ring", + "serde", + "serde_json", + "simple_asn1", +] + [[package]] name = "k256" version = "0.11.6" @@ -3477,6 +3491,15 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" +[[package]] +name = "pem" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8835c273a76a90455d7344889b0964598e3316e2a79ede8e36f16bdcf2228b8" +dependencies = [ + "base64 0.13.1", +] + [[package]] name = "percent-encoding" version = "2.2.0" @@ -4573,7 +4596,13 @@ version = "0.1.0" dependencies = [ "async-trait", "hex", + "http", + "http-body", + "hyper", "jsonrpsee", + "jsonwebtoken", + "pin-project", + "rand 0.8.5", "reth-interfaces", "reth-network-api", "reth-primitives", @@ -4583,11 +4612,14 @@ dependencies = [ "reth-rpc-engine-api", "reth-rpc-types", "reth-transaction-pool", - "secp256k1 0.24.3", + "secp256k1 0.26.0", "serde", "serde_json", "thiserror", "tokio", + "tokio-stream", + "tower", + "tracing", ] [[package]] @@ -5105,6 +5137,16 @@ dependencies = [ "secp256k1-sys 0.7.0", ] +[[package]] +name = "secp256k1" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4124a35fe33ae14259c490fd70fa199a32b9ce9502f2ee6bc4f81ec06fa65894" +dependencies = [ + "rand 0.8.5", + "secp256k1-sys 0.8.0", +] + [[package]] name = "secp256k1-sys" version = "0.6.1" @@ -5123,6 +5165,15 @@ dependencies = [ "cc", ] +[[package]] +name = "secp256k1-sys" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "642a62736682fdd8c71da0eb273e453c8ac74e33b9fb310e22ba5b03ec7651ff" +dependencies = [ + "cc", +] + [[package]] name = "security-framework" version = "2.8.2" @@ -5436,6 +5487,18 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "simple_asn1" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085" +dependencies = [ + "num-bigint", + "num-traits", + "thiserror", + "time", +] + [[package]] name = "sketches-ddsketch" version = "0.2.0" diff --git a/crates/rpc/rpc/Cargo.toml b/crates/rpc/rpc/Cargo.toml index 01c979793..ab73d6034 100644 --- a/crates/rpc/rpc/Cargo.toml +++ b/crates/rpc/rpc/Cargo.toml @@ -15,20 +15,27 @@ reth-primitives = { path = "../../primitives" } reth-rpc-api = { path = "../rpc-api" } reth-rlp = { path = "../../rlp" } reth-rpc-types = { path = "../rpc-types" } -reth-provider = { path = "../../storage/provider" } +reth-provider = { path = "../../storage/provider", features = ["test-utils"] } reth-transaction-pool = { path = "../../transaction-pool" } reth-network-api = { path = "../../net/network-api" } reth-rpc-engine-api = { path = "../rpc-engine-api" } # rpc jsonrpsee = { version = "0.16" } +http = "0.2.8" +http-body = "0.4.5" +hyper = "0.14.24" +jsonwebtoken = "8" # async async-trait = "0.1" tokio = { version = "1", features = ["sync"] } +tower = "0.4" +tokio-stream = "0.1" +pin-project = "1.0" # misc -secp256k1 = { version = "0.24", features = [ +secp256k1 = { version = "0.26.0", features = [ "global-context", "rand-std", "recovery", @@ -37,3 +44,8 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" thiserror = "1.0" hex = "0.4" +rand = "0.8.5" +tracing = "0.1" + +[dev-dependencies] +jsonrpsee = { version = "0.16", features = ["client"]} diff --git a/crates/rpc/rpc/src/layers/auth_layer.rs b/crates/rpc/rpc/src/layers/auth_layer.rs new file mode 100644 index 000000000..be4932439 --- /dev/null +++ b/crates/rpc/rpc/src/layers/auth_layer.rs @@ -0,0 +1,296 @@ +use http::{Request, Response}; +use http_body::Body; +use pin_project::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tower::{Layer, Service}; + +use super::AuthValidator; + +/// This is an Http middleware layer that acts as an +/// interceptor for `Authorization` headers. Incoming requests are dispatched to +/// an inner [`AuthValidator`]. Invalid requests are blocked and the validator's error response is +/// returned. Valid requests are instead dispatched to the next layer along the chain. +/// +/// # How to integrate +/// ```rust +/// async fn build_layered_rpc_server() { +/// use jsonrpsee::server::ServerBuilder; +/// use reth_rpc::{AuthLayer, JwtAuthValidator, JwtSecret}; +/// use std::net::SocketAddr; +/// +/// const AUTH_PORT: u32 = 8551; +/// const AUTH_ADDR: &str = "0.0.0.0"; +/// const AUTH_SECRET: &str = "f79ae8046bc11c9927afe911db7143c51a806c4a537cc08e0d37140b0192f430"; +/// +/// let addr = format!("{AUTH_ADDR}:{AUTH_PORT}"); +/// let secret = JwtSecret::from_hex(AUTH_SECRET).unwrap(); +/// let validator = JwtAuthValidator::new(secret); +/// let layer = AuthLayer::new(validator); +/// let middleware = tower::ServiceBuilder::default().layer(layer); +/// +/// let _server = ServerBuilder::default() +/// .set_middleware(middleware) +/// .build(addr.parse::().unwrap()) +/// .await +/// .unwrap(); +/// } +/// ``` +#[allow(missing_debug_implementations)] +pub struct AuthLayer { + validator: V, +} + +impl AuthLayer +where + V: AuthValidator, + V::ResponseBody: Body, +{ + /// Creates an instance of [`AuthLayer`][crate::layers::AuthLayer]. + /// `validator` is a generic trait able to validate requests (see [`AuthValidator`]). + pub fn new(validator: V) -> Self { + Self { validator } + } +} + +impl Layer for AuthLayer +where + V: Clone, +{ + type Service = AuthService; + + fn layer(&self, inner: S) -> Self::Service { + AuthService { validator: self.validator.clone(), inner } + } +} + +/// This type is the actual implementation of +/// the middleware. It follows the [`Service`](tower::Service) +/// specification to correctly proxy Http requests +/// to its inner service after headers validation. +#[allow(missing_debug_implementations)] +pub struct AuthService { + /// Performs auth validation logics + validator: V, + /// Recipient of authorized Http requests + inner: S, +} + +impl Service> for AuthService +where + S: Service, Response = Response>, + V: AuthValidator, + ReqBody: Body, + ResBody: Body, +{ + type Response = Response; + type Error = S::Error; + type Future = ResponseFuture; + + /// If we get polled it means that we dispatched an authorized Http request to the inner layer. + /// So we just poll the inner layer ourselves. + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + /// This is the entrypoint of the service. We receive an Http request and check the validity of + /// the authorization header. + /// + /// Returns a future that wraps either: + /// - The inner service future for authorized requests + /// - An error Http response in case of authorization errors + fn call(&mut self, req: Request) -> Self::Future { + match self.validator.validate(req.headers()) { + Ok(_) => ResponseFuture::future(self.inner.call(req)), + Err(res) => ResponseFuture::invalid_auth(res), + } + } +} + +#[pin_project] +#[allow(missing_debug_implementations)] +pub struct ResponseFuture { + #[pin] + kind: Kind, +} + +impl ResponseFuture +where + B: Body, +{ + fn future(future: F) -> Self { + Self { kind: Kind::Future { future } } + } + + fn invalid_auth(err_res: Response) -> Self { + Self { kind: Kind::Error { response: Some(err_res) } } + } +} + +#[pin_project(project = KindProj)] +enum Kind { + Future { + #[pin] + future: F, + }, + Error { + response: Option>, + }, +} + +impl Future for ResponseFuture +where + F: Future, E>>, + B: Body, +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project().kind.project() { + KindProj::Future { future } => future.poll(cx), + KindProj::Error { response } => { + let response = response.take().unwrap(); + Poll::Ready(Ok(response)) + } + } + } +} + +#[cfg(test)] +mod tests { + + use http::{header, Method, Request, StatusCode}; + use hyper::{body, Body}; + use jsonrpsee::{ + server::{RandomStringIdProvider, ServerBuilder, ServerHandle}, + RpcModule, + }; + use std::{ + net::SocketAddr, + time::{SystemTime, UNIX_EPOCH}, + }; + + use super::AuthLayer; + use crate::{layers::jwt_secret::Claims, JwtAuthValidator, JwtError, JwtSecret}; + + const AUTH_PORT: u32 = 8551; + const AUTH_ADDR: &str = "0.0.0.0"; + const SECRET: &str = "f79ae8046bc11c9927afe911db7143c51a806c4a537cc08e0d37140b0192f430"; + + #[tokio::test] + async fn test_jwt_layer() { + // We group all tests into one to avoid individual #[tokio::test] + // to concurrently spawn a server on the same port. + valid_jwt().await; + missing_jwt_error().await; + wrong_jwt_signature_error().await; + invalid_issuance_timestamp_error().await; + jwt_decode_error().await; + } + + async fn valid_jwt() { + let claims = Claims { iat: to_u64(SystemTime::now()), exp: 10000000000 }; + let secret = JwtSecret::from_hex(SECRET).unwrap(); // Same secret as the server + let jwt = secret.encode(&claims).unwrap(); + let (status, _) = send_request(Some(jwt)).await; + assert_eq!(status, StatusCode::OK); + } + + async fn missing_jwt_error() { + let (status, body) = send_request(None).await; + let expected = JwtError::MissingOrInvalidAuthorizationHeader; + assert_eq!(status, StatusCode::UNAUTHORIZED); + assert_eq!(body, expected.to_string()); + } + + async fn wrong_jwt_signature_error() { + // This secret is different from the server. This will generate a + // different signature + let secret = JwtSecret::random(); + let claims = Claims { iat: to_u64(SystemTime::now()), exp: 10000000000 }; + let jwt = secret.encode(&claims).unwrap(); + + let (status, body) = send_request(Some(jwt)).await; + let expected = JwtError::InvalidSignature; + assert_eq!(status, StatusCode::UNAUTHORIZED); + assert_eq!(body, expected.to_string()); + } + + async fn invalid_issuance_timestamp_error() { + let secret = JwtSecret::from_hex(SECRET).unwrap(); // Same secret as the server + + let iat = to_u64(SystemTime::now()) + 1000; + let claims = Claims { iat, exp: 10000000000 }; + let jwt = secret.encode(&claims).unwrap(); + + let (status, body) = send_request(Some(jwt)).await; + let expected = JwtError::InvalidIssuanceTimestamp; + assert_eq!(status, StatusCode::UNAUTHORIZED); + assert_eq!(body, expected.to_string()); + } + + async fn jwt_decode_error() { + let jwt = "this jwt has serious encoding problems".to_string(); + let (status, body) = send_request(Some(jwt)).await; + assert_eq!(status, StatusCode::UNAUTHORIZED); + assert_eq!(body, "JWT decoding error Error(InvalidToken)".to_string()); + } + + async fn send_request(jwt: Option) -> (StatusCode, String) { + let server = spawn_server().await; + let client = hyper::Client::new(); + + let jwt = jwt.unwrap_or("".into()); + let address = format!("http://{AUTH_ADDR}:{AUTH_PORT}"); + let bearer = format!("Bearer {jwt}"); + let body = r#"{"jsonrpc": "2.0", "method": "greet_melkor", "params": [], "id": 1}"#; + + let req = Request::builder() + .method(Method::POST) + .header(header::AUTHORIZATION, bearer) + .header(header::CONTENT_TYPE, "application/json") + .uri(address) + .body(Body::from(body)) + .unwrap(); + + let res = client.request(req).await.unwrap(); + let status = res.status(); + let body_bytes = body::to_bytes(res.into_body()).await.unwrap(); + let body = String::from_utf8(body_bytes.to_vec()).expect("response was not valid utf-8"); + + server.stop().unwrap(); + server.stopped().await; + + (status, body) + } + + /// Spawn a new RPC server equipped with a JwtLayer auth middleware. + async fn spawn_server() -> ServerHandle { + let secret = JwtSecret::from_hex(SECRET).unwrap(); + let addr = format!("{AUTH_ADDR}:{AUTH_PORT}"); + let validator = JwtAuthValidator::new(secret); + let layer = AuthLayer::new(validator); + let middleware = tower::ServiceBuilder::default().layer(layer); + + // Create a layered server + let server = ServerBuilder::default() + .set_id_provider(RandomStringIdProvider::new(16)) + .set_middleware(middleware) + .build(addr.parse::().unwrap()) + .await + .unwrap(); + + // Create a mock rpc module + let mut module = RpcModule::new(()); + module.register_method("greet_melkor", |_, _| Ok("You are the dark lord")).unwrap(); + + server.start(module).unwrap() + } + + fn to_u64(time: SystemTime) -> u64 { + time.duration_since(UNIX_EPOCH).unwrap().as_secs() + } +} diff --git a/crates/rpc/rpc/src/layers/jwt_secret.rs b/crates/rpc/rpc/src/layers/jwt_secret.rs new file mode 100644 index 000000000..f9340b6cc --- /dev/null +++ b/crates/rpc/rpc/src/layers/jwt_secret.rs @@ -0,0 +1,270 @@ +use hex::encode as hex_encode; +use jsonwebtoken::{decode, errors::ErrorKind, Algorithm, DecodingKey, Validation}; +use rand::Rng; +use serde::{Deserialize, Serialize}; +use std::{ + collections::hash_map::DefaultHasher, + hash::{Hash, Hasher}, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; +use thiserror::Error; + +/// Errors returned by the [`JwtSecret`][crate::layers::JwtSecret] +#[derive(Error, Debug)] +#[allow(missing_docs)] +pub enum JwtError { + #[error(transparent)] + JwtSecretHexDecodeError(#[from] hex::FromHexError), + #[error("JWT key is expected to have a length of {0} digits. {1} digits key provided")] + InvalidLength(usize, usize), + #[error("Unsupported signature algorithm. Only HS256 is supported")] + UnsupportedSignatureAlgorithm, + #[error("The provided signature is invalid")] + InvalidSignature, + #[error("The iat (issued-at) claim is not within +-60 seconds from the current time")] + InvalidIssuanceTimestamp, + #[error("Autorization header is missing or invalid")] + MissingOrInvalidAuthorizationHeader, + #[error("JWT decoding error {0}")] + JwtDecodingError(String), +} + +/// Length of the hex-encoded 256 bit secret key. +/// A 256-bit encoded string in Rust has a length of 64 digits because each digit represents 4 bits +/// of data. In hexadecimal representation, each digit can have 16 possible values (0-9 and A-F), so +/// 4 bits can be represented using a single hex digit. Therefore, to represent a 256-bit string, +/// we need 64 hexadecimal digits (256 bits รท 4 bits per digit = 64 digits). +const JWT_SECRET_LEN: usize = 64; + +/// The JWT `iat` (issued-at) claim cannot exceed +-60 seconds from the current time. +const JWT_MAX_IAT_DIFF: Duration = Duration::from_secs(60); + +/// The execution layer client MUST support at least the following alg HMAC + SHA256 (HS256) +const JWT_SIGNATURE_ALGO: Algorithm = Algorithm::HS256; + +/// Value-object holding a reference to an hex-encoded 256-bit secret key. +/// A JWT secret key is used to secure JWT-based authentication. The secret key is +/// a shared secret between the server and the client and is used to calculate a digital signature +/// for the JWT, which is included in the JWT along with its payload. +/// +/// See also: [Secret key - Engine API specs](https://github.com/ethereum/execution-apis/blob/main/src/engine/authentication.md#key-distribution) +#[derive(Clone)] +pub struct JwtSecret([u8; 32]); + +impl JwtSecret { + /// Creates an instance of [`JwtSecret`][crate::layers::JwtSecret]. + /// + /// Returns an error if one of the following applies: + /// - `hex` is not a valid hexadecimal string + /// - `hex` argument length is less than `JWT_SECRET_LEN` + pub fn from_hex>(hex: S) -> Result { + let hex: &str = hex.as_ref().trim(); + if hex.len() != JWT_SECRET_LEN { + Err(JwtError::InvalidLength(JWT_SECRET_LEN, hex.len())) + } else { + let hex_bytes = hex::decode(hex)?; + // is 32bytes, see length check + let bytes = hex_bytes.try_into().expect("is expected len"); + Ok(JwtSecret(bytes)) + } + } +} + +impl std::fmt::Debug for JwtSecret { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut hasher = DefaultHasher::new(); + let bytes = &self.0; + bytes.hash(&mut hasher); + let hash = format!("{}", hasher.finish()); + f.debug_tuple("JwtSecretHash").field(&hex::encode(hash)).finish() + } +} + +impl JwtSecret { + /// Validates a JWT token along the following rules: + /// - The JWT signature is valid. + /// - The JWT is signed with the `HMAC + SHA256 (HS256)` algorithm. + /// - The JWT `iat` (issued-at) claim is a timestamp within +-60 seconds from the current time. + /// + /// See also: [JWT Claims - Engine API specs](https://github.com/ethereum/execution-apis/blob/main/src/engine/authentication.md#jwt-claims) + pub fn validate(&self, jwt: String) -> Result<(), JwtError> { + let validation = Validation::new(JWT_SIGNATURE_ALGO); + let bytes = &self.0; + + match decode::(&jwt, &DecodingKey::from_secret(bytes), &validation) { + Ok(token) => { + if !token.claims.is_within_time_window() { + Err(JwtError::InvalidIssuanceTimestamp)? + } + } + Err(err) => match *err.kind() { + ErrorKind::InvalidSignature => Err(JwtError::InvalidSignature)?, + ErrorKind::InvalidAlgorithm => Err(JwtError::UnsupportedSignatureAlgorithm)?, + _ => { + let detail = format!("{err:?}"); + Err(JwtError::JwtDecodingError(detail))? + } + }, + }; + + Ok(()) + } + + /// Generates a random [`JwtSecret`][crate::layers::JwtSecret] + /// containing a hex-encoded 256 bit secret key. + pub fn random() -> Self { + let random_bytes: [u8; 32] = rand::thread_rng().gen(); + let secret = hex_encode(random_bytes); + JwtSecret::from_hex(secret).unwrap() + } + + #[cfg(test)] + pub(crate) fn encode(&self, claims: &Claims) -> Result> { + let bytes = &self.0; + let key = jsonwebtoken::EncodingKey::from_secret(bytes); + let algo = jsonwebtoken::Header::new(Algorithm::HS256); + Ok(jsonwebtoken::encode(&algo, claims, &key)?) + } +} + +/// Claims in JWT are used to represent a set of information about an entity. +/// Claims are essentially key-value pairs that are encoded as JSON objects and included in the +/// payload of a JWT. They are used to transmit information such as the identity of the entity, the +/// time the JWT was issued, and the expiration time of the JWT, among others. +/// +/// The Engine API spec requires that just the `iat` (issued-at) claim is provided. +/// It ignores claims that are optional or additional for this specification. +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct Claims { + /// The "iat" value MUST be a number containing a NumericDate value. + /// According to the RFC A NumericDate represents the number of seconds since + /// the UNIX_EPOCH. + /// - [`RFC-7519 - Spec`](https://www.rfc-editor.org/rfc/rfc7519#section-4.1.6) + /// - [`RFC-7519 - Notations`](https://www.rfc-editor.org/rfc/rfc7519#section-2) + pub(crate) iat: u64, + pub(crate) exp: u64, +} + +impl Claims { + fn is_within_time_window(&self) -> bool { + let now = SystemTime::now(); + let now_secs = now.duration_since(UNIX_EPOCH).unwrap().as_secs(); + now_secs.abs_diff(self.iat) <= JWT_MAX_IAT_DIFF.as_secs() + } +} + +#[cfg(test)] +mod tests { + use super::{Claims, JwtError, JwtSecret}; + use crate::layers::jwt_secret::JWT_MAX_IAT_DIFF; + use jsonwebtoken::{encode, Algorithm, EncodingKey, Header}; + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + + #[test] + fn from_hex() { + let key = "f79ae8046bc11c9927afe911db7143c51a806c4a537cc08e0d37140b0192f430"; + let secret: Result = JwtSecret::from_hex(key); + assert!(matches!(secret, Ok(_))); + + let secret: Result = JwtSecret::from_hex(key); + assert!(matches!(secret, Ok(_))); + } + + #[test] + fn original_key_integrity_across_transformations() { + let original = "f79ae8046bc11c9927afe911db7143c51a806c4a537cc08e0d37140b0192f430"; + let secret = JwtSecret::from_hex(original).unwrap(); + let bytes = &secret.0; + let computed = hex::encode(bytes); + assert_eq!(original, computed); + } + + #[test] + fn secret_has_64_hex_digits() { + let expected_len = 64; + let secret = JwtSecret::random(); + let hex = hex::encode(secret.0); + assert_eq!(hex.len(), expected_len); + } + + #[test] + fn creation_error_wrong_len() { + let hex = "f79ae8046"; + let result = JwtSecret::from_hex(hex); + assert!(matches!(result, Err(JwtError::InvalidLength(_, _)))); + } + + #[test] + fn creation_error_wrong_hex_string() { + let hex: String = "This__________Is__________Not_______An____Hex_____________String".into(); + let result = JwtSecret::from_hex(hex); + assert!(matches!(result, Err(JwtError::JwtSecretHexDecodeError(_)))); + } + + #[test] + fn validation_ok() { + let secret = JwtSecret::random(); + let claims = Claims { iat: to_u64(SystemTime::now()), exp: 10000000000 }; + let jwt: String = secret.encode(&claims).unwrap(); + + let result = secret.validate(jwt); + + assert!(matches!(result, Ok(()))); + } + + #[test] + fn validation_error_iat_out_of_window() { + let secret = JwtSecret::random(); + + // Check past 'iat' claim more than 60 secs + let offset = Duration::from_secs(JWT_MAX_IAT_DIFF.as_secs() + 1); + let out_of_window_time = SystemTime::now().checked_sub(offset).unwrap(); + let claims = Claims { iat: to_u64(out_of_window_time), exp: 10000000000 }; + let jwt: String = secret.encode(&claims).unwrap(); + + let result = secret.validate(jwt); + + assert!(matches!(result, Err(JwtError::InvalidIssuanceTimestamp))); + + // Check future 'iat' claim more than 60 secs + let offset = Duration::from_secs(JWT_MAX_IAT_DIFF.as_secs() + 1); + let out_of_window_time = SystemTime::now().checked_add(offset).unwrap(); + let claims = Claims { iat: to_u64(out_of_window_time), exp: 10000000000 }; + let jwt: String = secret.encode(&claims).unwrap(); + + let result = secret.validate(jwt); + + assert!(matches!(result, Err(JwtError::InvalidIssuanceTimestamp))); + } + + #[test] + fn validation_error_wrong_signature() { + let secret_1 = JwtSecret::random(); + let claims = Claims { iat: to_u64(SystemTime::now()), exp: 10000000000 }; + let jwt: String = secret_1.encode(&claims).unwrap(); + + // A different secret will generate a different signature. + let secret_2 = JwtSecret::random(); + let result = secret_2.validate(jwt); + assert!(matches!(result, Err(JwtError::InvalidSignature))); + } + + #[test] + fn validation_error_unsupported_algorithm() { + let secret = JwtSecret::random(); + let bytes = &secret.0; + + let key = EncodingKey::from_secret(bytes); + let unsupported_algo = Header::new(Algorithm::HS384); + + let claims = Claims { iat: to_u64(SystemTime::now()), exp: 10000000000 }; + let jwt: String = encode(&unsupported_algo, &claims, &key).unwrap(); + let result = secret.validate(jwt); + + assert!(matches!(result, Err(JwtError::UnsupportedSignatureAlgorithm))); + } + + fn to_u64(time: SystemTime) -> u64 { + time.duration_since(UNIX_EPOCH).unwrap().as_secs() + } +} diff --git a/crates/rpc/rpc/src/layers/jwt_validator.rs b/crates/rpc/rpc/src/layers/jwt_validator.rs new file mode 100644 index 000000000..0626b37d9 --- /dev/null +++ b/crates/rpc/rpc/src/layers/jwt_validator.rs @@ -0,0 +1,101 @@ +use http::{header, HeaderMap, Response, StatusCode}; +use tracing::error; + +use crate::{AuthValidator, JwtError, JwtSecret}; + +/// Implements JWT validation logics and integrates +/// to an Http [`AuthLayer`][crate::layers::AuthLayer] +/// by implementing the [`AuthValidator`] trait. +#[derive(Clone)] +#[allow(missing_debug_implementations)] +pub struct JwtAuthValidator { + secret: JwtSecret, +} + +impl JwtAuthValidator { + /// Creates a new instance of [`JwtAuthValidator`]. + /// Validation logics are implemnted by the `secret` + /// argument (see [`JwtSecret`]). + pub fn new(secret: JwtSecret) -> Self { + Self { secret } + } +} + +impl AuthValidator for JwtAuthValidator { + type ResponseBody = hyper::Body; + + fn validate(&self, headers: &HeaderMap) -> Result<(), Response> { + match get_bearer(headers) { + Some(jwt) => match self.secret.validate(jwt) { + Ok(_) => Ok(()), + Err(e) => { + error!(target = "engine::jwt-validator", "{e}"); + let response = err_response(e); + Err(response) + } + }, + None => { + let e = JwtError::MissingOrInvalidAuthorizationHeader; + error!(target = "engine::jwt-validator", "{e}"); + let response = err_response(e); + Err(response) + } + } + } +} + +/// This is an utility function that retrieves a bearer +/// token from an authorization Http header. +fn get_bearer(headers: &HeaderMap) -> Option { + let header = headers.get(header::AUTHORIZATION)?; + let auth: &str = header.to_str().ok()?; + let prefix = "Bearer "; + let index = auth.find(prefix)?; + let token: &str = &auth[index + prefix.len()..]; + Some(token.into()) +} + +fn err_response(err: JwtError) -> Response { + let body = hyper::Body::from(err.to_string()); + // We build a response from an error message. + // We don't cope with headers or other structured fields. + // Then we are safe to "expect" on the result. + Response::builder() + .status(StatusCode::UNAUTHORIZED) + .body(body) + .expect("This should never happen") +} + +#[cfg(test)] +mod tests { + use http::{header, HeaderMap}; + + use crate::layers::jwt_validator::get_bearer; + + #[test] + fn auth_header_available() { + let jwt = "foo"; + let bearer = format!("Bearer {jwt}"); + let mut headers = HeaderMap::new(); + headers.insert(header::AUTHORIZATION, bearer.parse().unwrap()); + let token = get_bearer(&headers).unwrap(); + assert_eq!(token, jwt); + } + + #[test] + fn auth_header_not_available() { + let headers = HeaderMap::new(); + let token = get_bearer(&headers); + assert!(matches!(token, None)); + } + + #[test] + fn auth_header_malformed() { + let jwt = "foo"; + let bearer = format!("Bea___rer {jwt}"); + let mut headers = HeaderMap::new(); + headers.insert(header::AUTHORIZATION, bearer.parse().unwrap()); + let token = get_bearer(&headers); + assert!(matches!(token, None)); + } +} diff --git a/crates/rpc/rpc/src/layers/mod.rs b/crates/rpc/rpc/src/layers/mod.rs new file mode 100644 index 000000000..f464e6f5d --- /dev/null +++ b/crates/rpc/rpc/src/layers/mod.rs @@ -0,0 +1,21 @@ +use http::{HeaderMap, Response}; + +mod auth_layer; +mod jwt_secret; +mod jwt_validator; +pub use auth_layer::AuthLayer; +pub use jwt_secret::{JwtError, JwtSecret}; +pub use jwt_validator::JwtAuthValidator; + +/// General purpose trait to validate Http Authorization +/// headers. It's supposed to be integrated as a validator +/// trait into an [`AuthLayer`][crate::layers::AuthLayer]. +pub trait AuthValidator { + /// Body type of the error response + type ResponseBody; + + /// This function is invoked by the [`AuthLayer`][crate::layers::AuthLayer] + /// to perform validation on Http headers. + /// The result conveys validation errors in the form of an Http response. + fn validate(&self, headers: &HeaderMap) -> Result<(), Response>; +} diff --git a/crates/rpc/rpc/src/lib.rs b/crates/rpc/rpc/src/lib.rs index 3e9412fb8..be0c627b7 100644 --- a/crates/rpc/rpc/src/lib.rs +++ b/crates/rpc/rpc/src/lib.rs @@ -15,6 +15,7 @@ mod admin; mod debug; mod engine; mod eth; +mod layers; mod net; mod trace; mod web3; @@ -23,6 +24,7 @@ pub use admin::AdminApi; pub use debug::DebugApi; pub use engine::EngineApi; pub use eth::{EthApi, EthApiSpec, EthPubSub}; +pub use layers::{AuthLayer, AuthValidator, JwtAuthValidator, JwtError, JwtSecret}; pub use net::NetApi; pub use trace::TraceApi; pub use web3::Web3Api;