Skip to content

Commit 5cd1887

Browse files
authored
feat: allow users to specify a custom header not defined in the struct (#420)
* allow users to specify a custom header not defined in the struct * modify tests, docs and fix broken checks * inline header encoding and decoding + update bench tests * formatting * update bench * add verify_slices_are_equal to fix deprecation errors * remove deprecated verify_slices_are_equal * address comments * add decode test with extra headers * bug fix * fix broken tests
1 parent ab8dbb1 commit 5cd1887

File tree

8 files changed

+120
-14
lines changed

8 files changed

+120
-14
lines changed

README.md

+5
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@ If you want to set the `kid` parameter or change the algorithm for example:
7979
```rust
8080
let mut header = Header::new(Algorithm::HS512);
8181
header.kid = Some("blabla".to_owned());
82+
83+
let mut extras = HashMap::with_capacity(1);
84+
extras.insert("custom".to_string(), "header".to_string());
85+
header.extras = Some(extras);
86+
8287
let token = encode(&header, &my_claims, &EncodingKey::from_secret("secret".as_ref()))?;
8388
```
8489
Look at `examples/custom_header.rs` for a full working example.

benches/jwt.rs

+14-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use criterion::{black_box, criterion_group, criterion_main, Criterion};
22
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
33
use serde::{Deserialize, Serialize};
4+
use std::collections::HashMap;
45

56
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
67
struct Claims {
@@ -17,6 +18,18 @@ fn bench_encode(c: &mut Criterion) {
1718
});
1819
}
1920

21+
fn bench_encode_custom_extra_headers(c: &mut Criterion) {
22+
let claim = Claims { sub: "b@b.com".to_owned(), company: "ACME".to_owned() };
23+
let key = EncodingKey::from_secret("secret".as_ref());
24+
let mut extras = HashMap::with_capacity(1);
25+
extras.insert("custom".to_string(), "header".to_string());
26+
let header = &Header { extras, ..Default::default() };
27+
28+
c.bench_function("bench_encode", |b| {
29+
b.iter(|| encode(black_box(header), black_box(&claim), black_box(&key)))
30+
});
31+
}
32+
2033
fn bench_decode(c: &mut Criterion) {
2134
let token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ";
2235
let key = DecodingKey::from_secret("secret".as_ref());
@@ -32,5 +45,5 @@ fn bench_decode(c: &mut Criterion) {
3245
});
3346
}
3447

35-
criterion_group!(benches, bench_encode, bench_decode);
48+
criterion_group!(benches, bench_encode, bench_encode_custom_extra_headers, bench_decode);
3649
criterion_main!(benches);

examples/custom_header.rs

+10-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use serde::{Deserialize, Serialize};
2+
use std::collections::HashMap;
23

