Skip to content

Commit 4bcf3dc

Browse files
committed
Reload client certificates
This allows creating a client with certificate paths instead of a preloaded certificate. When created this way, on reconnection the client will check if the certificate files have been changed on disk and reload them if they have. This allows us to have auto-reloading of refreshed certificates client side.
1 parent 468b6e3 commit 4bcf3dc

File tree

5 files changed

+415
-39
lines changed

5 files changed

+415
-39
lines changed

Cargo.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ rust-version = "1.76"
1616

1717
[features]
1818
default = []
19-
tls = ["rustls", "rustls-pemfile", "webpki-roots", "futures-rustls"]
19+
tls = ["rustls", "rustls-pemfile", "webpki-roots", "futures-rustls", "async-fs"]
2020
sasl = ["sasl-gssapi", "sasl-digest-md5"]
2121
sasl-digest-md5 = ["rsasl/unstable_custom_mechanism", "md5", "linkme", "hex"]
2222
sasl-gssapi = ["rsasl/gssapi"]
@@ -48,6 +48,7 @@ md5 = { version = "0.7.0", optional = true }
4848
hex = { version = "0.4.3", optional = true }
4949
linkme = { version = "0.3", optional = true }
5050
async-io = "2.3.2"
51+
async-fs = { version = "2.1.2", optional = true }
5152
futures = "0.3.30"
5253
async-net = "2.0.0"
5354
futures-rustls = { version = "0.26.0", optional = true }
@@ -67,6 +68,7 @@ tempfile = "3.6.0"
6768
rcgen = { version = "0.12.1", features = ["default", "x509-parser"] }
6869
serial_test = "3.0.0"
6970
asyncs = { version = "0.3.0", features = ["test"] }
71+
smol = "2.0.2"
7072
blocking = "1.6.0"
7173

7274
[package.metadata.cargo-all-features]
@@ -78,3 +80,8 @@ all-features = true
7880
[profile.dev]
7981
# Need this for linkme crate to work for spawns in macOS
8082
lto = "thin"
83+
84+
[[example]]
85+
name = "tls_file_based"
86+
path = "examples/tls_file_based.rs"
87+
required-features = ["tls", "smol"]

examples/tls_file_based.rs

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
use std::env;
2+
use std::io::{self, Write};
3+
use std::path::PathBuf;
4+
use std::time::Duration;
5+
6+
use zookeeper_client::Error::NodeExists;
7+
use zookeeper_client::{Acls, Client, CreateMode, TlsOptions};
8+
9+
fn main() -> Result<(), Box<dyn std::error::Error>> {
10+
env_logger::init();
11+
smol::block_on(run()).unwrap_or_else(|e| {
12+
eprintln!("Error: {}", e);
13+
std::process::exit(1);
14+
});
15+
Ok(())
16+
}
17+
18+
async fn run() -> Result<(), Box<dyn std::error::Error>> {
19+
let connect_string = env::var("ZK_CONNECT_STRING").unwrap_or_else(|_| "tcp+tls://localhost:2281".to_string());
20+
let ca_cert = PathBuf::from(env::var("ZK_CA_CERT").expect("ZK_CA_CERT environment variable is required"));
21+
let client_cert =
22+
PathBuf::from(env::var("ZK_CLIENT_CERT").expect("ZK_CLIENT_CERT environment variable is required"));
23+
let client_key = PathBuf::from(env::var("ZK_CLIENT_KEY").expect("ZK_CLIENT_KEY environment variable is required"));
24+
25+
println!("Connecting to ZooKeeper with file-based TLS...");
26+
println!("Server: {}", connect_string);
27+
println!("CA cert: {}", ca_cert.display());
28+
println!("Client cert: {}", client_cert.display());
29+
println!("Client key: {}", client_key.display());
30+
31+
let loaded_ca_cert = async_fs::read_to_string(&ca_cert).await?;
32+
let tls_options = TlsOptions::default()
33+
.with_pem_ca_certs(&loaded_ca_cert)?
34+
.with_pem_identity_files(&client_cert, &client_key)
35+
.await?;
36+
37+
let tls_options = unsafe { tls_options.with_no_hostname_verification() };
38+
39+
println!("WARNING: Hostname verification disabled!");
40+
41+
let client = Client::connector()
42+
.connection_timeout(Duration::from_secs(10))
43+
.session_timeout(Duration::from_secs(30))
44+
.tls(tls_options)
45+
.secure_connect(&connect_string)
46+
.await?;
47+
48+
println!("Connected to ZooKeeper successfully!");
49+
50+
let path = "/tls_example";
51+
52+
loop {
53+
print!("\nOptions:\ne. Edit key\nq. Quit\nEnter choice (e/q): ");
54+
io::stdout().flush()?;
55+
56+
let mut input = String::new();
57+
io::stdin().read_line(&mut input)?;
58+
59+
match input.trim() {
60+
"e" => {
61+
print!("Enter new data for the key: ");
62+
io::stdout().flush()?;
63+
64+
let mut data_input = String::new();
65+
io::stdin().read_line(&mut data_input)?;
66+
let data = data_input.trim().as_bytes();
67+
68+
println!("Setting data at path: {}", path);
69+
match client.create(path, data, &CreateMode::Ephemeral.with_acls(Acls::anyone_all())).await {
70+
Ok(_) => println!("ZNode created successfully"),
71+
Err(NodeExists) => {
72+
println!("ZNode already exists, updating data...");
73+
client.set_data(path, data, None).await?;
74+
println!("ZNode data updated successfully");
75+
},
76+
Err(e) => {
77+
println!("Error creating/updating ZNode: {}", e);
78+
continue;
79+
},
80+
}
81+
82+
match client.get_data(path).await {
83+
Ok((data, _stat)) => {
84+
println!("Current data: {}", String::from_utf8_lossy(&data));
85+
},
86+
Err(e) => println!("Error reading data: {}", e),
87+
}
88+
},
89+
"q" => {
90+
println!("Cleaning up and exiting...");
91+
match client.delete(path, None).await {
92+
Ok(_) => println!("ZNode deleted successfully"),
93+
Err(_) => println!("ZNode may not exist or already deleted"),
94+
}
95+
break;
96+
},
97+
_ => {
98+
println!("Invalid choice. Please enter 'e' or 'q'.");
99+
},
100+
}
101+
}
102+
103+
println!("Example completed successfully!");
104+
Ok(())
105+
}

