Skip to content

Commit 46d9b44

Browse files
authored
Added mTLS support (#213)
1 parent 28694d8 commit 46d9b44

File tree

8 files changed

+127
-25
lines changed

8 files changed

+127
-25
lines changed

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,10 @@ endif(WIN32)
128128
find_package(OpenSSL)
129129
if(OpenSSL_FOUND)
130130
target_link_libraries(${PROJECT_NAME} PRIVATE OpenSSL::SSL OpenSSL::Crypto)
131+
132+
# To enable the acceptance of local issued certificates
133+
# Add the flag: ALLOW_SELF_SIGNED_CERTS
134+
# ONLY FOR TESTING PURPOSES
131135
target_compile_definitions(${PROJECT_NAME} PRIVATE USE_OPENSSL)
132136
endif()
133137

trantor/net/TcpClient.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,12 +226,13 @@ void TcpClient::enableSSL(
226226
std::string hostname,
227227
const std::vector<std::pair<std::string, std::string>> &sslConfCmds,
228228
const std::string &certPath,
229-
const std::string &keyPath)
229+
const std::string &keyPath,
230+
const std::string &caPath)
230231
{
231232
#ifdef USE_OPENSSL
232233
/* Create a new OpenSSL context */
233234
sslCtxPtr_ = newSSLClientContext(
234-
useOldTLS, validateCert, certPath, keyPath, sslConfCmds);
235+
useOldTLS, validateCert, certPath, keyPath, sslConfCmds, caPath);
235236
validateCert_ = validateCert;
236237
if (!hostname.empty())
237238
{
@@ -251,6 +252,7 @@ void TcpClient::enableSSL(
251252
(void)sslConfCmds;
252253
(void)certPath;
253254
(void)keyPath;
255+
(void)caPath;
254256

255257
LOG_FATAL << "OpenSSL is not found in your system!";
256258
throw std::runtime_error("OpenSSL is not found in your system!");

trantor/net/TcpClient.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,8 @@ class TRANTOR_EXPORT TcpClient : NonCopyable,
211211
const std::vector<std::pair<std::string, std::string>>
212212
&sslConfCmds = {},
213213
const std::string &certPath = "",
214-
const std::string &keyPath = "");
214+
const std::string &keyPath = "",
215+
const std::string &caPath = "");
215216

216217
private:
217218
/// Not thread safe, but in loop

trantor/net/TcpConnection.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ TRANTOR_EXPORT std::shared_ptr<SSLContext> newSSLServerContext(
3030
const std::string &certPath,
3131
const std::string &keyPath,
3232
bool useOldTLS = false,
33-
const std::vector<std::pair<std::string, std::string>> &sslConfCmds = {});
33+
const std::vector<std::pair<std::string, std::string>> &sslConfCmds = {},
34+
const std::string &caPath = "");
3435
/**
3536
* @brief This class represents a TCP connection.
3637
*

trantor/net/TcpServer.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,18 +206,21 @@ void TcpServer::enableSSL(
206206
const std::string &certPath,
207207
const std::string &keyPath,
208208
bool useOldTLS,
209-
const std::vector<std::pair<std::string, std::string>> &sslConfCmds)
209+
const std::vector<std::pair<std::string, std::string>> &sslConfCmds,
210+
const std::string &caPath)
210211
{
211212
#ifdef USE_OPENSSL
212213
/* Create a new OpenSSL context */
213-
sslCtxPtr_ = newSSLServerContext(certPath, keyPath, useOldTLS, sslConfCmds);
214+
sslCtxPtr_ =
215+
newSSLServerContext(certPath, keyPath, useOldTLS, sslConfCmds, caPath);
214216
#else
215217
// When not using OpenSSL, using `void` here will
216218
// work around the unused parameter warnings without overhead.
217219
(void)certPath;
218220
(void)keyPath;
219221
(void)useOldTLS;
220222
(void)sslConfCmds;
223+
(void)caPath;
221224

222225
LOG_FATAL << "OpenSSL is not found in your system!";
223226
throw std::runtime_error("OpenSSL is not found in your system!");

