Skip to content

Commit 7a735e7

Browse files
committed
add allow_key interface
1 parent 0cadf3d commit 7a735e7

File tree

2 files changed

+123
-9
lines changed

2 files changed

+123
-9
lines changed

include/jwt-cpp/jwt.h

Lines changed: 104 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1617,6 +1617,8 @@ namespace jwt {
16171617
explicit rs256(const std::string& public_key, const std::string& private_key = "",
16181618
const std::string& public_key_password = "", const std::string& private_key_password = "")
16191619
: rsa(public_key, private_key, public_key_password, private_key_password, EVP_sha256, "RS256") {}
1620+
1621+
explicit rs256(helper::evp_pkey_handle pkey) : rsa(pkey, EVP_sha256, "RS256") {}
16201622
};
16211623
/**
16221624
* RS384 algorithm
@@ -1632,6 +1634,8 @@ namespace jwt {
16321634
explicit rs384(const std::string& public_key, const std::string& private_key = "",
16331635
const std::string& public_key_password = "", const std::string& private_key_password = "")
16341636
: rsa(public_key, private_key, public_key_password, private_key_password, EVP_sha384, "RS384") {}
1637+
1638+
explicit rs384(helper::evp_pkey_handle pkey) : rsa(pkey, EVP_sha384, "RS384") {}
16351639
};
16361640
/**
16371641
* RS512 algorithm
@@ -1647,6 +1651,8 @@ namespace jwt {
16471651
explicit rs512(const std::string& public_key, const std::string& private_key = "",
16481652
const std::string& public_key_password = "", const std::string& private_key_password = "")
16491653
: rsa(public_key, private_key, public_key_password, private_key_password, EVP_sha512, "RS512") {}
1654+
1655+
explicit rs512(helper::evp_pkey_handle pkey) : rsa(pkey, EVP_sha512, "RS512") {}
16501656
};
16511657
/**
16521658
* ES256 algorithm
@@ -3126,6 +3132,12 @@ namespace jwt {
31263132
};
31273133
} // namespace verify_ops
31283134

3135+
using alg_name = std::string;
3136+
using alg_list = std::vector<alg_name>;
3137+
using algorithms = std::unordered_map<std::string, alg_list>;
3138+
static const algorithms supported_alg = {{"RSA", {"RS256", "RS384", "RS512", "PS256", "PS384", "PS512"}},
3139+
{"EC", {"ES256", "ES384", "ES512", "ES256K"}},
3140+
{"oct", {"HS256", "HS384", "HS512"}}};
31293141
/**
31303142
* \brief JSON Web Key
31313143
*
@@ -3346,6 +3358,11 @@ namespace jwt {
33463358

33473359
std::string get_oct_key() const { return key.get_symmetric_key(); }
33483360

3361+
bool supports(const std::string& alg_name) const {
3362+
const alg_list& x = supported_alg.find(get_key_type())->second;
3363+
return std::find(x.begin(), x.end(), alg_name) != x.end();
3364+
}
3365+
33493366
private:
33503367
class key {
33513368
public:
@@ -3488,6 +3505,11 @@ namespace jwt {
34883505
/// Supported algorithms
34893506
std::unordered_map<std::string, std::shared_ptr<algo_base>> algs;
34903507

3508+
typedef std::vector<jwt::jwk<json_traits>> key_list;
3509+
/// https://datatracker.ietf.org/doc/html/rfc7517#section-4.5 - kid to keys
3510+
typedef std::unordered_map<std::string, key_list> keysets;
3511+
keysets keys;
3512+
34913513
void verify_claims(const decoded_jwt<json_traits>& jwt, std::error_code& ec) const {
34923514
verify_ops::verify_context<json_traits> ctx{clock.now(), jwt, default_leeway};
34933515
for (auto& c : claims) {
@@ -3497,6 +3519,52 @@ namespace jwt {
34973519
}
34983520
}
34993521

3522+
static inline std::unique_ptr<algo_base> from_key_and_alg(const jwt::jwk<json_traits>& key,
3523+
const std::string& alg_name, std::error_code& ec) {
3524+
ec.clear();
3525+
algorithms::const_iterator it = supported_alg.find(key.get_key_type());
3526+
if (it == supported_alg.end()) {
3527+
ec = error::token_verification_error::wrong_algorithm;
3528+
return nullptr;
3529+
}
3530+
3531+
const alg_list& supported_jwt_algorithms = it->second;
3532+
if (std::find(supported_jwt_algorithms.begin(), supported_jwt_algorithms.end(), alg_name) ==
3533+
supported_jwt_algorithms.end()) {
3534+
ec = error::token_verification_error::wrong_algorithm;
3535+
return nullptr;
3536+
}
3537+
3538+
if (alg_name == "RS256") {
3539+
return std::make_unique<algo<jwt::algorithm::rs256>>(jwt::algorithm::rs256(key.get_pkey()));
3540+
} else if (alg_name == "RS384") {
3541+
return std::make_unique<algo<jwt::algorithm::rs384>>(jwt::algorithm::rs384(key.get_pkey()));
3542+
} else if (alg_name == "RS512") {
3543+
return std::make_unique<algo<jwt::algorithm::rs512>>(jwt::algorithm::rs512(key.get_pkey()));
3544+
} else if (alg_name == "PS256") {
3545+
return std::make_unique<algo<jwt::algorithm::ps256>>(jwt::algorithm::ps256(key.get_pkey()));
3546+
} else if (alg_name == "PS384") {
3547+
return std::make_unique<algo<jwt::algorithm::ps384>>(jwt::algorithm::ps384(key.get_pkey()));
3548+
} else if (alg_name == "PS512") {
3549+
return std::make_unique<algo<jwt::algorithm::ps512>>(jwt::algorithm::ps512(key.get_pkey()));
3550+
} else if (alg_name == "ES256") {
3551+
return std::make_unique<algo<jwt::algorithm::es256>>(jwt::algorithm::es256(key.get_pkey()));
3552+
} else if (alg_name == "ES384") {
3553+
return std::make_unique<algo<jwt::algorithm::es384>>(jwt::algorithm::es384(key.get_pkey()));
3554+
} else if (alg_name == "ES512") {
3555+
return std::make_unique<algo<jwt::algorithm::es512>>(jwt::algorithm::es512(key.get_pkey()));
3556+
} else if (alg_name == "HS256") {
3557+
return std::make_unique<algo<jwt::algorithm::hs256>>(jwt::algorithm::hs256(key.get_oct_key()));
3558+
} else if (alg_name == "HS384") {
3559+
return std::make_unique<algo<jwt::algorithm::hs384>>(jwt::algorithm::hs384(key.get_oct_key()));
3560+
} else if (alg_name == "HS512") {
3561+
return std::make_unique<algo<jwt::algorithm::hs512>>(jwt::algorithm::hs512(key.get_oct_key()));
3562+
}
3563+
3564+
ec = error::token_verification_error::wrong_algorithm;
3565+
return nullptr;
3566+
}
3567+
35003568
public:
35013569
/**
35023570
* Constructor for building a new verifier instance
@@ -3661,6 +3729,18 @@ namespace jwt {
36613729
return *this;
36623730
}
36633731

3732+
verifier& allow_key(const jwt::jwk<json_traits>& key) {
3733+
std::string keyid = "";
3734+
if (key.has_key_id()) {
3735+
keyid = key.get_key_id();
3736+
auto it = keys.find(keyid);
3737+
if (it == keys.end()) { keys[keyid] = key_list(); }
3738+
}
3739+
3740+
keys[keyid].push_back(key);
3741+
return *this;
3742+
}
3743+
36643744
/**
36653745
* Verify the given token.
36663746
* \param jwt Token to check
@@ -3681,13 +3761,32 @@ namespace jwt {
36813761
const typename json_traits::string_type data = jwt.get_header_base64() + "." + jwt.get_payload_base64();
36823762
const typename json_traits::string_type sig = jwt.get_signature();
36833763
const std::string algo = jwt.get_algorithm();
3684-
if (algs.count(algo) == 0) {
3685-
ec = error::token_verification_error::wrong_algorithm;
3686-
return;
3764+
std::string kid("");
3765+
if (jwt.has_header_claim("kid")) { kid = jwt.get_header_claim("kid").as_string(); }
3766+
3767+
typename keysets::const_iterator key_set_it = keys.find(kid);
3768+
bool key_found = false;
3769+
if (key_set_it != keys.end()) {
3770+
const key_list& keys = key_set_it->second;
3771+
for (const auto& key : keys) {
3772+
if (key.supports(algo)) {
3773+
key_found = true;
3774+
auto alg = from_key_and_alg(key, algo, ec);
3775+
alg->verify(data, sig, ec);
3776+
break;
3777+
}
3778+
}
3779+
}
3780+
3781+
if (!key_found) {
3782+
if (algs.count(algo) == 0) {
3783+
ec = error::token_verification_error::wrong_algorithm;
3784+
return;
3785+
}
3786+
algs.at(algo)->verify(data, sig, ec);
36873787
}
3688-
algs.at(algo)->verify(data, sig, ec);
3689-
if (ec) return;
36903788

3789+
if (ec) return;
36913790
verify_claims(jwt, ec);
36923791
}
36933792
};

tests/JwkTest.cpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ vQIDAQAB
4040
-----END PUBLIC KEY-----
4141
*/
4242

