Skip to content

Commit 3ed4ff9

Browse files
authored
Use EVP_PKEY_up_ref if available (#238)
1 parent c9a511f commit 3ed4ff9

File tree

3 files changed

+129
-68
lines changed

3 files changed

+129
-68
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,8 @@ test
307307
*.o
308308
*.o.d
309309
.vscode
310+
# ClangD cache files
311+
.cache
310312

311313
doxy/
312314
doxygen-awesome*.css

include/jwt-cpp/jwt.h

Lines changed: 121 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,83 @@ namespace jwt {
388388
* you maybe need to extract the modulus and exponent of an RSA Public Key.
389389
*/
390390
namespace helper {
391+
/**
392+
* \brief Handle class for EVP_PKEY structures
393+
*
394+
* Starting from OpenSSL 1.1.0, EVP_PKEY has internal reference counting. This handle class allows
395+
* jwt-cpp to leverage that and thus safe an allocation for the control block in std::shared_ptr.
396+
* The handle uses shared_ptr as a fallback on older versions. The behaviour should be identical between both.
397+
*/
398+
class evp_pkey_handle {
399+
public:
400+
constexpr evp_pkey_handle() noexcept = default;
401+
#ifdef JWT_OPENSSL_1_0_0
402+
/**
403+
* \brief Contruct a new handle. The handle takes ownership of the key.
404+
* \param key The key to store
405+
*/
406+
explicit evp_pkey_handle(EVP_PKEY* key) { m_key = std::shared_ptr<EVP_PKEY>(key, EVP_PKEY_free); }
407+
408+
EVP_PKEY* get() const noexcept { return m_key.get(); }
409+
bool operator!() const noexcept { return m_key == nullptr; }
410+
explicit operator bool() const noexcept { return m_key != nullptr; }
411+
412+
private:
413+
std::shared_ptr<EVP_PKEY> m_key{nullptr};
414+
#else
415+
/**
416+
* \brief Contruct a new handle. The handle takes ownership of the key.
417+
* \param key The key to store
418+
*/
419+
explicit constexpr evp_pkey_handle(EVP_PKEY* key) noexcept : m_key{key} {}
420+
evp_pkey_handle(const evp_pkey_handle& other) : m_key{other.m_key} {
421+
if (m_key != nullptr && EVP_PKEY_up_ref(m_key) != 1) throw std::runtime_error("EVP_PKEY_up_ref failed");
422+
}
423+
// C++11 requires the body of a constexpr constructor to be empty
424+
#if __cplusplus >= 201402L
425+
constexpr
426+
#endif
427+
evp_pkey_handle(evp_pkey_handle&& other) noexcept
428+
: m_key{other.m_key} {
429+
other.m_key = nullptr;
430+
}
431+
evp_pkey_handle& operator=(const evp_pkey_handle& other) {
432+
if (&other == this) return *this;
433+
decrement_ref_count(m_key);
434+
m_key = other.m_key;
435+
increment_ref_count(m_key);
436+
return *this;
437+
}
438+
evp_pkey_handle& operator=(evp_pkey_handle&& other) noexcept {
439+
if (&other == this) return *this;
440+
decrement_ref_count(m_key);
441+
m_key = other.m_key;
442+
other.m_key = nullptr;
443+
return *this;
444+
}
445+
evp_pkey_handle& operator=(EVP_PKEY* key) {
446+
decrement_ref_count(m_key);
447+
m_key = key;
448+
increment_ref_count(m_key);
449+
return *this;
450+
}
451+
~evp_pkey_handle() noexcept { decrement_ref_count(m_key); }
452+
453+
EVP_PKEY* get() const noexcept { return m_key; }
454+
bool operator!() const noexcept { return m_key == nullptr; }
455+
explicit operator bool() const noexcept { return m_key != nullptr; }
456+
457+
private:
458+
EVP_PKEY* m_key{nullptr};
459+
460+
static void increment_ref_count(EVP_PKEY* key) {
461+
if (key != nullptr && EVP_PKEY_up_ref(key) != 1) throw std::runtime_error("EVP_PKEY_up_ref failed");
462+
}
463+
static void decrement_ref_count(EVP_PKEY* key) noexcept {
464+
if (key != nullptr) EVP_PKEY_free(key);
465+
}
466+
#endif
467+
};
391468
/**
392469
* \brief Extract the public key of a pem certificate
393470
*
@@ -556,38 +633,34 @@ namespace jwt {
556633
* \param password Password used to decrypt certificate (leave empty if not encrypted)
557634
* \param ec error_code for error_detection (gets cleared if no error occures)
558635
*/
559-
inline std::shared_ptr<EVP_PKEY> load_public_key_from_string(const std::string& key,
560-
const std::string& password, std::error_code& ec) {
636+
inline evp_pkey_handle load_public_key_from_string(const std::string& key, const std::string& password,
637+
std::error_code& ec) {
561638
ec.clear();
562639
std::unique_ptr<BIO, decltype(&BIO_free_all)> pubkey_bio(BIO_new(BIO_s_mem()), BIO_free_all);
563640
if (!pubkey_bio) {
564641
ec = error::rsa_error::create_mem_bio_failed;
565-
return nullptr;
642+
return {};
566643
}
567644
if (key.substr(0, 27) == "-----BEGIN CERTIFICATE-----") {
568645
auto epkey = helper::extract_pubkey_from_cert(key, password, ec);
569-
if (ec) return nullptr;
646+
if (ec) return {};
570647
const int len = static_cast<int>(epkey.size());
571648
if (BIO_write(pubkey_bio.get(), epkey.data(), len) != len) {
572649
ec = error::rsa_error::load_key_bio_write;
573-
return nullptr;
650+
return {};
574651
}
575652
} else {
576653
const int len = static_cast<int>(key.size());
577654
if (BIO_write(pubkey_bio.get(), key.data(), len) != len) {
578655
ec = error::rsa_error::load_key_bio_write;
579-
return nullptr;
656+
return {};
580657
}
581658
}
582659

583-
std::shared_ptr<EVP_PKEY> pkey(
584-
PEM_read_bio_PUBKEY(pubkey_bio.get(), nullptr, nullptr,
585-
(void*)password.data()), // NOLINT(google-readability-casting) requires `const_cast`
586-
EVP_PKEY_free);
587-
if (!pkey) {
588-
ec = error::rsa_error::load_key_bio_read;
589-
return nullptr;
590-
}
660+
evp_pkey_handle pkey(PEM_read_bio_PUBKEY(
661+
pubkey_bio.get(), nullptr, nullptr,
662+
(void*)password.data())); // NOLINT(google-readability-casting) requires `const_cast`
663+
if (!pkey) ec = error::rsa_error::load_key_bio_read;
591664
return pkey;
592665
}
593666

@@ -600,8 +673,7 @@ namespace jwt {
600673
* \param password Password used to decrypt certificate or key (leave empty if not encrypted)
601674
* \throw rsa_exception if an error occurred
602675
*/
603-
inline std::shared_ptr<EVP_PKEY> load_public_key_from_string(const std::string& key,
604-
const std::string& password = "") {
676+
inline evp_pkey_handle load_public_key_from_string(const std::string& key, const std::string& password = "") {
605677
std::error_code ec;
606678
auto res = load_public_key_from_string(key, password, ec);
607679
error::throw_if_error(ec);
@@ -615,25 +687,21 @@ namespace jwt {
615687
* \param password Password used to decrypt key (leave empty if not encrypted)
616688
* \param ec error_code for error_detection (gets cleared if no error occures)
617689
*/
618-
inline std::shared_ptr<EVP_PKEY>
619-
load_private_key_from_string(const std::string& key, const std::string& password, std::error_code& ec) {
690+
inline evp_pkey_handle load_private_key_from_string(const std::string& key, const std::string& password,
691+
std::error_code& ec) {
620692
std::unique_ptr<BIO, decltype(&BIO_free_all)> privkey_bio(BIO_new(BIO_s_mem()), BIO_free_all);
621693
if (!privkey_bio) {
622694
ec = error::rsa_error::create_mem_bio_failed;
623-
return nullptr;
695+
return {};
624696
}
625697
const int len = static_cast<int>(key.size());
626698
if (BIO_write(privkey_bio.get(), key.data(), len) != len) {
627699
ec = error::rsa_error::load_key_bio_write;
628-
return nullptr;
629-
}
630-
std::shared_ptr<EVP_PKEY> pkey(
631-
PEM_read_bio_PrivateKey(privkey_bio.get(), nullptr, nullptr, const_cast<char*>(password.c_str())),
632-
EVP_PKEY_free);
633-
if (!pkey) {
634-
ec = error::rsa_error::load_key_bio_read;
635-
return nullptr;
700+
return {};
636701
}
702+
evp_pkey_handle pkey(
703+
PEM_read_bio_PrivateKey(privkey_bio.get(), nullptr, nullptr, const_cast<char*>(password.c_str())));
704+
if (!pkey) ec = error::rsa_error::load_key_bio_read;
637705
return pkey;
638706
}
639707

@@ -644,8 +712,7 @@ namespace jwt {
644712
* \param password Password used to decrypt key (leave empty if not encrypted)
645713
* \throw rsa_exception if an error occurred
646714
*/
647-
inline std::shared_ptr<EVP_PKEY> load_private_key_from_string(const std::string& key,
648-
const std::string& password = "") {
715+
inline evp_pkey_handle load_private_key_from_string(const std::string& key, const std::string& password = "") {
649716
std::error_code ec;
650717
auto res = load_private_key_from_string(key, password, ec);
651718
error::throw_if_error(ec);
@@ -661,38 +728,34 @@ namespace jwt {
661728
* \param password Password used to decrypt certificate (leave empty if not encrypted)
662729
* \param ec error_code for error_detection (gets cleared if no error occures)
663730
*/
664-
inline std::shared_ptr<EVP_PKEY>
665-
load_public_ec_key_from_string(const std::string& key, const std::string& password, std::error_code& ec) {
731+
inline evp_pkey_handle load_public_ec_key_from_string(const std::string& key, const std::string& password,
732+
std::error_code& ec) {
666733
ec.clear();
667734
std::unique_ptr<BIO, decltype(&BIO_free_all)> pubkey_bio(BIO_new(BIO_s_mem()), BIO_free_all);
668735
if (!pubkey_bio) {
669736
ec = error::ecdsa_error::create_mem_bio_failed;
670-
return nullptr;
737+
return {};
671738
}
672739
if (key.substr(0, 27) == "-----BEGIN CERTIFICATE-----") {
673740
auto epkey = helper::extract_pubkey_from_cert(key, password, ec);
674-
if (ec) return nullptr;
741+
if (ec) return {};
675742
const int len = static_cast<int>(epkey.size());
676743
if (BIO_write(pubkey_bio.get(), epkey.data(), len) != len) {
677744
ec = error::ecdsa_error::load_key_bio_write;
678-
return nullptr;
745+
return {};
679746
}
680747
} else {
681748
const int len = static_cast<int>(key.size());
682749
if (BIO_write(pubkey_bio.get(), key.data(), len) != len) {
683750
ec = error::ecdsa_error::load_key_bio_write;
684-
return nullptr;
751+
return {};
685752
}
686753
}
687754

688-
std::shared_ptr<EVP_PKEY> pkey(
689-
PEM_read_bio_PUBKEY(pubkey_bio.get(), nullptr, nullptr,
690-
(void*)password.data()), // NOLINT(google-readability-casting) requires `const_cast`
691-
EVP_PKEY_free);
692-
if (!pkey) {
693-
ec = error::ecdsa_error::load_key_bio_read;
694-
return nullptr;
695-
}
755+
evp_pkey_handle pkey(PEM_read_bio_PUBKEY(
756+
pubkey_bio.get(), nullptr, nullptr,
757+
(void*)password.data())); // NOLINT(google-readability-casting) requires `const_cast`
758+
if (!pkey) ec = error::ecdsa_error::load_key_bio_read;
696759
return pkey;
697760
}
698761

@@ -705,8 +768,8 @@ namespace jwt {
705768
* \param password Password used to decrypt certificate or key (leave empty if not encrypted)
706769
* \throw ecdsa_exception if an error occurred
707770
*/
708-
inline std::shared_ptr<EVP_PKEY> load_public_ec_key_from_string(const std::string& key,
709-
const std::string& password = "") {
771+
inline evp_pkey_handle load_public_ec_key_from_string(const std::string& key,
772+
const std::string& password = "") {
710773
std::error_code ec;
711774
auto res = load_public_ec_key_from_string(key, password, ec);
712775
error::throw_if_error(ec);
@@ -720,25 +783,21 @@ namespace jwt {
720783
* \param password Password used to decrypt key (leave empty if not encrypted)
721784
* \param ec error_code for error_detection (gets cleared if no error occures)
722785
*/
723-
inline std::shared_ptr<EVP_PKEY>
724-
load_private_ec_key_from_string(const std::string& key, const std::string& password, std::error_code& ec) {
786+
inline evp_pkey_handle load_private_ec_key_from_string(const std::string& key, const std::string& password,
787+
std::error_code& ec) {
725788
std::unique_ptr<BIO, decltype(&BIO_free_all)> privkey_bio(BIO_new(BIO_s_mem()), BIO_free_all);
726789
if (!privkey_bio) {
727790
ec = error::ecdsa_error::create_mem_bio_failed;
728-
return nullptr;
791+
return {};
729792
}
730793
const int len = static_cast<int>(key.size());
731794
if (BIO_write(privkey_bio.get(), key.data(), len) != len) {
732795
ec = error::ecdsa_error::load_key_bio_write;
733-
return nullptr;
734-
}
735-
std::shared_ptr<EVP_PKEY> pkey(
736-
PEM_read_bio_PrivateKey(privkey_bio.get(), nullptr, nullptr, const_cast<char*>(password.c_str())),
737-
EVP_PKEY_free);
738-
if (!pkey) {
739-
ec = error::ecdsa_error::load_key_bio_read;
740-
return nullptr;
796+
return {};
741797
}
798+
evp_pkey_handle pkey(
799+
PEM_read_bio_PrivateKey(privkey_bio.get(), nullptr, nullptr, const_cast<char*>(password.c_str())));
800+
if (!pkey) ec = error::ecdsa_error::load_key_bio_read;
742801
return pkey;
743802
}
744803

@@ -749,8 +808,8 @@ namespace jwt {
749808
* \param password Password used to decrypt key (leave empty if not encrypted)
750809
* \throw ecdsa_exception if an error occurred
751810
*/
752-
inline std::shared_ptr<EVP_PKEY> load_private_ec_key_from_string(const std::string& key,
753-
const std::string& password = "") {
811+
inline evp_pkey_handle load_private_ec_key_from_string(const std::string& key,
812+
const std::string& password = "") {
754813
std::error_code ec;
755814
auto res = load_private_ec_key_from_string(key, password, ec);
756815
error::throw_if_error(ec);
@@ -990,7 +1049,7 @@ namespace jwt {
9901049

9911050
private:
9921051
/// OpenSSL structure containing converted keys
993-
std::shared_ptr<EVP_PKEY> pkey;
1052+
helper::evp_pkey_handle pkey;
9941053
/// Hash generator
9951054
const EVP_MD* (*md)();
9961055
/// algorithm's name
@@ -1214,7 +1273,7 @@ namespace jwt {
12141273
}
12151274

12161275
/// OpenSSL struct containing keys
1217-
std::shared_ptr<EVP_PKEY> pkey;
1276+
helper::evp_pkey_handle pkey;
12181277
/// Hash generator function
12191278
const EVP_MD* (*md)();
12201279
/// algorithm's name
@@ -1360,7 +1419,7 @@ namespace jwt {
13601419

13611420
private:
13621421
/// OpenSSL struct containing keys
1363-
std::shared_ptr<EVP_PKEY> pkey;
1422+
helper::evp_pkey_handle pkey;
13641423
/// algorithm's name
13651424
const std::string alg_name;
13661425
};
@@ -1496,7 +1555,7 @@ namespace jwt {
14961555

14971556
private:
14981557
/// OpenSSL structure containing keys
1499-
std::shared_ptr<EVP_PKEY> pkey;
1558+
helper::evp_pkey_handle pkey;
15001559
/// Hash generator function
15011560
const EVP_MD* (*md)();
15021561
/// algorithm's name

tests/OpenSSLErrorTest.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ TEST(OpenSSLErrorTest, ConvertCertBase64DerToPemErrorCode) {
533533

534534
TEST(OpenSSLErrorTest, LoadPublicKeyFromStringReference) {
535535
auto res = jwt::helper::load_public_key_from_string(rsa_pub_key, "");
536-
ASSERT_NE(res, nullptr);
536+
ASSERT_TRUE(res);
537537
}
538538

539539
TEST(OpenSSLErrorTest, LoadPublicKeyFromString) {
@@ -556,13 +556,13 @@ TEST(OpenSSLErrorTest, LoadPublicKeyFromStringErrorCode) {
556556

557557
run_multitest(mapping, [](std::error_code& ec) {
558558
auto res = jwt::helper::load_public_key_from_string(rsa_pub_key, "", ec);
559-
ASSERT_EQ(res, nullptr);
559+
ASSERT_FALSE(res);
560560
});
561561
}
562562

563563
TEST(OpenSSLErrorTest, LoadPublicKeyCertFromStringReference) {
564564
auto res = jwt::helper::load_public_key_from_string(sample_cert, "");
565-
ASSERT_NE(res, nullptr);
565+
ASSERT_TRUE(res);
566566
}
567567

568568
TEST(OpenSSLErrorTest, LoadPublicKeyCertFromString) {
@@ -601,13 +601,13 @@ TEST(OpenSSLErrorTest, LoadPublicKeyCertFromStringErrorCode) {
601601

602602
run_multitest(mapping, [](std::error_code& ec) {
603603
auto res = jwt::helper::load_public_key_from_string(sample_cert, "", ec);
604-
ASSERT_EQ(res, nullptr);
604+
ASSERT_FALSE(res);
605605
});
606606
}
607607

608608
TEST(OpenSSLErrorTest, LoadPrivateKeyFromStringReference) {
609609
auto res = jwt::helper::load_private_key_from_string(rsa_priv_key, "");
610-
ASSERT_NE(res, nullptr);
610+
ASSERT_TRUE(res);
611611
}
612612

613613
TEST(OpenSSLErrorTest, LoadPrivateKeyFromString) {
@@ -630,7 +630,7 @@ TEST(OpenSSLErrorTest, LoadPrivateKeyFromStringErrorCode) {
630630

631631
run_multitest(mapping, [](std::error_code& ec) {
632632
auto res = jwt::helper::load_private_key_from_string(rsa_priv_key, "", ec);
633-
ASSERT_EQ(res, nullptr);
633+
ASSERT_FALSE(res);
634634
});
635635
}
636636

0 commit comments

Comments
 (0)