From 4bcf3dc88f529e703b1fe918b10ffd7a77133478 Mon Sep 17 00:00:00 2001 From: Giorgos Georgiou Date: Thu, 26 Jun 2025 18:39:28 +0300 Subject: [PATCH] Reload client certificates This allows creating a client with certificate paths instead of a preloaded certificate. When created this way, on reconnection the client will check if the certificate files have been changed on disk and reload them if they have. This allows us to have auto-reloading of refreshed certificates client side. --- Cargo.toml | 9 +- examples/tls_file_based.rs | 105 +++++++++++++ src/session/connection.rs | 41 +++-- src/session/mod.rs | 2 +- src/tls.rs | 297 +++++++++++++++++++++++++++++++++---- 5 files changed, 415 insertions(+), 39 deletions(-) create mode 100644 examples/tls_file_based.rs diff --git a/Cargo.toml b/Cargo.toml index 1e1a58d..b8bdc20 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ rust-version = "1.76" [features] default = [] -tls = ["rustls", "rustls-pemfile", "webpki-roots", "futures-rustls"] +tls = ["rustls", "rustls-pemfile", "webpki-roots", "futures-rustls", "async-fs"] sasl = ["sasl-gssapi", "sasl-digest-md5"] sasl-digest-md5 = ["rsasl/unstable_custom_mechanism", "md5", "linkme", "hex"] sasl-gssapi = ["rsasl/gssapi"] @@ -48,6 +48,7 @@ md5 = { version = "0.7.0", optional = true } hex = { version = "0.4.3", optional = true } linkme = { version = "0.3", optional = true } async-io = "2.3.2" +async-fs = { version = "2.1.2", optional = true } futures = "0.3.30" async-net = "2.0.0" futures-rustls = { version = "0.26.0", optional = true } @@ -67,6 +68,7 @@ tempfile = "3.6.0" rcgen = { version = "0.12.1", features = ["default", "x509-parser"] } serial_test = "3.0.0" asyncs = { version = "0.3.0", features = ["test"] } +smol = "2.0.2" blocking = "1.6.0" [package.metadata.cargo-all-features] @@ -78,3 +80,8 @@ all-features = true [profile.dev] # Need this for linkme crate to work for spawns in macOS lto = "thin" + +[[example]] +name = "tls_file_based" +path = "examples/tls_file_based.rs" +required-features = ["tls", "smol"] diff --git a/examples/tls_file_based.rs b/examples/tls_file_based.rs new file mode 100644 index 0000000..7d5e03a --- /dev/null +++ b/examples/tls_file_based.rs @@ -0,0 +1,105 @@ +use std::env; +use std::io::{self, Write}; +use std::path::PathBuf; +use std::time::Duration; + +use zookeeper_client::Error::NodeExists; +use zookeeper_client::{Acls, Client, CreateMode, TlsOptions}; + +fn main() -> Result<(), Box> { + env_logger::init(); + smol::block_on(run()).unwrap_or_else(|e| { + eprintln!("Error: {}", e); + std::process::exit(1); + }); + Ok(()) +} + +async fn run() -> Result<(), Box> { + let connect_string = env::var("ZK_CONNECT_STRING").unwrap_or_else(|_| "tcp+tls://localhost:2281".to_string()); + let ca_cert = PathBuf::from(env::var("ZK_CA_CERT").expect("ZK_CA_CERT environment variable is required")); + let client_cert = + PathBuf::from(env::var("ZK_CLIENT_CERT").expect("ZK_CLIENT_CERT environment variable is required")); + let client_key = PathBuf::from(env::var("ZK_CLIENT_KEY").expect("ZK_CLIENT_KEY environment variable is required")); + + println!("Connecting to ZooKeeper with file-based TLS..."); + println!("Server: {}", connect_string); + println!("CA cert: {}", ca_cert.display()); + println!("Client cert: {}", client_cert.display()); + println!("Client key: {}", client_key.display()); + + let loaded_ca_cert = async_fs::read_to_string(&ca_cert).await?; + let tls_options = TlsOptions::default() + .with_pem_ca_certs(&loaded_ca_cert)? + .with_pem_identity_files(&client_cert, &client_key) + .await?; + + let tls_options = unsafe { tls_options.with_no_hostname_verification() }; + + println!("WARNING: Hostname verification disabled!"); + + let client = Client::connector() + .connection_timeout(Duration::from_secs(10)) + .session_timeout(Duration::from_secs(30)) + .tls(tls_options) + .secure_connect(&connect_string) + .await?; + + println!("Connected to ZooKeeper successfully!"); + + let path = "/tls_example"; + + loop { + print!("\nOptions:\ne. Edit key\nq. Quit\nEnter choice (e/q): "); + io::stdout().flush()?; + + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + + match input.trim() { + "e" => { + print!("Enter new data for the key: "); + io::stdout().flush()?; + + let mut data_input = String::new(); + io::stdin().read_line(&mut data_input)?; + let data = data_input.trim().as_bytes(); + + println!("Setting data at path: {}", path); + match client.create(path, data, &CreateMode::Ephemeral.with_acls(Acls::anyone_all())).await { + Ok(_) => println!("ZNode created successfully"), + Err(NodeExists) => { + println!("ZNode already exists, updating data..."); + client.set_data(path, data, None).await?; + println!("ZNode data updated successfully"); + }, + Err(e) => { + println!("Error creating/updating ZNode: {}", e); + continue; + }, + } + + match client.get_data(path).await { + Ok((data, _stat)) => { + println!("Current data: {}", String::from_utf8_lossy(&data)); + }, + Err(e) => println!("Error reading data: {}", e), + } + }, + "q" => { + println!("Cleaning up and exiting..."); + match client.delete(path, None).await { + Ok(_) => println!("ZNode deleted successfully"), + Err(_) => println!("ZNode may not exist or already deleted"), + } + break; + }, + _ => { + println!("Invalid choice. Please enter 'e' or 'q'."); + }, + } + } + + println!("Example completed successfully!"); + Ok(()) +} diff --git a/src/session/connection.rs b/src/session/connection.rs index 7f3dfc1..ac2de67 100644 --- a/src/session/connection.rs +++ b/src/session/connection.rs @@ -1,5 +1,7 @@ use std::io::{Error, ErrorKind, IoSlice, Result}; use std::pin::Pin; +#[cfg(feature = "tls")] +use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; @@ -15,18 +17,18 @@ use tracing::{debug, trace}; #[cfg(feature = "tls")] mod tls { - pub use std::sync::Arc; - pub use futures_rustls::client::TlsStream; pub use futures_rustls::TlsConnector; pub use rustls::pki_types::ServerName; - pub use rustls::ClientConfig; } + #[cfg(feature = "tls")] use tls::*; use crate::deadline::Deadline; use crate::endpoint::{EndpointRef, IterableEndpoints}; +#[cfg(feature = "tls")] +use crate::TlsOptions; #[derive(Debug)] pub enum Connection { @@ -170,7 +172,7 @@ impl Connection { #[derive(Clone)] pub struct Connector { #[cfg(feature = "tls")] - tls: Option, + tls_options: Option, timeout: Duration, } @@ -178,7 +180,7 @@ impl Connector { #[cfg(feature = "tls")] #[allow(dead_code)] pub fn new() -> Self { - Self { tls: None, timeout: Duration::from_secs(10) } + Self { tls_options: None, timeout: Duration::from_secs(10) } } #[cfg(not(feature = "tls"))] @@ -187,14 +189,27 @@ impl Connector { } #[cfg(feature = "tls")] - pub fn with_tls(config: ClientConfig) -> Self { - Self { tls: Some(TlsConnector::from(Arc::new(config))), timeout: Duration::from_secs(10) } + pub fn with_tls_options(tls_options: TlsOptions) -> Self { + Self { tls_options: Some(tls_options), timeout: Duration::from_secs(10) } + } + + #[cfg(feature = "tls")] + async fn get_current_tls_connector(&self) -> Result { + let Some(ref tls_opts) = self.tls_options else { + return Err(Error::new(ErrorKind::InvalidInput, "no TLS configuration")); + }; + let config = tls_opts + .to_config() + .await + .map_err(|e| Error::new(ErrorKind::InvalidData, format!("TLS config creation failed: {}", e)))?; + Ok(TlsConnector::from(Arc::new(config))) } #[cfg(feature = "tls")] async fn connect_tls(&self, stream: TcpStream, host: &str) -> Result { + let tls_connector = self.get_current_tls_connector().await?; let domain = ServerName::try_from(host).unwrap().to_owned(); - let stream = self.tls.as_ref().unwrap().connect(domain, stream).await?; + let stream = tls_connector.connect(domain, stream).await?; Ok(Connection::new_tls(stream)) } @@ -209,7 +224,7 @@ impl Connector { pub async fn connect(&self, endpoint: EndpointRef<'_>, deadline: &mut Deadline) -> Result { if endpoint.tls { #[cfg(feature = "tls")] - if self.tls.is_none() { + if self.tls_options.is_none() { return Err(Error::new(ErrorKind::Unsupported, "tls not supported")); } #[cfg(not(feature = "tls"))] @@ -288,4 +303,12 @@ mod tests { let err = connector.connect(endpoint, &mut Deadline::never()).await.unwrap_err(); assert_eq!(err.kind(), ErrorKind::Unsupported); } + + #[cfg(feature = "tls")] + #[test] + fn test_with_tls_options() { + let tls_options = crate::TlsOptions::default(); + let connector = Connector::with_tls_options(tls_options); + assert!(connector.tls_options.is_some()); + } } diff --git a/src/session/mod.rs b/src/session/mod.rs index c57bcd5..300172a 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -130,7 +130,7 @@ impl Builder { return Err(Error::BadArguments(&"connection timeout must not be negative")); } #[cfg(feature = "tls")] - let connector = Connector::with_tls(self.tls.unwrap_or_default().into_config()?); + let connector = Connector::with_tls_options(self.tls.unwrap_or_default()); #[cfg(not(feature = "tls"))] let connector = Connector::new(); let (state_sender, state_receiver) = asyncs::sync::watch::channel(SessionState::Disconnected); diff --git a/src/tls.rs b/src/tls.rs index 7b68925..bbb51e9 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -1,5 +1,8 @@ +use std::path::{Path, PathBuf}; use std::sync::Arc; +use std::time::SystemTime; +use futures::lock::Mutex; use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; use rustls::crypto::{CryptoProvider, WebPkiSupportedAlgorithms}; use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime}; @@ -9,24 +12,107 @@ use rustls::{ClientConfig, DigitallySignedStruct, Error as TlsError, RootCertSto use crate::client::Result; use crate::Error; -/// Options for tls connection. #[derive(Debug)] -pub struct TlsOptions { - identity: Option<(Vec>, PrivateKeyDer<'static>)>, - ca_certs: RootCertStore, - hostname_verification: bool, +struct FileProvider { + certs: Vec>, + key: PrivateKeyDer<'static>, + cert_path: PathBuf, + key_path: PathBuf, + cert_modified: SystemTime, + key_modified: SystemTime, +} + +impl FileProvider { + async fn new(cert_path: PathBuf, key_path: PathBuf) -> Result { + let (certs, key) = load_certificates_from_files(&cert_path, &key_path).await?; + let (cert_modified, key_modified) = get_file_timestamps(&cert_path, &key_path).await?; + Ok(Self { certs, key, cert_path, key_path, cert_modified, key_modified }) + } + + async fn update_and_fetch(&mut self) -> Result<(Vec>, PrivateKeyDer<'static>)> { + let (cert_modified, key_modified) = get_file_timestamps(&self.cert_path, &self.key_path).await?; + let cert_changed = cert_modified > self.cert_modified; + let key_changed = key_modified > self.key_modified; + // Refresh if both files were modified, as we want to make sure that we don't pick up a new cert/key with + // an old key/cert. + if cert_changed && key_changed { + tracing::debug!("Reloading client certificates"); + match load_certificates_from_files(&self.cert_path, &self.key_path).await { + Err(e) => tracing::warn!("Failed to reload certificates, keeping existing ones: {}", e), + Ok((certs, key)) => { + tracing::info!("Reloaded client certificates"); + println!("Reloaded client certificates"); + self.cert_modified = cert_modified; + self.key_modified = key_modified; + self.certs = certs; + self.key = key; + }, + } + } + Ok((self.certs.clone(), self.key.clone_key())) + } +} + +async fn load_certificates_from_files( + cert_path: &Path, + key_path: &Path, +) -> Result<(Vec>, PrivateKeyDer<'static>)> { + let cert_content = async_fs::read_to_string(cert_path) + .await + .map_err(|e| Error::with_other("Failed to read certificate file", e))?; + let key_content = + async_fs::read_to_string(key_path).await.map_err(|e| Error::with_other("Failed to read key file", e))?; + parse_pem_identity(&cert_content, &key_content) +} + +async fn get_file_timestamps(cert_path: &Path, key_path: &Path) -> Result<(SystemTime, SystemTime)> { + let cert_metadata = async_fs::metadata(cert_path) + .await + .map_err(|e| Error::with_other("Failed to get certificate file metadata", e))?; + let key_metadata = + async_fs::metadata(key_path).await.map_err(|e| Error::with_other("Failed to get key file metadata", e))?; + + let cert_modified = cert_metadata.modified().unwrap_or(SystemTime::UNIX_EPOCH); + let key_modified = key_metadata.modified().unwrap_or(SystemTime::UNIX_EPOCH); + Ok((cert_modified, key_modified)) +} + +#[derive(Debug)] +enum IdentityProvider { + Static { certs: Vec>, key: PrivateKeyDer<'static> }, + FileBased { provider: Arc> }, } -impl Clone for TlsOptions { +impl IdentityProvider { + pub async fn check_and_reload_certificates( + &self, + ) -> Result<(Vec>, PrivateKeyDer<'static>)> { + match self { + IdentityProvider::Static { certs, key } => Ok((certs.clone(), key.clone_key())), + IdentityProvider::FileBased { provider } => provider.lock().await.update_and_fetch().await, + } + } +} + +impl Clone for IdentityProvider { fn clone(&self) -> Self { - Self { - identity: self.identity.as_ref().map(|id| (id.0.clone(), id.1.clone_key())), - ca_certs: self.ca_certs.clone(), - hostname_verification: self.hostname_verification, + match self { + IdentityProvider::Static { certs, key } => { + IdentityProvider::Static { certs: certs.clone(), key: key.clone_key() } + }, + provider @ IdentityProvider::FileBased { .. } => provider.clone(), } } } +/// Options for tls connection. +#[derive(Debug, Clone)] +pub struct TlsOptions { + identity_provider: Option, + ca_certs: RootCertStore, + hostname_verification: bool, +} + impl Default for TlsOptions { /// Tls options with well-known ca roots. fn default() -> Self { @@ -105,11 +191,29 @@ impl ServerCertVerifier for TlsServerCertVerifier { } } +/// Helper function to parse certificate and key content from strings +fn parse_pem_identity( + cert_content: &str, + key_content: &str, +) -> Result<(Vec>, PrivateKeyDer<'static>)> { + let r: std::result::Result, _> = rustls_pemfile::certs(&mut cert_content.as_bytes()).collect(); + let certs = match r { + Err(err) => return Err(Error::with_other("fail to read cert", err)), + Ok(certs) => certs, + }; + let key = match rustls_pemfile::private_key(&mut key_content.as_bytes()) { + Err(err) => return Err(Error::with_other("fail to read client private key", err)), + Ok(None) => return Err(Error::BadArguments(&"no client private key")), + Ok(Some(key)) => key, + }; + Ok((certs, key)) +} + impl TlsOptions { /// Tls options with no ca certificates. Use [TlsOptions::default] if well-known ca roots is /// desirable. pub fn no_ca() -> Self { - Self { ca_certs: RootCertStore::empty(), identity: None, hostname_verification: true } + Self { ca_certs: RootCertStore::empty(), identity_provider: None, hostname_verification: true } } /// Disables hostname verification in tls handshake. @@ -137,30 +241,33 @@ impl TlsOptions { /// Specifies client identity for server to authenticate. pub fn with_pem_identity(mut self, cert: &str, key: &str) -> Result { - let r: std::result::Result, _> = rustls_pemfile::certs(&mut cert.as_bytes()).collect(); - let certs = match r { - Err(err) => return Err(Error::with_other("fail to read cert", err)), - Ok(certs) => certs, - }; - let key = match rustls_pemfile::private_key(&mut key.as_bytes()) { - Err(err) => return Err(Error::with_other("fail to read client private key", err)), - Ok(None) => return Err(Error::BadArguments(&"no client private key")), - Ok(Some(key)) => key, - }; - self.identity = Some((certs, key)); + let (certs, key) = parse_pem_identity(cert, key)?; + self.identity_provider = Some(IdentityProvider::Static { certs, key }); Ok(self) } - fn take_roots(&mut self) -> RootCertStore { - std::mem::replace(&mut self.ca_certs, RootCertStore::empty()) + /// Specifies client identity from file paths with automatic reloading on file changes when + /// reconnections take place. + pub async fn with_pem_identity_files( + mut self, + cert_path: impl Into, + key_path: impl Into, + ) -> Result { + let cert_path = cert_path.into(); + let key_path = key_path.into(); + + let file_provider = FileProvider::new(cert_path, key_path).await?; + self.identity_provider = Some(IdentityProvider::FileBased { provider: Arc::new(Mutex::new(file_provider)) }); + + Ok(self) } - pub(crate) fn into_config(mut self) -> Result { - // This has to be called before server cert verifier to install default crypto provider. + pub(crate) async fn to_config(&self) -> Result { let builder = ClientConfig::builder(); - let verifier = TlsServerCertVerifier::new(self.take_roots(), self.hostname_verification); + let verifier = TlsServerCertVerifier::new(self.ca_certs.clone(), self.hostname_verification); let builder = builder.dangerous().with_custom_certificate_verifier(Arc::new(verifier)); - if let Some((client_cert, client_key)) = self.identity.take() { + if let Some(identity_provider) = &self.identity_provider { + let (client_cert, client_key) = identity_provider.check_and_reload_certificates().await?; match builder.with_client_auth_cert(client_cert, client_key) { Ok(config) => Ok(config), Err(err) => Err(Error::with_other("invalid client private key", err)), @@ -170,3 +277,137 @@ impl TlsOptions { } } } + +#[cfg(test)] +mod tests { + use std::time::Duration; + use std::{fs, thread}; + + use rcgen::{Certificate, CertificateParams}; + use tempfile::TempDir; + + use super::*; + + fn generate_test_cert_and_key() -> (String, String) { + let mut params = CertificateParams::new(vec!["localhost".to_string()]); + params.alg = &rcgen::PKCS_ECDSA_P256_SHA256; + + let cert = Certificate::from_params(params).unwrap(); + let cert_pem = cert.serialize_pem().unwrap(); + let key_pem = cert.serialize_private_key_pem(); + + (cert_pem, key_pem) + } + + #[asyncs::test] + async fn test_with_pem_identity_files() { + let temp_dir = TempDir::new().unwrap(); + let cert_path = temp_dir.path().join("test.crt"); + let key_path = temp_dir.path().join("test.key"); + + // Generate valid test certificates + let (cert_pem, key_pem) = generate_test_cert_and_key(); + fs::write(&cert_path, &cert_pem).unwrap(); + fs::write(&key_path, &key_pem).unwrap(); + + // Test loading certificates from files + let tls_options = TlsOptions::default().with_pem_identity_files(&cert_path, &key_path).await.unwrap(); + + // Verify that identity was loaded + assert!(tls_options.identity_provider.is_some()); + } + + #[asyncs::test] + async fn test_with_pem_identity_files_missing_cert() { + let temp_dir = TempDir::new().unwrap(); + let cert_path = temp_dir.path().join("missing.crt"); + let key_path = temp_dir.path().join("test.key"); + + let (_, key_pem) = generate_test_cert_and_key(); + fs::write(&key_path, &key_pem).unwrap(); + + // Should fail when certificate file is missing + let result = TlsOptions::default().with_pem_identity_files(&cert_path, &key_path).await; + + assert!(result.is_err()); + } + + #[asyncs::test] + async fn test_with_pem_identity_files_missing_key() { + let temp_dir = TempDir::new().unwrap(); + let cert_path = temp_dir.path().join("test.crt"); + let key_path = temp_dir.path().join("missing.key"); + + let (cert_pem, _) = generate_test_cert_and_key(); + fs::write(&cert_path, &cert_pem).unwrap(); + + // Should fail when key file is missing + let result = TlsOptions::default().with_pem_identity_files(&cert_path, &key_path).await; + + assert!(result.is_err()); + } + + #[asyncs::test] + async fn test_check_and_reload_certificates_no_changes() { + let temp_dir = TempDir::new().unwrap(); + let cert_path = temp_dir.path().join("test.crt"); + let key_path = temp_dir.path().join("test.key"); + + let (cert_pem, key_pem) = generate_test_cert_and_key(); + fs::write(&cert_path, &cert_pem).unwrap(); + fs::write(&key_path, &key_pem).unwrap(); + + let tls_options = TlsOptions::default().with_pem_identity_files(&cert_path, &key_path).await.unwrap(); + + let (cert_1, key_1) = + tls_options.identity_provider.as_ref().unwrap().check_and_reload_certificates().await.unwrap(); + let (cert_2, key_2) = + tls_options.identity_provider.as_ref().unwrap().check_and_reload_certificates().await.unwrap(); + assert_eq!(cert_1, cert_2); + assert_eq!(key_1, key_2); + } + + #[asyncs::test] + async fn test_check_and_reload_certificates_key_changes() { + let temp_dir = TempDir::new().unwrap(); + let cert_path = temp_dir.path().join("test.crt"); + let key_path = temp_dir.path().join("test.key"); + + let (cert_pem, key_pem) = generate_test_cert_and_key(); + fs::write(&cert_path, &cert_pem).unwrap(); + fs::write(&key_path, &key_pem).unwrap(); + + let tls_options = TlsOptions::default().with_pem_identity_files(&cert_path, &key_path).await.unwrap(); + let (cert_1, key_1) = tls_options.identity_provider.unwrap().check_and_reload_certificates().await.unwrap(); + + // Sleep to ensure different modification time + thread::sleep(Duration::from_millis(50)); + + // Update the key file with new content (must update both cert and key for valid pair) + let (new_cert_pem, new_key_pem) = generate_test_cert_and_key(); + fs::write(&cert_path, &new_cert_pem).unwrap(); + fs::write(&key_path, &new_key_pem).unwrap(); + + let tls_options = TlsOptions::default().with_pem_identity_files(&cert_path, &key_path).await.unwrap(); + let (cert_2, key_2) = tls_options.identity_provider.unwrap().check_and_reload_certificates().await.unwrap(); + assert!(cert_1 != cert_2); + assert!(key_1 != key_2); + } + + #[asyncs::test] + async fn test_into_config_with_file_based_certs() { + let temp_dir = TempDir::new().unwrap(); + let cert_path = temp_dir.path().join("test.crt"); + let key_path = temp_dir.path().join("test.key"); + + let (cert_pem, key_pem) = generate_test_cert_and_key(); + fs::write(&cert_path, &cert_pem).unwrap(); + fs::write(&key_path, &key_pem).unwrap(); + + let tls_options = TlsOptions::default().with_pem_identity_files(&cert_path, &key_path).await.unwrap(); + + // Should be able to create a valid ClientConfig + let config = tls_options.to_config().await; + assert!(config.is_ok()); + } +}