src/session/connection.rs

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use std::io::{Error, ErrorKind, IoSlice, Result};
22
use std::pin::Pin;
3+
#[cfg(feature = "tls")]
4+
use std::sync::Arc;
35
use std::task::{Context, Poll};
46
use std::time::Duration;
57

@@ -15,18 +17,18 @@ use tracing::{debug, trace};
1517

1618
#[cfg(feature = "tls")]
1719
mod tls {
18-
pub use std::sync::Arc;
19-
2020
pub use futures_rustls::client::TlsStream;
2121
pub use futures_rustls::TlsConnector;
2222
pub use rustls::pki_types::ServerName;
23-
pub use rustls::ClientConfig;
2423
}
24+
2525
#[cfg(feature = "tls")]
2626
use tls::*;
2727

2828
use crate::deadline::Deadline;
2929
use crate::endpoint::{EndpointRef, IterableEndpoints};
30+
#[cfg(feature = "tls")]
31+
use crate::TlsOptions;
3032

3133
#[derive(Debug)]
3234
pub enum Connection {
@@ -170,15 +172,15 @@ impl Connection {
170172
#[derive(Clone)]
171173
pub struct Connector {
172174
#[cfg(feature = "tls")]
173-
tls: Option<TlsConnector>,
175+
tls_options: Option<TlsOptions>,
174176
timeout: Duration,
175177
}
176178

177179
impl Connector {
178180
#[cfg(feature = "tls")]
179181
#[allow(dead_code)]
180182
pub fn new() -> Self {
181-
Self { tls: None, timeout: Duration::from_secs(10) }
183+
Self { tls_options: None, timeout: Duration::from_secs(10) }
182184
}
183185

184186
#[cfg(not(feature = "tls"))]
@@ -187,14 +189,27 @@ impl Connector {
187189
}
188190

189191
#[cfg(feature = "tls")]
190-
pub fn with_tls(config: ClientConfig) -> Self {
191-
Self { tls: Some(TlsConnector::from(Arc::new(config))), timeout: Duration::from_secs(10) }
192+
pub fn with_tls_options(tls_options: TlsOptions) -> Self {
193+
Self { tls_options: Some(tls_options), timeout: Duration::from_secs(10) }
194+
}
195+
196+
#[cfg(feature = "tls")]
197+
async fn get_current_tls_connector(&self) -> Result<TlsConnector> {
198+
let Some(ref tls_opts) = self.tls_options else {
199+
return Err(Error::new(ErrorKind::InvalidInput, "no TLS configuration"));
200+
};
201+
let config = tls_opts
202+
.to_config()
203+
.await
204+
.map_err(|e| Error::new(ErrorKind::InvalidData, format!("TLS config creation failed: {}", e)))?;
205+
Ok(TlsConnector::from(Arc::new(config)))
192206
}
193207

194208
#[cfg(feature = "tls")]
195209
async fn connect_tls(&self, stream: TcpStream, host: &str) -> Result<Connection> {
210+
let tls_connector = self.get_current_tls_connector().await?;
196211
let domain = ServerName::try_from(host).unwrap().to_owned();
197-
let stream = self.tls.as_ref().unwrap().connect(domain, stream).await?;
212+
let stream = tls_connector.connect(domain, stream).await?;
198213
Ok(Connection::new_tls(stream))
199214
}
200215

@@ -209,7 +224,7 @@ impl Connector {
209224
pub async fn connect(&self, endpoint: EndpointRef<'_>, deadline: &mut Deadline) -> Result<Connection> {
210225
if endpoint.tls {
211226
#[cfg(feature = "tls")]
212-
if self.tls.is_none() {
227+
if self.tls_options.is_none() {
213228
return Err(Error::new(ErrorKind::Unsupported, "tls not supported"));
214229
}
215230
#[cfg(not(feature = "tls"))]
@@ -288,4 +303,12 @@ mod tests {
288303
let err = connector.connect(endpoint, &mut Deadline::never()).await.unwrap_err();
289304
assert_eq!(err.kind(), ErrorKind::Unsupported);
290305
}
306+
307+
#[cfg(feature = "tls")]
308+
#[test]
309+
fn test_with_tls_options() {
310+
let tls_options = crate::TlsOptions::default();
311+
let connector = Connector::with_tls_options(tls_options);
312+
assert!(connector.tls_options.is_some());
313+
}
291314
}

src/session/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ impl Builder {
130130
return Err(Error::BadArguments(&"connection timeout must not be negative"));
131131
}
132132
#[cfg(feature = "tls")]
133-
let connector = Connector::with_tls(self.tls.unwrap_or_default().into_config()?);
133+
let connector = Connector::with_tls_options(self.tls.unwrap_or_default());
134134
#[cfg(not(feature = "tls"))]
135135
let connector = Connector::new();
136136
let (state_sender, state_receiver) = asyncs::sync::watch::channel(SessionState::Disconnected);

0 commit comments

Comments
 (0)