diff --git a/Cargo.toml b/Cargo.toml index 1e1a58d..b8bdc20 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ rust-version = "1.76" [features] default = [] -tls = ["rustls", "rustls-pemfile", "webpki-roots", "futures-rustls"] +tls = ["rustls", "rustls-pemfile", "webpki-roots", "futures-rustls", "async-fs"] sasl = ["sasl-gssapi", "sasl-digest-md5"] sasl-digest-md5 = ["rsasl/unstable_custom_mechanism", "md5", "linkme", "hex"] sasl-gssapi = ["rsasl/gssapi"] @@ -48,6 +48,7 @@ md5 = { version = "0.7.0", optional = true } hex = { version = "0.4.3", optional = true } linkme = { version = "0.3", optional = true } async-io = "2.3.2" +async-fs = { version = "2.1.2", optional = true } futures = "0.3.30" async-net = "2.0.0" futures-rustls = { version = "0.26.0", optional = true } @@ -67,6 +68,7 @@ tempfile = "3.6.0" rcgen = { version = "0.12.1", features = ["default", "x509-parser"] } serial_test = "3.0.0" asyncs = { version = "0.3.0", features = ["test"] } +smol = "2.0.2" blocking = "1.6.0" [package.metadata.cargo-all-features] @@ -78,3 +80,8 @@ all-features = true [profile.dev] # Need this for linkme crate to work for spawns in macOS lto = "thin" + +[[example]] +name = "tls_file_based" +path = "examples/tls_file_based.rs" +required-features = ["tls", "smol"] diff --git a/examples/tls_file_based.rs b/examples/tls_file_based.rs new file mode 100644 index 0000000..7d5e03a --- /dev/null +++ b/examples/tls_file_based.rs @@ -0,0 +1,105 @@ +use std::env; +use std::io::{self, Write}; +use std::path::PathBuf; +use std::time::Duration; + +use zookeeper_client::Error::NodeExists; +use zookeeper_client::{Acls, Client, CreateMode, TlsOptions}; + +fn main() -> Result<(), Box> { + env_logger::init(); + smol::block_on(run()).unwrap_or_else(|e| { + eprintln!("Error: {}", e); + std::process::exit(1); + }); + Ok(()) +} + +async fn run() -> Result<(), Box> { + let connect_string = env::var("ZK_CONNECT_STRING").unwrap_or_else(|_| "tcp+tls://localhost:2281".to_string()); + let ca_cert = PathBuf::from(env::var("ZK_CA_CERT").expect("ZK_CA_CERT environment variable is required")); + let client_cert = + PathBuf::from(env::var("ZK_CLIENT_CERT").expect("ZK_CLIENT_CERT environment variable is required")); + let client_key = PathBuf::from(env::var("ZK_CLIENT_KEY").expect("ZK_CLIENT_KEY environment variable is required")); + + println!("Connecting to ZooKeeper with file-based TLS..."); + println!("Server: {}", connect_string); + println!("CA cert: {}", ca_cert.display()); + println!("Client cert: {}", client_cert.display()); + println!("Client key: {}", client_key.display()); + + let loaded_ca_cert = async_fs::read_to_string(&ca_cert).await?; + let tls_options = TlsOptions::default() + .with_pem_ca_certs(&loaded_ca_cert)? + .with_pem_identity_files(&client_cert, &client_key) + .await?; + + let tls_options = unsafe { tls_options.with_no_hostname_verification() }; + + println!("WARNING: Hostname verification disabled!"); + + let client = Client::connector() + .connection_timeout(Duration::from_secs(10)) + .session_timeout(Duration::from_secs(30)) + .tls(tls_options) + .secure_connect(&connect_string) + .await?; + + println!("Connected to ZooKeeper successfully!"); + + let path = "/tls_example"; + + loop { + print!("\nOptions:\ne. Edit key\nq. Quit\nEnter choice (e/q): "); + io::stdout().flush()?; + + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + + match input.trim() { + "e" => { + print!("Enter new data for the key: "); + io::stdout().flush()?; + + let mut data_input = String::new(); + io::stdin().read_line(&mut data_input)?; + let data = data_input.trim().as_bytes(); + + println!("Setting data at path: {}", path); + match client.create(path, data, &CreateMode::Ephemeral.with_acls(Acls::anyone_all())).await { + Ok(_) => println!("ZNode created successfully"), + Err(NodeExists) => { + println!("ZNode already exists, updating data..."); + client.set_data(path, data, None).await?; + println!("ZNode data updated successfully"); + }, + Err(e) => { + println!("Error creating/updating ZNode: {}", e); + continue; + }, + } + + match client.get_data(path).await { + Ok((data, _stat)) => { + println!("Current data: {}", String::from_utf8_lossy(&data)); + }, + Err(e) => println!("Error reading data: {}", e), + } + }, + "q" => { + println!("Cleaning up and exiting..."); + match client.delete(path, None).await { + Ok(_) => println!("ZNode deleted successfully"), + Err(_) => println!("ZNode may not exist or already deleted"), + } + break; + }, + _ => { + println!("Invalid choice. Please enter 'e' or 'q'."); + }, + } + } + + println!("Example completed successfully!"); + Ok(()) +} diff --git a/src/session/connection.rs b/src/session/connection.rs index 7f3dfc1..ac2de67 100644 --- a/src/session/connection.rs +++ b/src/session/connection.rs @@ -1,5 +1,7 @@ use std::io::{Error, ErrorKind, IoSlice, Result}; use std::pin::Pin; +#[cfg(feature = "tls")] +use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; @@ -15,18 +17,18 @@ use tracing::{debug, trace}; #[cfg(feature = "tls")] mod tls { - pub use std::sync::Arc; - pub use futures_rustls::client::TlsStream; pub use futures_rustls::TlsConnector; pub use rustls::pki_types::ServerName; - pub use rustls::ClientConfig; } + #[cfg(feature = "tls")] use tls::*; use crate::deadline::Deadline; use crate::endpoint::{EndpointRef, IterableEndpoints}; +#[cfg(feature = "tls")] +use crate::TlsOptions; #[derive(Debug)] pub enum Connection { @@ -170,7 +172,7 @@ impl Connection { #[derive(Clone)] pub struct Connector { #[cfg(feature = "tls")] - tls: Option, + tls_options: Option, timeout: Duration, } @@ -178,7 +180,7 @@ impl Connector { #[cfg(feature = "tls")] #[allow(dead_code)] pub fn new() -> Self { - Self { tls: None, timeout: Duration::from_secs(10) } + Self { tls_options: None, timeout: Duration::from_secs(10) } } #[cfg(not(feature = "tls"))] @@ -187,14 +189,27 @@ impl Connector { } #[cfg(feature = "tls")] - pub fn with_tls(config: ClientConfig) -> Self { - Self { tls: Some(TlsConnector::from(Arc::new(config))), timeout: Duration::from_secs(10) } + pub fn with_tls_options(tls_options: TlsOptions) -> Self { + Self { tls_options: Some(tls_options), timeout: Duration::from_secs(10) } + } + + #[cfg(feature = "tls")] + async fn get_current_tls_connector(&self) -> Result { + let Some(ref tls_opts) = self.tls_options else { + return Err(Error::new(ErrorKind::InvalidInput, "no TLS configuration")); + }; + let config = tls_opts + .to_config() + .await + .map_err(|e| Error::new(ErrorKind::InvalidData, format!("TLS config creation failed: {}", e)))?; + Ok(TlsConnector::from(Arc::new(config))) } #[cfg(feature = "tls")] async fn connect_tls(&self, stream: TcpStream, host: &str) -> Result { + let tls_connector = self.get_current_tls_connector().await?; let domain = ServerName::try_from(host).unwrap().to_owned(); - let stream = self.tls.as_ref().unwrap().connect(domain, stream).await?; + let stream = tls_connector.connect(domain, stream).await?; Ok(Connection::new_tls(stream)) } @@ -209,7 +224,7 @@ impl Connector { pub async fn connect(&self, endpoint: EndpointRef<'_>, deadline: &mut Deadline) -> Result { if endpoint.tls { #[cfg(feature = "tls")] - if self.tls.is_none() { + if self.tls_options.is_none() { return Err(Error::new(ErrorKind::Unsupported, "tls not supported")); } #[cfg(not(feature = "tls"))] @@ -288,4 +303,12 @@ mod tests { let err = connector.connect(endpoint, &mut Deadline::never()).await.unwrap_err(); assert_eq!(err.kind(), ErrorKind::Unsupported); } + + #[cfg(feature = "tls")] + #[test] + fn test_with_tls_options() { + let tls_options = crate::TlsOptions::default(); + let connector = Connector::with_tls_options(tls_options); + assert!(connector.tls_options.is_some()); + } } diff --git a/src/session/mod.rs b/src/session/mod.rs index c57bcd5..300172a 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -130,7 +130,7 @@ impl Builder { return Err(Error::BadArguments(&"connection timeout must not be negative")); } #[cfg(feature = "tls")] - let connector = Connector::with_tls(self.tls.unwrap_or_default().into_config()?); + let connector = Connector::with_tls_options(self.tls.unwrap_or_default()); #[cfg(not(feature = "tls"))] let connector = Connector::new(); let (state_sender, state_receiver) = asyncs::sync::watch::channel(SessionState::Disconnected); diff --git a/src/tls.rs b/src/tls.rs index 7b68925..bbb51e9 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -1,5 +1,8 @@ +use std::path::{Path, PathBuf}; use std::sync::Arc; +use std::time::SystemTime; +use futures::lock::Mutex; use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; use rustls::crypto::{CryptoProvider, WebPkiSupportedAlgorithms}; use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime}; @@ -9,24 +12,107 @@ use rustls::{ClientConfig, DigitallySignedStruct, Error as TlsError, RootCertSto use crate::client::Result; use crate::Error; -/// Options for tls connection. #[derive(Debug)] -pub struct TlsOptions { - identity: Option<(Vec>, PrivateKeyDer<'static>)>, - ca_certs: RootCertStore, - hostname_verification: bool, +struct FileProvider { + certs: Vec>, + key: PrivateKeyDer<'static>, + cert_path: PathBuf, + key_path: PathBuf, + cert_modified: SystemTime, + key_modified: SystemTime, +} + +impl FileProvider { + async fn new(cert_path: PathBuf, key_path: PathBuf) -> Result { + let (certs, key) = load_certificates_from_files(&cert_path, &key_path).await?; + let (cert_modified, key_modified) = get_file_timestamps(&cert_path, &key_path).await?; + Ok(Self { certs, key, cert_path, key_path, cert_modified, key_modified }) + } + + async fn update_and_fetch(&mut self) -> Result<(Vec>, PrivateKeyDer<'static>)> { + let (cert_modified, key_modified) = get_file_timestamps(&self.cert_path, &self.key_path).await?; + let cert_changed = cert_modified > self.cert_modified; + let key_changed = key_modified > self.key_modified; + // Refresh if both files were modified, as we want to make sure that we don't pick up a new cert/key with + // an old key/cert. + if cert_changed && key_changed { + tracing::debug!("Reloading client certificates"); + match load_certificates_from_files(&self.cert_path, &self.key_path).await { + Err(e) => tracing::warn!("Failed to reload certificates, keeping existing ones: {}", e), + Ok((certs, key)) => { + tracing::info!("Reloaded client certificates"); + println!("Reloaded client certificates"); + self.cert_modified = cert_modified; + self.key_modified = key_modified; + self.certs = certs; + self.key = key; + }, + } + } + Ok((self.certs.clone(), self.key.clone_key())) + } +} + +async fn load_certificates_from_files( + cert_path: &Path, + key_path: &Path, +) -> Result<(Vec>, PrivateKeyDer<'static>)> { + let cert_content = async_fs::read_to_string(cert_path) + .await + .map_err(|e| Error::with_other("Failed to read certificate file", e))?; + let key_content = + async_fs::read_to_string(key_path).await.map_err(|e| Error::with_other("Failed to read key file", e))?; + parse_pem_identity(&cert_content, &key_content) +} + +async fn get_file_timestamps(cert_path: &Path, key_path: &Path) -> Result<(SystemTime, SystemTime)> { + let cert_metadata = async_fs::metadata(cert_path) + .await + .map_err(|e| Error::with_other("Failed to get certificate file metadata", e))?; + let key_metadata = + async_fs::metadata(key_path).await.map_err(|e| Error::with_other("Failed to get key file metadata", e))?; + + let cert_modified = cert_metadata.modified().unwrap_or(SystemTime::UNIX_EPOCH); + let key_modified = key_metadata.modified().unwrap_or(SystemTime::UNIX_EPOCH); + Ok((cert_modified, key_modified)) +} + +#[derive(Debug)] +enum IdentityProvider { + Static { certs: Vec>, key: PrivateKeyDer<'static> }, + FileBased { provider: Arc> }, } -impl Clone for TlsOptions { +impl IdentityProvider { + pub async fn check_and_reload_certificates( + &self, + ) -> Result<(Vec>, PrivateKeyDer<'static>)> { + match self { + IdentityProvider::Static { certs, key } => Ok((certs.clone(), key.clone_key())), + IdentityProvider::FileBased { provider } => provider.lock().await.update_and_fetch().await, + } + } +} + +impl Clone for IdentityProvider { fn clone(&self) -> Self { - Self { - identity: self.identity.as_ref().map(|id| (id.0.clone(), id.1.clone_key())), - ca_certs: self.ca_certs.clone(), - hostname_verification: self.hostname_verification, + match self { + IdentityProvider::Static { certs, key } => { + IdentityProvider::Static { certs: certs.clone(), key: key.clone_key() } + }, + provider @ IdentityProvider::FileBased { .. } => provider.clone(), } } } +/// Options for tls connection. +#[derive(Debug, Clone)] +pub struct TlsOptions { + identity_provider: Option, + ca_certs: RootCertStore, + hostname_verification: bool, +} + impl Default for TlsOptions { /// Tls options with well-known ca roots. fn default() -> Self { @@ -105,11 +191,29 @@ impl ServerCertVerifier for TlsServerCertVerifier { } } +/// Helper function to parse certificate and key content from strings +fn parse_pem_identity( + cert_content: &str, + key_content: &str, +) -> Result<(Vec>, PrivateKeyDer<'static>)> { + let r: std::result::Result, _> = rustls_pemfile::certs(&mut cert_content.as_bytes()).collect(); + let certs = match r { + Err(err) => return Err(Error::with_other("fail to read cert", err)), + Ok(certs) => certs, + }; + let key = match rustls_pemfile::private_key(&mut key_content.as_bytes()) { + Err(err) => return Err(Error::with_other("fail to read client private key", err)), + Ok(None) => return Err(Error::BadArguments(&"no client private key")), + Ok(Some(key)) => key, + }; + Ok((certs, key)) +} + impl TlsOptions { /// Tls options with no ca certificates. Use [TlsOptions::default] if well-known ca roots is /// desirable. pub fn no_ca() -> Self { - Self { ca_certs: RootCertStore::empty(), identity: None, hostname_verification: true } + Self { ca_certs: RootCertStore::empty(), identity_provider: None, hostname_verification: true } } /// Disables hostname verification in tls handshake. @@ -137,30 +241,33 @@ impl TlsOptions { /// Specifies client identity for server to authenticate. pub fn with_pem_identity(mut self, cert: &str, key: &str) -> Result { - let r: std::result::Result, _> = rustls_pemfile::certs(&mut cert.as_bytes()).collect(); - let certs = match r { - Err(err) => return Err(Error::with_other("fail to read cert", err)), - Ok(certs) => certs, - }; - let key = match rustls_pemfile::private_key(&mut key.as_bytes()) { - Err(err) => return Err(Error::with_other("fail to read client private key", err)), - Ok(None) => return Err(Error::BadArguments(&"no client private key")), - Ok(Some(key)) => key, - }; - self.identity = Some((certs, key)); + let (certs, key) = parse_pem_identity(cert, key)?; + self.identity_provider = Some(IdentityProvider::Static { certs, key }); Ok(self) } - fn take_roots(&mut self) -> RootCertStore { - std::mem::replace(&mut self.ca_certs, RootCertStore::empty()) + /// Specifies client identity from file paths with automatic reloading on file changes when + /// reconnections take place. + pub async fn with_pem_identity_files( + mut self, + cert_path: impl Into, + key_path: impl Into, + ) -> Result { + let cert_path = cert_path.into(); + let key_path = key_path.into(); + + let file_provider = FileProvider::new(cert_path, key_path).await?; + self.identity_provider = Some(IdentityProvider::FileBased { provider: Arc::new(Mutex::new(file_provider)) }); + + Ok(self) } - pub(crate) fn into_config(mut self) -> Result { - // This has to be called before server cert verifier to install default crypto provider. + pub(crate) async fn to_config(&self) -> Result { let builder = ClientConfig::builder(); - let verifier = TlsServerCertVerifier::new(self.take_roots(), self.hostname_verification); + let verifier = TlsServerCertVerifier::new(self.ca_certs.clone(), self.hostname_verification); let builder = builder.dangerous().with_custom_certificate_verifier(Arc::new(verifier)); - if let Some((client_cert, client_key)) = self.identity.take() { + if let Some(identity_provider) = &self.identity_provider { + let (client_cert, client_key) = identity_provider.check_and_reload_certificates().await?; match builder.with_client_auth_cert(client_cert, client_key) { Ok(config) => Ok(config), Err(err) => Err(Error::with_other("invalid client private key", err)), @@ -170,3 +277,137 @@ impl TlsOptions { } } } + +#[cfg(test)] +mod tests { + use std::time::Duration; + use std::{fs, thread}; + + use rcgen::{Certificate, CertificateParams}; + use tempfile::TempDir; + + use super::*; + + fn generate_test_cert_and_key() -> (String, String) { + let mut params = CertificateParams::new(vec!["localhost".to_string()]); + params.alg = &rcgen::PKCS_ECDSA_P256_SHA256; + + let cert = Certificate::from_params(params).unwrap(); + let cert_pem = cert.serialize_pem().unwrap(); + let key_pem = cert.serialize_private_key_pem(); + + (cert_pem, key_pem) + } + + #[asyncs::test] + async fn test_with_pem_identity_files() { + let temp_dir = TempDir::new().unwrap(); + let cert_path = temp_dir.path().join("test.crt"); + let key_path = temp_dir.path().join("test.key"); + + // Generate valid test certificates + let (cert_pem, key_pem) = generate_test_cert_and_key(); + fs::write(&cert_path, &cert_pem).unwrap(); + fs::write(&key_path, &key_pem).unwrap(); + + // Test loading certificates from files + let tls_options = TlsOptions::default().with_pem_identity_files(&cert_path, &key_path).await.unwrap(); + + // Verify that identity was loaded + assert!(tls_options.identity_provider.is_some()); + } + + #[asyncs::test] + async fn test_with_pem_identity_files_missing_cert() { + let temp_dir = TempDir::new().unwrap(); + let cert_path = temp_dir.path().join("missing.crt"); + let key_path = temp_dir.path().join("test.key"); + + let (_, key_pem) = generate_test_cert_and_key(); + fs::write(&key_path, &key_pem).unwrap(); + + // Should fail when certificate file is missing + let result = TlsOptions::default().with_pem_identity_files(&cert_path, &key_path).await; + + assert!(result.is_err()); + } + + #[asyncs::test] + async fn test_with_pem_identity_files_missing_key() { + let temp_dir = TempDir::new().unwrap(); + let cert_path = temp_dir.path().join("test.crt"); + let key_path = temp_dir.path().join("missing.key"); + + let (cert_pem, _) = generate_test_cert_and_key(); + fs::write(&cert_path, &cert_pem).unwrap(); + + // Should fail when key file is missing + let result = TlsOptions::default().with_pem_identity_files(&cert_path, &key_path).await; + + assert!(result.is_err()); + } + + #[asyncs::test] + async fn test_check_and_reload_certificates_no_changes() { + let temp_dir = TempDir::new().unwrap(); + let cert_path = temp_dir.path().join("test.crt"); + let key_path = temp_dir.path().join("test.key"); + + let (cert_pem, key_pem) = generate_test_cert_and_key(); + fs::write(&cert_path, &cert_pem).unwrap(); + fs::write(&key_path, &key_pem).unwrap(); + + let tls_options = TlsOptions::default().with_pem_identity_files(&cert_path, &key_path).await.unwrap(); + + let (cert_1, key_1) = + tls_options.identity_provider.as_ref().unwrap().check_and_reload_certificates().await.unwrap(); + let (cert_2, key_2) = + tls_options.identity_provider.as_ref().unwrap().check_and_reload_certificates().await.unwrap(); + assert_eq!(cert_1, cert_2); + assert_eq!(key_1, key_2); + } + + #[asyncs::test] + async fn test_check_and_reload_certificates_key_changes() { + let temp_dir = TempDir::new().unwrap(); + let cert_path = temp_dir.path().join("test.crt"); + let key_path = temp_dir.path().join("test.key"); + + let (cert_pem, key_pem) = generate_test_cert_and_key(); + fs::write(&cert_path, &cert_pem).unwrap(); + fs::write(&key_path, &key_pem).unwrap(); + + let tls_options = TlsOptions::default().with_pem_identity_files(&cert_path, &key_path).await.unwrap(); + let (cert_1, key_1) = tls_options.identity_provider.unwrap().check_and_reload_certificates().await.unwrap(); + + // Sleep to ensure different modification time + thread::sleep(Duration::from_millis(50)); + + // Update the key file with new content (must update both cert and key for valid pair) + let (new_cert_pem, new_key_pem) = generate_test_cert_and_key(); + fs::write(&cert_path, &new_cert_pem).unwrap(); + fs::write(&key_path, &new_key_pem).unwrap(); + + let tls_options = TlsOptions::default().with_pem_identity_files(&cert_path, &key_path).await.unwrap(); + let (cert_2, key_2) = tls_options.identity_provider.unwrap().check_and_reload_certificates().await.unwrap(); + assert!(cert_1 != cert_2); + assert!(key_1 != key_2); + } + + #[asyncs::test] + async fn test_into_config_with_file_based_certs() { + let temp_dir = TempDir::new().unwrap(); + let cert_path = temp_dir.path().join("test.crt"); + let key_path = temp_dir.path().join("test.key"); + + let (cert_pem, key_pem) = generate_test_cert_and_key(); + fs::write(&cert_path, &cert_pem).unwrap(); + fs::write(&key_path, &key_pem).unwrap(); + + let tls_options = TlsOptions::default().with_pem_identity_files(&cert_path, &key_path).await.unwrap(); + + // Should be able to create a valid ClientConfig + let config = tls_options.to_config().await; + assert!(config.is_ok()); + } +}