Skip to content

Commit f72bc68

Browse files
authored
Add SSL_CONF_cmd support (#148)
1 parent 8d21e95 commit f72bc68

File tree

7 files changed

+98
-43
lines changed

7 files changed

+98
-43
lines changed

trantor/net/TcpClient.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ void TcpClient::enableSSL(bool useOldTLS,
208208
{
209209
#ifdef USE_OPENSSL
210210
/* Create a new OpenSSL context */
211-
sslCtxPtr_ = newSSLContext(useOldTLS, validateCert);
211+
sslCtxPtr_ = newSSLContext(useOldTLS, validateCert, {});
212212
validateCert_ = validateCert;
213213
if (!hostname.empty())
214214
{

trantor/net/TcpConnection.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ class SSLContext;
2929
TRANTOR_EXPORT std::shared_ptr<SSLContext> newSSLServerContext(
3030
const std::string &certPath,
3131
const std::string &keyPath,
32-
bool useOldTLS = false);
32+
bool useOldTLS = false,
33+
const std::vector<std::pair<std::string, std::string>> &sslConfCmds = {});
3334
/**
3435
* @brief This class represents a TCP connection.
3536
*
@@ -230,10 +231,13 @@ class TRANTOR_EXPORT TcpConnection
230231
* @param hostname The server hostname for SNI. If it is empty, the SNI is
231232
* not used.
232233
*/
233-
virtual void startClientEncryption(std::function<void()> callback,
234-
bool useOldTLS = false,
235-
bool validateCert = true,
236-
std::string hostname = "") = 0;
234+
virtual void startClientEncryption(
235+
std::function<void()> callback,
236+
bool useOldTLS = false,
237+
bool validateCert = true,
238+
std::string hostname = "",
239+
const std::vector<std::pair<std::string, std::string>> &sslConfCmds =
240+
{}) = 0;
237241

238242
/**
239243
* @brief Start the SSL encryption on the connection (as a server).

trantor/net/TcpServer.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,13 +185,15 @@ const trantor::InetAddress &TcpServer::address() const
185185
return acceptorPtr_->addr();
186186
}
187187

188-
void TcpServer::enableSSL(const std::string &certPath,
189-
const std::string &keyPath,
190-
bool useOldTLS)
188+
void TcpServer::enableSSL(
189+
const std::string &certPath,
190+
const std::string &keyPath,
191+
bool useOldTLS,
192+
const std::vector<std::pair<std::string, std::string>> &sslConfCmds)
191193
{
192194
#ifdef USE_OPENSSL
193195
/* Create a new OpenSSL context */
194-
sslCtxPtr_ = newSSLServerContext(certPath, keyPath, useOldTLS);
196+
sslCtxPtr_ = newSSLServerContext(certPath, keyPath, useOldTLS, sslConfCmds);
195197
#else
196198
LOG_FATAL << "OpenSSL is not found in your system!";
197199
abort();

trantor/net/TcpServer.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,9 @@ class TRANTOR_EXPORT TcpServer : NonCopyable
207207
*/
208208
void enableSSL(const std::string &certPath,
209209
const std::string &keyPath,
210-
bool useOldTLS = false);
210+
bool useOldTLS = false,
211+
const std::vector<std::pair<std::string, std::string>>
212+
&sslConfCmds = {});
211213

212214
private:
213215
EventLoop *loop_;

trantor/net/inner/TcpConnectionImpl.cc

Lines changed: 56 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,23 @@ void initOpenSSL()
197197
class SSLContext
198198
{
199199
public:
200-
explicit SSLContext(bool useOldTLS, bool enableValidtion)
200+
explicit SSLContext(
201+
bool useOldTLS,
202+
bool enableValidtion,
203+
const std::vector<std::pair<std::string, std::string>> &sslConfCmds)
201204
{
202205
#if (OPENSSL_VERSION_NUMBER >= 0x10100000L)
203206
ctxPtr_ = SSL_CTX_new(TLS_method());
207+
SSL_CONF_CTX *cctx = SSL_CONF_CTX_new();
208+
SSL_CONF_CTX_set_flags(cctx, SSL_CONF_FLAG_SERVER);
209+
SSL_CONF_CTX_set_flags(cctx, SSL_CONF_FLAG_CLIENT);
210+
SSL_CONF_CTX_set_flags(cctx, SSL_CONF_FLAG_CERTIFICATE);
211+
SSL_CONF_CTX_set_ssl_ctx(cctx, ctxPtr_);
212+
for (auto cmd : sslConfCmds)
213+
{
214+
SSL_CONF_cmd(cctx, cmd.first.data(), cmd.second.data());
215+
}
216+
SSL_CONF_CTX_finish(cctx);
204217
if (!useOldTLS)
205218
{
206219
SSL_CTX_set_min_proto_version(ctxPtr_, TLS1_2_VERSION);
@@ -213,6 +226,16 @@ class SSLContext
213226
}
214227
#else
215228
ctxPtr_ = SSL_CTX_new(SSLv23_method());
229+
SSL_CONF_CTX *cctx = SSL_CONF_CTX_new();
230+
SSL_CONF_CTX_set_flags(cctx, SSL_CONF_FLAG_SERVER);
231+
SSL_CONF_CTX_set_flags(cctx, SSL_CONF_FLAG_CLIENT);
232+
SSL_CONF_CTX_set_flags(cctx, SSL_CONF_FLAG_CERTIFICATE);
233+
SSL_CONF_CTX_set_ssl_ctx(cctx, ctxPtr_);
234+
for (auto cmd : sslConfCmds)
235+
{
236+
SSL_CONF_cmd(cctx, cmd.first.data(), cmd.second.data());
237+
}
238+
SSL_CONF_CTX_finish(cctx);
216239
if (!useOldTLS)
217240
{
218241
SSL_CTX_set_options(ctxPtr_, SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1);
@@ -271,16 +294,21 @@ class SSLConn
271294
SSL *SSL_;
272295
};
273296

274-
std::shared_ptr<SSLContext> newSSLContext(bool useOldTLS, bool validateCert)
297+
std::shared_ptr<SSLContext> newSSLContext(
298+
bool useOldTLS,
299+
bool validateCert,
300+
const std::vector<std::pair<std::string, std::string>> &sslConfCmds)
275301
{ // init OpenSSL
276302
initOpenSSL();
277-
return std::make_shared<SSLContext>(useOldTLS, validateCert);
303+
return std::make_shared<SSLContext>(useOldTLS, validateCert, sslConfCmds);
278304
}
279-
std::shared_ptr<SSLContext> newSSLServerContext(const std::string &certPath,
280-
const std::string &keyPath,
281-
bool useOldTLS)
305+
std::shared_ptr<SSLContext> newSSLServerContext(
306+
const std::string &certPath,
307+
const std::string &keyPath,
308+
bool useOldTLS,
309+
const std::vector<std::pair<std::string, std::string>> &sslConfCmds)
282310
{
283-
auto ctx = newSSLContext(useOldTLS, false);
311+
auto ctx = newSSLContext(useOldTLS, false, sslConfCmds);
284312
auto r = SSL_CTX_use_certificate_chain_file(ctx->get(), certPath.c_str());
285313
if (!r)
286314
{
@@ -319,9 +347,11 @@ std::shared_ptr<SSLContext> newSSLServerContext(const std::string &certPath,
319347
#else
320348
namespace trantor
321349
{
322-
std::shared_ptr<SSLContext> newSSLServerContext(const std::string &certPath,
323-
const std::string &keyPath,
324-
bool useOldTLS)
350+
std::shared_ptr<SSLContext> newSSLServerContext(
351+
const std::string &certPath,
352+
const std::string &keyPath,
353+
bool useOldTLS,
354+
const std::vector<std::pair<std::string, std::string>> &sslConfCmds)
325355
{
326356
LOG_FATAL << "OpenSSL is not found in your system!";
327357
abort();
@@ -360,7 +390,8 @@ void TcpConnectionImpl::startClientEncryptionInLoop(
360390
std::function<void()> &&callback,
361391
bool useOldTLS,
362392
bool validateCert,
363-
const std::string &hostname)
393+
const std::string &hostname,
394+
const std::vector<std::pair<std::string, std::string>> &sslConfCmds)
364395
{
365396
validateCert_ = validateCert;
366397
loop_->assertInLoopThread();
@@ -371,7 +402,8 @@ void TcpConnectionImpl::startClientEncryptionInLoop(
371402
}
372403
sslEncryptionPtr_ = std::make_unique<SSLEncryption>();
373404
sslEncryptionPtr_->upgradeCallback_ = std::move(callback);
374-
sslEncryptionPtr_->sslCtxPtr_ = newSSLContext(useOldTLS, validateCert_);
405+
sslEncryptionPtr_->sslCtxPtr_ =
406+
newSSLContext(useOldTLS, validateCert_, sslConfCmds);
375407
sslEncryptionPtr_->sslPtr_ =
376408
std::make_unique<SSLConn>(sslEncryptionPtr_->sslCtxPtr_->get());
377409
if (validateCert)
@@ -452,10 +484,12 @@ void TcpConnectionImpl::startServerEncryption(
452484

453485
#endif
454486
}
455-
void TcpConnectionImpl::startClientEncryption(std::function<void()> callback,
456-
bool useOldTLS,
457-
bool validateCert,
458-
std::string hostname)
487+
void TcpConnectionImpl::startClientEncryption(
488+
std::function<void()> callback,
489+
bool useOldTLS,
490+
bool validateCert,
491+
std::string hostname,
492+
const std::vector<std::pair<std::string, std::string>> &sslConfCmds)
459493
{
460494
#ifndef USE_OPENSSL
461495
LOG_FATAL << "OpenSSL is not found in your system!";
@@ -475,19 +509,22 @@ void TcpConnectionImpl::startClientEncryption(std::function<void()> callback,
475509
startClientEncryptionInLoop(std::move(callback),
476510
useOldTLS,
477511
validateCert,
478-
hostname);
512+
hostname,
513+
sslConfCmds);
479514
}
480515
else
481516
{
482517
loop_->queueInLoop([thisPtr = shared_from_this(),
483518
callback = std::move(callback),
484519
useOldTLS,
485520
hostname = std::move(hostname),
486-
validateCert]() mutable {
521+
validateCert,
522+
&sslConfCmds]() mutable {
487523
thisPtr->startClientEncryptionInLoop(std::move(callback),
488524
useOldTLS,
489525
validateCert,
490-
hostname);
526+
hostname,
527+
sslConfCmds);
491528
});
492529
}
493530
#endif

trantor/net/inner/TcpConnectionImpl.h

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,15 @@ enum class SSLStatus
3838
class SSLContext;
3939
class SSLConn;
4040

41-
std::shared_ptr<SSLContext> newSSLContext(bool useOldTLS, bool validateCert);
42-
std::shared_ptr<SSLContext> newSSLServerContext(const std::string &certPath,
43-
const std::string &keyPath,
44-
bool useOldTLS);
41+
std::shared_ptr<SSLContext> newSSLContext(
42+
bool useOldTLS,
43+
bool validateCert,
44+
const std::vector<std::pair<std::string, std::string>> &sslConfCmds);
45+
std::shared_ptr<SSLContext> newSSLServerContext(
46+
const std::string &certPath,
47+
const std::string &keyPath,
48+
bool useOldTLS,
49+
const std::vector<std::pair<std::string, std::string>> &sslConfCmds);
4550
// void initServerSSLContext(const std::shared_ptr<SSLContext> &ctx,
4651
// const std::string &certPath,
4752
// const std::string &keyPath);
@@ -171,10 +176,13 @@ class TcpConnectionImpl : public TcpConnection,
171176
{
172177
return bytesReceived_;
173178
}
174-
virtual void startClientEncryption(std::function<void()> callback,
175-
bool useOldTLS = false,
176-
bool validateCert = true,
177-
std::string hostname = "") override;
179+
virtual void startClientEncryption(
180+
std::function<void()> callback,
181+
bool useOldTLS = false,
182+
bool validateCert = true,
183+
std::string hostname = "",
184+
const std::vector<std::pair<std::string, std::string>> &sslConfCmds =
185+
{}) override;
178186
virtual void startServerEncryption(const std::shared_ptr<SSLContext> &ctx,
179187
std::function<void()> callback) override;
180188
virtual bool isSSLConnection() const override
@@ -320,10 +328,12 @@ class TcpConnectionImpl : public TcpConnection,
320328
std::string hostname_;
321329
};
322330
std::unique_ptr<SSLEncryption> sslEncryptionPtr_;
323-
void startClientEncryptionInLoop(std::function<void()> &&callback,
324-
bool useOldTLS,
325-
bool validateCert,
326-
const std::string &hostname);
331+
void startClientEncryptionInLoop(
332+
std::function<void()> &&callback,
333+
bool useOldTLS,
334+
bool validateCert,
335+
const std::string &hostname,
336+
const std::vector<std::pair<std::string, std::string>> &sslConfCmds);
327337
void startServerEncryptionInLoop(const std::shared_ptr<SSLContext> &ctx,
328338
std::function<void()> &&callback);
329339
#endif

trantor/tests/DelayedSSLServerTest.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ int main()
1717
InetAddress addr(8888);
1818
#endif
1919
TcpServer server(loopThread.getLoop(), addr, "test");
20-
auto ctx = newSSLServerContext("server.pem", "server.pem");
20+
auto ctx = newSSLServerContext("server.pem", "server.pem", {});
2121
LOG_INFO << "start";
2222
server.setRecvMessageCallback(
2323
[](const TcpConnectionPtr &connectionPtr, MsgBuffer *buffer) {

0 commit comments

Comments
 (0)