Skip to content

Commit a7a17be

Browse files
behoskezhuw
authored andcommitted
Support client side tls certificates reload
This allows creating a client with dynamic tls certificates. When created this way, on reconnection the client will use latest tls certificates. This allows us to have auto-reloading of refreshed certificates stored anywhere in client side. Resolves #59.
1 parent 1f3eb7a commit a7a17be

File tree

8 files changed

+785
-225
lines changed

8 files changed

+785
-225
lines changed

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ rcgen = { version = "0.14.1", features = ["default", "x509-parser"] }
6969
serial_test = "3.0.0"
7070
asyncs = { version = "0.4.0", features = ["test"] }
7171
blocking = "1.6.0"
72+
rustls-pki-types = "1.12.0"
73+
x509-parser = "0.17.0"
74+
atomic-write-file = "0.2.3"
75+
notify = "7.0.0"
7276

7377
[package.metadata.cargo-all-features]
7478
skip_optional_dependencies = true

src/session/connection.rs

Lines changed: 18 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,15 @@ use bytes::buf::BufMut;
1010
use futures::io::BufReader;
1111
use futures::prelude::*;
1212
use futures_lite::AsyncReadExt;
13+
#[cfg(feature = "tls")]
14+
pub use futures_rustls::client::TlsStream;
1315
use ignore_result::Ignore;
1416
use tracing::{debug, trace};
1517

16-
#[cfg(feature = "tls")]
17-
mod tls {
18-
pub use std::sync::Arc;
19-
20-
pub use futures_rustls::client::TlsStream;
21-
pub use futures_rustls::TlsConnector;
22-
pub use rustls::pki_types::ServerName;
23-
pub use rustls::ClientConfig;
24-
}
25-
#[cfg(feature = "tls")]
26-
use tls::*;
27-
2818
use crate::deadline::Deadline;
2919
use crate::endpoint::{EndpointRef, IterableEndpoints};
20+
#[cfg(feature = "tls")]
21+
use crate::tls::TlsClient;
3022

