diff --git a/Cargo.toml b/Cargo.toml index b12a732..14ab219 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ rust-version = "1.76" [features] default = [] -tls = ["rustls", "rustls-pemfile", "futures-rustls"] +tls = ["rustls", "rustls-webpki", "rustls-pemfile", "futures-rustls"] sasl = ["sasl-gssapi", "sasl-digest-md5"] sasl-digest-md5 = ["rsasl/unstable_custom_mechanism", "md5", "linkme", "hex"] sasl-gssapi = ["rsasl/gssapi"] @@ -38,6 +38,7 @@ hashlink = "0.8.0" either = "1.9.0" uuid = { version = "1.4.1", features = ["v4"] } rustls = { version = "0.23.2", optional = true } +rustls-webpki = { version = "0.103.4", optional = true } rustls-pemfile = { version = "2", optional = true } derive-where = "1.2.7" fastrand = "2.0.2" @@ -67,6 +68,10 @@ rcgen = { version = "0.14.1", features = ["default", "x509-parser"] } serial_test = "3.0.0" asyncs = { version = "0.4.0", features = ["test"] } blocking = "1.6.0" +rustls-pki-types = "1.12.0" +x509-parser = "0.17.0" +atomic-write-file = "0.2.3" +notify = "7.0.0" [package.metadata.cargo-all-features] skip_optional_dependencies = true diff --git a/src/lib.rs b/src/lib.rs index 39f3d01..4a18a46 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -45,7 +45,7 @@ mod util; pub use self::acl::{Acl, Acls, AuthId, AuthUser, Permission}; pub use self::error::Error; #[cfg(feature = "tls")] -pub use self::tls::TlsOptions; +pub use self::tls::{TlsCa, TlsCerts, TlsCertsBuilder, TlsCertsOptions, TlsDynamicCerts, TlsIdentity, TlsOptions}; pub use crate::client::*; #[cfg(feature = "sasl-digest-md5")] pub use crate::sasl::DigestMd5SaslOptions; diff --git a/src/session/connection.rs b/src/session/connection.rs index 5e8d36d..73397f3 100644 --- a/src/session/connection.rs +++ b/src/session/connection.rs @@ -10,23 +10,15 @@ use bytes::buf::BufMut; use futures::io::BufReader; use futures::prelude::*; use futures_lite::AsyncReadExt; +#[cfg(feature = "tls")] +pub use futures_rustls::client::TlsStream; use ignore_result::Ignore; use tracing::{debug, trace}; -#[cfg(feature = "tls")] -mod tls { - pub use std::sync::Arc; - - pub use futures_rustls::client::TlsStream; - pub use futures_rustls::TlsConnector; - pub use rustls::pki_types::ServerName; - pub use rustls::ClientConfig; -} -#[cfg(feature = "tls")] -use tls::*; - use crate::deadline::Deadline; use crate::endpoint::{EndpointRef, IterableEndpoints}; +#[cfg(feature = "tls")] +use crate::tls::TlsClient; #[derive(Debug)] pub enum Connection { @@ -170,31 +162,22 @@ impl Connection { #[derive(Clone)] pub struct Connector { #[cfg(feature = "tls")] - tls: Option, + tls: Option, timeout: Duration, } impl Connector { - #[cfg(feature = "tls")] - pub fn new() -> Self { - Self { tls: None, timeout: Duration::from_secs(10) } - } - - #[cfg(not(feature = "tls"))] pub fn new() -> Self { - Self { timeout: Duration::from_secs(10) } - } - - #[cfg(feature = "tls")] - pub fn with_tls(config: ClientConfig) -> Self { - Self { tls: Some(TlsConnector::from(Arc::new(config))), timeout: Duration::from_secs(10) } + Self { + #[cfg(feature = "tls")] + tls: None, + timeout: Duration::from_secs(10), + } } #[cfg(feature = "tls")] - async fn connect_tls(&self, stream: TcpStream, host: &str) -> Result { - let domain = ServerName::try_from(host).unwrap().to_owned(); - let stream = self.tls.as_ref().unwrap().connect(domain, stream).await?; - Ok(Connection::new_tls(stream)) + pub fn with_tls(client: TlsClient) -> Self { + Self { tls: Some(client), timeout: Duration::from_secs(10) } } pub fn timeout(&self) -> Duration { @@ -205,34 +188,25 @@ impl Connector { self.timeout = timeout; } - pub async fn connect(&self, endpoint: EndpointRef<'_>, deadline: &mut Deadline) -> Result { + async fn connect_endpoint(&self, endpoint: EndpointRef<'_>) -> Result { if endpoint.tls { #[cfg(feature = "tls")] - if self.tls.is_none() { - return Err(Error::new(ErrorKind::Unsupported, "tls not configured")); - } + return match self.tls.as_ref() { + None => return Err(Error::new(ErrorKind::Unsupported, "tls not configured")), + Some(client) => client.connect(endpoint.host, endpoint.port).await.map(Connection::new_tls), + }; #[cfg(not(feature = "tls"))] return Err(Error::new(ErrorKind::Unsupported, "tls not supported")); } + TcpStream::connect((endpoint.host, endpoint.port)).await.map(Connection::new_raw) + } + + pub async fn connect(&self, endpoint: EndpointRef<'_>, deadline: &mut Deadline) -> Result { select! { + biased; + r = self.connect_endpoint(endpoint) => r, _ = unsafe { Pin::new_unchecked(deadline) } => Err(Error::new(ErrorKind::TimedOut, "deadline exceed")), _ = Timer::after(self.timeout) => Err(Error::new(ErrorKind::TimedOut, format!("connection timeout{:?} exceed", self.timeout))), - r = TcpStream::connect((endpoint.host, endpoint.port)) => { - match r { - Err(err) => Err(err), - Ok(sock) => { - let connection = if endpoint.tls { - #[cfg(not(feature = "tls"))] - unreachable!("tls not supported"); - #[cfg(feature = "tls")] - self.connect_tls(sock, endpoint.host).await? - } else { - Connection::new_raw(sock) - }; - Ok(connection) - }, - } - }, } } diff --git a/src/session/mod.rs b/src/session/mod.rs index 4c25f93..01f1ea0 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -131,7 +131,7 @@ impl Builder { } #[cfg(feature = "tls")] let connector = match self.tls { - Some(options) => Connector::with_tls(options.into_config()?), + Some(options) => Connector::with_tls(options.into_client()?), None => Connector::new(), }; #[cfg(not(feature = "tls"))] diff --git a/src/tls.rs b/src/tls.rs deleted file mode 100644 index c7f4e40..0000000 --- a/src/tls.rs +++ /dev/null @@ -1,173 +0,0 @@ -use std::sync::Arc; - -use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; -use rustls::crypto::{CryptoProvider, WebPkiSupportedAlgorithms}; -use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime}; -use rustls::server::ParsedCertificate; -use rustls::{ClientConfig, DigitallySignedStruct, Error as TlsError, RootCertStore, SignatureScheme}; - -use crate::client::Result; -use crate::Error; - -/// Options for tls connection. -#[derive(Debug)] -pub struct TlsOptions { - identity: Option<(Vec>, PrivateKeyDer<'static>)>, - ca_certs: RootCertStore, - hostname_verification: bool, -} - -impl Clone for TlsOptions { - fn clone(&self) -> Self { - Self { - identity: self.identity.as_ref().map(|id| (id.0.clone(), id.1.clone_key())), - ca_certs: self.ca_certs.clone(), - hostname_verification: self.hostname_verification, - } - } -} - -impl Default for TlsOptions { - /// Same as [Self::new]. - fn default() -> Self { - Self::new() - } -} - -// Rustls tends to make disable of hostname verification verbose since it exposes man-in-the-middle -// attacks. Though, there are still attempts to disable hostname verification in rustls, but no got -// merged until now. -// * Allow disabling Hostname Verification: https://github.com/rustls/rustls/issues/578 -// * Dangerous verifiers API proposal: https://github.com/rustls/rustls/pull/1197 -#[derive(Debug)] -struct NoHostnameVerificationServerCertVerifier { - roots: RootCertStore, - supported: WebPkiSupportedAlgorithms, -} - -impl NoHostnameVerificationServerCertVerifier { - unsafe fn new(roots: RootCertStore) -> Self { - Self { roots, supported: CryptoProvider::get_default().unwrap().signature_verification_algorithms } - } -} - -impl ServerCertVerifier for NoHostnameVerificationServerCertVerifier { - fn verify_server_cert( - &self, - end_entity: &CertificateDer<'_>, - intermediates: &[CertificateDer<'_>], - _server_name: &ServerName<'_>, - _ocsp_response: &[u8], - now: UnixTime, - ) -> Result { - let cert = ParsedCertificate::try_from(end_entity)?; - rustls::client::verify_server_cert_signed_by_trust_anchor( - &cert, - &self.roots, - intermediates, - now, - self.supported.all, - )?; - - Ok(ServerCertVerified::assertion()) - } - - fn verify_tls12_signature( - &self, - message: &[u8], - cert: &CertificateDer<'_>, - dss: &DigitallySignedStruct, - ) -> Result { - rustls::crypto::verify_tls12_signature(message, cert, dss, &self.supported) - } - - fn verify_tls13_signature( - &self, - message: &[u8], - cert: &CertificateDer<'_>, - dss: &DigitallySignedStruct, - ) -> Result { - rustls::crypto::verify_tls12_signature(message, cert, dss, &self.supported) - } - - fn supported_verify_schemes(&self) -> Vec { - self.supported.supported_schemes() - } -} - -impl TlsOptions { - /// Tls options with no ca certificates. - #[deprecated(since = "0.10.0", note = "use TlsOptions::new instead")] - pub fn no_ca() -> Self { - Self::new() - } - - /// Tls options with no ca certificates. - pub fn new() -> Self { - Self { ca_certs: RootCertStore::empty(), identity: None, hostname_verification: true } - } - - /// Disables hostname verification in tls handshake. - /// - /// # Safety - /// This exposes risk to man-in-the-middle attacks. - pub unsafe fn with_no_hostname_verification(mut self) -> Self { - self.hostname_verification = false; - self - } - - /// Adds new ca certificates. - pub fn with_pem_ca_certs(mut self, certs: &str) -> Result { - for r in rustls_pemfile::certs(&mut certs.as_bytes()) { - let cert = match r { - Ok(cert) => cert, - Err(err) => return Err(Error::with_other("fail to read cert", err)), - }; - if let Err(err) = self.ca_certs.add(cert) { - return Err(Error::with_other("fail to add cert", err)); - } - } - Ok(self) - } - - /// Specifies client identity for server to authenticate. - pub fn with_pem_identity(mut self, cert: &str, key: &str) -> Result { - let r: std::result::Result, _> = rustls_pemfile::certs(&mut cert.as_bytes()).collect(); - let certs = match r { - Err(err) => return Err(Error::with_other("fail to read cert", err)), - Ok(certs) => certs, - }; - let key = match rustls_pemfile::private_key(&mut key.as_bytes()) { - Err(err) => return Err(Error::with_other("fail to read client private key", err)), - Ok(None) => return Err(Error::BadArguments(&"no client private key")), - Ok(Some(key)) => key, - }; - self.identity = Some((certs, key)); - Ok(self) - } - - fn take_roots(&mut self) -> RootCertStore { - std::mem::replace(&mut self.ca_certs, RootCertStore::empty()) - } - - pub(crate) fn into_config(mut self) -> Result { - let roots = self.take_roots(); - // This has to be called before server cert verifier to install default crypto provider. - let builder = ClientConfig::builder(); - let builder = match self.hostname_verification { - true => builder.with_root_certificates(roots), - false => unsafe { - let verifier = NoHostnameVerificationServerCertVerifier::new(roots); - builder.dangerous().with_custom_certificate_verifier(Arc::new(verifier)) - }, - }; - if let Some((client_cert, client_key)) = self.identity.take() { - match builder.with_client_auth_cert(client_cert, client_key) { - Ok(config) => Ok(config), - Err(err) => Err(Error::with_other("invalid client private key", err)), - } - } else { - Ok(builder.with_no_client_auth()) - } - } -} diff --git a/src/tls/client.rs b/src/tls/client.rs new file mode 100644 index 0000000..bd65a26 --- /dev/null +++ b/src/tls/client.rs @@ -0,0 +1,551 @@ +use std::sync::{Arc, RwLock}; + +use async_net::TcpStream; +use futures_rustls::client::TlsStream; +use futures_rustls::TlsConnector; +use rustls::client::WebPkiServerVerifier; +use rustls::pki_types::ServerName; +use rustls::ClientConfig; +use tracing::warn; + +use super::{NoHostnameVerificationServerCertVerifier, TlsCerts, TlsDynamicCerts}; +use crate::client::Result; +use crate::error::Error; + +pub(crate) struct TlsDynamicConnector { + config: RwLock<(u64, Arc)>, + dynamic_certs: TlsDynamicCerts, + hostname_verification: bool, +} + +impl TlsDynamicConnector { + pub fn new(dynamic_certs: TlsDynamicCerts, hostname_verification: bool) -> Result> { + let (version, certs) = dynamic_certs.get_versioned(); + let config = TlsClient::create_config((*certs).clone(), hostname_verification)?; + Ok(Arc::new(Self { config: RwLock::new((version, config)), dynamic_certs, hostname_verification })) + } + + pub fn get(&self) -> TlsConnector { + let (version, mut config) = self.config.read().unwrap().clone(); + if let Some((updated_version, certs)) = self.dynamic_certs.get_updated(version) { + config = match TlsClient::create_config((*certs).clone(), self.hostname_verification) { + Ok(config) => self.update_config(updated_version, config), + Err(err) => { + if self.skip_version(version, updated_version) { + warn!("fail to create tls config for updated certs: {:?}", err); + } + config + }, + }; + } + TlsConnector::from(config) + } + + fn skip_version(&self, expected_version: u64, updated_version: u64) -> bool { + let mut locked = self.config.write().unwrap(); + let update = expected_version == locked.0; + if update { + locked.0 = updated_version; + } + update + } + + fn update_config(&self, version: u64, config: Arc) -> Arc { + let mut locked = self.config.write().unwrap(); + if version > locked.0 { + *locked = (version, config); + } + locked.1.clone() + } +} + +#[derive(Clone)] +pub(crate) enum TlsClient { + Static(TlsConnector), + Dynamic(Arc), +} + +impl TlsClient { + pub(super) fn new_static(certs: TlsCerts, hostname_verification: bool) -> Result { + let config = Self::create_config(certs, hostname_verification)?; + Ok(Self::Static(TlsConnector::from(config))) + } + + pub(super) fn new_dynamic(dynamic_certs: TlsDynamicCerts, hostname_verification: bool) -> Result { + TlsDynamicConnector::new(dynamic_certs, hostname_verification).map(TlsClient::Dynamic) + } + + fn create_config(certs: TlsCerts, hostname_verification: bool) -> Result> { + // This has to be called before server cert verifier to install default crypto provider. + let builder = ClientConfig::builder(); + let builder = match hostname_verification { + true => { + let verifier = WebPkiServerVerifier::builder_with_provider( + certs.ca.roots.into(), + builder.crypto_provider().clone(), + ) + .with_crls(certs.ca.crls) + .build() + .map_err(|err| Error::with_other("fail to create tls server cert verifier", err))?; + builder.with_webpki_verifier(verifier) + }, + false => unsafe { + let verifier = NoHostnameVerificationServerCertVerifier::new(certs.ca.roots, certs.ca.crls); + builder.dangerous().with_custom_certificate_verifier(Arc::new(verifier)) + }, + }; + match certs.identity { + Some(identity) => match builder.with_client_auth_cert(identity.cert, identity.key) { + Ok(config) => Ok(config.into()), + Err(err) => Err(Error::with_other("invalid client private key", err)), + }, + None => Ok(builder.with_no_client_auth().into()), + } + } + + async fn connect_tls( + &self, + domain: ServerName<'static>, + stream: TcpStream, + ) -> std::io::Result> { + match self { + Self::Static(connector) => connector.connect(domain, stream).await, + Self::Dynamic(connector) => { + let connector = connector.get(); + connector.connect(domain, stream).await + }, + } + } + + pub async fn connect(&self, host: &str, port: u16) -> std::io::Result> { + let stream = TcpStream::connect((host, port)).await?; + let domain = ServerName::try_from(host).unwrap().to_owned(); + self.connect_tls(domain, stream).await + } +} + +#[cfg(test)] +mod tests { + use std::io::Write; + use std::path::Path; + use std::sync::Arc; + use std::time::{Duration, SystemTime}; + + use async_net::{TcpListener, TcpStream}; + use asyncs::task::TaskHandle; + use atomic_write_file::AtomicWriteFile; + use futures::channel::mpsc; + use futures::channel::mpsc::{UnboundedReceiver, UnboundedSender}; + use futures::join; + use futures::stream::StreamExt; + use futures_rustls::{TlsAcceptor, TlsStream}; + use notify::{Event, RecursiveMode, Watcher}; + use rcgen::{ + Certificate, + CertificateParams, + CertificateRevocationList, + CertificateRevocationListParams, + CertifiedKey, + Issuer, + KeyIdMethod, + KeyPair, + RevokedCertParams, + SerialNumber, + }; + use rustls::server::{ServerConfig, WebPkiClientVerifier}; + use rustls::RootCertStore; + use rustls_pki_types::PrivatePkcs8KeyDer; + use tempfile::TempDir; + use x509_parser::prelude::*; + + use crate::tls::{TlsCa, TlsCerts, TlsClient, TlsDynamicCerts, TlsIdentity, TlsOptions}; + + const HOSTNAME: &str = "127.0.0.1"; + const MISTMATCH_HOSTNAME: &str = "localhost"; + + struct Ca { + pub cert: Certificate, + pub crls: Vec, + pub issuer: Issuer<'static, KeyPair>, + } + + impl Ca { + pub fn new(cert: Certificate, key: KeyPair) -> Self { + let issuer = Issuer::from_ca_cert_der(cert.der(), key).unwrap(); + Self { cert, crls: vec![], issuer } + } + + pub fn revoke(&mut self, serial_number: SerialNumber) { + let revoked_params = RevokedCertParams { + serial_number, + revocation_time: SystemTime::now().into(), + reason_code: None, + invalidity_date: None, + }; + + let crl_params = CertificateRevocationListParams { + this_update: revoked_params.revocation_time, + next_update: revoked_params.revocation_time + Duration::from_secs(3600), + crl_number: SerialNumber::from(128), + issuing_distribution_point: None, + revoked_certs: vec![revoked_params], + key_identifier_method: KeyIdMethod::Sha256, + }; + self.crls.push(crl_params.signed_by(&self.issuer).unwrap()); + } + + pub fn pem(&self) -> String { + let mut pem = self.cert.pem(); + for crl in self.crls.iter() { + pem += crl.pem().unwrap().as_str(); + } + pem + } + } + + struct ServerCert { + pub cert: Certificate, + pub signing_key: KeyPair, + pub serial_number: SerialNumber, + } + + impl Clone for ServerCert { + fn clone(&self) -> Self { + Self { + cert: self.cert.clone(), + signing_key: KeyPair::try_from(self.signing_key.serialize_der()).unwrap(), + serial_number: self.serial_number.clone(), + } + } + } + + fn generate_ca_cert() -> Ca { + let mut params = CertificateParams::default(); + params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained); + params.distinguished_name.push(rcgen::DnType::CommonName, "ca"); + params.key_usages = vec![ + rcgen::KeyUsagePurpose::KeyCertSign, + rcgen::KeyUsagePurpose::DigitalSignature, + rcgen::KeyUsagePurpose::CrlSign, + ]; + let key = KeyPair::generate().unwrap(); + let ca_cert = params.self_signed(&key).unwrap(); + Ca::new(ca_cert, key) + } + + fn generate_server_cert(issuer: &Issuer<'_, KeyPair>) -> ServerCert { + let serial_number = SerialNumber::from_slice(uuid::Uuid::new_v4().as_bytes().as_slice()); + let mut params = CertificateParams::new(vec![HOSTNAME.to_string()]).unwrap(); + params.serial_number = Some(serial_number.clone()); + params.key_usages = vec![rcgen::KeyUsagePurpose::DigitalSignature, rcgen::KeyUsagePurpose::KeyEncipherment]; + params.extended_key_usages = vec![rcgen::ExtendedKeyUsagePurpose::ServerAuth]; + params.distinguished_name.push(rcgen::DnType::CommonName, "server"); + + let signing_key = KeyPair::generate().unwrap(); + let cert = params.signed_by(&signing_key, issuer).unwrap(); + ServerCert { cert, signing_key, serial_number } + } + + fn generate_client_cert(cn: &str, issuer: &Issuer<'_, KeyPair>) -> CertifiedKey { + let mut params = CertificateParams::default(); + params.distinguished_name.push(rcgen::DnType::CommonName, cn); + let signing_key = KeyPair::generate().unwrap(); + let cert = params.signed_by(&signing_key, issuer).unwrap(); + CertifiedKey { cert, signing_key } + } + + struct TlsListener { + server_cert: ServerCert, + listener: TcpListener, + acceptor: TlsAcceptor, + } + + impl TlsListener { + async fn listen(roots: RootCertStore, server_cert: ServerCert) -> Self { + let verifier = WebPkiClientVerifier::builder(roots.into()).build().unwrap(); + let server_config = ServerConfig::builder() + .with_client_cert_verifier(verifier) + .with_single_cert( + vec![server_cert.cert.der().clone()], + PrivatePkcs8KeyDer::from(server_cert.signing_key.serialize_der()).into(), + ) + .unwrap(); + let listener = TcpListener::bind((HOSTNAME, 0)).await.unwrap(); + Self { server_cert, listener, acceptor: TlsAcceptor::from(Arc::new(server_config)) } + } + + async fn accept(&self) -> TlsStream { + let (stream, _addr) = self.listener.accept().await.unwrap(); + self.acceptor.accept(stream).await.unwrap().into() + } + + fn local_port(&self) -> u16 { + self.listener.local_addr().unwrap().port() + } + } + + async fn listen() -> (Ca, TlsListener) { + let ca = generate_ca_cert(); + let server_cert = generate_server_cert(&ca.issuer); + let mut roots = RootCertStore::empty(); + roots.add(ca.cert.der().clone()).unwrap(); + + let listener = TlsListener::listen(roots, server_cert).await; + (ca, listener) + } + + async fn hostname_verification(hostname_verification: bool, host: &str, revoke: bool) { + let (mut ca, listener) = listen().await; + + let client_cert = generate_client_cert("client1", &ca.issuer); + + if revoke { + ca.revoke(listener.server_cert.serial_number.clone()); + } + + let mut options = TlsOptions::new() + .with_pem_ca(&ca.pem()) + .unwrap() + .with_pem_identity(&client_cert.cert.pem(), &client_cert.signing_key.serialize_pem()) + .unwrap(); + + if !hostname_verification { + options = unsafe { options.with_no_hostname_verification() }; + } + + let client = options.into_client().unwrap(); + + let (_server_stream, _client_stream) = + join!(listener.accept(), async { client.connect(host, listener.local_port()).await.unwrap() }); + } + + #[asyncs::test] + async fn hostname_verification_ok() { + hostname_verification(true, HOSTNAME, false).await; + hostname_verification(false, HOSTNAME, false).await; + hostname_verification(false, MISTMATCH_HOSTNAME, false).await; + } + + #[asyncs::test] + #[should_panic(expected = "NotValidForName")] + async fn hostname_verification_failure() { + hostname_verification(true, MISTMATCH_HOSTNAME, false).await; + } + + #[asyncs::test] + #[should_panic(expected = "InvalidCertificate(Revoked)")] + async fn no_hostname_verification_with_crls() { + hostname_verification(false, MISTMATCH_HOSTNAME, true).await; + } + + async fn assert_client_name(listener: &TlsListener, client: &TlsClient, client_name: &str) { + let (server_stream, _client_stream) = + join!(listener.accept(), async { client.connect(HOSTNAME, listener.local_port()).await.unwrap() }); + + let (_, state) = server_stream.get_ref(); + let peer_cert = state.peer_certificates().unwrap(); + let cert = X509Certificate::from_der(peer_cert[0].as_ref()).unwrap().1; + let name = cert.subject().iter_common_name().next().unwrap(); + assert_eq!(name.as_str().unwrap(), client_name); + } + + #[asyncs::test] + async fn with_pem_identity() { + let (ca, listener) = listen().await; + + let client_cert = generate_client_cert("client1", &ca.issuer); + + let options = TlsOptions::new() + .with_pem_ca(&ca.pem()) + .unwrap() + .with_pem_identity(&client_cert.cert.pem(), &client_cert.signing_key.serialize_pem()) + .unwrap(); + + let client = options.into_client().unwrap(); + + assert_client_name(&listener, &client, "client1").await; + } + + #[asyncs::test] + #[should_panic(expected = "InvalidCertificate(Revoked)")] + async fn with_crls() { + let (mut ca, listener) = listen().await; + ca.revoke(listener.server_cert.serial_number.clone()); + + let client_cert = generate_client_cert("client1", &ca.issuer); + + let options = TlsOptions::new() + .with_pem_ca(&ca.pem()) + .unwrap() + .with_pem_identity(&client_cert.cert.pem(), &client_cert.signing_key.serialize_pem()) + .unwrap(); + + let client = options.into_client().unwrap(); + + assert_client_name(&listener, &client, "client1").await; + } + + #[asyncs::test] + async fn with_static_certs() { + let (ca, listener) = listen().await; + + let client_cert = generate_client_cert("client1", &ca.issuer); + + let options = TlsOptions::new().with_certs( + TlsCerts::builder() + .with_ca(TlsCa::from_pem(&ca.pem()).unwrap()) + .with_identity( + TlsIdentity::from_pem(&client_cert.cert.pem(), &client_cert.signing_key.serialize_pem()).unwrap(), + ) + .build() + .unwrap(), + ); + + let client = options.into_client().unwrap(); + + assert_client_name(&listener, &client, "client1").await; + } + + #[asyncs::test] + async fn with_dynamic_certs() { + let (ca, listener) = listen().await; + + let client_cert = generate_client_cert("client1", &ca.issuer); + + let dynamic_certs = TlsDynamicCerts::new( + TlsCerts::builder() + .with_ca(TlsCa::from_pem(&ca.pem()).unwrap()) + .with_identity( + TlsIdentity::from_pem(&client_cert.cert.pem(), &client_cert.signing_key.serialize_pem()).unwrap(), + ) + .build() + .unwrap(), + ); + + let options = TlsOptions::new().with_certs(dynamic_certs.clone()); + + let client = options.into_client().unwrap(); + + assert_client_name(&listener, &client, "client1").await; + + let client_cert = generate_client_cert("client2", &ca.issuer); + dynamic_certs.update( + TlsCerts::builder() + .with_ca(TlsCa::from_pem(&ca.pem()).unwrap()) + .with_identity( + TlsIdentity::from_pem(&client_cert.cert.pem(), &client_cert.signing_key.serialize_pem()).unwrap(), + ) + .build() + .unwrap(), + ); + + assert_client_name(&listener, &client, "client2").await; + } + + struct FileBasedDynamicCerts { + ca: Ca, + dir: TempDir, + certs: TlsDynamicCerts, + feedback: UnboundedReceiver<()>, + _task: TaskHandle<()>, + } + + struct EventSender { + sender: UnboundedSender, + } + + impl notify::EventHandler for EventSender { + fn handle_event(&mut self, event: Result) { + if let Ok(event) = event { + self.sender.unbounded_send(event).unwrap(); + } + } + } + + impl FileBasedDynamicCerts { + pub fn new(ca: Ca, client_name: &str) -> Self { + let dir = TempDir::new().unwrap(); + Self::generate_cert(&ca, dir.path(), client_name); + let (certs, feedback, _task) = Self::load_dynamic_certs(&ca, dir.path()); + Self { ca, dir, certs, feedback, _task } + } + + fn load_dynamic_certs(ca: &Ca, dir: &Path) -> (TlsDynamicCerts, UnboundedReceiver<()>, TaskHandle<()>) { + let cert_path = dir.join("cert.pem").to_path_buf(); + let key_path = dir.join("cert.key.pem").to_path_buf(); + + let mut cert_modified = std::fs::metadata(&cert_path).unwrap().modified().unwrap(); + let mut key_modified = std::fs::metadata(&key_path).unwrap().modified().unwrap(); + let client_cert = std::fs::read_to_string(&cert_path).unwrap(); + let client_key = std::fs::read_to_string(&key_path).unwrap(); + + let dynamic_certs = TlsDynamicCerts::new( + TlsCerts::builder() + .with_ca(TlsCa::from_pem(&ca.pem()).unwrap()) + .with_identity(TlsIdentity::from_pem(&client_cert, &client_key).unwrap()) + .build() + .unwrap(), + ); + let dynamic_certs_updator = dynamic_certs.clone(); + + let (feedback_sender, feedback_receiver) = mpsc::unbounded(); + let task = asyncs::spawn(async move { + let (tx, mut rx) = mpsc::unbounded(); + let mut watcher = notify::recommended_watcher(EventSender { sender: tx }).unwrap(); + watcher.watch(&cert_path, RecursiveMode::NonRecursive).unwrap(); + watcher.watch(&key_path, RecursiveMode::NonRecursive).unwrap(); + while rx.next().await.is_some() { + let updated_cert_modified = std::fs::metadata(&cert_path).unwrap().modified().unwrap(); + let updated_key_modified = std::fs::metadata(&key_path).unwrap().modified().unwrap(); + if updated_cert_modified <= cert_modified || updated_key_modified <= key_modified { + continue; + } + cert_modified = updated_cert_modified; + key_modified = updated_key_modified; + let client_cert = std::fs::read_to_string(&cert_path).unwrap(); + let client_key = std::fs::read_to_string(&key_path).unwrap(); + dynamic_certs_updator + .update_identity(Some(TlsIdentity::from_pem(&client_cert, &client_key).unwrap())); + feedback_sender.unbounded_send(()).unwrap(); + } + }) + .attach(); + (dynamic_certs, feedback_receiver, task) + } + + fn generate_cert(ca: &Ca, dir: &Path, name: &str) { + let client_cert = generate_client_cert(name, &ca.issuer); + let file = AtomicWriteFile::open(dir.join("cert.pem")).unwrap(); + write!(&file, "{}", client_cert.cert.pem()).unwrap(); + file.commit().unwrap(); + + let file = AtomicWriteFile::open(dir.join("cert.key.pem")).unwrap(); + write!(&file, "{}", client_cert.signing_key.serialize_pem()).unwrap(); + file.commit().unwrap(); + } + + pub async fn regenerate_cert(&mut self, client_name: &str) { + Self::generate_cert(&self.ca, self.dir.path(), client_name); + self.feedback.next().await; + } + } + + #[asyncs::test] + async fn with_file_based_dynamic_certs() { + let (ca, listener) = listen().await; + + let options = TlsOptions::new().with_pem_ca(&ca.pem()).unwrap(); + + let mut certs = FileBasedDynamicCerts::new(ca, "client1"); + + let options = options.with_certs(certs.certs.clone()); + + let client = options.into_client().unwrap(); + + assert_client_name(&listener, &client, "client1").await; + + certs.regenerate_cert("client2").await; + + assert_client_name(&listener, &client, "client2").await; + } +} diff --git a/src/tls/mod.rs b/src/tls/mod.rs new file mode 100644 index 0000000..bec84f7 --- /dev/null +++ b/src/tls/mod.rs @@ -0,0 +1,7 @@ +mod client; +mod options; +mod verifier; + +pub(crate) use self::client::*; +pub use self::options::*; +use self::verifier::*; diff --git a/src/tls/options.rs b/src/tls/options.rs new file mode 100644 index 0000000..c4554a9 --- /dev/null +++ b/src/tls/options.rs @@ -0,0 +1,327 @@ +use std::sync::{Arc, RwLock}; + +use derive_where::derive_where; +use ignore_result::Ignore; +use rustls::pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer}; +use rustls::RootCertStore; + +use super::TlsClient; +use crate::client::Result; +use crate::Error; + +type PemItem = rustls_pemfile::Item; + +/// Ca certificates and crls to authenticate peer. +#[derive(Clone, Debug)] +pub struct TlsCa { + pub(super) roots: RootCertStore, + pub(super) crls: Vec>, +} + +impl TlsCa { + /// Constructs [TlsCa] from pem. + pub fn from_pem(pem: &str) -> Result { + let mut ca = Self { roots: RootCertStore::empty(), crls: Vec::new() }; + for r in rustls_pemfile::read_all(&mut pem.as_bytes()) { + match r { + Ok(PemItem::X509Certificate(cert)) => ca.roots.add(cert).ignore(), + Ok(PemItem::Crl(crl)) => ca.crls.push(crl), + Ok(_) => continue, + Err(err) => return Err(Error::with_other("fail to read ca", err)), + } + } + if ca.roots.is_empty() { + return Err(Error::BadArguments(&"no valid tls trust anchor in pem")); + } + Ok(ca) + } + + fn merge(&mut self, ca: TlsCa) { + self.roots.roots.extend(ca.roots.roots); + self.crls.extend(ca.crls); + } +} + +/// A CA signed certificate and its private key. +#[derive_where(Debug)] +pub struct TlsIdentity { + /// CA signed certificate. + pub(super) cert: Vec>, + + /// Key to certificate. + #[derive_where(skip)] + pub(super) key: PrivateKeyDer<'static>, +} + +impl TlsIdentity { + /// Constructs [TlsIdentity] from pem. + pub fn from_pem(cert: &str, key: &str) -> Result { + let r: Result, _> = rustls_pemfile::certs(&mut cert.as_bytes()).collect(); + let cert = match r { + Err(err) => return Err(Error::with_other("fail to read cert", err)), + Ok(cert) => cert, + }; + let key = match rustls_pemfile::private_key(&mut key.as_bytes()) { + Err(err) => return Err(Error::with_other("fail to read client private key", err)), + Ok(None) => return Err(Error::BadArguments(&"no client private key")), + Ok(Some(key)) => key, + }; + Ok(Self { cert, key }) + } +} + +impl Clone for TlsIdentity { + fn clone(&self) -> Self { + Self { cert: self.cert.clone(), key: self.key.clone_key() } + } +} + +/// Certificates used by client to authenticate with server. +#[derive(Clone, Debug)] +pub struct TlsCerts { + /// Ca to authenticate server. + pub(super) ca: TlsCa, + /// Optional client side identity for server to authenticate. + pub(super) identity: Option, +} + +impl TlsCerts { + /// Constructs a builder to build [TlsCerts]. + pub fn builder() -> TlsCertsBuilder { + TlsCertsBuilder::new() + } +} + +/// Builder to construct [TlsCerts]. +#[derive(Clone, Debug)] +pub struct TlsCertsBuilder { + ca: Option, + /// Optional client side identity. + identity: Option, +} + +impl TlsCertsBuilder { + /// Constructs an empty builder. + fn new() -> Self { + Self { ca: None, identity: None } + } + + /// Specifies ca certificates and also crls. + pub fn with_ca(mut self, ca: TlsCa) -> Self { + self.ca = Some(ca); + self + } + + /// Specifies client identity for server to authenticate. + pub fn with_identity(mut self, identity: TlsIdentity) -> Self { + self.identity = Some(identity); + self + } + + /// Builds [TlsCerts] and fails if no ca specified. + pub fn build(self) -> Result { + let ca = match self.ca { + None => return Err(Error::BadArguments(&"no tls ca")), + Some(ca) => ca, + }; + Ok(TlsCerts { ca, identity: self.identity }) + } +} + +/// Options to carry [TlsCerts]. +#[derive(Clone, Debug)] +pub struct TlsCertsOptions { + certs: TlsInnerCerts, +} + +#[derive(Clone, Debug)] +enum TlsInnerCerts { + Static(TlsCerts), + Dynamic(TlsDynamicCerts), +} + +impl From for TlsInnerCerts { + fn from(options: TlsCertsOptions) -> Self { + options.certs + } +} + +impl From for TlsCertsOptions { + fn from(certs: TlsInnerCerts) -> Self { + Self { certs } + } +} + +impl From for TlsCertsOptions { + fn from(certs: TlsCerts) -> Self { + TlsInnerCerts::Static(certs).into() + } +} + +impl From for TlsCertsOptions { + fn from(certs: TlsDynamicCerts) -> Self { + TlsInnerCerts::Dynamic(certs).into() + } +} + +/// Cell to keep up to date [TlsCerts]. +/// +/// [TlsDynamicCerts] by itself are concurrent safe in updating certs, but concurrency implies +/// uncertainty which means you won't known which one will last. +#[derive(Clone, Debug)] +pub struct TlsDynamicCerts { + certs: Arc)>>, +} + +impl TlsDynamicCerts { + /// Constructs [TlsDynamicCerts] with certs. + pub fn new(certs: TlsCerts) -> Self { + let certs = certs.into(); + Self { certs: Arc::new(RwLock::new((1, certs))) } + } + + /// Updates with newer certs. + pub fn update(&self, certs: TlsCerts) { + let certs = certs.into(); + let mut writer = self.certs.write().unwrap(); + writer.0 += 1; + let old = std::mem::replace(&mut writer.1, certs); + drop(writer); + drop(old); + } + + /// Updates with newer ca certificates. + pub fn update_ca(&self, ca: TlsCa) { + self.update_partially(|certs| certs.ca = ca.clone()) + } + + /// Updates with newer client tls identity. + pub fn update_identity(&self, identity: Option) { + self.update_partially(|certs| certs.identity = identity.clone()) + } + + fn update_versioned(&self, version: u64, certs: TlsCerts) -> bool { + let certs = certs.into(); + let mut writer = self.certs.write().unwrap(); + if writer.0 != version { + return false; + } + writer.0 += 1; + let old = std::mem::replace(&mut writer.1, certs); + drop(writer); + drop(old); + true + } + + fn update_partially(&self, update: impl Fn(&mut TlsCerts)) { + loop { + let (version, certs) = self.get_versioned(); + let mut certs = (*certs).clone(); + update(&mut certs); + if self.update_versioned(version, certs) { + break; + } + } + } + + pub(crate) fn get_versioned(&self) -> (u64, Arc) { + self.certs.read().unwrap().clone() + } + + pub(crate) fn get_updated(&self, version: u64) -> Option<(u64, Arc)> { + let locked = self.certs.read().unwrap(); + if version >= locked.0 { + return None; + } + Some(locked.clone()) + } +} + +/// Options for tls connection. +#[derive(Clone, Debug)] +pub struct TlsOptions { + ca: Option, + identity: Option, + certs: Option, + hostname_verification: bool, +} + +impl Default for TlsOptions { + /// Same as [Self::new]. + fn default() -> Self { + Self::new() + } +} + +impl TlsOptions { + /// Tls options with no ca certificates. + #[deprecated(since = "0.10.0", note = "use TlsOptions::new instead")] + pub fn no_ca() -> Self { + Self::new() + } + + /// Tls options with no ca certificates. + pub fn new() -> Self { + Self { ca: None, identity: None, certs: None, hostname_verification: true } + } + + /// Disables hostname verification in tls handshake. + /// + /// # Safety + /// This exposes risk to man-in-the-middle attacks. + pub unsafe fn with_no_hostname_verification(mut self) -> Self { + self.hostname_verification = false; + self + } + + /// Adds new ca certificates. + /// + /// It behaves different to [TlsOptions::with_pem_ca] in two ways: + /// 1. It is additive. + /// 2. It takes only certs into account. + #[deprecated(since = "0.10.0", note = "use TlsOptions::with_pem_ca instead")] + pub fn with_pem_ca_certs(mut self, certs: &str) -> Result { + let mut ca = TlsCa::from_pem(certs)?; + ca.crls.clear(); + match self.ca.as_mut() { + None => self.ca = Some(ca), + Some(existing_ca) => existing_ca.merge(ca), + }; + Ok(self) + } + + /// Specifies ca certificates and also crls. + /// + /// See also [TlsCa::from_pem]. + pub fn with_pem_ca(mut self, ca: &str) -> Result { + self.ca = Some(TlsCa::from_pem(ca)?); + Ok(self) + } + + /// Specifies client identity for server to authenticate. + /// + /// See also [TlsIdentity::from_pem]. + pub fn with_pem_identity(mut self, cert: &str, key: &str) -> Result { + self.identity = Some(TlsIdentity::from_pem(cert, key)?); + Ok(self) + } + + /// Specifies certificates to connection to server. This takes precedence over + /// [TlsOptions::with_pem_identity] and [TlsOptions::with_pem_ca]. + pub fn with_certs(mut self, certs: impl Into) -> Self { + self.certs = Some(certs.into()); + self + } + + pub(crate) fn into_client(self) -> Result { + let hostname_verification = self.hostname_verification; + match self.certs.map(TlsInnerCerts::from) { + None => { + let certs = TlsCertsBuilder { ca: self.ca, identity: self.identity }.build()?; + TlsClient::new_static(certs, hostname_verification) + }, + Some(TlsInnerCerts::Static(certs)) => TlsClient::new_static(certs, hostname_verification), + Some(TlsInnerCerts::Dynamic(certs)) => TlsClient::new_dynamic(certs, hostname_verification), + } + } +} diff --git a/src/tls/verifier.rs b/src/tls/verifier.rs new file mode 100644 index 0000000..6182069 --- /dev/null +++ b/src/tls/verifier.rs @@ -0,0 +1,152 @@ +use std::sync::Arc; + +use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; +use rustls::crypto::{CryptoProvider, WebPkiSupportedAlgorithms}; +use rustls::pki_types::{CertificateDer, CertificateRevocationListDer, ServerName, UnixTime}; +use rustls::{ + CertRevocationListError, + CertificateError, + DigitallySignedStruct, + Error as TlsError, + ExtendedKeyPurpose, + OtherError, + RootCertStore, + SignatureScheme, +}; +use webpki::{ + BorrowedCertRevocationList, + CertRevocationList, + EndEntityCert, + InvalidNameContext, + KeyUsage, + RevocationOptionsBuilder, +}; + +use crate::client::Result; + +// Rustls tends to make disable of hostname verification verbose since it exposes man-in-the-middle +// attacks. Though, there are still attempts to disable hostname verification in rustls, but no got +// merged until now. +// * Allow disabling Hostname Verification: https://github.com/rustls/rustls/issues/578 +// * Dangerous verifiers API proposal: https://github.com/rustls/rustls/pull/1197 +#[derive(Debug)] +pub(super) struct NoHostnameVerificationServerCertVerifier { + roots: RootCertStore, + crls: Vec>, + supported: WebPkiSupportedAlgorithms, +} + +impl NoHostnameVerificationServerCertVerifier { + pub unsafe fn new(roots: RootCertStore, crls: Vec>) -> Self { + let crls: Vec<_> = crls + .iter() + .map(|crl| BorrowedCertRevocationList::from_der(crl.as_ref()).unwrap().to_owned().unwrap()) + .map(CertRevocationList::Owned) + .collect(); + Self { roots, crls, supported: CryptoProvider::get_default().unwrap().signature_verification_algorithms } + } +} + +fn extended_key_purpose(values: impl Iterator) -> ExtendedKeyPurpose { + let values = values.collect::>(); + match &*values { + KeyUsage::CLIENT_AUTH_REPR => ExtendedKeyPurpose::ClientAuth, + KeyUsage::SERVER_AUTH_REPR => ExtendedKeyPurpose::ServerAuth, + _ => ExtendedKeyPurpose::Other(values), + } +} + +// Copied from https://github.com/rustls/rustls/blob/v/0.23.29/rustls/src/webpki/mod.rs#L59 +// LICENSE: https://github.com/rustls/rustls/blob/v/0.23.29/LICENSE (any of Apache 2.0, MIT and ISC) +fn pki_error(error: webpki::Error) -> TlsError { + use webpki::Error::*; + match error { + BadDer | BadDerTime | TrailingData(_) => CertificateError::BadEncoding.into(), + CertNotValidYet { time, not_before } => CertificateError::NotValidYetContext { time, not_before }.into(), + CertExpired { time, not_after } => CertificateError::ExpiredContext { time, not_after }.into(), + InvalidCertValidity => CertificateError::Expired.into(), + UnknownIssuer => CertificateError::UnknownIssuer.into(), + CertNotValidForName(InvalidNameContext { expected, presented }) => { + CertificateError::NotValidForNameContext { expected, presented }.into() + }, + CertRevoked => CertificateError::Revoked.into(), + UnknownRevocationStatus => CertificateError::UnknownRevocationStatus.into(), + CrlExpired { time, next_update } => CertificateError::ExpiredRevocationListContext { time, next_update }.into(), + IssuerNotCrlSigner => CertRevocationListError::IssuerInvalidForCrl.into(), + + InvalidSignatureForPublicKey => CertificateError::BadSignature.into(), + #[allow(deprecated)] + UnsupportedSignatureAlgorithm + | UnsupportedSignatureAlgorithmContext(_) + | UnsupportedSignatureAlgorithmForPublicKey => CertificateError::UnsupportedSignatureAlgorithm.into(), + + InvalidCrlSignatureForPublicKey => CertRevocationListError::BadSignature.into(), + #[allow(deprecated)] + UnsupportedCrlSignatureAlgorithm + | UnsupportedCrlSignatureAlgorithmContext(_) + | UnsupportedCrlSignatureAlgorithmForPublicKey => CertRevocationListError::UnsupportedSignatureAlgorithm.into(), + + #[allow(deprecated)] + RequiredEkuNotFound => CertificateError::InvalidPurpose.into(), + RequiredEkuNotFoundContext(webpki::RequiredEkuNotFoundContext { required, present }) => { + CertificateError::InvalidPurposeContext { + required: extended_key_purpose(required.oid_values()), + presented: present.into_iter().map(|eku| extended_key_purpose(eku.into_iter())).collect(), + } + .into() + }, + + _ => CertificateError::Other(OtherError(Arc::new(error))).into(), + } +} + +impl ServerCertVerifier for NoHostnameVerificationServerCertVerifier { + fn verify_server_cert( + &self, + end_entity: &CertificateDer<'_>, + intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp_response: &[u8], + now: UnixTime, + ) -> Result { + let cert = EndEntityCert::try_from(end_entity).map_err(pki_error)?; + let crls = self.crls.iter().collect::>(); + let revocation = match RevocationOptionsBuilder::new(&crls) { + Err(_) => None, + Ok(builder) => Some(builder.build()), + }; + cert.verify_for_usage( + self.supported.all, + &self.roots.roots, + intermediates, + now, + webpki::KeyUsage::server_auth(), + revocation, + None, + ) + .map_err(pki_error)?; + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls12_signature(message, cert, dss, &self.supported) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls12_signature(message, cert, dss, &self.supported) + } + + fn supported_verify_schemes(&self) -> Vec { + self.supported.supported_schemes() + } +} diff --git a/tests/zookeeper.rs b/tests/zookeeper.rs index 2725e7f..76867a9 100644 --- a/tests/zookeeper.rs +++ b/tests/zookeeper.rs @@ -262,7 +262,7 @@ impl Tls { fn options(&self) -> zk::TlsOptions { let mut options = zk::TlsOptions::default() - .with_pem_ca_certs(&self.ca_cert_pem) + .with_pem_ca(&self.ca_cert_pem) .unwrap() .with_pem_identity(&self.client_cert_pem, &self.client_cert_key) .unwrap(); @@ -273,7 +273,12 @@ impl Tls { } fn options_x(&self) -> zk::TlsOptions { - self.options().with_pem_identity(&self.client_x_cert_pem, &self.client_x_cert_key).unwrap() + let certs = zk::TlsCerts::builder() + .with_ca(zk::TlsCa::from_pem(&self.ca_cert_pem).unwrap()) + .with_identity(zk::TlsIdentity::from_pem(&self.client_x_cert_pem, &self.client_x_cert_key).unwrap()) + .build() + .unwrap(); + self.options().with_certs(zk::TlsDynamicCerts::new(certs)) } }