Skip to content

Commit 5ee3675

Browse files
committed
Reload client certificates
1 parent 468b6e3 commit 5ee3675

File tree

5 files changed

+412
-39
lines changed

5 files changed

+412
-39
lines changed

Cargo.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ rust-version = "1.76"
1616

1717
[features]
1818
default = []
19-
tls = ["rustls", "rustls-pemfile", "webpki-roots", "futures-rustls"]
19+
tls = ["rustls", "rustls-pemfile", "webpki-roots", "futures-rustls", "async-fs"]
2020
sasl = ["sasl-gssapi", "sasl-digest-md5"]
2121
sasl-digest-md5 = ["rsasl/unstable_custom_mechanism", "md5", "linkme", "hex"]
2222
sasl-gssapi = ["rsasl/gssapi"]
@@ -48,6 +48,7 @@ md5 = { version = "0.7.0", optional = true }
4848
hex = { version = "0.4.3", optional = true }
4949
linkme = { version = "0.3", optional = true }
5050
async-io = "2.3.2"
51+
async-fs = { version = "2.1.2", optional = true }
5152
futures = "0.3.30"
5253
async-net = "2.0.0"
5354
futures-rustls = { version = "0.26.0", optional = true }
@@ -78,3 +79,8 @@ all-features = true
7879
[profile.dev]
7980
# Need this for linkme crate to work for spawns in macOS
8081
lto = "thin"
82+
83+
[[example]]
84+
name = "tls_file_based"
85+
path = "examples/tls_file_based.rs"
86+
required-features = ["tls", "tokio"]

examples/tls_file_based.rs

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
use std::env;
2+
use std::io::{self, Write};
3+
use std::path::PathBuf;
4+
use std::time::Duration;
5+
6+
use futures::executor::block_on;
7+
use zookeeper_client::Error::NodeExists;
8+
use zookeeper_client::{Acls, Client, CreateMode, TlsOptions};
9+
10+
fn main() -> Result<(), Box<dyn std::error::Error>> {
11+
block_on(run()).unwrap_or_else(|e| {
12+
eprintln!("Error: {}", e);
13+
std::process::exit(1);
14+
});
15+
Ok(())
16+
}
17+
18+
async fn run() -> Result<(), Box<dyn std::error::Error>> {
19+
let connect_string = env::var("ZK_CONNECT_STRING").unwrap_or_else(|_| "tcp+tls://localhost:2281".to_string());
20+
let ca_cert = PathBuf::from(env::var("ZK_CA_CERT").expect("ZK_CA_CERT environment variable is required"));
21+
let client_cert =
22+
PathBuf::from(env::var("ZK_CLIENT_CERT").expect("ZK_CLIENT_CERT environment variable is required"));
23+
let client_key = PathBuf::from(env::var("ZK_CLIENT_KEY").expect("ZK_CLIENT_KEY environment variable is required"));
24+
25+
println!("Connecting to ZooKeeper with file-based TLS...");
26+
println!("Server: {}", connect_string);
27+
println!("CA cert: {}", ca_cert.display());
28+
println!("Client cert: {}", client_cert.display());
29+
println!("Client key: {}", client_key.display());
30+
31+
let tls_options = TlsOptions::default()
32+
.with_pem_ca_certs(&ca_cert.to_str().unwrap())?
33+
.with_pem_identity_files(&client_cert, &client_key)
34+
.await?;
35+
36+
let tls_options = unsafe { tls_options.with_no_hostname_verification() };
37+
38+
println!("WARNING: Hostname verification disabled!");
39+
40+
let client = Client::connector()
41+
.connection_timeout(Duration::from_secs(10))
42+
.session_timeout(Duration::from_secs(30))
43+
.tls(tls_options)
44+
.secure_connect(&connect_string)
45+
.await?;
46+
47+
println!("Connected to ZooKeeper successfully!");
48+
49+
let path = "/tls_example";
50+
51+
loop {
52+
print!("\nOptions:\ne. Edit key\nq. Quit\nEnter choice (e/q): ");
53+
io::stdout().flush()?;
54+
55+
let mut input = String::new();
56+
io::stdin().read_line(&mut input)?;
57+
58+
match input.trim() {
59+
"e" => {
60+
print!("Enter new data for the key: ");
61+
io::stdout().flush()?;
62+
63+
let mut data_input = String::new();
64+
io::stdin().read_line(&mut data_input)?;
65+
let data = data_input.trim().as_bytes();
66+
67+
println!("Setting data at path: {}", path);
68+
match client.create(path, data, &CreateMode::Ephemeral.with_acls(Acls::anyone_all())).await {
69+
Ok(_) => println!("ZNode created successfully"),
70+
Err(NodeExists) => {
71+
println!("ZNode already exists, updating data...");
72+
client.set_data(path, data, None).await?;
73+
println!("ZNode data updated successfully");
74+
},
75+
Err(e) => {
76+
println!("Error creating/updating ZNode: {}", e);
77+
continue;
78+
},
79+
}
80+
81+
match client.get_data(path).await {
82+
Ok((data, _stat)) => {
83+
println!("Current data: {}", String::from_utf8_lossy(&data));
84+
},
85+
Err(e) => println!("Error reading data: {}", e),
86+
}
87+
},
88+
"q" => {
89+
println!("Cleaning up and exiting...");
90+
match client.delete(path, None).await {
91+
Ok(_) => println!("ZNode deleted successfully"),
92+
Err(_) => println!("ZNode may not exist or already deleted"),
93+
}
94+
break;
95+
},
96+
_ => {
97+
println!("Invalid choice. Please enter 'e' or 'q'.");
98+
},
99+
}
100+
}
101+
102+
println!("Example completed successfully!");
103+
Ok(())
104+
}

