@@ -1542,6 +1542,8 @@ namespace jwt {
1542
1542
explicit rs256 (const std::string& public_key, const std::string& private_key = " " ,
1543
1543
const std::string& public_key_password = " " , const std::string& private_key_password = " " )
1544
1544
: rsa(public_key, private_key, public_key_password, private_key_password, EVP_sha256, " RS256" ) {}
1545
+
1546
+ explicit rs256 (std::shared_ptr<EVP_PKEY> pkey) : rsa(pkey, EVP_sha256, " RS256" ) {}
1545
1547
};
1546
1548
/* *
1547
1549
* RS384 algorithm
@@ -1557,6 +1559,8 @@ namespace jwt {
1557
1559
explicit rs384 (const std::string& public_key, const std::string& private_key = " " ,
1558
1560
const std::string& public_key_password = " " , const std::string& private_key_password = " " )
1559
1561
: rsa(public_key, private_key, public_key_password, private_key_password, EVP_sha384, " RS384" ) {}
1562
+
1563
+ explicit rs384 (std::shared_ptr<EVP_PKEY> pkey) : rsa(pkey, EVP_sha384, " RS384" ) {}
1560
1564
};
1561
1565
/* *
1562
1566
* RS512 algorithm
@@ -1572,6 +1576,8 @@ namespace jwt {
1572
1576
explicit rs512 (const std::string& public_key, const std::string& private_key = " " ,
1573
1577
const std::string& public_key_password = " " , const std::string& private_key_password = " " )
1574
1578
: rsa(public_key, private_key, public_key_password, private_key_password, EVP_sha512, " RS512" ) {}
1579
+
1580
+ explicit rs512 (std::shared_ptr<EVP_PKEY> pkey) : rsa(pkey, EVP_sha512, " RS512" ) {}
1575
1581
};
1576
1582
/* *
1577
1583
* ES256 algorithm
@@ -3053,6 +3059,12 @@ namespace jwt {
3053
3059
};
3054
3060
} // namespace verify_ops
3055
3061
3062
+ using alg_name = std::string;
3063
+ using alg_list = std::vector<alg_name>;
3064
+ using algorithms = std::unordered_map<std::string, alg_list>;
3065
+ static const algorithms supported_alg = {{" RSA" , {" RS256" , " RS384" , " RS512" , " PS256" , " PS384" , " PS512" }},
3066
+ {" EC" , {" ES256" , " ES384" , " ES512" , " ES256K" }},
3067
+ {" oct" , {" HS256" , " HS384" , " HS512" }}};
3056
3068
/* *
3057
3069
* \brief JSON Web Key
3058
3070
*
@@ -3273,6 +3285,11 @@ namespace jwt {
3273
3285
3274
3286
std::string get_oct_key () const { return key.get_symmetric_key (); }
3275
3287
3288
+ bool supports (const std::string& alg_name) const {
3289
+ const alg_list& x = supported_alg.find (get_key_type ())->second ;
3290
+ return std::find (x.begin (), x.end (), alg_name) != x.end ();
3291
+ }
3292
+
3276
3293
private:
3277
3294
class key {
3278
3295
public:
@@ -3415,6 +3432,11 @@ namespace jwt {
3415
3432
// / Supported algorithms
3416
3433
std::unordered_map<std::string, std::shared_ptr<algo_base>> algs;
3417
3434
3435
+ typedef std::vector<jwt::jwk<json_traits>> key_list;
3436
+ // / https://datatracker.ietf.org/doc/html/rfc7517#section-4.5 - kid to keys
3437
+ typedef std::unordered_map<std::string, key_list> keysets;
3438
+ keysets keys;
3439
+
3418
3440
void verify_claims (const decoded_jwt<json_traits>& jwt, std::error_code& ec) const {
3419
3441
verify_ops::verify_context<json_traits> ctx{clock.now (), jwt, default_leeway};
3420
3442
for (auto & c : claims) {
@@ -3424,6 +3446,52 @@ namespace jwt {
3424
3446
}
3425
3447
}
3426
3448
3449
+ static inline std::unique_ptr<algo_base> from_key_and_alg (const jwt::jwk<json_traits>& key,
3450
+ const std::string& alg_name, std::error_code& ec) {
3451
+ ec.clear ();
3452
+ algorithms::const_iterator it = supported_alg.find (key.get_key_type ());
3453
+ if (it == supported_alg.end ()) {
3454
+ ec = error::token_verification_error::wrong_algorithm;
3455
+ return nullptr ;
3456
+ }
3457
+
3458
+ const alg_list& supported_jwt_algorithms = it->second ;
3459
+ if (std::find (supported_jwt_algorithms.begin (), supported_jwt_algorithms.end (), alg_name) ==
3460
+ supported_jwt_algorithms.end ()) {
3461
+ ec = error::token_verification_error::wrong_algorithm;
3462
+ return nullptr ;
3463
+ }
3464
+
3465
+ if (alg_name == " RS256" ) {
3466
+ return std::make_unique<algo<jwt::algorithm::rs256>>(jwt::algorithm::rs256 (key.get_pkey ()));
3467
+ } else if (alg_name == " RS384" ) {
3468
+ return std::make_unique<algo<jwt::algorithm::rs384>>(jwt::algorithm::rs384 (key.get_pkey ()));
3469
+ } else if (alg_name == " RS512" ) {
3470
+ return std::make_unique<algo<jwt::algorithm::rs512>>(jwt::algorithm::rs512 (key.get_pkey ()));
3471
+ } else if (alg_name == " PS256" ) {
3472
+ return std::make_unique<algo<jwt::algorithm::ps256>>(jwt::algorithm::ps256 (key.get_pkey ()));
3473
+ } else if (alg_name == " PS384" ) {
3474
+ return std::make_unique<algo<jwt::algorithm::ps384>>(jwt::algorithm::ps384 (key.get_pkey ()));
3475
+ } else if (alg_name == " PS512" ) {
3476
+ return std::make_unique<algo<jwt::algorithm::ps512>>(jwt::algorithm::ps512 (key.get_pkey ()));
3477
+ } else if (alg_name == " ES256" ) {
3478
+ return std::make_unique<algo<jwt::algorithm::es256>>(jwt::algorithm::es256 (key.get_pkey ()));
3479
+ } else if (alg_name == " ES384" ) {
3480
+ return std::make_unique<algo<jwt::algorithm::es384>>(jwt::algorithm::es384 (key.get_pkey ()));
3481
+ } else if (alg_name == " ES512" ) {
3482
+ return std::make_unique<algo<jwt::algorithm::es512>>(jwt::algorithm::es512 (key.get_pkey ()));
3483
+ } else if (alg_name == " HS256" ) {
3484
+ return std::make_unique<algo<jwt::algorithm::hs256>>(jwt::algorithm::hs256 (key.get_oct_key ()));
3485
+ } else if (alg_name == " HS384" ) {
3486
+ return std::make_unique<algo<jwt::algorithm::hs384>>(jwt::algorithm::hs384 (key.get_oct_key ()));
3487
+ } else if (alg_name == " HS512" ) {
3488
+ return std::make_unique<algo<jwt::algorithm::hs512>>(jwt::algorithm::hs512 (key.get_oct_key ()));
3489
+ }
3490
+
3491
+ ec = error::token_verification_error::wrong_algorithm;
3492
+ return nullptr ;
3493
+ }
3494
+
3427
3495
public:
3428
3496
/* *
3429
3497
* Constructor for building a new verifier instance
@@ -3583,6 +3651,18 @@ namespace jwt {
3583
3651
return *this ;
3584
3652
}
3585
3653
3654
+ verifier& allow_key (const jwt::jwk<json_traits>& key) {
3655
+ std::string keyid = " " ;
3656
+ if (key.has_key_id ()) {
3657
+ keyid = key.get_key_id ();
3658
+ typename keysets::const_iterator it = keys.find (keyid);
3659
+ if (it == keys.end ()) { keys[keyid] = key_list (); }
3660
+ }
3661
+
3662
+ keys[keyid].push_back (key);
3663
+ return *this ;
3664
+ }
3665
+
3586
3666
/* *
3587
3667
* Verify the given token.
3588
3668
* \param jwt Token to check
@@ -3603,13 +3683,32 @@ namespace jwt {
3603
3683
const typename json_traits::string_type data = jwt.get_header_base64 () + " ." + jwt.get_payload_base64 ();
3604
3684
const typename json_traits::string_type sig = jwt.get_signature ();
3605
3685
const std::string algo = jwt.get_algorithm ();
3606
- if (algs.count (algo) == 0 ) {
3607
- ec = error::token_verification_error::wrong_algorithm;
3608
- return ;
3686
+ std::string kid (" " );
3687
+ if (jwt.has_header_claim (" kid" )) { kid = jwt.get_header_claim (" kid" ).as_string (); }
3688
+
3689
+ typename keysets::const_iterator key_set_it = keys.find (kid);
3690
+ bool key_found = false ;
3691
+ if (key_set_it != keys.end ()) {
3692
+ const key_list& keys = key_set_it->second ;
3693
+ for (const auto & key : keys) {
3694
+ if (key.supports (algo)) {
3695
+ key_found = true ;
3696
+ auto alg = from_key_and_alg (key, algo, ec);
3697
+ alg->verify (data, sig, ec);
3698
+ break ;
3699
+ }
3700
+ }
3609
3701
}
3610
- algs.at (algo)->verify (data, sig, ec);
3611
- if (ec) return ;
3612
3702
3703
+ if (!key_found) {
3704
+ if (algs.count (algo) == 0 ) {
3705
+ ec = error::token_verification_error::wrong_algorithm;
3706
+ return ;
3707
+ }
3708
+ algs.at (algo)->verify (data, sig, ec);
3709
+ }
3710
+
3711
+ if (ec) return ;
3613
3712
verify_claims (jwt, ec);
3614
3713
}
3615
3714
};
0 commit comments