Skip to content

Commit 2e17c10

Browse files
committed
Reload client certificates
1 parent 468b6e3 commit 2e17c10

File tree

5 files changed

+415
-39
lines changed

5 files changed

+415
-39
lines changed

Cargo.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ rust-version = "1.76"
1616

1717
[features]
1818
default = []
19-
tls = ["rustls", "rustls-pemfile", "webpki-roots", "futures-rustls"]
19+
tls = ["rustls", "rustls-pemfile", "webpki-roots", "futures-rustls", "async-fs"]
2020
sasl = ["sasl-gssapi", "sasl-digest-md5"]
2121
sasl-digest-md5 = ["rsasl/unstable_custom_mechanism", "md5", "linkme", "hex"]
2222
sasl-gssapi = ["rsasl/gssapi"]
@@ -48,6 +48,7 @@ md5 = { version = "0.7.0", optional = true }
4848
hex = { version = "0.4.3", optional = true }
4949
linkme = { version = "0.3", optional = true }
5050
async-io = "2.3.2"
51+
async-fs = { version = "2.1.2", optional = true }
5152
futures = "0.3.30"
5253
async-net = "2.0.0"
5354
futures-rustls = { version = "0.26.0", optional = true }
@@ -67,6 +68,7 @@ tempfile = "3.6.0"
6768
rcgen = { version = "0.12.1", features = ["default", "x509-parser"] }
6869
serial_test = "3.0.0"
6970
asyncs = { version = "0.3.0", features = ["test"] }
71+
smol = "2.0.2"
7072
blocking = "1.6.0"
7173

7274
[package.metadata.cargo-all-features]
@@ -78,3 +80,8 @@ all-features = true
7880
[profile.dev]
7981
# Need this for linkme crate to work for spawns in macOS
8082
lto = "thin"
83+
84+
[[example]]
85+
name = "tls_file_based"
86+
path = "examples/tls_file_based.rs"
87+
required-features = ["tls", "smol"]

examples/tls_file_based.rs

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
use std::env;
2+
use std::io::{self, Write};
3+
use std::path::PathBuf;
4+
use std::time::Duration;
5+
6+
use zookeeper_client::Error::NodeExists;
7+
use zookeeper_client::{Acls, Client, CreateMode, TlsOptions};
8+
9+
fn main() -> Result<(), Box<dyn std::error::Error>> {
10+
env_logger::init();
11+
smol::block_on(run()).unwrap_or_else(|e| {
12+
eprintln!("Error: {}", e);
13+
std::process::exit(1);
14+
});
15+
Ok(())
16+
}
17+
18+
async fn run() -> Result<(), Box<dyn std::error::Error>> {
19+
let connect_string = env::var("ZK_CONNECT_STRING").unwrap_or_else(|_| "tcp+tls://localhost:2281".to_string());
20+
let ca_cert = PathBuf::from(env::var("ZK_CA_CERT").expect("ZK_CA_CERT environment variable is required"));
21+
let client_cert =
22+
PathBuf::from(env::var("ZK_CLIENT_CERT").expect("ZK_CLIENT_CERT environment variable is required"));
23+
let client_key = PathBuf::from(env::var("ZK_CLIENT_KEY").expect("ZK_CLIENT_KEY environment variable is required"));
24+
25+
println!("Connecting to ZooKeeper with file-based TLS...");
26+
println!("Server: {}", connect_string);
27+
println!("CA cert: {}", ca_cert.display());
28+
println!("Client cert: {}", client_cert.display());
29+
println!("Client key: {}", client_key.display());
30+
31+
let loaded_ca_cert = async_fs::read_to_string(&ca_cert).await?;
32+
let tls_options = TlsOptions::default()
33+
.with_pem_ca_certs(&loaded_ca_cert)?
34+
.with_pem_identity_files(&client_cert, &client_key)
35+
.await?;
36+
37+
let tls_options = unsafe { tls_options.with_no_hostname_verification() };
38+
39+
println!("WARNING: Hostname verification disabled!");
40+
41+
let client = Client::connector()
42+
.connection_timeout(Duration::from_secs(10))
43+
.session_timeout(Duration::from_secs(30))
44+
.tls(tls_options)
45+
.secure_connect(&connect_string)
46+
.await?;
47+
48+
println!("Connected to ZooKeeper successfully!");
49+
50+
let path = "/tls_example";
51+
52+
loop {
53+
print!("\nOptions:\ne. Edit key\nq. Quit\nEnter choice (e/q): ");
54+
io::stdout().flush()?;
55+
56+
let mut input = String::new();
57+
io::stdin().read_line(&mut input)?;
58+
59+
match input.trim() {
60+
"e" => {
61+
print!("Enter new data for the key: ");
62+
io::stdout().flush()?;
63+
64+
let mut data_input = String::new();
65+
io::stdin().read_line(&mut data_input)?;
66+
let data = data_input.trim().as_bytes();
67+
68+
println!("Setting data at path: {}", path);
69+
match client.create(path, data, &CreateMode::Ephemeral.with_acls(Acls::anyone_all())).await {
70+
Ok(_) => println!("ZNode created successfully"),
71+
Err(NodeExists) => {
72+
println!("ZNode already exists, updating data...");
73+
client.set_data(path, data, None).await?;
74+
println!("ZNode data updated successfully");
75+
},
76+
Err(e) => {
77+
println!("Error creating/updating ZNode: {}", e);
78+
continue;
79+
},
80+
}
81+
82+
match client.get_data(path).await {
83+
Ok((data, _stat)) => {
84+
println!("Current data: {}", String::from_utf8_lossy(&data));
85+
},
86+
Err(e) => println!("Error reading data: {}", e),
87+
}
88+
},
89+
"q" => {
90+
println!("Cleaning up and exiting...");
91+
match client.delete(path, None).await {
92+
Ok(_) => println!("ZNode deleted successfully"),
93+
Err(_) => println!("ZNode may not exist or already deleted"),
94+
}
95+
break;
96+
},
97+
_ => {
98+
println!("Invalid choice. Please enter 'e' or 'q'.");
99+
},
100+
}
101+
}
102+
103+
println!("Example completed successfully!");
104+
Ok(())
105+
}