src/session/connection.rs

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use std::io::{Error, ErrorKind, IoSlice, Result};
22
use std::pin::Pin;
3+
#[cfg(feature = "tls")]
4+
use std::sync::Arc;
35
use std::task::{Context, Poll};
46
use std::time::Duration;
57

@@ -15,18 +17,18 @@ use tracing::{debug, trace};
1517

1618
#[cfg(feature = "tls")]
1719
mod tls {
18-
pub use std::sync::Arc;
19-
2020
pub use futures_rustls::client::TlsStream;
2121
pub use futures_rustls::TlsConnector;
2222
pub use rustls::pki_types::ServerName;
23-
pub use rustls::ClientConfig;
2423
}
24+
2525
#[cfg(feature = "tls")]
2626
use tls::*;
2727

2828
use crate::deadline::Deadline;
2929
use crate::endpoint::{EndpointRef, IterableEndpoints};
30+
#[cfg(feature = "tls")]
31+
use crate::TlsOptions;
3032

3133
#[derive(Debug)]
3234
pub enum Connection {
@@ -170,15 +172,15 @@ impl Connection {
170172
#[derive(Clone)]
171173
pub struct Connector {
172174
#[cfg(feature = "tls")]
173-
tls: Option<TlsConnector>,
175+
tls_options: Option<TlsOptions>,
174176
timeout: Duration,
175177
}
176178

177179
impl Connector {
178180
#[cfg(feature = "tls")]
179181
#[allow(dead_code)]
180182
pub fn new() -> Self {
181-
Self { tls: None, timeout: Duration::from_secs(10) }
183+
Self { tls_options: None, timeout: Duration::from_secs(10) }
182184
}
183185

184186
#[cfg(not(feature = "tls"))]
@@ -187,14 +189,27 @@ impl Connector {
187189
}
188190

189191
#[cfg(feature = "tls")]
190-
pub fn with_tls(config: ClientConfig) -> Self {
191-
Self { tls: Some(TlsConnector::from(Arc::new(config))), timeout: Duration::from_secs(10) }
192+
pub fn with_tls_options(tls_options: TlsOptions) -> Self {
193+
Self { tls_options: Some(tls_options), timeout: Duration::from_secs(10) }
194+
}
195+
196+
#[cfg(feature = "tls")]
197+
async fn get_current_tls_connector(&self) -> Result<TlsConnector> {
198+
let Some(ref tls_opts) = self.tls_options else {
199+
return Err(Error::new(ErrorKind::InvalidInput, "no TLS configuration"));
200+
};
201+
let config = tls_opts
202+
.to_config()
203+
.await
204+
.map_err(|e| Error::new(ErrorKind::InvalidData, format!("TLS config creation failed: {}", e)))?;
205+
Ok(TlsConnector::from(Arc::new(config)))
192206
}
193207

194208
#[cfg(feature = "tls")]
195209
async fn connect_tls(&self, stream: TcpStream, host: &str) -> Result<Connection> {
210+
let tls_connector = self.get_current_tls_connector().await?;
196211
let domain = ServerName::try_from(host).unwrap().to_owned();
197-
let stream = self.tls.as_ref().unwrap().connect(domain, stream).await?;
212+
let stream = tls_connector.connect(domain, stream).await?;
198213
Ok(Connection::new_tls(stream))
199214
}
200215

@@ -209,7 +224,7 @@ impl Connector {
209224
pub async fn connect(&self, endpoint: EndpointRef<'_>, deadline: &mut Deadline) -> Result<Connection> {
210225
if endpoint.tls {
211226
#[cfg(feature = "tls")]
212-
if self.tls.is_none() {
227+
if self.tls_options.is_none() {
213228
return Err(Error::new(ErrorKind::Unsupported, "tls not supported"));
214229
}
215230
#[cfg(not(feature = "tls"))]
@@ -288,4 +303,12 @@ mod tests {
288303
let err = connector.connect(endpoint, &mut Deadline::never()).await.unwrap_err();
289304
assert_eq!(err.kind(), ErrorKind::Unsupported);
290305
}
306+
307+
#[cfg(feature = "tls")]
308+
#[test]
309+
fn test_with_tls_options() {
310+
let tls_options = crate::TlsOptions::default();
311+
let connector = Connector::with_tls_options(tls_options);
312+
assert!(connector.tls_options.is_some());
313+
}
291314
}

src/session/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ impl Builder {
130130
return Err(Error::BadArguments(&"connection timeout must not be negative"));
131131
}
132132
#[cfg(feature = "tls")]
133-
let connector = Connector::with_tls(self.tls.unwrap_or_default().into_config()?);
133+
let connector = Connector::with_tls_options(self.tls.unwrap_or_default());
134134
#[cfg(not(feature = "tls"))]
135135
let connector = Connector::new();
136136
let (state_sender, state_receiver) = asyncs::sync::watch::channel(SessionState::Disconnected);

0 commit comments

Comments
 (0)