Skip to content

Commit 2493902

Browse files
committed
add allow_key interface
1 parent d4a022a commit 2493902

File tree

2 files changed

+123
-9
lines changed

2 files changed

+123
-9
lines changed

include/jwt-cpp/jwt.h

Lines changed: 104 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1542,6 +1542,8 @@ namespace jwt {
15421542
explicit rs256(const std::string& public_key, const std::string& private_key = "",
15431543
const std::string& public_key_password = "", const std::string& private_key_password = "")
15441544
: rsa(public_key, private_key, public_key_password, private_key_password, EVP_sha256, "RS256") {}
1545+
1546+
explicit rs256(std::shared_ptr<EVP_PKEY> pkey) : rsa(pkey, EVP_sha256, "RS256") {}
15451547
};
15461548
/**
15471549
* RS384 algorithm
@@ -1557,6 +1559,8 @@ namespace jwt {
15571559
explicit rs384(const std::string& public_key, const std::string& private_key = "",
15581560
const std::string& public_key_password = "", const std::string& private_key_password = "")
15591561
: rsa(public_key, private_key, public_key_password, private_key_password, EVP_sha384, "RS384") {}
1562+
1563+
explicit rs384(std::shared_ptr<EVP_PKEY> pkey) : rsa(pkey, EVP_sha384, "RS384") {}
15601564
};
15611565
/**
15621566
* RS512 algorithm
@@ -1572,6 +1576,8 @@ namespace jwt {
15721576
explicit rs512(const std::string& public_key, const std::string& private_key = "",
15731577
const std::string& public_key_password = "", const std::string& private_key_password = "")
15741578
: rsa(public_key, private_key, public_key_password, private_key_password, EVP_sha512, "RS512") {}
1579+
1580+
explicit rs512(std::shared_ptr<EVP_PKEY> pkey) : rsa(pkey, EVP_sha512, "RS512") {}
15751581
};
15761582
/**
15771583
* ES256 algorithm
@@ -3053,6 +3059,12 @@ namespace jwt {
30533059
};
30543060
} // namespace verify_ops
30553061

3062+
using alg_name = std::string;
3063+
using alg_list = std::vector<alg_name>;
3064+
using algorithms = std::unordered_map<std::string, alg_list>;
3065+
static const algorithms supported_alg = {{"RSA", {"RS256", "RS384", "RS512", "PS256", "PS384", "PS512"}},
3066+
{"EC", {"ES256", "ES384", "ES512", "ES256K"}},
3067+
{"oct", {"HS256", "HS384", "HS512"}}};
30563068
/**
30573069
* \brief JSON Web Key
30583070
*
@@ -3273,6 +3285,11 @@ namespace jwt {
32733285

32743286
std::string get_oct_key() const { return key.get_symmetric_key(); }
32753287

3288+
bool supports(const std::string& alg_name) const {
3289+
const alg_list& x = supported_alg.find(get_key_type())->second;
3290+
return std::find(x.begin(), x.end(), alg_name) != x.end();
3291+
}
3292+
32763293
private:
32773294
class key {
32783295
public:
@@ -3415,6 +3432,11 @@ namespace jwt {
34153432
/// Supported algorithms
34163433
std::unordered_map<std::string, std::shared_ptr<algo_base>> algs;
34173434

3435+
typedef std::vector<jwt::jwk<json_traits>> key_list;
3436+
/// https://datatracker.ietf.org/doc/html/rfc7517#section-4.5 - kid to keys
3437+
typedef std::unordered_map<std::string, key_list> keysets;
3438+
keysets keys;
3439+
34183440
void verify_claims(const decoded_jwt<json_traits>& jwt, std::error_code& ec) const {
34193441
verify_ops::verify_context<json_traits> ctx{clock.now(), jwt, default_leeway};
34203442
for (auto& c : claims) {
@@ -3424,6 +3446,52 @@ namespace jwt {
34243446
}
34253447
}
34263448

3449+
static inline std::unique_ptr<algo_base> from_key_and_alg(const jwt::jwk<json_traits>& key,
3450+
const std::string& alg_name, std::error_code& ec) {
3451+
ec.clear();
3452+
algorithms::const_iterator it = supported_alg.find(key.get_key_type());
3453+
if (it == supported_alg.end()) {
3454+
ec = error::token_verification_error::wrong_algorithm;
3455+
return nullptr;
3456+
}
3457+
3458+
const alg_list& supported_jwt_algorithms = it->second;
3459+
if (std::find(supported_jwt_algorithms.begin(), supported_jwt_algorithms.end(), alg_name) ==
3460+
supported_jwt_algorithms.end()) {
3461+
ec = error::token_verification_error::wrong_algorithm;
3462+
return nullptr;
3463+
}
3464+
3465+
if (alg_name == "RS256") {
3466+
return std::make_unique<algo<jwt::algorithm::rs256>>(jwt::algorithm::rs256(key.get_pkey()));
3467+
} else if (alg_name == "RS384") {
3468+
return std::make_unique<algo<jwt::algorithm::rs384>>(jwt::algorithm::rs384(key.get_pkey()));
3469+
} else if (alg_name == "RS512") {
3470+
return std::make_unique<algo<jwt::algorithm::rs512>>(jwt::algorithm::rs512(key.get_pkey()));
3471+
} else if (alg_name == "PS256") {
3472+
return std::make_unique<algo<jwt::algorithm::ps256>>(jwt::algorithm::ps256(key.get_pkey()));
3473+
} else if (alg_name == "PS384") {
3474+
return std::make_unique<algo<jwt::algorithm::ps384>>(jwt::algorithm::ps384(key.get_pkey()));
3475+
} else if (alg_name == "PS512") {
3476+
return std::make_unique<algo<jwt::algorithm::ps512>>(jwt::algorithm::ps512(key.get_pkey()));
3477+
} else if (alg_name == "ES256") {
3478+
return std::make_unique<algo<jwt::algorithm::es256>>(jwt::algorithm::es256(key.get_pkey()));
3479+
} else if (alg_name == "ES384") {
3480+
return std::make_unique<algo<jwt::algorithm::es384>>(jwt::algorithm::es384(key.get_pkey()));
3481+
} else if (alg_name == "ES512") {
3482+
return std::make_unique<algo<jwt::algorithm::es512>>(jwt::algorithm::es512(key.get_pkey()));
3483+
} else if (alg_name == "HS256") {
3484+
return std::make_unique<algo<jwt::algorithm::hs256>>(jwt::algorithm::hs256(key.get_oct_key()));
3485+
} else if (alg_name == "HS384") {
3486+
return std::make_unique<algo<jwt::algorithm::hs384>>(jwt::algorithm::hs384(key.get_oct_key()));
3487+
} else if (alg_name == "HS512") {
3488+
return std::make_unique<algo<jwt::algorithm::hs512>>(jwt::algorithm::hs512(key.get_oct_key()));
3489+
}
3490+
3491+
ec = error::token_verification_error::wrong_algorithm;
3492+
return nullptr;
3493+
}
3494+
34273495
public:
34283496
/**
34293497
* Constructor for building a new verifier instance
@@ -3583,6 +3651,18 @@ namespace jwt {
35833651
return *this;
35843652
}
35853653

3654+
verifier& allow_key(const jwt::jwk<json_traits>& key) {
3655+
std::string keyid = "";
3656+
if (key.has_key_id()) {
3657+
keyid = key.get_key_id();
3658+
typename keysets::const_iterator it = keys.find(keyid);
3659+
if (it == keys.end()) { keys[keyid] = key_list(); }
3660+
}
3661+
3662+
keys[keyid].push_back(key);
3663+
return *this;
3664+
}
3665+
35863666
/**
35873667
* Verify the given token.
35883668
* \param jwt Token to check
@@ -3603,13 +3683,32 @@ namespace jwt {
36033683
const typename json_traits::string_type data = jwt.get_header_base64() + "." + jwt.get_payload_base64();
36043684
const typename json_traits::string_type sig = jwt.get_signature();
36053685
const std::string algo = jwt.get_algorithm();
3606-
if (algs.count(algo) == 0) {
3607-
ec = error::token_verification_error::wrong_algorithm;
3608-
return;
3686+
std::string kid("");
3687+
if (jwt.has_header_claim("kid")) { kid = jwt.get_header_claim("kid").as_string(); }
3688+
3689+
typename keysets::const_iterator key_set_it = keys.find(kid);
3690+
bool key_found = false;
3691+
if (key_set_it != keys.end()) {
3692+
const key_list& keys = key_set_it->second;
3693+
for (const auto& key : keys) {
3694+
if (key.supports(algo)) {
3695+
key_found = true;
3696+
auto alg = from_key_and_alg(key, algo, ec);
3697+
alg->verify(data, sig, ec);
3698+
break;
3699+
}
3700+
}
36093701
}
3610-
algs.at(algo)->verify(data, sig, ec);
3611-
if (ec) return;
36123702

3703+
if (!key_found) {
3704+
if (algs.count(algo) == 0) {
3705+
ec = error::token_verification_error::wrong_algorithm;
3706+
return;
3707+
}
3708+
algs.at(algo)->verify(data, sig, ec);
3709+
}
3710+
3711+
if (ec) return;
36133712
verify_claims(jwt, ec);
36143713
}
36153714
};

tests/JwkTest.cpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ vQIDAQAB
4040
-----END PUBLIC KEY-----
4141
*/
4242

