Skip to content

Commit 59f857c

Browse files
authored
Add SNI support to TcpClient (#122)
1 parent e35fd04 commit 59f857c

File tree

5 files changed

+67
-17
lines changed

5 files changed

+67
-17
lines changed

trantor/net/TcpClient.cc

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@
1515
#include "Connector.h"
1616
#include "inner/TcpConnectionImpl.h"
1717
#include <trantor/net/EventLoop.h>
18+
19+
#include <functional>
20+
#include <algorithm>
21+
1822
#include "Socket.h"
1923

2024
#include <stdio.h> // snprintf
21-
#include <functional>
2225

2326
using namespace trantor;
2427
using namespace std::placeholders;
@@ -140,8 +143,13 @@ void TcpClient::newConnection(int sockfd)
140143
if (sslCtxPtr_)
141144
{
142145
#ifdef USE_OPENSSL
143-
conn = std::make_shared<TcpConnectionImpl>(
144-
loop_, sockfd, localAddr, peerAddr, sslCtxPtr_, false);
146+
conn = std::make_shared<TcpConnectionImpl>(loop_,
147+
sockfd,
148+
localAddr,
149+
peerAddr,
150+
sslCtxPtr_,
151+
false,
152+
SSLHostName_);
145153
#else
146154
LOG_FATAL << "OpenSSL is not found in your system!";
147155
abort();
@@ -187,11 +195,20 @@ void TcpClient::removeConnection(const TcpConnectionPtr &conn)
187195
}
188196
}
189197

