Skip to content

Commit 4791f76

Browse files
marty1885an-tao
andauthored
SSL certificate validation (#120)
Co-authored-by: an-tao <antao2002@gmail.com>
1 parent 59f857c commit 4791f76

File tree

8 files changed

+292
-22
lines changed

8 files changed

+292
-22
lines changed

trantor/net/TcpClient.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ void TcpClient::newConnection(int sockfd)
149149
peerAddr,
150150
sslCtxPtr_,
151151
false,
152+
validateCert_,
152153
SSLHostName_);
153154
#else
154155
LOG_FATAL << "OpenSSL is not found in your system!";
@@ -170,6 +171,12 @@ void TcpClient::newConnection(int sockfd)
170171
std::lock_guard<std::mutex> lock(mutex_);
171172
connection_ = conn;
172173
}
174+
conn->setSSLErrorCallback([this](SSLError err) {
175+
if (sslErrorCallback_)
176+
{
177+
sslErrorCallback_(err);
178+
}
179+
});
173180
conn->connectEstablished();
174181
}
175182

@@ -195,11 +202,14 @@ void TcpClient::removeConnection(const TcpConnectionPtr &conn)
195202
}
196203
}
197204

198-
void TcpClient::enableSSL(bool useOldTLS, std::string hostname)
205+
void TcpClient::enableSSL(bool useOldTLS,
206+
bool validateCert,
207+
std::string hostname)
199208
{
200209
#ifdef USE_OPENSSL
201210
/* Create a new OpenSSL context */
202-
sslCtxPtr_ = newSSLContext(useOldTLS);
211+
sslCtxPtr_ = newSSLContext(useOldTLS, validateCert);
212+
validateCert_ = validateCert;
203213
if (!hostname.empty())
204214
{
205215
std::transform(hostname.begin(),

trantor/net/TcpClient.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,16 +175,33 @@ class TcpClient : NonCopyable
175175
writeCompleteCallback_ = std::move(cb);
176176
}
177177

178+
/**
179+
* @brief Set the callback for errors of SSL
180+
* @param cb The callback is called when an SSL error occurs.
181+
*/
182+
void setSSLErrorCallback(const SSLErrorCallback &cb)
183+
{
184+
sslErrorCallback_ = cb;
185+
}
186+
void setSSLErrorCallback(SSLErrorCallback &&cb)
187+
{
188+
sslErrorCallback_ = std::move(cb);
189+
}
190+
178191
/**
179192
* @brief Enable SSL encryption.
180193
* @param useOldTLS If true, the TLS 1.0 and 1.1 are supported by the
181194
* client.
182195
* @param hostname The server hostname for SNI. If it is empty, the SNI is
183196
* not used.
197+
* @param validateCert If true, we try to validate if the peer's SSL cert
198+
* is valid.
184199
* @note It's well known that TLS 1.0 and 1.1 are not considered secure in
185200
* 2020. And it's a good practice to only use TLS 1.2 and above.
186201
*/
187-
void enableSSL(bool useOldTLS = false, std::string hostname = "");
202+
void enableSSL(bool useOldTLS = false,
203+
bool validateCert = true,
204+
std::string hostname = "");
188205

189206
private:
190207
/// Not thread safe, but in loop
@@ -199,12 +216,14 @@ class TcpClient : NonCopyable
199216
ConnectionErrorCallback connectionErrorCallback_;
200217
RecvMessageCallback messageCallback_;
201218
WriteCompleteCallback writeCompleteCallback_;
219+
SSLErrorCallback sslErrorCallback_;
202220
std::atomic_bool retry_; // atomic
203221
std::atomic_bool connect_; // atomic
204222
// always in loop thread
205223
mutable std::mutex mutex_;
206224
TcpConnectionPtr connection_; // @GuardedBy mutex_
207225
std::shared_ptr<SSLContext> sslCtxPtr_;
226+
bool validateCert_{false};
208227
std::string SSLHostName_;
209228
#ifndef _WIN32
210229
class IgnoreSigPipe

trantor/net/TcpConnection.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ class TcpConnection
230230
*/
231231
virtual void startClientEncryption(std::function<void()> callback,
232232
bool useOldTLS = false,
233+
bool validateCert = true,
233234
std::string hostname = "") = 0;
234235

235236
/**
@@ -242,6 +243,9 @@ class TcpConnection
242243
virtual void startServerEncryption(const std::shared_ptr<SSLContext> &ctx,
243244
std::function<void()> callback) = 0;
244245

246+
protected:
247+
bool validateCert_ = false;
248+
245249
private:
246250
std::shared_ptr<void> contextPtr_;
247251
};

trantor/net/callbacks.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,11 @@
1818
#include <memory>
1919
namespace trantor
2020
{
21-
enum TrantorError
21+
enum class SSLError
2222
{
23-
TrantorError_None,
24-
TrantorError_UnkownError
23+
kSSLHandshakeError,
24+
kSSLInvalidCertificate
2525
};
26-
2726
using TimerCallback = std::function<void()>;
2827

2928
// the data has been read to (buf, len)
@@ -39,7 +38,6 @@ using CloseCallback = std::function<void(const TcpConnectionPtr &)>;
3938
using WriteCompleteCallback = std::function<void(const TcpConnectionPtr &)>;
4039
using HighWaterMarkCallback =
4140
std::function<void(const TcpConnectionPtr &, const size_t)>;
42-
43-
using OperationCompleteCallback = std::function<void(const TrantorError)>;
41+
using SSLErrorCallback = std::function<void(SSLError)>;
4442

4543
} // namespace trantor

0 commit comments

Comments
 (0)