43-
TEST(JwkTest, ParseKey) {
43+
TEST(JwkTest, RsaKey) {
4444
std::string token =
4545
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXUyJ9.eyJpc3MiOiJhdXRoMCJ9.IzM0dgbhU1CsRbjmwyPHXkc8LagqFtsZD6p1ls_"
4646
"WBugkEKNfFmZmhOM1YYiFg59xId_KtzNdp4puzGIafut15U06DL2ZGH_H4xE7ONy6WLA_i5z5H8gPxD3ui2W4nHEEf-mvqKSn-"
@@ -55,9 +55,24 @@ TEST(JwkTest, ParseKey) {
5555

5656
auto jwk = jwt::parse_jwk(public_key);
5757
ASSERT_EQ("RSA", jwk.get_key_type());
58-
auto alg = jwt::algorithm::rsa(jwk.get_pkey(), EVP_sha256, "RS256");
59-
auto verify = jwt::verify();
58+
auto verifier = jwt::verify();
59+
verifier.allow_key(jwk);
6060
auto decoded_token = jwt::decode(token);
61+
ASSERT_NO_THROW(verifier.verify(decoded_token));
62+
}
6163

62-
ASSERT_NO_THROW(verify.do_verify(jwk, decoded_token));
64+
TEST(JwkTest, HmacKey) {
65+
std::string token =
66+
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXUyJ9.eyJpc3MiOiJhdXRoMCJ9.AbIJTDMFc7yUa5MhvcP03nJPyCPzZtQcGEp-zWfOkEE";
67+
std::string secret_key = R"({
68+
"kty": "oct",
69+
"k": "c2VjcmV0"
70+
})";
71+
72+
auto jwk = jwt::parse_jwk(secret_key);
73+
ASSERT_EQ("oct", jwk.get_key_type());
74+
auto verifier = jwt::verify();
75+
verifier.allow_key(jwk);
76+
auto decoded_token = jwt::decode(token);
77+
ASSERT_NO_THROW(verifier.verify(decoded_token));
6378
}

0 commit comments

Comments
 (0)