3123
#[derive(Debug)]
3224
pub enum Connection {
@@ -170,7 +162,7 @@ impl Connection {
170162
#[derive(Clone)]
171163
pub struct Connector {
172164
#[cfg(feature = "tls")]
173-
tls: Option<TlsConnector>,
165+
tls: Option<TlsClient>,
174166
timeout: Duration,
175167
}
176168

@@ -186,15 +178,8 @@ impl Connector {
186178
}
187179

188180
#[cfg(feature = "tls")]
189-
pub fn with_tls(config: ClientConfig) -> Self {
190-
Self { tls: Some(TlsConnector::from(Arc::new(config))), timeout: Duration::from_secs(10) }
191-
}
192-
193-
#[cfg(feature = "tls")]
194-
async fn connect_tls(&self, stream: TcpStream, host: &str) -> Result<Connection> {
195-
let domain = ServerName::try_from(host).unwrap().to_owned();
196-
let stream = self.tls.as_ref().unwrap().connect(domain, stream).await?;
197-
Ok(Connection::new_tls(stream))
181+
pub fn with_tls(client: TlsClient) -> Self {
182+
Self { tls: Some(client), timeout: Duration::from_secs(10) }
198183
}
199184

200185
pub fn timeout(&self) -> Duration {
@@ -205,34 +190,25 @@ impl Connector {
205190
self.timeout = timeout;
206191
}
207192

208-
pub async fn connect(&self, endpoint: EndpointRef<'_>, deadline: &mut Deadline) -> Result<Connection> {
193+
async fn connect_endpoint(&self, endpoint: EndpointRef<'_>) -> Result<Connection> {
209194
if endpoint.tls {
210195
#[cfg(feature = "tls")]
211-
if self.tls.is_none() {
212-
return Err(Error::new(ErrorKind::Unsupported, "tls not configured"));
213-
}
196+
return match self.tls.as_ref() {
197+
None => return Err(Error::new(ErrorKind::Unsupported, "tls not configured")),
198+
Some(client) => client.connect(endpoint.host, endpoint.port).await.map(Connection::new_tls),
199+
};
214200
#[cfg(not(feature = "tls"))]
215201
return Err(Error::new(ErrorKind::Unsupported, "tls not supported"));
216202
}
203+
TcpStream::connect((endpoint.host, endpoint.port)).await.map(Connection::new_raw)
204+
}
205+
206+
pub async fn connect(&self, endpoint: EndpointRef<'_>, deadline: &mut Deadline) -> Result<Connection> {
217207
select! {
208+
biased;
209+
r = self.connect_endpoint(endpoint) => r,
218210
_ = unsafe { Pin::new_unchecked(deadline) } => Err(Error::new(ErrorKind::TimedOut, "deadline exceed")),
219211
_ = Timer::after(self.timeout) => Err(Error::new(ErrorKind::TimedOut, format!("connection timeout{:?} exceed", self.timeout))),
220-
r = TcpStream::connect((endpoint.host, endpoint.port)) => {
221-
match r {
222-
Err(err) => Err(err),
223-
Ok(sock) => {
224-
let connection = if endpoint.tls {
225-
#[cfg(not(feature = "tls"))]
226-
unreachable!("tls not supported");
227-
#[cfg(feature = "tls")]
228-
self.connect_tls(sock, endpoint.host).await?
229-
} else {
230-
Connection::new_raw(sock)
231-
};
232-
Ok(connection)
233-
},
234-
}
235-
},
236212
}
237213
}
238214

src/session/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ impl Builder {
131131
}
132132
#[cfg(feature = "tls")]
133133
let connector = match self.tls {
134-
Some(options) => Connector::with_tls(options.into_config()?),
134+
Some(options) => Connector::with_tls(options.into_client()?),
135135
None => Connector::new(),
136136
};
137137
#[cfg(not(feature = "tls"))]

src/tls.rs

Lines changed: 0 additions & 182 deletions
This file was deleted.

src/tls/cert.rs

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
2+
use rustls::crypto::{CryptoProvider, WebPkiSupportedAlgorithms};
3+
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
4+
use rustls::server::ParsedCertificate;
5+
use rustls::{DigitallySignedStruct, Error as TlsError, RootCertStore, SignatureScheme};
6+
7+
use crate::client::Result;
8+
use crate::Error;
9+
10+
// Rustls tends to make disable of hostname verification verbose since it exposes man-in-the-middle
11+
// attacks. Though, there are still attempts to disable hostname verification in rustls, but no got
12+
// merged until now.
13+
// * Allow disabling Hostname Verification: https://github.com/rustls/rustls/issues/578
14+
// * Dangerous verifiers API proposal: https://github.com/rustls/rustls/pull/1197
15+
#[derive(Debug)]
16+
pub(super) struct NoHostnameVerificationServerCertVerifier {
17+
roots: RootCertStore,
18+
supported: WebPkiSupportedAlgorithms,
19+
hostname_verification: bool,
20+
}
21+
22+
impl NoHostnameVerificationServerCertVerifier {
23+
pub unsafe fn new(roots: RootCertStore, hostname_verification: bool) -> Self {
24+
Self {
25+
roots,
26+
supported: CryptoProvider::get_default().unwrap().signature_verification_algorithms,
27+
hostname_verification,
28+
}
29+
}
30+
}
31+
32+
impl ServerCertVerifier for NoHostnameVerificationServerCertVerifier {
33+
fn verify_server_cert(
34+
&self,
35+
end_entity: &CertificateDer<'_>,
36+
intermediates: &[CertificateDer<'_>],
37+
server_name: &ServerName<'_>,
38+
_ocsp_response: &[u8],
39+
now: UnixTime,
40+
) -> Result<ServerCertVerified, TlsError> {
41+
let cert = ParsedCertificate::try_from(end_entity)?;
42+
rustls::client::verify_server_cert_signed_by_trust_anchor(
43+
&cert,
44+
&self.roots,
45+
intermediates,
46+
now,
47+
self.supported.all,
48+
)?;
49+
50+
if self.hostname_verification {
51+
rustls::client::verify_server_name(&cert, server_name)?;
52+
}
53+
Ok(ServerCertVerified::assertion())
54+
}
55+
56+
fn verify_tls12_signature(
57+
&self,
58+
message: &[u8],
59+
cert: &CertificateDer<'_>,
60+
dss: &DigitallySignedStruct,
61+
) -> Result<HandshakeSignatureValid, TlsError> {
62+
rustls::crypto::verify_tls12_signature(message, cert, dss, &self.supported)
63+
}
64+
65+
fn verify_tls13_signature(
66+
&self,
67+
message: &[u8],
68+
cert: &CertificateDer<'_>,
69+
dss: &DigitallySignedStruct,
70+
) -> Result<HandshakeSignatureValid, TlsError> {
71+
rustls::crypto::verify_tls12_signature(message, cert, dss, &self.supported)
72+
}
73+
74+
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
75+
self.supported.supported_schemes()
76+
}
77+
}
78+
79+
/// Helper function to parse certificate and key content from strings
80+
pub(super) fn parse_pem_identity(
81+
cert: &str,
82+
key: &str,
83+
) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
84+
let r: std::result::Result<Vec<_>, _> = rustls_pemfile::certs(&mut cert.as_bytes()).collect();
85+
let certs = match r {
86+
Err(err) => return Err(Error::with_other("fail to read cert", err)),
87+
Ok(certs) => certs,
88+
};
89+
let key = match rustls_pemfile::private_key(&mut key.as_bytes()) {
90+
Err(err) => return Err(Error::with_other("fail to read client private key", err)),
91+
Ok(None) => return Err(Error::BadArguments(&"no client private key")),
92+
Ok(Some(key)) => key,
93+
};
94+
Ok((certs, key))
95+
}

0 commit comments

Comments
 (0)