Skip to content

feat: Add remote JWKS decoder with improved configuration and error handling #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ repository = "https://github.com/cmackenzie1/axum-jwt-auth"
[dependencies]
axum = { version = "0.8", features = ["macros"] }
axum-extra = { version = "0.10.0", features = ["typed-header"] }
dashmap = "6.1.0"
derive_builder = "0.20.2"
jsonwebtoken = { version = "9" }
reqwest = { version = "0.12", default-features = false, features = [
"json",
Expand All @@ -27,7 +29,13 @@ tokio = { version = "1", default-features = false, features = [
"macros",
] }
serde_json = "1"
tracing-subscriber = "0.3.19"
rand = { version = "0.8.5", features = ["small_rng"] }

[[example]]
name = "local"
path = "examples/local/local.rs"

[[example]]
name = "remote"
path = "examples/remote/remote.rs"
63 changes: 48 additions & 15 deletions examples/local/local.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,34 @@
use std::sync::Arc;

use axum::{
extract::FromRef,
extract::{FromRef, State},
response::{IntoResponse, Response},
routing::{get, post},
Json, Router,
};

use axum_jwt_auth::{Claims, Decoder, JwtDecoderState, LocalDecoder};
use axum_jwt_auth::{Claims, JwtDecoderState, LocalDecoder};
use chrono::{Duration, Utc};
use jsonwebtoken::{encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};

#[derive(Clone, FromRef)]
struct AppState {
decoder: JwtDecoderState,
}

#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct MyClaims {
iat: u64,
aud: String,
exp: u64,
}

#[derive(Clone, FromRef)]
struct AppState {
decoder: JwtDecoderState<MyClaims>,
}

async fn index() -> Response {
"Hello, World!".into_response()
}

async fn user_info(Claims(claims): Claims<MyClaims>) -> Response {
#[axum::debug_handler]
async fn user_info(Claims(claims): Claims<MyClaims>, State(_state): State<AppState>) -> Response {
Json(claims).into_response()
}

Expand All @@ -52,9 +54,11 @@ async fn main() {
let keys = vec![DecodingKey::from_rsa_pem(include_bytes!("jwt.key.pub")).unwrap()];
let mut validation = Validation::new(Algorithm::RS256);
validation.set_audience(&["https://example.com"]);
let decoder: Decoder = LocalDecoder::new(keys, validation).into();
let decoder = LocalDecoder::new(keys, validation);
let state = AppState {
decoder: JwtDecoderState { decoder },
decoder: JwtDecoderState {
decoder: Arc::new(decoder),
},
};

let app = Router::new()
Expand All @@ -63,8 +67,37 @@ async fn main() {
.route("/login", post(login))
.with_state(state);

// run it on localhost:3000
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
// Create client and server
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
.await
.expect("Failed to bind");
let client = reqwest::Client::new();

// Run server in background
let server_handle = tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});

// Make requests to test endpoints
let login_resp = client
.post("http://127.0.0.1:3000/login")
.send()
.await
.expect("Login failed");
let token = login_resp.text().await.expect("Failed to get token");

let user_info = client
.get("http://127.0.0.1:3000/user_info")
.header("Authorization", format!("Bearer {}", token))
.send()
.await
.expect("User info request failed")
.json::<MyClaims>()
.await
.expect("Failed to parse claims");

println!("Successfully validated claims: {:?}", user_info);

axum::serve(listener, app).await.unwrap();
// Clean shutdown
server_handle.abort();
}
51 changes: 51 additions & 0 deletions examples/remote/jwt.key
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
-----BEGIN RSA PRIVATE KEY-----
MIIJKAIBAAKCAgEAsPWeaqsN+2KZu9rlto59XASssMoaVjIxMYXtLyifky1sXS4E
vYFnvr37X63B+lMwuZ3xACc7xsUPK+GXPe6XqZGJdj+Wgf7a3J6FieSNpnrDK4x6
CMr0iAPgIhoEYp7BUyPKzPv21vMl6A5kJvlAAdxfPm3jhk5NDWHSfiFnWiC7UESA
RgyFl0TlJ+f9H3qaArkzp3Cb+m+wlHpleewOSr9maTPLdIS+ZzZ1ZC4lDIQnetJJ
0kue+o1wAL4VmdBMY8IVxEutPAaZO+9G8eYJywZiDDkcrrqWymDvSUarcB/AOzEQ
jxN6nSSNuW6UbalfnDlGmR0kFK8fopraA4nwU4tG6fAuKTPpOmahC910IRAkedOp
6IrRU+2LmcBQ0oyzukHjXd9o9/5MES2wTDFgZBalVRZCo55vdQt5CtQDQWVUbQ1y
95dm/0EmmgZzWBgiguSKcO2QuqwYIiq5t9uikFleeVQDVnd+V6yZ5wWfnA6H0+dP
w4VTEUkxaTN8jQImQtB9gvj8iknsGX08LGF5WjWh1ewJI0L74Ey5T/ytsXME6Xpn
1qfXB2sr5tPol3KeV8pjuGrAymvaLJZz4ZqNY3f4wULfCsyVasUOdknMm8UmTgPR
+vnDlF+1ItsmN+Jl+RJ1dFkXRDcelCIJS44sMSchnxv47OwnqvBHCPbiUI8CAwEA
AQKCAgBJd+AqdxQZ/2jGNm5SqbvgHUy5JV9j0/jaj7jWcG44A47O7NEpAHXbGjMo
GRLE5A8BsVIidyd5Mc1HsaRCITG0Q+knP+Uz2WRyXhohEtPAf41SIkN0LRby9XDz
l4ukijbHVr/W9PEZct+VBYyNJcRuQVkFqUfiNdYFrUxf82xeXeKGw7nh20cHc6IU
PFu52wPgB5YreTQ4+G/+ZQaGZPvWCrrxCID6wjXu0gxQ6FuXY7KkanQdrCm36krK
9CAxuOpOLIEu+yBUIIUz/fadbZ05PlAstPV0kaETKsWNzZpVtcjwikFOtY6deVSh
3Qggs0YvrRPjc9bMA50FvHaxK26r/0XjXaowa4wCSc0x/JOERI+eWc49/Wwykuh2
H6yZ/zJAUFVaM7ic9yLKLPDS6lSFH19Q4ed3psd7nsCrjythYhV9PPfKJUwtROl1
Dbd3KL0DIlwbitnLR23k5PYzXQbEGoJpBqj5ruZnQ0bWlODEQ3F0HQXVlPrJmiGy
4J85G8EepO+j9pSADSikDiKgb06EeINFOP85a7o/E4vHdM5C1un8Z7Jym58wxptY
AjXYw6yD5tr/iXqLzMbeg2pmy7Kg3gLERrT+F7kIxVMt+w5GKPUkxTwAmyoTr4JX
756j9GgKB/1L6Z4dh0aXg5pf7IBbuRzsXEKiQaZbQFJQK12FOQKCAQEA6kWSz/Ug
qp65jw74amsxCfrkLPKWyvAtcB++GA46C8Dnz0Nz3zUe7/o5AfmCNE48Gqc3dYJH
ZG6jgjE3IVCdAnloYMl+RSavGNF9lW9SaxCgQsoN4OrkRpNxMtzFdX0KFnc0dMYo
Ps+51LdDE9XAv40MBf2QVkytJDUq+sxlvFx1rLX4hAAqaaTf8dsNOru7bq0RKKy4
HKZ7XrXAxlJrYUGaHRJNfM9fjn/MTzPUOIoAdSXjrOU0/MlopTg10Od6pKTQCHqG
rUpf2wO2F4FXZ8wPV65uZqr45lu1Kxpbv4ihCIxR1pBWVuMdfteLgcwLae/LH9Z+
YImfjpdcdJTQjQKCAQEAwV9AYUTb9ejtfrX0Af6gjtCUXh9tkI1dyJia8+O0AdEF
eYsMz+Z29Ew0qQvytS9KX70xknwOwN/86tYN+gtQP+dSXebkfCnqSgvWO4NjZMp+
Xqf47ftZ+93wTwFO6YJK4+FuPb+2hC0iGCU8dh4vYdskxtYTR9w+jsOTHNjR9IYL
vpFitnc3o3RFegrPjgUZj05vf9mH5y0qKk5PpsFETu4inFEobXwwAdEG2V3Tr+bA
NS2DQhrfg7h4zVlrUZSN+pzKcPEEcxNzoQ8+xzsGk2vSnrIvjyDvdaRdO3e77hA2
XWjNPXU4gtTaym3xGMhuSONiOSF+2RBJBUwgLpZkiwKCAQAfBGkkuXrCvFMrGrtP
M7QBc0NkpBXM9rG9Z6Z+ftu2lKrcaTzdL6ZR9Zo4pbVUgYs5qCwSldYn+PITGbsH
4Sl4m2RzdBoQw8dpDMuIzn1mCYR+c0wVHGRu57SUHGDUZmLAiLXcRCQt9MjQ3ha2
eJWVhvIxlNnYYzyFT7jKDefmYYN/A3TM3UzAQgEYf30n6pUtWSKtdPjHak9pQb0t
RNpMvSfPc43o2Xf4YPlG/0C436Sh3gtf59T1JyGAxolxiERXqi6VAMv2A6PfVoV0
ZT6SUpUxcbnSRA7CSSAafdnp6QgRHqrzMpcL1/QeyCEDZWWZeBM3uululKoYcffe
w5k9AoIBAE/LoqGA7NPZPsffBcYc8Nx+Lft5NJlF/MFeV/L0r79gJcY2Hx9blxLQ
r5pil9E0pphDVkWAdAYbaB7wHexk5sS4DEE7mmWyVkAgClOcsFNTTDp7TjnGUyeg
Oh4gCBRL8+N9jyRkDEkW5s7X7s8/PYZADDkQ9fvdYuM+yWJKBrnE5uvIytdI8ui8
fj8SXvvYFugQEerMNUysUo4KqsvBTRLVKesfgnNLn/Pf8deY5FXd/sry8QtCU6Hj
adYzZBnSF5SnRtK8Yn2qNTjtNZa5QMls4QkjtoR1rtr47JAxpJdkkUqSiL8ntB4o
//Aw1iDH9NqXGl1A+TtRgRByjYUsAmkCggEBAJ8ihjXAyveNVNRhIorAKk70Elq7
lVdcZVks150Hn+A0MmnT3HX9lt/UICjhDnhlgsAJabKdUuta8ksuErqsFzA4m9Ok
hihSN3bjQsFSB/d20tZbr20u+ao219p1XTS6KqtEObYEwlNwNjX3uz0mIstCQk/v
XefdBZVzDVgUap+cip5DshhKasEOPlMUwqSrNgjJVNjPcUIzqR7kk7+QkI/3T46D
gMxczij1/Ib9RTO42/e1AlmaIQCsPQ0lU2fO9BkzWyj6CkMoTSoUvrVVM224iO6H
H4f9Kh9zevvv3ZMPBQBdwjAvuo3RueXukNYJNaI2zjpDDaMP3gIFmO+mqyE=
-----END RSA PRIVATE KEY-----
113 changes: 113 additions & 0 deletions examples/remote/remote.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
use std::{sync::Arc, time::Duration};

use axum::{routing::get, Json, Router};
use axum_jwt_auth::{
JwtDecoder, RemoteJwksDecoder, RemoteJwksDecoderBuilder, RemoteJwksDecoderConfigBuilder,
};
use dashmap::DashMap;
use jsonwebtoken::{Algorithm, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use tokio;

#[derive(Debug, Serialize, Deserialize, Clone)]
struct CustomClaims {
sub: String,
name: String,
exp: usize,
}

// This is a sample JWKS handler. In a real application, you would fetch the JWKS from a remote source.
// For testing purposes, we randomly fail 50% of the time to simulate a remote JWKS endpoint that is not available.
async fn jwks_handler() -> Json<Value> {
// Randomly fail 50% of the time
if rand::random::<bool>() {
return Json(json!({
"error": "Internal Server Error",
"message": "Random failure for testing"
}));
}

// This is a sample JWKS. In a real application, you would generate proper keys or fetch them from a remote source
Json(json!({
"keys": [{
"kty": "RSA",
"use": "sig",
"alg": "RS256",
"kid": "1dea0016-46b8-4289-ad7b-226cfaf5305e",
"n": "sPWeaqsN-2KZu9rlto59XASssMoaVjIxMYXtLyifky1sXS4EvYFnvr37X63B-lMwuZ3xACc7xsUPK-GXPe6XqZGJdj-Wgf7a3J6FieSNpnrDK4x6CMr0iAPgIhoEYp7BUyPKzPv21vMl6A5kJvlAAdxfPm3jhk5NDWHSfiFnWiC7UESARgyFl0TlJ-f9H3qaArkzp3Cb-m-wlHpleewOSr9maTPLdIS-ZzZ1ZC4lDIQnetJJ0kue-o1wAL4VmdBMY8IVxEutPAaZO-9G8eYJywZiDDkcrrqWymDvSUarcB_AOzEQjxN6nSSNuW6UbalfnDlGmR0kFK8fopraA4nwU4tG6fAuKTPpOmahC910IRAkedOp6IrRU-2LmcBQ0oyzukHjXd9o9_5MES2wTDFgZBalVRZCo55vdQt5CtQDQWVUbQ1y95dm_0EmmgZzWBgiguSKcO2QuqwYIiq5t9uikFleeVQDVnd-V6yZ5wWfnA6H0-dPw4VTEUkxaTN8jQImQtB9gvj8iknsGX08LGF5WjWh1ewJI0L74Ey5T_ytsXME6Xpn1qfXB2sr5tPol3KeV8pjuGrAymvaLJZz4ZqNY3f4wULfCsyVasUOdknMm8UmTgPR-vnDlF-1ItsmN-Jl-RJ1dFkXRDcelCIJS44sMSchnxv47OwnqvBHCPbiUI8",
"e": "AQAB"
}]
}))
}

#[tokio::main]
async fn main() {
// Initialize tracing for logging
tracing_subscriber::fmt::init();

// Create the Axum router with the JWKS endpoint
let app = Router::new().route("/.well-known/jwks.json", get(jwks_handler));

// Spawn the server task
let server_handle = tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
.await
.expect("Failed to bind server");
println!("JWKS server listening on http://127.0.0.1:3000");
axum::serve(listener, app)
.await
.expect("Failed to start server");
});

// Create a remote JWKS decoder with custom configuration
let decoder = RemoteJwksDecoderBuilder::default()
.jwks_url("http://127.0.0.1:3000/.well-known/jwks.json".to_string())
.config(
RemoteJwksDecoderConfigBuilder::default()
.cache_duration(Duration::from_secs(1)) // Low value for testing, in a real application you should use a higher value
.retry_count(3)
.backoff(Duration::from_secs(1))
.build()
.unwrap(),
)
.validation(Validation::new(Algorithm::RS256))
.client(reqwest::Client::new())
.keys_cache(Arc::new(DashMap::new()))
.build()
.expect("Failed to build decoder");

// Spawn a task to periodically refresh the JWKS
let decoder_clone = decoder.clone();
tokio::spawn(async move {
decoder_clone.refresh_keys_periodically().await;
});

// Create a token
let token = jsonwebtoken::encode(
&Header::new(Algorithm::RS256),
&CustomClaims {
sub: "123".to_string(),
name: "John Doe".to_string(),
exp: (chrono::Utc::now().timestamp() + 60 * 60) as usize,
},
&EncodingKey::from_rsa_pem(include_bytes!("jwt.key")).unwrap(),
)
.unwrap();

// Decode the token
match <RemoteJwksDecoder as JwtDecoder<CustomClaims>>::decode(&decoder, &token) {
Ok(token_data) => {
println!("Token successfully decoded: {:?}", token_data.claims);
}
Err(err) => {
eprintln!("Failed to decode token: {:?}", err);
}
}

// Keep the main task running for a while to see the periodic refresh in action
tokio::time::sleep(Duration::from_secs(60)).await;

// Clean shutdown
server_handle.abort();
}
77 changes: 45 additions & 32 deletions src/axum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,44 +9,15 @@ use axum_extra::TypedHeader;
use serde::de::DeserializeOwned;
use serde::Deserialize;