34
use jsonwebtoken::errors::ErrorKind;
45
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
@@ -15,8 +16,15 @@ fn main() {
1516
Claims { sub: "b@b.com".to_owned(), company: "ACME".to_owned(), exp: 10000000000 };
1617
let key = b"secret";
1718

18-
let header =
19-
Header { kid: Some("signing_key".to_owned()), alg: Algorithm::HS512, ..Default::default() };
19+
let mut extras = HashMap::with_capacity(1);
20+
extras.insert("custom".to_string(), "header".to_string());
21+
22+
let header = Header {
23+
kid: Some("signing_key".to_owned()),
24+
alg: Algorithm::HS512,
25+
extras,
26+
..Default::default()
27+
};
2028

2129
let token = match encode(&header, &my_claims, &EncodingKey::from_secret(key)) {
2230
Ok(t) => t,

src/header.rs

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::collections::HashMap;
12
use std::result;
23

34
use base64::{engine::general_purpose::STANDARD, Engine};
@@ -10,7 +11,7 @@ use crate::serialization::b64_decode;
1011

1112
/// A basic JWT header, the alg defaults to HS256 and typ is automatically
1213
/// set to `JWT`. All the other fields are optional.
13-
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
14+
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
1415
pub struct Header {
1516
/// The type of JWS: it can only be "JWT" here
1617
///
@@ -64,6 +65,12 @@ pub struct Header {
6465
#[serde(skip_serializing_if = "Option::is_none")]
6566
#[serde(rename = "x5t#S256")]
6667
pub x5t_s256: Option<String>,
68+
69+
/// Any additional non-standard headers not defined in [RFC7515#4.1](https://datatracker.ietf.org/doc/html/rfc7515#section-4.1).
70+
/// Once serialized, all keys will be converted to fields at the root level of the header payload
71+
/// Ex: Dict("custom" -> "header") will be converted to "{"typ": "JWT", ..., "custom": "header"}"
72+
#[serde(flatten)]
73+
pub extras: HashMap<String, String>,
6774
}
6875

6976
impl Header {
@@ -80,6 +87,7 @@ impl Header {
8087
x5c: None,
8188
x5t: None,
8289
x5t_s256: None,
90+
extras: Default::default(),
8391
}
8492
}
8593

src/jwk.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ impl<'de> Deserialize<'de> for PublicKeyUse {
4343
D: Deserializer<'de>,
4444
{
4545
struct PublicKeyUseVisitor;
46-
impl<'de> de::Visitor<'de> for PublicKeyUseVisitor {
46+
impl de::Visitor<'_> for PublicKeyUseVisitor {
4747
type Value = PublicKeyUse;
4848

4949
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
@@ -116,7 +116,7 @@ impl<'de> Deserialize<'de> for KeyOperations {
116116
D: Deserializer<'de>,
117117
{
118118
struct KeyOperationsVisitor;
119-
impl<'de> de::Visitor<'de> for KeyOperationsVisitor {
119+
impl de::Visitor<'_> for KeyOperationsVisitor {
120120
type Value = KeyOperations;
121121

122122
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {

src/serialization.rs

+2
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ use serde::{Deserialize, Serialize};
33

44
use crate::errors::Result;
55

6+
#[inline]
67
pub(crate) fn b64_encode<T: AsRef<[u8]>>(input: T) -> String {
78
URL_SAFE_NO_PAD.encode(input)
89
}
910

11+
#[inline]
1012
pub(crate) fn b64_decode<T: AsRef<[u8]>>(input: T) -> Result<Vec<u8>> {
1113
URL_SAFE_NO_PAD.decode(input).map_err(|e| e.into())
1214
}

src/validation.rs

+8-8
Original file line numberDiff line numberDiff line change
@@ -337,13 +337,20 @@ where
337337
{
338338
struct NumericType(PhantomData<fn() -> TryParse<u64>>);
339339

340-
impl<'de> Visitor<'de> for NumericType {
340+
impl Visitor<'_> for NumericType {
341341
type Value = TryParse<u64>;
342342

343343
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
344344
formatter.write_str("A NumericType that can be reasonably coerced into a u64")
345345
}
346346

347+
fn visit_u64<E>(self, value: u64) -> std::result::Result<Self::Value, E>
348+
where
349+
E: de::Error,
350+
{
351+
Ok(TryParse::Parsed(value))
352+
}
353+
347354
fn visit_f64<E>(self, value: f64) -> std::result::Result<Self::Value, E>
348355
where
349356
E: de::Error,
@@ -354,13 +361,6 @@ where
354361
Err(serde::de::Error::custom("NumericType must be representable as a u64"))
355362
}
356363
}
357-
358-
fn visit_u64<E>(self, value: u64) -> std::result::Result<Self::Value, E>
359-
where
360-
E: de::Error,
361-
{
362-
Ok(TryParse::Parsed(value))
363-
}
364364
}
365365

366366
match deserializer.deserialize_any(NumericType(PhantomData)) {

tests/hmac.rs

+70
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use jsonwebtoken::{
55
decode, decode_header, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation,
66
};
77
use serde::{Deserialize, Serialize};
8+
use std::collections::HashMap;
89
use time::OffsetDateTime;
910
use wasm_bindgen_test::wasm_bindgen_test;
1011

@@ -51,6 +52,56 @@ fn encode_with_custom_header() {
5152
.unwrap();
5253
assert_eq!(my_claims, token_data.claims);
5354
assert_eq!("kid", token_data.header.kid.unwrap());
55+
assert!(token_data.header.extras.is_empty());
56+
}
57+
58+
#[test]
59+
#[wasm_bindgen_test]
60+
fn encode_with_extra_custom_header() {
61+
let my_claims = Claims {
62+
sub: "b@b.com".to_string(),
63+
company: "ACME".to_string(),
64+
exp: OffsetDateTime::now_utc().unix_timestamp() + 10000,
65+
};
66+
let mut extras = HashMap::with_capacity(1);
67+
extras.insert("custom".to_string(), "header".to_string());
68+
let header = Header { kid: Some("kid".to_string()), extras, ..Default::default() };
69+
let token = encode(&header, &my_claims, &EncodingKey::from_secret(b"secret")).unwrap();
70+
let token_data = decode::<Claims>(
71+
&token,
72+
&DecodingKey::from_secret(b"secret"),
73+
&Validation::new(Algorithm::HS256),
74+
)
75+
.unwrap();
76+
assert_eq!(my_claims, token_data.claims);
77+
assert_eq!("kid", token_data.header.kid.unwrap());
78+
assert_eq!("header", token_data.header.extras.get("custom").unwrap().as_str());
79+
}
80+
81+
#[test]
82+
#[wasm_bindgen_test]
83+
fn encode_with_multiple_extra_custom_headers() {
84+
let my_claims = Claims {
85+
sub: "b@b.com".to_string(),
86+
company: "ACME".to_string(),
87+
exp: OffsetDateTime::now_utc().unix_timestamp() + 10000,
88+
};
89+
let mut extras = HashMap::with_capacity(2);
90+
extras.insert("custom1".to_string(), "header1".to_string());
91+
extras.insert("custom2".to_string(), "header2".to_string());
92+
let header = Header { kid: Some("kid".to_string()), extras, ..Default::default() };
93+
let token = encode(&header, &my_claims, &EncodingKey::from_secret(b"secret")).unwrap();
94+
let token_data = decode::<Claims>(
95+
&token,
96+
&DecodingKey::from_secret(b"secret"),
97+
&Validation::new(Algorithm::HS256),
98+
)
99+
.unwrap();
100+
assert_eq!(my_claims, token_data.claims);
101+
assert_eq!("kid", token_data.header.kid.unwrap());
102+
let extras = token_data.header.extras;
103+
assert_eq!("header1", extras.get("custom1").unwrap().as_str());
104+
assert_eq!("header2", extras.get("custom2").unwrap().as_str());
54105
}
55106

56107
#[test]
@@ -86,6 +137,25 @@ fn decode_token() {
86137
claims.unwrap();
87138
}
88139

140+
#[test]
141+
#[wasm_bindgen_test]
142+
fn decode_token_custom_headers() {
143+
let token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsImN1c3RvbTEiOiJoZWFkZXIxIiwiY3VzdG9tMiI6ImhlYWRlcjIifQ.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjI1MzI1MjQ4OTF9.FtOHsoKcNH3SriK3tnR-uWJg4UV4FkOzvq_JCfLngfU";
144+
let claims = decode::<Claims>(
145+
token,
146+
&DecodingKey::from_secret(b"secret"),
147+
&Validation::new(Algorithm::HS256),
148+
)
149+
.unwrap();
150+
let my_claims =
151+
Claims { sub: "b@b.com".to_string(), company: "ACME".to_string(), exp: 2532524891 };
152+
assert_eq!(my_claims, claims.claims);
153+
assert_eq!("kid", claims.header.kid.unwrap());
154+
let extras = claims.header.extras;
155+
assert_eq!("header1", extras.get("custom1").unwrap().as_str());
156+
assert_eq!("header2", extras.get("custom2").unwrap().as_str());
157+
}
158+
89159
#[test]
90160
#[wasm_bindgen_test]
91161
#[should_panic(expected = "InvalidToken")]

0 commit comments

Comments
 (0)