190-
void TcpClient::enableSSL(bool useOldTLS)
198+
void TcpClient::enableSSL(bool useOldTLS, std::string hostname)
191199
{
192200
#ifdef USE_OPENSSL
193201
/* Create a new OpenSSL context */
194202
sslCtxPtr_ = newSSLContext(useOldTLS);
203+
if (!hostname.empty())
204+
{
205+
std::transform(hostname.begin(),
206+
hostname.end(),
207+
hostname.begin(),
208+
tolower);
209+
SSLHostName_ = std::move(hostname);
210+
}
211+
195212
#else
196213
LOG_FATAL << "OpenSSL is not found in your system!";
197214
abort();

trantor/net/TcpClient.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,12 @@ class TcpClient : NonCopyable
179179
* @brief Enable SSL encryption.
180180
* @param useOldTLS If true, the TLS 1.0 and 1.1 are supported by the
181181
* client.
182+
* @param hostname The server hostname for SNI. If it is empty, the SNI is
183+
* not used.
182184
* @note It's well known that TLS 1.0 and 1.1 are not considered secure in
183185
* 2020. And it's a good practice to only use TLS 1.2 and above.
184186
*/
185-
void enableSSL(bool useOldTLS = false);
187+
void enableSSL(bool useOldTLS = false, std::string hostname = "");
186188

187189
private:
188190
/// Not thread safe, but in loop
@@ -203,6 +205,7 @@ class TcpClient : NonCopyable
203205
mutable std::mutex mutex_;
204206
TcpConnectionPtr connection_; // @GuardedBy mutex_
205207
std::shared_ptr<SSLContext> sslCtxPtr_;
208+
std::string SSLHostName_;
206209
#ifndef _WIN32
207210
class IgnoreSigPipe
208211
{

trantor/net/TcpConnection.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,12 @@ class TcpConnection
225225
*
226226
* @param callback The callback is called when the SSL connection is
227227
* established.
228+
* @param hostname The server hostname for SNI. If it is empty, the SNI is
229+
* not used.
228230
*/
229231
virtual void startClientEncryption(std::function<void()> callback,
230-
bool useOldTLS = false) = 0;
232+
bool useOldTLS = false,
233+
std::string hostname = "") = 0;
231234

232235
/**
233236
* @brief Start the SSL encryption on the connection (as a server).

trantor/net/inner/TcpConnectionImpl.cc

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ void initOpenSSL()
5454
});
5555
#endif
5656
}
57+
5758
class SSLContext
5859
{
5960
public:
@@ -62,7 +63,9 @@ class SSLContext
6263
#if (OPENSSL_VERSION_NUMBER >= 0x10100000L)
6364
ctxPtr_ = SSL_CTX_new(TLS_method());
6465
if (!useOldTLS)
66+
{
6567
SSL_CTX_set_min_proto_version(ctxPtr_, TLS1_2_VERSION);
68+
}
6669
else
6770
{
6871
LOG_WARN << "TLS 1.0/1.1 are enabled. They are considered "
@@ -73,8 +76,7 @@ class SSLContext
7376
ctxPtr_ = SSL_CTX_new(SSLv23_method());
7477
if (!useOldTLS)
7578
{
76-
SSL_CTX_set_options(ctxPtr_, SSL_OP_NO_TLSv1);
77-
SSL_CTX_set_options(ctxPtr_, SSL_OP_NO_TLSv1_1);
79+
SSL_CTX_set_options(ctxPtr_, SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1);
7880
}
7981
else
8082
{
@@ -210,7 +212,8 @@ TcpConnectionImpl::~TcpConnectionImpl()
210212
#ifdef USE_OPENSSL
211213
void TcpConnectionImpl::startClientEncryptionInLoop(
212214
std::function<void()> &&callback,
213-
bool useOldTLS)
215+
bool useOldTLS,
216+
const std::string &hostname)
214217
{
215218
loop_->assertInLoopThread();
216219
if (isEncrypted_)
@@ -223,6 +226,11 @@ void TcpConnectionImpl::startClientEncryptionInLoop(
223226
sslEncryptionPtr_->sslCtxPtr_ = newSSLContext(useOldTLS);
224227
sslEncryptionPtr_->sslPtr_ =
225228
std::make_unique<SSLConn>(sslEncryptionPtr_->sslCtxPtr_->get());
229+
if (!hostname.empty())
230+
{
231+
SSL_set_tlsext_host_name(sslEncryptionPtr_->sslPtr_->get(),
232+
hostname.data());
233+
}
226234
isEncrypted_ = true;
227235
sslEncryptionPtr_->isUpgrade_ = true;
228236
auto r = SSL_set_fd(sslEncryptionPtr_->sslPtr_->get(), socketPtr_->fd());
@@ -285,23 +293,33 @@ void TcpConnectionImpl::startServerEncryption(
285293
#endif
286294
}
287295
void TcpConnectionImpl::startClientEncryption(std::function<void()> callback,
288-
bool useOldTLS)
296+
bool useOldTLS,
297+
std::string hostname)
289298
{
290299
#ifndef USE_OPENSSL
291300
LOG_FATAL << "OpenSSL is not found in your system!";
292301
abort();
293302
#else
303+
if (!hostname.empty())
304+
{
305+
std::transform(hostname.begin(),
306+
hostname.end(),
307+
hostname.begin(),
308+
tolower);
309+
}
294310
if (loop_->isInLoopThread())
295311
{
296-
startClientEncryptionInLoop(std::move(callback), useOldTLS);
312+
startClientEncryptionInLoop(std::move(callback), useOldTLS, hostname);
297313
}
298314
else
299315
{
300316
loop_->queueInLoop([thisPtr = shared_from_this(),
301317
callback = std::move(callback),
302-
useOldTLS]() mutable {
318+
useOldTLS,
319+
hostname = std::move(hostname)]() mutable {
303320
thisPtr->startClientEncryptionInLoop(std::move(callback),
304-
useOldTLS);
321+
useOldTLS,
322+
hostname);
305323
});
306324
}
307325
#endif
@@ -1372,7 +1390,8 @@ TcpConnectionImpl::TcpConnectionImpl(EventLoop *loop,
13721390
const InetAddress &localAddr,
13731391
const InetAddress &peerAddr,
13741392
const std::shared_ptr<SSLContext> &ctxPtr,
1375-
bool isServer)
1393+
bool isServer,
1394+
const std::string &hostname)
13761395
: loop_(loop),
13771396
ioChannelPtr_(new Channel(loop, socketfd)),
13781397
socketPtr_(new Socket(socketfd)),
@@ -1395,6 +1414,11 @@ TcpConnectionImpl::TcpConnectionImpl(EventLoop *loop,
13951414
sslEncryptionPtr_ = std::make_unique<SSLEncryption>();
13961415
sslEncryptionPtr_->sslPtr_ = std::make_unique<SSLConn>(ctxPtr->get());
13971416
sslEncryptionPtr_->isServer_ = isServer;
1417+
if (!isServer && !hostname.empty())
1418+
{
1419+
SSL_set_tlsext_host_name(sslEncryptionPtr_->sslPtr_->get(),
1420+
hostname.data());
1421+
}
13981422
assert(sslEncryptionPtr_->sslPtr_);
13991423
auto r = SSL_set_fd(sslEncryptionPtr_->sslPtr_->get(), socketfd);
14001424
(void)r;

trantor/net/inner/TcpConnectionImpl.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ class TcpConnectionImpl : public TcpConnection,
9393
const InetAddress &localAddr,
9494
const InetAddress &peerAddr,
9595
const std::shared_ptr<SSLContext> &ctxPtr,
96-
bool isServer = true);
96+
bool isServer = true,
97+
const std::string &hostname = "");
9798
#endif
9899
virtual ~TcpConnectionImpl();
99100
virtual void send(const char *msg, size_t len) override;
@@ -169,7 +170,8 @@ class TcpConnectionImpl : public TcpConnection,
169170
return bytesReceived_;
170171
}
171172
virtual void startClientEncryption(std::function<void()> callback,
172-
bool useOldTLS = false) override;
173+
bool useOldTLS = false,
174+
std::string hostname = "") override;
173175
virtual void startServerEncryption(const std::shared_ptr<SSLContext> &ctx,
174176
std::function<void()> callback) override;
175177
virtual bool isSSLConnection() const override
@@ -308,7 +310,8 @@ class TcpConnectionImpl : public TcpConnection,
308310
};
309311
std::unique_ptr<SSLEncryption> sslEncryptionPtr_;
310312
void startClientEncryptionInLoop(std::function<void()> &&callback,
311-
bool useOldTLS);
313+
bool useOldTLS,
314+
const std::string &hostname);
312315
void startServerEncryptionInLoop(const std::shared_ptr<SSLContext> &ctx,
313316
std::function<void()> &&callback);
314317
#endif

0 commit comments

Comments
 (0)