use crate::{Decoder, JwtDecoder};
use crate::Decoder;

/// A generic struct for holding the claims of a JWT token.
#[derive(Debug, Deserialize)]
pub struct Claims<T>(pub T);

pub enum AuthError {
InvalidToken,
MissingToken,
ExpiredToken,
InvalidSignature,
InvalidAudience,
InternalError,
}

impl IntoResponse for AuthError {
fn into_response(self) -> Response {
let (status, msg) = match self {
AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
AuthError::MissingToken => (StatusCode::UNAUTHORIZED, "Missing token"),
AuthError::ExpiredToken => (StatusCode::UNAUTHORIZED, "Expired token"),
AuthError::InvalidSignature => (StatusCode::UNAUTHORIZED, "Invalid signature"),
AuthError::InvalidAudience => (StatusCode::UNAUTHORIZED, "Invalid audience"),
AuthError::InternalError => (StatusCode::INTERNAL_SERVER_ERROR, "Internal error"),
};

(status, msg).into_response()
}
}

#[derive(Clone, FromRef)]
pub struct JwtDecoderState {
pub decoder: Decoder,
}

impl<S, T> axum::extract::FromRequestParts<S> for Claims<T>
where
JwtDecoderState: FromRef<S>,
JwtDecoderState<T>: FromRef<S>,
S: Send + Sync,
T: DeserializeOwned,
{
Expand Down Expand Up @@ -75,6 +46,48 @@ where
_ => Self::Rejection::InternalError,
})?;