43-
TEST(JwkTest, ParseKey) {
43+
TEST(JwkTest, RsaKey) {
4444
std::string token =
4545
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXUyJ9.eyJpc3MiOiJhdXRoMCJ9.IzM0dgbhU1CsRbjmwyPHXkc8LagqFtsZD6p1ls_"
4646
"WBugkEKNfFmZmhOM1YYiFg59xId_KtzNdp4puzGIafut15U06DL2ZGH_H4xE7ONy6WLA_i5z5H8gPxD3ui2W4nHEEf-mvqKSn-"
@@ -55,9 +55,24 @@ TEST(JwkTest, ParseKey) {
5555

5656
auto jwk = jwt::parse_jwk(public_key);
5757
ASSERT_EQ("RSA", jwk.get_key_type());
58-
auto alg = jwt::algorithm::rsa(jwk.get_pkey(), EVP_sha256, "RS256");
59-
auto verify = jwt::verify();
58+
auto verifier = jwt::verify();
59+
verifier.allow_key(jwk);
6060
auto decoded_token = jwt::decode(token);
61+
ASSERT_NO_THROW(verifier.verify(decoded_token));
62+
}
6163

62-
ASSERT_NO_THROW(verify.do_verify(jwk, decoded_token));
64+
TEST(JwkTest, HmacKey) {
65+
std::string token =
66+
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXUyJ9.eyJpc3MiOiJhdXRoMCJ9.AbIJTDMFc7yUa5MhvcP03nJPyCPzZtQcGEp-zWfOkEE";
67+
std::string secret_key = R"({
68+
"kty": "oct",
69+
"k": "c2VjcmV0"
70+
})";
71+
72+
auto jwk = jwt::parse_jwk(secret_key);
73+
ASSERT_EQ("oct", jwk.get_key_type());
74+
auto verifier = jwt::verify();
75+
verifier.allow_key(jwk);
76+
auto decoded_token = jwt::decode(token);
77+
ASSERT_NO_THROW(verifier.verify(decoded_token));
6378
}

0 commit comments

Comments
 (0)