src/session/connection.rs

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use std::io::{Error, ErrorKind, IoSlice, Result};
22
use std::pin::Pin;
3+
#[cfg(feature = "tls")]
4+
use std::sync::Arc;
35
use std::task::{Context, Poll};
46
use std::time::Duration;
57

@@ -15,18 +17,18 @@ use tracing::{debug, trace};
1517

1618
#[cfg(feature = "tls")]
1719
mod tls {
18-
pub use std::sync::Arc;
19-
2020
pub use futures_rustls::client::TlsStream;
2121
pub use futures_rustls::TlsConnector;
2222
pub use rustls::pki_types::ServerName;
23-
pub use rustls::ClientConfig;
2423
}
24+
2525
#[cfg(feature = "tls")]
2626
use tls::*;
2727

2828
use crate::deadline::Deadline;
2929
use crate::endpoint::{EndpointRef, IterableEndpoints};
30+
#[cfg(feature = "tls")]
31+
use crate::TlsOptions;
3032

3133
#[derive(Debug)]
3234
pub enum Connection {
@@ -170,15 +172,15 @@ impl Connection {
170172
#[derive(Clone)]
171173
pub struct Connector {
172174
#[cfg(feature = "tls")]
173-
tls: Option<TlsConnector>,
175+
tls_options: Option<TlsOptions>,
174176
timeout: Duration,
175177
}
176178

177179
impl Connector {
178180
#[cfg(feature = "tls")]
179181
#[allow(dead_code)]
180182
pub fn new() -> Self {
181-
Self { tls: None, timeout: Duration::from_secs(10) }
183+
Self { tls_options: None, timeout: Duration::from_secs(10) }
182184
}
183185

184186
#[cfg(not(feature = "tls"))]
@@ -187,14 +189,27 @@ impl Connector {
187189
}
188190

189191
#[cfg(feature = "tls")]
190-
pub fn with_tls(config: ClientConfig) -> Self {
191-
Self { tls: Some(TlsConnector::from(Arc::new(config))), timeout: Duration::from_secs(10) }
192+
pub fn with_tls_options(tls_options: TlsOptions) -> Self {
193+
Self { tls_options: Some(tls_options), timeout: Duration::from_secs(10) }
194+
}
195+
196+
#[cfg(feature = "tls")]
197+
async fn get_current_tls_connector(&self) -> Result<TlsConnector> {
198+
let Some(ref tls_opts) = self.tls_options else {
199+
return Err(Error::new(ErrorKind::InvalidInput, "no TLS configuration"));
200+
};
201+
let config = tls_opts
202+
.to_config()
203+
.await
204+
.map_err(|e| Error::new(ErrorKind::InvalidData, format!("TLS config creation failed: {}", e)))?;
205+
Ok(TlsConnector::from(Arc::new(config)))
192206
}
193207

194208
#[cfg(feature = "tls")]
195209
async fn connect_tls(&self, stream: TcpStream, host: &str) -> Result<Connection> {
210+
let tls_connector = self.get_current_tls_connector().await?;
196211
let domain = ServerName::try_from(host).unwrap().to_owned();
197-
let stream = self.tls.as_ref().unwrap().connect(domain, stream).await?;
212+
let stream = tls_connector.connect(domain, stream).await?;
198213
Ok(Connection::new_tls(stream))
199214
}
200215

@@ -209,7 +224,7 @@ impl Connector {
209224
pub async fn connect(&self, endpoint: EndpointRef<'_>, deadline: &mut Deadline) -> Result<Connection> {
210225
if endpoint.tls {
211226
#[cfg(feature = "tls")]
212-
if self.tls.is_none() {
227+
if self.tls_options.is_none() {
213228
return Err(Error::new(ErrorKind::Unsupported, "tls not supported"));
214229
}
215230
#[cfg(not(feature = "tls"))]
@@ -288,4 +303,12 @@ mod tests {
288303
let err = connector.connect(endpoint, &mut Deadline::never()).await.unwrap_err();
289304
assert_eq!(err.kind(), ErrorKind::Unsupported);
290305
}
306+
307+
#[cfg(feature = "tls")]
308+
#[test]
309+
fn test_with_tls_options() {
310+
let tls_options = crate::TlsOptions::default();
311+
let connector = Connector::with_tls_options(tls_options);
312+
assert!(connector.tls_options.is_some());
313+
}
291314
}

src/session/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ impl Builder {
130130
return Err(Error::BadArguments(&"connection timeout must not be negative"));
131131
}
132132
#[cfg(feature = "tls")]
133-
let connector = Connector::with_tls(self.tls.unwrap_or_default().into_config()?);
133+
let connector = Connector::with_tls_options(self.tls.unwrap_or_default());
134134
#[cfg(not(feature = "tls"))]
135135
let connector = Connector::new();
136136
let (state_sender, state_receiver) = asyncs::sync::watch::channel(SessionState::Disconnected);

0 commit comments

Comments
 (0)