Ok(token_data.claims)
Ok(Claims(token_data.claims))
}
}

#[derive(Debug, thiserror::Error)]
pub enum AuthError {
#[error("Invalid token")]
InvalidToken,
#[error("Missing token")]
MissingToken,
#[error("Expired token")]
ExpiredToken,
#[error("Invalid signature")]
InvalidSignature,
#[error("Invalid audience")]
InvalidAudience,
#[error("Internal error")]
InternalError,
}

impl IntoResponse for AuthError {
fn into_response(self) -> Response {
let (status, msg) = match self {
AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
AuthError::MissingToken => (StatusCode::UNAUTHORIZED, "Missing token"),
AuthError::ExpiredToken => (StatusCode::UNAUTHORIZED, "Expired token"),
AuthError::InvalidSignature => (StatusCode::UNAUTHORIZED, "Invalid signature"),
AuthError::InvalidAudience => (StatusCode::UNAUTHORIZED, "Invalid audience"),
AuthError::InternalError => (StatusCode::INTERNAL_SERVER_ERROR, "Internal error"),
};

(status, msg).into_response()
}
}

#[derive(Clone)]
pub struct JwtDecoderState<T> {
pub decoder: Decoder<T>,
}

impl<T> FromRef<JwtDecoderState<T>> for Decoder<T> {
fn from_ref(state: &JwtDecoderState<T>) -> Self {
state.decoder.clone()
}
}
Loading
Loading