@@ -1617,6 +1617,8 @@ namespace jwt {
1617
1617
explicit rs256 (const std::string& public_key, const std::string& private_key = " " ,
1618
1618
const std::string& public_key_password = " " , const std::string& private_key_password = " " )
1619
1619
: rsa(public_key, private_key, public_key_password, private_key_password, EVP_sha256, " RS256" ) {}
1620
+
1621
+ explicit rs256 (helper::evp_pkey_handle pkey) : rsa(pkey, EVP_sha256, " RS256" ) {}
1620
1622
};
1621
1623
/* *
1622
1624
* RS384 algorithm
@@ -1632,6 +1634,8 @@ namespace jwt {
1632
1634
explicit rs384 (const std::string& public_key, const std::string& private_key = " " ,
1633
1635
const std::string& public_key_password = " " , const std::string& private_key_password = " " )
1634
1636
: rsa(public_key, private_key, public_key_password, private_key_password, EVP_sha384, " RS384" ) {}
1637
+
1638
+ explicit rs384 (helper::evp_pkey_handle pkey) : rsa(pkey, EVP_sha384, " RS384" ) {}
1635
1639
};
1636
1640
/* *
1637
1641
* RS512 algorithm
@@ -1647,6 +1651,8 @@ namespace jwt {
1647
1651
explicit rs512 (const std::string& public_key, const std::string& private_key = " " ,
1648
1652
const std::string& public_key_password = " " , const std::string& private_key_password = " " )
1649
1653
: rsa(public_key, private_key, public_key_password, private_key_password, EVP_sha512, " RS512" ) {}
1654
+
1655
+ explicit rs512 (helper::evp_pkey_handle pkey) : rsa(pkey, EVP_sha512, " RS512" ) {}
1650
1656
};
1651
1657
/* *
1652
1658
* ES256 algorithm
@@ -3126,6 +3132,12 @@ namespace jwt {
3126
3132
};
3127
3133
} // namespace verify_ops
3128
3134
3135
+ using alg_name = std::string;
3136
+ using alg_list = std::vector<alg_name>;
3137
+ using algorithms = std::unordered_map<std::string, alg_list>;
3138
+ static const algorithms supported_alg = {{" RSA" , {" RS256" , " RS384" , " RS512" , " PS256" , " PS384" , " PS512" }},
3139
+ {" EC" , {" ES256" , " ES384" , " ES512" , " ES256K" }},
3140
+ {" oct" , {" HS256" , " HS384" , " HS512" }}};
3129
3141
/* *
3130
3142
* \brief JSON Web Key
3131
3143
*
@@ -3346,6 +3358,11 @@ namespace jwt {
3346
3358
3347
3359
std::string get_oct_key () const { return key.get_symmetric_key (); }
3348
3360
3361
+ bool supports (const std::string& alg_name) const {
3362
+ const alg_list& x = supported_alg.find (get_key_type ())->second ;
3363
+ return std::find (x.begin (), x.end (), alg_name) != x.end ();
3364
+ }
3365
+
3349
3366
private:
3350
3367
class key {
3351
3368
public:
@@ -3488,6 +3505,11 @@ namespace jwt {
3488
3505
// / Supported algorithms
3489
3506
std::unordered_map<std::string, std::shared_ptr<algo_base>> algs;
3490
3507
3508
+ typedef std::vector<jwt::jwk<json_traits>> key_list;
3509
+ // / https://datatracker.ietf.org/doc/html/rfc7517#section-4.5 - kid to keys
3510
+ typedef std::unordered_map<std::string, key_list> keysets;
3511
+ keysets keys;
3512
+
3491
3513
void verify_claims (const decoded_jwt<json_traits>& jwt, std::error_code& ec) const {
3492
3514
verify_ops::verify_context<json_traits> ctx{clock.now (), jwt, default_leeway};
3493
3515
for (auto & c : claims) {
@@ -3497,6 +3519,52 @@ namespace jwt {
3497
3519
}
3498
3520
}
3499
3521
3522
+ static inline std::unique_ptr<algo_base> from_key_and_alg (const jwt::jwk<json_traits>& key,
3523
+ const std::string& alg_name, std::error_code& ec) {
3524
+ ec.clear ();
3525
+ algorithms::const_iterator it = supported_alg.find (key.get_key_type ());
3526
+ if (it == supported_alg.end ()) {
3527
+ ec = error::token_verification_error::wrong_algorithm;
3528
+ return nullptr ;
3529
+ }
3530
+
3531
+ const alg_list& supported_jwt_algorithms = it->second ;
3532
+ if (std::find (supported_jwt_algorithms.begin (), supported_jwt_algorithms.end (), alg_name) ==
3533
+ supported_jwt_algorithms.end ()) {
3534
+ ec = error::token_verification_error::wrong_algorithm;
3535
+ return nullptr ;
3536
+ }
3537
+
3538
+ if (alg_name == " RS256" ) {
3539
+ return std::make_unique<algo<jwt::algorithm::rs256>>(jwt::algorithm::rs256 (key.get_pkey ()));
3540
+ } else if (alg_name == " RS384" ) {
3541
+ return std::make_unique<algo<jwt::algorithm::rs384>>(jwt::algorithm::rs384 (key.get_pkey ()));
3542
+ } else if (alg_name == " RS512" ) {
3543
+ return std::make_unique<algo<jwt::algorithm::rs512>>(jwt::algorithm::rs512 (key.get_pkey ()));
3544
+ } else if (alg_name == " PS256" ) {
3545
+ return std::make_unique<algo<jwt::algorithm::ps256>>(jwt::algorithm::ps256 (key.get_pkey ()));
3546
+ } else if (alg_name == " PS384" ) {
3547
+ return std::make_unique<algo<jwt::algorithm::ps384>>(jwt::algorithm::ps384 (key.get_pkey ()));
3548
+ } else if (alg_name == " PS512" ) {
3549
+ return std::make_unique<algo<jwt::algorithm::ps512>>(jwt::algorithm::ps512 (key.get_pkey ()));
3550
+ } else if (alg_name == " ES256" ) {
3551
+ return std::make_unique<algo<jwt::algorithm::es256>>(jwt::algorithm::es256 (key.get_pkey ()));
3552
+ } else if (alg_name == " ES384" ) {
3553
+ return std::make_unique<algo<jwt::algorithm::es384>>(jwt::algorithm::es384 (key.get_pkey ()));
3554
+ } else if (alg_name == " ES512" ) {
3555
+ return std::make_unique<algo<jwt::algorithm::es512>>(jwt::algorithm::es512 (key.get_pkey ()));
3556
+ } else if (alg_name == " HS256" ) {
3557
+ return std::make_unique<algo<jwt::algorithm::hs256>>(jwt::algorithm::hs256 (key.get_oct_key ()));
3558
+ } else if (alg_name == " HS384" ) {
3559
+ return std::make_unique<algo<jwt::algorithm::hs384>>(jwt::algorithm::hs384 (key.get_oct_key ()));
3560
+ } else if (alg_name == " HS512" ) {
3561
+ return std::make_unique<algo<jwt::algorithm::hs512>>(jwt::algorithm::hs512 (key.get_oct_key ()));
3562
+ }
3563
+
3564
+ ec = error::token_verification_error::wrong_algorithm;
3565
+ return nullptr ;
3566
+ }
3567
+
3500
3568
public:
3501
3569
/* *
3502
3570
* Constructor for building a new verifier instance
@@ -3661,6 +3729,18 @@ namespace jwt {
3661
3729
return *this ;
3662
3730
}
3663
3731
3732
+ verifier& allow_key (const jwt::jwk<json_traits>& key) {
3733
+ std::string keyid = " " ;
3734
+ if (key.has_key_id ()) {
3735
+ keyid = key.get_key_id ();
3736
+ auto it = keys.find (keyid);
3737
+ if (it == keys.end ()) { keys[keyid] = key_list (); }
3738
+ }
3739
+
3740
+ keys[keyid].push_back (key);
3741
+ return *this ;
3742
+ }
3743
+
3664
3744
/* *
3665
3745
* Verify the given token.
3666
3746
* \param jwt Token to check
@@ -3681,13 +3761,32 @@ namespace jwt {
3681
3761
const typename json_traits::string_type data = jwt.get_header_base64 () + " ." + jwt.get_payload_base64 ();
3682
3762
const typename json_traits::string_type sig = jwt.get_signature ();
3683
3763
const std::string algo = jwt.get_algorithm ();
3684
- if (algs.count (algo) == 0 ) {
3685
- ec = error::token_verification_error::wrong_algorithm;
3686
- return ;
3764
+ std::string kid (" " );
3765
+ if (jwt.has_header_claim (" kid" )) { kid = jwt.get_header_claim (" kid" ).as_string (); }
3766
+
3767
+ typename keysets::const_iterator key_set_it = keys.find (kid);
3768
+ bool key_found = false ;
3769
+ if (key_set_it != keys.end ()) {
3770
+ const key_list& keys = key_set_it->second ;
3771
+ for (const auto & key : keys) {
3772
+ if (key.supports (algo)) {
3773
+ key_found = true ;
3774
+ auto alg = from_key_and_alg (key, algo, ec);
3775
+ alg->verify (data, sig, ec);
3776
+ break ;
3777
+ }
3778
+ }
3779
+ }
3780
+
3781
+ if (!key_found) {
3782
+ if (algs.count (algo) == 0 ) {
3783
+ ec = error::token_verification_error::wrong_algorithm;
3784
+ return ;
3785
+ }
3786
+ algs.at (algo)->verify (data, sig, ec);
3687
3787
}
3688
- algs.at (algo)->verify (data, sig, ec);
3689
- if (ec) return ;
3690
3788
3789
+ if (ec) return ;
3691
3790
verify_claims (jwt, ec);
3692
3791
}
3693
3792
};
0 commit comments