trantor/net/TcpServer.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,8 @@ class TRANTOR_EXPORT TcpServer : NonCopyable
211211
const std::string &keyPath,
212212
bool useOldTLS = false,
213213
const std::vector<std::pair<std::string, std::string>>
214-
&sslConfCmds = {});
214+
&sslConfCmds = {},
215+
const std::string &caPath = "");
215216

216217
private:
217218
EventLoop *loop_;

trantor/net/inner/TcpConnectionImpl.cc

Lines changed: 104 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -290,16 +290,18 @@ class SSLContext
290290
{
291291
return ctxPtr_;
292292
}
293+
bool mtlsEnabled = false;
293294

294295
private:
295296
SSL_CTX *ctxPtr_;
296297
};
297298
class SSLConn
298299
{
299300
public:
300-
explicit SSLConn(SSL_CTX *ctx)
301+
explicit SSLConn(SSL_CTX *ctx, bool mtlsEnabled_)
301302
{
302303
SSL_ = SSL_new(ctx);
304+
mtlsEnabled = mtlsEnabled_;
303305
}
304306
~SSLConn()
305307
{
@@ -312,6 +314,7 @@ class SSLConn
312314
{
313315
return SSL_;
314316
}
317+
bool mtlsEnabled = false;
315318

316319
private:
317320
SSL *SSL_;
@@ -329,7 +332,8 @@ std::shared_ptr<SSLContext> newSSLServerContext(
329332
const std::string &certPath,
330333
const std::string &keyPath,
331334
bool useOldTLS,
332-
const std::vector<std::pair<std::string, std::string>> &sslConfCmds)
335+
const std::vector<std::pair<std::string, std::string>> &sslConfCmds,
336+
const std::string &caPath)
333337
{
334338
auto ctx = newSSLContext(useOldTLS, false, sslConfCmds);
335339
auto r = SSL_CTX_use_certificate_chain_file(ctx->get(), certPath.c_str());
@@ -356,14 +360,38 @@ std::shared_ptr<SSLContext> newSSLServerContext(
356360
LOG_FATAL << "Checking private key matches certificate: " << errbuf;
357361
throw std::runtime_error("SSL_CTX_check_private_key error");
358362
}
363+
364+
if (!caPath.empty())
365+
{
366+
auto checkCA =
367+
SSL_CTX_load_verify_locations(ctx->get(), caPath.c_str(), NULL);
368+
LOG_DEBUG << "CA CHECK LOC: " << checkCA;
369+
if (checkCA)
370+
{
371+
STACK_OF(X509_NAME) *cert_names =
372+
SSL_load_client_CA_file(caPath.c_str());
373+
if (cert_names != NULL)
374+
{
375+
SSL_CTX_set_client_CA_list(ctx->get(), cert_names);
376+
}
377+
ctx->mtlsEnabled = true;
378+
}
379+
else
380+
{
381+
LOG_FATAL << "caPath location error ";
382+
throw std::runtime_error("SSL_CTX_load_verify_locations error");
383+
}
384+
}
385+
359386
return ctx;
360387
}
361388
std::shared_ptr<SSLContext> newSSLClientContext(
362389
bool useOldTLS,
363390
bool validateCert,
364391
const std::string &certPath,
365392
const std::string &keyPath,
366-
const std::vector<std::pair<std::string, std::string>> &sslConfCmds)
393+
const std::vector<std::pair<std::string, std::string>> &sslConfCmds,
394+
const std::string &caPath)
367395
{
368396
auto ctx = newSSLContext(useOldTLS, validateCert, sslConfCmds);
369397
if (certPath.empty() || keyPath.empty())
@@ -393,6 +421,29 @@ std::shared_ptr<SSLContext> newSSLClientContext(
393421
LOG_FATAL << "Checking private key matches certificate: " << errbuf;
394422
throw std::runtime_error("SSL_CTX_check_private_key error.");
395423
}
424+
425+
if (!caPath.empty())
426+
{
427+
auto checkCA =
428+
SSL_CTX_load_verify_locations(ctx->get(), caPath.c_str(), NULL);
429+
LOG_DEBUG << "CA CHECK LOC: " << checkCA;
430+
if (checkCA)
431+
{
432+
STACK_OF(X509_NAME) *cert_names =
433+
SSL_load_client_CA_file(caPath.c_str());
434+
if (cert_names != NULL)
435+
{
436+
SSL_CTX_set_client_CA_list(ctx->get(), cert_names);
437+
}
438+
ctx->mtlsEnabled = true;
439+
}
440+
else
441+
{
442+
LOG_FATAL << "caPath location error ";
443+
throw std::runtime_error("SSL_CTX_load_verify_locations error");
444+
}
445+
}
446+
396447
return ctx;
397448
}
398449
} // namespace trantor
@@ -403,7 +454,8 @@ std::shared_ptr<SSLContext> newSSLServerContext(
403454
const std::string &,
404455
const std::string &,
405456
bool,
406-
const std::vector<std::pair<std::string, std::string>> &)
457+
const std::vector<std::pair<std::string, std::string>> &,
458+
const std::string &)
407459
{
408460
LOG_FATAL << "OpenSSL is not found in your system!";
409461
throw std::runtime_error("OpenSSL is not found in your system!");
@@ -457,11 +509,15 @@ void TcpConnectionImpl::startClientEncryptionInLoop(
457509
sslEncryptionPtr_->sslCtxPtr_ =
458510
newSSLContext(useOldTLS, validateCert_, sslConfCmds);
459511
sslEncryptionPtr_->sslPtr_ =
460-
std::make_unique<SSLConn>(sslEncryptionPtr_->sslCtxPtr_->get());
461-
if (validateCert)
512+
std::make_unique<SSLConn>(sslEncryptionPtr_->sslCtxPtr_->get(),
513+
sslEncryptionPtr_->sslCtxPtr_->mtlsEnabled);
514+
if (validateCert || sslEncryptionPtr_->sslPtr_->mtlsEnabled)
462515
{
516+
LOG_DEBUG << "MTLS: " << sslEncryptionPtr_->sslPtr_->mtlsEnabled;
463517
SSL_set_verify(sslEncryptionPtr_->sslPtr_->get(),
464-
SSL_VERIFY_NONE,
518+
sslEncryptionPtr_->sslPtr_->mtlsEnabled
519+
? SSL_VERIFY_PEER
520+
: SSL_VERIFY_NONE,
465521
nullptr);
466522
validateCert_ = validateCert;
467523
}
@@ -497,13 +553,21 @@ void TcpConnectionImpl::startServerEncryptionInLoop(
497553
sslEncryptionPtr_->sslCtxPtr_ = ctx;
498554
sslEncryptionPtr_->isServer_ = true;
499555
sslEncryptionPtr_->sslPtr_ =
500-
std::make_unique<SSLConn>(sslEncryptionPtr_->sslCtxPtr_->get());
556+
std::make_unique<SSLConn>(sslEncryptionPtr_->sslCtxPtr_->get(),
557+
sslEncryptionPtr_->sslCtxPtr_->mtlsEnabled);
501558
isEncrypted_ = true;
502559
sslEncryptionPtr_->isUpgrade_ = true;
503-
if (sslEncryptionPtr_->isServer_ == false)
560+
if (sslEncryptionPtr_->isServer_ == false ||
561+
sslEncryptionPtr_->sslPtr_->mtlsEnabled)
562+
{
563+
LOG_DEBUG << "MTLS: " << sslEncryptionPtr_->sslPtr_->mtlsEnabled;
504564
SSL_set_verify(sslEncryptionPtr_->sslPtr_->get(),
505-
SSL_VERIFY_NONE,
565+
sslEncryptionPtr_->sslPtr_->mtlsEnabled
566+
? SSL_VERIFY_PEER
567+
: SSL_VERIFY_NONE,
506568
nullptr);
569+
}
570+
507571
auto r = SSL_set_fd(sslEncryptionPtr_->sslPtr_->get(), socketPtr_->fd());
508572
(void)r;
509573
assert(r);
@@ -1895,13 +1959,20 @@ TcpConnectionImpl::TcpConnectionImpl(EventLoop *loop,
18951959
socketPtr_->setKeepAlive(true);
18961960
name_ = localAddr.toIpPort() + "--" + peerAddr.toIpPort();
18971961
sslEncryptionPtr_ = std::make_unique<SSLEncryption>();
1898-
sslEncryptionPtr_->sslPtr_ = std::make_unique<SSLConn>(ctxPtr->get());
1962+
sslEncryptionPtr_->sslPtr_ =
1963+
std::make_unique<SSLConn>(ctxPtr->get(), ctxPtr->mtlsEnabled);
18991964
sslEncryptionPtr_->isServer_ = isServer;
19001965
validateCert_ = validateCert;
1901-
if (isServer == false)
1966+
if (isServer == false || sslEncryptionPtr_->sslPtr_->mtlsEnabled)
1967+
{
1968+
LOG_DEBUG << "MTLS: " << sslEncryptionPtr_->sslPtr_->mtlsEnabled;
19021969
SSL_set_verify(sslEncryptionPtr_->sslPtr_->get(),
1903-
SSL_VERIFY_NONE,
1970+
sslEncryptionPtr_->sslPtr_->mtlsEnabled
1971+
? SSL_VERIFY_PEER
1972+
: SSL_VERIFY_NONE,
19041973
nullptr);
1974+
}
1975+
19051976
if (!isServer && !hostname.empty())
19061977
{
19071978
SSL_set_tlsext_host_name(sslEncryptionPtr_->sslPtr_->get(),
@@ -1925,12 +1996,25 @@ bool TcpConnectionImpl::validatePeerCertificate()
19251996
SSL *ssl = sslEncryptionPtr_->sslPtr_->get();
19261997

19271998
auto result = SSL_get_verify_result(ssl);
1928-
if (result != X509_V_OK)
1999+
2000+
#ifdef ALLOW_SELF_SIGNED_CERTS
2001+
if (result != X509_V_OK &&
2002+
result != X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT &&
2003+
result != X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN &&
2004+
result != X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY)
2005+
{
2006+
LOG_DEBUG << "cert error code: " << result;
2007+
LOG_ERROR << "Server certificate is not valid";
2008+
return false;
2009+
}
2010+
#else
2011+
if (result != X509_V_OK && result)
19292012
{
19302013
LOG_DEBUG << "cert error code: " << result;
19312014
LOG_ERROR << "Server certificate is not valid";
19322015
return false;
19332016
}
2017+
#endif
19342018

19352019
X509 *cert = SSL_get_peer_certificate(ssl);
19362020
if (cert == nullptr)
@@ -1944,7 +2028,10 @@ bool TcpConnectionImpl::validatePeerCertificate()
19442028
internal::verifyAltName(cert, sslEncryptionPtr_->hostname_);
19452029
X509_free(cert);
19462030

1947-
if (domainIsValid)
2031+
LOG_DEBUG << "domainIsValid: " << domainIsValid;
2032+
2033+
// if mtlsEnabled, ignore domain validation
2034+
if (sslEncryptionPtr_->sslPtr_->mtlsEnabled || domainIsValid)
19482035
{
19492036
return true;
19502037
}
@@ -1965,7 +2052,8 @@ void TcpConnectionImpl::doHandshaking()
19652052
{
19662053
// Clients don't commonly have certificates. Let's not validate
19672054
// that
1968-
if (validateCert_ && sslEncryptionPtr_->isServer_ == false)
2055+
if (validateCert_ && (!sslEncryptionPtr_->isServer_ ||
2056+
sslEncryptionPtr_->sslPtr_->mtlsEnabled))
19692057
{
19702058
if (validatePeerCertificate() == false)
19712059
{

trantor/net/inner/TcpConnectionImpl.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,15 @@ std::shared_ptr<SSLContext> newSSLServerContext(
4646
const std::string &certPath,
4747
const std::string &keyPath,
4848
bool useOldTLS,
49-
const std::vector<std::pair<std::string, std::string>> &sslConfCmds);
49+
const std::vector<std::pair<std::string, std::string>> &sslConfCmds,
50+
const std::string &caPath);
5051
std::shared_ptr<SSLContext> newSSLClientContext(
5152
bool useOldTLS,
5253
bool validateCert,
5354
const std::string &certPath = "",
5455
const std::string &keyPath = "",
55-
const std::vector<std::pair<std::string, std::string>> &sslConfCmds = {});
56+
const std::vector<std::pair<std::string, std::string>> &sslConfCmds = {},
57+
const std::string &caPath = "");
5658

5759
// void initServerSSLContext(const std::shared_ptr<SSLContext> &ctx,
5860
// const std::string &certPath,

0 commit comments

Comments
 (0)