From 06a2ad05b20c78a7fdfd1d0f2f359a2a7ccab1b6 Mon Sep 17 00:00:00 2001 From: Cole MacKenzie Date: Sat, 25 Jan 2025 14:46:32 -0800 Subject: [PATCH 1/2] feat: use dashmap for better concurrency --- Cargo.toml | 1 + src/remote.rs | 37 +++++++++++++++---------------------- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index be142c1..656c180 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ repository = "https://github.com/cmackenzie1/axum-jwt-auth" [dependencies] axum = { version = "0.8", features = ["macros"] } axum-extra = { version = "0.10.0", features = ["typed-header"] } +dashmap = "6.1.0" jsonwebtoken = { version = "9" } reqwest = { version = "0.12", default-features = false, features = [ "json", diff --git a/src/remote.rs b/src/remote.rs index 9b8bdde..db9e2fd 100644 --- a/src/remote.rs +++ b/src/remote.rs @@ -1,5 +1,6 @@ -use std::sync::{Arc, RwLock}; +use std::sync::Arc; +use dashmap::DashMap; use jsonwebtoken::{jwk::JwkSet, DecodingKey, TokenData, Validation}; use serde::de::DeserializeOwned; @@ -11,7 +12,7 @@ use crate::{Decoder, Error, JwtDecoder}; pub struct RemoteJwksDecoder { jwks_url: String, cache_duration: std::time::Duration, - keys_cache: RwLock, DecodingKey)>>, + keys_cache: DashMap, validation: Validation, client: reqwest::Client, retry_count: usize, @@ -29,7 +30,7 @@ impl RemoteJwksDecoder { Self { jwks_url, cache_duration: std::time::Duration::from_secs(60 * 60), - keys_cache: RwLock::new(Vec::new()), + keys_cache: DashMap::new(), validation: Validation::default(), client: reqwest::Client::new(), retry_count: 3, @@ -66,17 +67,12 @@ impl RemoteJwksDecoder { .json::() .await?; - let mut jwks_cache = self.keys_cache.write().unwrap(); - *jwks_cache = jwks - .keys - .iter() - .flat_map(|jwk| -> Result<(Option, DecodingKey), Error> { - let key_id = jwk.common.key_id.to_owned(); - let key = DecodingKey::from_jwk(jwk).map_err(Error::Jwt)?; - - Ok((key_id, key)) - }) - .collect(); + self.keys_cache.clear(); + for jwk in jwks.keys.iter() { + let key_id = jwk.common.key_id.to_owned(); + let key = DecodingKey::from_jwk(jwk).map_err(Error::Jwt)?; + self.keys_cache.insert(key_id.unwrap_or_default(), key); + } Ok(()) } @@ -112,19 +108,16 @@ where let header = jsonwebtoken::decode_header(token)?; let target_kid = header.kid; - let jwks_cache = self.keys_cache.read().unwrap(); - // Try to find the key in the cache by kid - let jwk = jwks_cache.iter().find(|(kid, _)| kid == &target_kid); - if let Some((_, key)) = jwk { - return Ok(jsonwebtoken::decode::(token, key, &self.validation)?); + if let Some(key) = self.keys_cache.get(&target_kid.unwrap_or_default()) { + return Ok(jsonwebtoken::decode::(token, key.value(), &self.validation)?); } // Otherwise, try all the keys in the cache, returning the first one that works // If none of them work, return the error from the last one let mut err: Option = None; - for (_, key) in jwks_cache.iter() { - match jsonwebtoken::decode::(token, key, &self.validation) { + for key in self.keys_cache.iter() { + match jsonwebtoken::decode::(token, key.value(), &self.validation) { Ok(token_data) => return Ok(token_data), Err(e) => err = Some(e.into()), } @@ -184,7 +177,7 @@ impl RemoteJwksDecoderBuilder { RemoteJwksDecoder { jwks_url: self.jwks_url, cache_duration: self.cache_duration, - keys_cache: RwLock::new(Vec::new()), + keys_cache: DashMap::new(), validation: self.validation, client: self.client, retry_count: self.retry_count, From b31179fa3e7f4ea10b775094acd251c9431db909 Mon Sep 17 00:00:00 2001 From: Cole MacKenzie Date: Sat, 25 Jan 2025 16:11:01 -0800 Subject: [PATCH 2/2] feat: Add remote JWKS decoder with improved configuration and error handling - Introduce RemoteJwksDecoder with builder pattern for flexible configuration - Add RemoteJwksDecoderConfig with customizable cache duration, retry count, and backoff - Enhance error handling and logging for JWKS key refresh - Update example and test code to support new decoder implementation - Add derive_builder and tracing dependencies --- Cargo.toml | 7 ++ examples/local/local.rs | 63 +++++++++++++---- examples/remote/jwt.key | 51 ++++++++++++++ examples/remote/remote.rs | 113 ++++++++++++++++++++++++++++++ src/axum.rs | 77 ++++++++++++--------- src/lib.rs | 21 ++---- src/local.rs | 11 +-- src/remote.rs | 140 +++++++++++++------------------------- tests/integration_test.rs | 9 ++- 9 files changed, 327 insertions(+), 165 deletions(-) create mode 100644 examples/remote/jwt.key create mode 100644 examples/remote/remote.rs diff --git a/Cargo.toml b/Cargo.toml index 656c180..0d5bf6e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ repository = "https://github.com/cmackenzie1/axum-jwt-auth" axum = { version = "0.8", features = ["macros"] } axum-extra = { version = "0.10.0", features = ["typed-header"] } dashmap = "6.1.0" +derive_builder = "0.20.2" jsonwebtoken = { version = "9" } reqwest = { version = "0.12", default-features = false, features = [ "json", @@ -28,7 +29,13 @@ tokio = { version = "1", default-features = false, features = [ "macros", ] } serde_json = "1" +tracing-subscriber = "0.3.19" +rand = { version = "0.8.5", features = ["small_rng"] } [[example]] name = "local" path = "examples/local/local.rs" + +[[example]] +name = "remote" +path = "examples/remote/remote.rs" diff --git a/examples/local/local.rs b/examples/local/local.rs index 88f1e2b..df07f41 100644 --- a/examples/local/local.rs +++ b/examples/local/local.rs @@ -1,32 +1,34 @@ +use std::sync::Arc; + use axum::{ - extract::FromRef, + extract::{FromRef, State}, response::{IntoResponse, Response}, routing::{get, post}, Json, Router, }; - -use axum_jwt_auth::{Claims, Decoder, JwtDecoderState, LocalDecoder}; +use axum_jwt_auth::{Claims, JwtDecoderState, LocalDecoder}; use chrono::{Duration, Utc}; use jsonwebtoken::{encode, Algorithm, DecodingKey, EncodingKey, Header, Validation}; use serde::{Deserialize, Serialize}; -#[derive(Clone, FromRef)] -struct AppState { - decoder: JwtDecoderState, -} - -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct MyClaims { iat: u64, aud: String, exp: u64, } +#[derive(Clone, FromRef)] +struct AppState { + decoder: JwtDecoderState, +} + async fn index() -> Response { "Hello, World!".into_response() } -async fn user_info(Claims(claims): Claims) -> Response { +#[axum::debug_handler] +async fn user_info(Claims(claims): Claims, State(_state): State) -> Response { Json(claims).into_response() } @@ -52,9 +54,11 @@ async fn main() { let keys = vec![DecodingKey::from_rsa_pem(include_bytes!("jwt.key.pub")).unwrap()]; let mut validation = Validation::new(Algorithm::RS256); validation.set_audience(&["https://example.com"]); - let decoder: Decoder = LocalDecoder::new(keys, validation).into(); + let decoder = LocalDecoder::new(keys, validation); let state = AppState { - decoder: JwtDecoderState { decoder }, + decoder: JwtDecoderState { + decoder: Arc::new(decoder), + }, }; let app = Router::new() @@ -63,8 +67,37 @@ async fn main() { .route("/login", post(login)) .with_state(state); - // run it on localhost:3000 - let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); + // Create client and server + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .expect("Failed to bind"); + let client = reqwest::Client::new(); + + // Run server in background + let server_handle = tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + // Make requests to test endpoints + let login_resp = client + .post("http://127.0.0.1:3000/login") + .send() + .await + .expect("Login failed"); + let token = login_resp.text().await.expect("Failed to get token"); + + let user_info = client + .get("http://127.0.0.1:3000/user_info") + .header("Authorization", format!("Bearer {}", token)) + .send() + .await + .expect("User info request failed") + .json::() + .await + .expect("Failed to parse claims"); + + println!("Successfully validated claims: {:?}", user_info); - axum::serve(listener, app).await.unwrap(); + // Clean shutdown + server_handle.abort(); } diff --git a/examples/remote/jwt.key b/examples/remote/jwt.key new file mode 100644 index 0000000..79c4670 --- /dev/null +++ b/examples/remote/jwt.key @@ -0,0 +1,51 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIJKAIBAAKCAgEAsPWeaqsN+2KZu9rlto59XASssMoaVjIxMYXtLyifky1sXS4E +vYFnvr37X63B+lMwuZ3xACc7xsUPK+GXPe6XqZGJdj+Wgf7a3J6FieSNpnrDK4x6 +CMr0iAPgIhoEYp7BUyPKzPv21vMl6A5kJvlAAdxfPm3jhk5NDWHSfiFnWiC7UESA +RgyFl0TlJ+f9H3qaArkzp3Cb+m+wlHpleewOSr9maTPLdIS+ZzZ1ZC4lDIQnetJJ +0kue+o1wAL4VmdBMY8IVxEutPAaZO+9G8eYJywZiDDkcrrqWymDvSUarcB/AOzEQ +jxN6nSSNuW6UbalfnDlGmR0kFK8fopraA4nwU4tG6fAuKTPpOmahC910IRAkedOp +6IrRU+2LmcBQ0oyzukHjXd9o9/5MES2wTDFgZBalVRZCo55vdQt5CtQDQWVUbQ1y +95dm/0EmmgZzWBgiguSKcO2QuqwYIiq5t9uikFleeVQDVnd+V6yZ5wWfnA6H0+dP +w4VTEUkxaTN8jQImQtB9gvj8iknsGX08LGF5WjWh1ewJI0L74Ey5T/ytsXME6Xpn +1qfXB2sr5tPol3KeV8pjuGrAymvaLJZz4ZqNY3f4wULfCsyVasUOdknMm8UmTgPR ++vnDlF+1ItsmN+Jl+RJ1dFkXRDcelCIJS44sMSchnxv47OwnqvBHCPbiUI8CAwEA +AQKCAgBJd+AqdxQZ/2jGNm5SqbvgHUy5JV9j0/jaj7jWcG44A47O7NEpAHXbGjMo +GRLE5A8BsVIidyd5Mc1HsaRCITG0Q+knP+Uz2WRyXhohEtPAf41SIkN0LRby9XDz +l4ukijbHVr/W9PEZct+VBYyNJcRuQVkFqUfiNdYFrUxf82xeXeKGw7nh20cHc6IU +PFu52wPgB5YreTQ4+G/+ZQaGZPvWCrrxCID6wjXu0gxQ6FuXY7KkanQdrCm36krK +9CAxuOpOLIEu+yBUIIUz/fadbZ05PlAstPV0kaETKsWNzZpVtcjwikFOtY6deVSh +3Qggs0YvrRPjc9bMA50FvHaxK26r/0XjXaowa4wCSc0x/JOERI+eWc49/Wwykuh2 +H6yZ/zJAUFVaM7ic9yLKLPDS6lSFH19Q4ed3psd7nsCrjythYhV9PPfKJUwtROl1 +Dbd3KL0DIlwbitnLR23k5PYzXQbEGoJpBqj5ruZnQ0bWlODEQ3F0HQXVlPrJmiGy +4J85G8EepO+j9pSADSikDiKgb06EeINFOP85a7o/E4vHdM5C1un8Z7Jym58wxptY +AjXYw6yD5tr/iXqLzMbeg2pmy7Kg3gLERrT+F7kIxVMt+w5GKPUkxTwAmyoTr4JX +756j9GgKB/1L6Z4dh0aXg5pf7IBbuRzsXEKiQaZbQFJQK12FOQKCAQEA6kWSz/Ug +qp65jw74amsxCfrkLPKWyvAtcB++GA46C8Dnz0Nz3zUe7/o5AfmCNE48Gqc3dYJH +ZG6jgjE3IVCdAnloYMl+RSavGNF9lW9SaxCgQsoN4OrkRpNxMtzFdX0KFnc0dMYo +Ps+51LdDE9XAv40MBf2QVkytJDUq+sxlvFx1rLX4hAAqaaTf8dsNOru7bq0RKKy4 +HKZ7XrXAxlJrYUGaHRJNfM9fjn/MTzPUOIoAdSXjrOU0/MlopTg10Od6pKTQCHqG +rUpf2wO2F4FXZ8wPV65uZqr45lu1Kxpbv4ihCIxR1pBWVuMdfteLgcwLae/LH9Z+ +YImfjpdcdJTQjQKCAQEAwV9AYUTb9ejtfrX0Af6gjtCUXh9tkI1dyJia8+O0AdEF +eYsMz+Z29Ew0qQvytS9KX70xknwOwN/86tYN+gtQP+dSXebkfCnqSgvWO4NjZMp+ +Xqf47ftZ+93wTwFO6YJK4+FuPb+2hC0iGCU8dh4vYdskxtYTR9w+jsOTHNjR9IYL +vpFitnc3o3RFegrPjgUZj05vf9mH5y0qKk5PpsFETu4inFEobXwwAdEG2V3Tr+bA +NS2DQhrfg7h4zVlrUZSN+pzKcPEEcxNzoQ8+xzsGk2vSnrIvjyDvdaRdO3e77hA2 +XWjNPXU4gtTaym3xGMhuSONiOSF+2RBJBUwgLpZkiwKCAQAfBGkkuXrCvFMrGrtP +M7QBc0NkpBXM9rG9Z6Z+ftu2lKrcaTzdL6ZR9Zo4pbVUgYs5qCwSldYn+PITGbsH +4Sl4m2RzdBoQw8dpDMuIzn1mCYR+c0wVHGRu57SUHGDUZmLAiLXcRCQt9MjQ3ha2 +eJWVhvIxlNnYYzyFT7jKDefmYYN/A3TM3UzAQgEYf30n6pUtWSKtdPjHak9pQb0t +RNpMvSfPc43o2Xf4YPlG/0C436Sh3gtf59T1JyGAxolxiERXqi6VAMv2A6PfVoV0 +ZT6SUpUxcbnSRA7CSSAafdnp6QgRHqrzMpcL1/QeyCEDZWWZeBM3uululKoYcffe +w5k9AoIBAE/LoqGA7NPZPsffBcYc8Nx+Lft5NJlF/MFeV/L0r79gJcY2Hx9blxLQ +r5pil9E0pphDVkWAdAYbaB7wHexk5sS4DEE7mmWyVkAgClOcsFNTTDp7TjnGUyeg +Oh4gCBRL8+N9jyRkDEkW5s7X7s8/PYZADDkQ9fvdYuM+yWJKBrnE5uvIytdI8ui8 +fj8SXvvYFugQEerMNUysUo4KqsvBTRLVKesfgnNLn/Pf8deY5FXd/sry8QtCU6Hj +adYzZBnSF5SnRtK8Yn2qNTjtNZa5QMls4QkjtoR1rtr47JAxpJdkkUqSiL8ntB4o +//Aw1iDH9NqXGl1A+TtRgRByjYUsAmkCggEBAJ8ihjXAyveNVNRhIorAKk70Elq7 +lVdcZVks150Hn+A0MmnT3HX9lt/UICjhDnhlgsAJabKdUuta8ksuErqsFzA4m9Ok +hihSN3bjQsFSB/d20tZbr20u+ao219p1XTS6KqtEObYEwlNwNjX3uz0mIstCQk/v +XefdBZVzDVgUap+cip5DshhKasEOPlMUwqSrNgjJVNjPcUIzqR7kk7+QkI/3T46D +gMxczij1/Ib9RTO42/e1AlmaIQCsPQ0lU2fO9BkzWyj6CkMoTSoUvrVVM224iO6H +H4f9Kh9zevvv3ZMPBQBdwjAvuo3RueXukNYJNaI2zjpDDaMP3gIFmO+mqyE= +-----END RSA PRIVATE KEY----- diff --git a/examples/remote/remote.rs b/examples/remote/remote.rs new file mode 100644 index 0000000..19e07de --- /dev/null +++ b/examples/remote/remote.rs @@ -0,0 +1,113 @@ +use std::{sync::Arc, time::Duration}; + +use axum::{routing::get, Json, Router}; +use axum_jwt_auth::{ + JwtDecoder, RemoteJwksDecoder, RemoteJwksDecoderBuilder, RemoteJwksDecoderConfigBuilder, +}; +use dashmap::DashMap; +use jsonwebtoken::{Algorithm, EncodingKey, Header, Validation}; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use tokio; + +#[derive(Debug, Serialize, Deserialize, Clone)] +struct CustomClaims { + sub: String, + name: String, + exp: usize, +} + +// This is a sample JWKS handler. In a real application, you would fetch the JWKS from a remote source. +// For testing purposes, we randomly fail 50% of the time to simulate a remote JWKS endpoint that is not available. +async fn jwks_handler() -> Json { + // Randomly fail 50% of the time + if rand::random::() { + return Json(json!({ + "error": "Internal Server Error", + "message": "Random failure for testing" + })); + } + + // This is a sample JWKS. In a real application, you would generate proper keys or fetch them from a remote source + Json(json!({ + "keys": [{ + "kty": "RSA", + "use": "sig", + "alg": "RS256", + "kid": "1dea0016-46b8-4289-ad7b-226cfaf5305e", + "n": "sPWeaqsN-2KZu9rlto59XASssMoaVjIxMYXtLyifky1sXS4EvYFnvr37X63B-lMwuZ3xACc7xsUPK-GXPe6XqZGJdj-Wgf7a3J6FieSNpnrDK4x6CMr0iAPgIhoEYp7BUyPKzPv21vMl6A5kJvlAAdxfPm3jhk5NDWHSfiFnWiC7UESARgyFl0TlJ-f9H3qaArkzp3Cb-m-wlHpleewOSr9maTPLdIS-ZzZ1ZC4lDIQnetJJ0kue-o1wAL4VmdBMY8IVxEutPAaZO-9G8eYJywZiDDkcrrqWymDvSUarcB_AOzEQjxN6nSSNuW6UbalfnDlGmR0kFK8fopraA4nwU4tG6fAuKTPpOmahC910IRAkedOp6IrRU-2LmcBQ0oyzukHjXd9o9_5MES2wTDFgZBalVRZCo55vdQt5CtQDQWVUbQ1y95dm_0EmmgZzWBgiguSKcO2QuqwYIiq5t9uikFleeVQDVnd-V6yZ5wWfnA6H0-dPw4VTEUkxaTN8jQImQtB9gvj8iknsGX08LGF5WjWh1ewJI0L74Ey5T_ytsXME6Xpn1qfXB2sr5tPol3KeV8pjuGrAymvaLJZz4ZqNY3f4wULfCsyVasUOdknMm8UmTgPR-vnDlF-1ItsmN-Jl-RJ1dFkXRDcelCIJS44sMSchnxv47OwnqvBHCPbiUI8", + "e": "AQAB" + }] + })) +} + +#[tokio::main] +async fn main() { + // Initialize tracing for logging + tracing_subscriber::fmt::init(); + + // Create the Axum router with the JWKS endpoint + let app = Router::new().route("/.well-known/jwks.json", get(jwks_handler)); + + // Spawn the server task + let server_handle = tokio::spawn(async move { + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .expect("Failed to bind server"); + println!("JWKS server listening on http://127.0.0.1:3000"); + axum::serve(listener, app) + .await + .expect("Failed to start server"); + }); + + // Create a remote JWKS decoder with custom configuration + let decoder = RemoteJwksDecoderBuilder::default() + .jwks_url("http://127.0.0.1:3000/.well-known/jwks.json".to_string()) + .config( + RemoteJwksDecoderConfigBuilder::default() + .cache_duration(Duration::from_secs(1)) // Low value for testing, in a real application you should use a higher value + .retry_count(3) + .backoff(Duration::from_secs(1)) + .build() + .unwrap(), + ) + .validation(Validation::new(Algorithm::RS256)) + .client(reqwest::Client::new()) + .keys_cache(Arc::new(DashMap::new())) + .build() + .expect("Failed to build decoder"); + + // Spawn a task to periodically refresh the JWKS + let decoder_clone = decoder.clone(); + tokio::spawn(async move { + decoder_clone.refresh_keys_periodically().await; + }); + + // Create a token + let token = jsonwebtoken::encode( + &Header::new(Algorithm::RS256), + &CustomClaims { + sub: "123".to_string(), + name: "John Doe".to_string(), + exp: (chrono::Utc::now().timestamp() + 60 * 60) as usize, + }, + &EncodingKey::from_rsa_pem(include_bytes!("jwt.key")).unwrap(), + ) + .unwrap(); + + // Decode the token + match >::decode(&decoder, &token) { + Ok(token_data) => { + println!("Token successfully decoded: {:?}", token_data.claims); + } + Err(err) => { + eprintln!("Failed to decode token: {:?}", err); + } + } + + // Keep the main task running for a while to see the periodic refresh in action + tokio::time::sleep(Duration::from_secs(60)).await; + + // Clean shutdown + server_handle.abort(); +} diff --git a/src/axum.rs b/src/axum.rs index ebfc2c6..a39ed8d 100644 --- a/src/axum.rs +++ b/src/axum.rs @@ -9,44 +9,15 @@ use axum_extra::TypedHeader; use serde::de::DeserializeOwned; use serde::Deserialize; -use crate::{Decoder, JwtDecoder}; +use crate::Decoder; /// A generic struct for holding the claims of a JWT token. #[derive(Debug, Deserialize)] pub struct Claims(pub T); -pub enum AuthError { - InvalidToken, - MissingToken, - ExpiredToken, - InvalidSignature, - InvalidAudience, - InternalError, -} - -impl IntoResponse for AuthError { - fn into_response(self) -> Response { - let (status, msg) = match self { - AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"), - AuthError::MissingToken => (StatusCode::UNAUTHORIZED, "Missing token"), - AuthError::ExpiredToken => (StatusCode::UNAUTHORIZED, "Expired token"), - AuthError::InvalidSignature => (StatusCode::UNAUTHORIZED, "Invalid signature"), - AuthError::InvalidAudience => (StatusCode::UNAUTHORIZED, "Invalid audience"), - AuthError::InternalError => (StatusCode::INTERNAL_SERVER_ERROR, "Internal error"), - }; - - (status, msg).into_response() - } -} - -#[derive(Clone, FromRef)] -pub struct JwtDecoderState { - pub decoder: Decoder, -} - impl axum::extract::FromRequestParts for Claims where - JwtDecoderState: FromRef, + JwtDecoderState: FromRef, S: Send + Sync, T: DeserializeOwned, { @@ -75,6 +46,48 @@ where _ => Self::Rejection::InternalError, })?; - Ok(token_data.claims) + Ok(Claims(token_data.claims)) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum AuthError { + #[error("Invalid token")] + InvalidToken, + #[error("Missing token")] + MissingToken, + #[error("Expired token")] + ExpiredToken, + #[error("Invalid signature")] + InvalidSignature, + #[error("Invalid audience")] + InvalidAudience, + #[error("Internal error")] + InternalError, +} + +impl IntoResponse for AuthError { + fn into_response(self) -> Response { + let (status, msg) = match self { + AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"), + AuthError::MissingToken => (StatusCode::UNAUTHORIZED, "Missing token"), + AuthError::ExpiredToken => (StatusCode::UNAUTHORIZED, "Expired token"), + AuthError::InvalidSignature => (StatusCode::UNAUTHORIZED, "Invalid signature"), + AuthError::InvalidAudience => (StatusCode::UNAUTHORIZED, "Invalid audience"), + AuthError::InternalError => (StatusCode::INTERNAL_SERVER_ERROR, "Internal error"), + }; + + (status, msg).into_response() + } +} + +#[derive(Clone)] +pub struct JwtDecoderState { + pub decoder: Decoder, +} + +impl FromRef> for Decoder { + fn from_ref(state: &JwtDecoderState) -> Self { + state.decoder.clone() } } diff --git a/src/lib.rs b/src/lib.rs index 68a64bd..be547fe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,7 +25,10 @@ use thiserror::Error; pub use crate::axum::{AuthError, Claims, JwtDecoderState}; pub use crate::local::LocalDecoder; -pub use crate::remote::{RemoteJwksDecoder, RemoteJwksDecoderBuilder}; +pub use crate::remote::{ + RemoteJwksDecoder, RemoteJwksDecoderBuilder, RemoteJwksDecoderConfig, + RemoteJwksDecoderConfigBuilder, +}; #[derive(Debug, thiserror::Error)] pub enum Error { @@ -49,17 +52,5 @@ where fn decode(&self, token: &str) -> Result, Error>; } -#[derive(Clone)] -pub enum Decoder { - Local(Arc), - Remote(Arc), -} - -impl JwtDecoder for Decoder { - fn decode(&self, token: &str) -> Result, Error> { - match self { - Self::Local(decoder) => decoder.decode(token), - Self::Remote(decoder) => decoder.decode(token), - } - } -} +/// A type alias for a decoder that can be used as a state in an Axum application. +pub type Decoder = Arc + Send + Sync>; diff --git a/src/local.rs b/src/local.rs index 0d36db8..4f40035 100644 --- a/src/local.rs +++ b/src/local.rs @@ -1,23 +1,16 @@ -use std::sync::Arc; - use jsonwebtoken::{DecodingKey, TokenData, Validation}; use serde::de::DeserializeOwned; -use crate::{Decoder, Error, JwtDecoder}; +use crate::{Error, JwtDecoder}; /// Local decoder /// It uses the given JWKS to decode the JWT tokens. +#[derive(Clone)] pub struct LocalDecoder { keys: Vec, validation: Validation, } -impl From for Decoder { - fn from(decoder: LocalDecoder) -> Self { - Self::Local(Arc::new(decoder)) - } -} - impl LocalDecoder { pub fn new(keys: Vec, validation: Validation) -> Self { Self { keys, validation } diff --git a/src/remote.rs b/src/remote.rs index db9e2fd..8579c6d 100644 --- a/src/remote.rs +++ b/src/remote.rs @@ -1,45 +1,55 @@ use std::sync::Arc; use dashmap::DashMap; +use derive_builder::Builder; use jsonwebtoken::{jwk::JwkSet, DecodingKey, TokenData, Validation}; use serde::de::DeserializeOwned; -use crate::{Decoder, Error, JwtDecoder}; +use crate::{Error, JwtDecoder}; + +const DEFAULT_CACHE_DURATION: std::time::Duration = std::time::Duration::from_secs(60 * 60); // 1 hour +const DEFAULT_RETRY_COUNT: usize = 3; // 3 attempts +const DEFAULT_BACKOFF: std::time::Duration = std::time::Duration::from_secs(1); // 1 second + +#[derive(Debug, Clone, Builder)] +pub struct RemoteJwksDecoderConfig { + pub cache_duration: std::time::Duration, + pub retry_count: usize, + pub backoff: std::time::Duration, +} + +impl Default for RemoteJwksDecoderConfig { + fn default() -> Self { + Self { + cache_duration: DEFAULT_CACHE_DURATION, + retry_count: DEFAULT_RETRY_COUNT, + backoff: DEFAULT_BACKOFF, + } + } +} /// Remote JWKS decoder. /// It fetches the JWKS from the given URL and caches it for the given duration. /// It uses the cached JWKS to decode the JWT tokens. +#[derive(Clone, Builder)] pub struct RemoteJwksDecoder { jwks_url: String, - cache_duration: std::time::Duration, - keys_cache: DashMap, + config: RemoteJwksDecoderConfig, + keys_cache: Arc>, validation: Validation, client: reqwest::Client, - retry_count: usize, - backoff: std::time::Duration, -} - -impl From for Decoder { - fn from(decoder: RemoteJwksDecoder) -> Self { - Self::Remote(Arc::new(decoder)) - } } impl RemoteJwksDecoder { pub fn new(jwks_url: String) -> Self { - Self { - jwks_url, - cache_duration: std::time::Duration::from_secs(60 * 60), - keys_cache: DashMap::new(), - validation: Validation::default(), - client: reqwest::Client::new(), - retry_count: 3, - backoff: std::time::Duration::from_secs(1), - } + RemoteJwksDecoderBuilder::default() + .jwks_url(jwks_url) + .build() + .unwrap() } async fn refresh_keys(&self) -> Result<(), Error> { - let max_attempts = self.retry_count; + let max_attempts = self.config.retry_count; let mut attempt = 0; let mut err = None; @@ -49,7 +59,7 @@ impl RemoteJwksDecoder { Err(e) => { err = Some(e); attempt += 1; - tokio::time::sleep(self.backoff).await; + tokio::time::sleep(self.config.backoff).await; } } } @@ -84,18 +94,19 @@ impl RemoteJwksDecoder { /// succeeds or the universe ends, whichever comes first. pub async fn refresh_keys_periodically(&self) { loop { + tracing::info!("Refreshing JWKS"); match self.refresh_keys().await { Ok(_) => {} Err(err) => { // log the error and continue with stale keys tracing::error!( "Failed to refresh JWKS after {} attempts: {:?}", - self.retry_count, + self.config.retry_count, err ); } } - tokio::time::sleep(self.cache_duration).await; + tokio::time::sleep(self.config.cache_duration).await; } } } @@ -107,81 +118,28 @@ where fn decode(&self, token: &str) -> Result, Error> { let header = jsonwebtoken::decode_header(token)?; let target_kid = header.kid; - - // Try to find the key in the cache by kid - if let Some(key) = self.keys_cache.get(&target_kid.unwrap_or_default()) { - return Ok(jsonwebtoken::decode::(token, key.value(), &self.validation)?); + if let Some(kid) = target_kid { + // Try to find the key in the cache by kid + if let Some(key) = self.keys_cache.get(&kid) { + return Ok(jsonwebtoken::decode::( + token, + key.value(), + &self.validation, + )?); + } + return Err(Error::KeyNotFound(Some(kid))); } // Otherwise, try all the keys in the cache, returning the first one that works // If none of them work, return the error from the last one - let mut err: Option = None; for key in self.keys_cache.iter() { match jsonwebtoken::decode::(token, key.value(), &self.validation) { Ok(token_data) => return Ok(token_data), - Err(e) => err = Some(e.into()), + Err(e) => { + tracing::debug!("Failed to decode token with key {}: {:?}", key.key(), e); + } } } - - Err(err.unwrap()) - } -} - -pub struct RemoteJwksDecoderBuilder { - jwks_url: String, - cache_duration: std::time::Duration, - validation: Validation, - client: reqwest::Client, - retry_count: usize, - backoff: std::time::Duration, -} - -impl RemoteJwksDecoderBuilder { - pub fn new(jwks_url: String) -> Self { - Self { - jwks_url, - cache_duration: std::time::Duration::from_secs(60 * 60), - validation: Validation::default(), - client: reqwest::Client::new(), - retry_count: 3, - backoff: std::time::Duration::from_secs(1), - } - } - - pub fn with_jwks_cache_duration(mut self, jwks_cache_duration: std::time::Duration) -> Self { - self.cache_duration = jwks_cache_duration; - self - } - - pub fn with_client(mut self, client: reqwest::Client) -> Self { - self.client = client; - self - } - - pub fn with_validation(mut self, validation: Validation) -> Self { - self.validation = validation; - self - } - - pub fn with_retry_count(mut self, retry_count: usize) -> Self { - self.retry_count = retry_count; - self - } - - pub fn with_backoff(mut self, backoff: std::time::Duration) -> Self { - self.backoff = backoff; - self - } - - pub fn build(self) -> RemoteJwksDecoder { - RemoteJwksDecoder { - jwks_url: self.jwks_url, - cache_duration: self.cache_duration, - keys_cache: DashMap::new(), - validation: self.validation, - client: self.client, - retry_count: self.retry_count, - backoff: self.backoff, - } + Err(Error::KeyNotFound(target_kid)) } } diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 02d04eb..324b142 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use axum::{ extract::FromRef, response::IntoResponse, @@ -12,10 +14,10 @@ use serde::{Deserialize, Serialize}; #[derive(Clone, FromRef)] struct AppState { - decoder: JwtDecoderState, + decoder: JwtDecoderState, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct MyClaims { iat: u64, aud: String, @@ -30,7 +32,8 @@ async fn token_is_valid() { let mut validation = Validation::new(Algorithm::RS256); validation.set_audience(&["https://example.com"]); - let decoder: Decoder = LocalDecoder::new(vec![decoding_key.to_owned()], validation).into(); + let decoder: Decoder = + Arc::new(LocalDecoder::new(vec![decoding_key.to_owned()], validation)); let state = AppState { decoder: JwtDecoderState { decoder }, };