Skip to content

Reload client certificates #59

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ rust-version = "1.76"

[features]
default = []
tls = ["rustls", "rustls-pemfile", "webpki-roots", "futures-rustls"]
tls = ["rustls", "rustls-pemfile", "webpki-roots", "futures-rustls", "async-fs"]
sasl = ["sasl-gssapi", "sasl-digest-md5"]
sasl-digest-md5 = ["rsasl/unstable_custom_mechanism", "md5", "linkme", "hex"]
sasl-gssapi = ["rsasl/gssapi"]
Expand Down Expand Up @@ -48,6 +48,7 @@ md5 = { version = "0.7.0", optional = true }
hex = { version = "0.4.3", optional = true }
linkme = { version = "0.3", optional = true }
async-io = "2.3.2"
async-fs = { version = "2.1.2", optional = true }
futures = "0.3.30"
async-net = "2.0.0"
futures-rustls = { version = "0.26.0", optional = true }
Expand All @@ -67,6 +68,7 @@ tempfile = "3.6.0"
rcgen = { version = "0.12.1", features = ["default", "x509-parser"] }
serial_test = "3.0.0"
asyncs = { version = "0.3.0", features = ["test"] }
smol = "2.0.2"
blocking = "1.6.0"

[package.metadata.cargo-all-features]
Expand All @@ -78,3 +80,8 @@ all-features = true
[profile.dev]
# Need this for linkme crate to work for spawns in macOS
lto = "thin"

[[example]]
name = "tls_file_based"
path = "examples/tls_file_based.rs"
required-features = ["tls", "smol"]
105 changes: 105 additions & 0 deletions examples/tls_file_based.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
use std::env;
use std::io::{self, Write};
use std::path::PathBuf;
use std::time::Duration;

use zookeeper_client::Error::NodeExists;
use zookeeper_client::{Acls, Client, CreateMode, TlsOptions};

fn main() -> Result<(), Box<dyn std::error::Error>> {
env_logger::init();
smol::block_on(run()).unwrap_or_else(|e| {
eprintln!("Error: {}", e);
std::process::exit(1);

Check warning on line 13 in examples/tls_file_based.rs

View check run for this annotation

Codecov / codecov/patch

examples/tls_file_based.rs#L9-L13

Added lines #L9 - L13 were not covered by tests
});
Ok(())
}

Check warning on line 16 in examples/tls_file_based.rs

View check run for this annotation

Codecov / codecov/patch

examples/tls_file_based.rs#L15-L16

Added lines #L15 - L16 were not covered by tests

async fn run() -> Result<(), Box<dyn std::error::Error>> {
let connect_string = env::var("ZK_CONNECT_STRING").unwrap_or_else(|_| "tcp+tls://localhost:2281".to_string());
let ca_cert = PathBuf::from(env::var("ZK_CA_CERT").expect("ZK_CA_CERT environment variable is required"));
let client_cert =
PathBuf::from(env::var("ZK_CLIENT_CERT").expect("ZK_CLIENT_CERT environment variable is required"));
let client_key = PathBuf::from(env::var("ZK_CLIENT_KEY").expect("ZK_CLIENT_KEY environment variable is required"));

Check warning on line 23 in examples/tls_file_based.rs

View check run for this annotation

Codecov / codecov/patch

examples/tls_file_based.rs#L18-L23

Added lines #L18 - L23 were not covered by tests

println!("Connecting to ZooKeeper with file-based TLS...");
println!("Server: {}", connect_string);
println!("CA cert: {}", ca_cert.display());
println!("Client cert: {}", client_cert.display());
println!("Client key: {}", client_key.display());

Check warning on line 29 in examples/tls_file_based.rs

View check run for this annotation

Codecov / codecov/patch

examples/tls_file_based.rs#L25-L29

Added lines #L25 - L29 were not covered by tests

let loaded_ca_cert = async_fs::read_to_string(&ca_cert).await?;
let tls_options = TlsOptions::default()
.with_pem_ca_certs(&loaded_ca_cert)?
.with_pem_identity_files(&client_cert, &client_key)
.await?;

Check warning on line 35 in examples/tls_file_based.rs

View check run for this annotation

Codecov / codecov/patch

examples/tls_file_based.rs#L31-L35

Added lines #L31 - L35 were not covered by tests

let tls_options = unsafe { tls_options.with_no_hostname_verification() };

Check warning on line 37 in examples/tls_file_based.rs

View check run for this annotation

Codecov / codecov/patch

examples/tls_file_based.rs#L37

Added line #L37 was not covered by tests

println!("WARNING: Hostname verification disabled!");

Check warning on line 39 in examples/tls_file_based.rs

View check run for this annotation

Codecov / codecov/patch

examples/tls_file_based.rs#L39

Added line #L39 was not covered by tests

let client = Client::connector()
.connection_timeout(Duration::from_secs(10))
.session_timeout(Duration::from_secs(30))
.tls(tls_options)
.secure_connect(&connect_string)
.await?;

Check warning on line 46 in examples/tls_file_based.rs

View check run for this annotation

Codecov / codecov/patch

examples/tls_file_based.rs#L41-L46

Added lines #L41 - L46 were not covered by tests

println!("Connected to ZooKeeper successfully!");

Check warning on line 48 in examples/tls_file_based.rs

View check run for this annotation

Codecov / codecov/patch

examples/tls_file_based.rs#L48

Added line #L48 was not covered by tests

let path = "/tls_example";

Check warning on line 50 in examples/tls_file_based.rs

View check run for this annotation

Codecov / codecov/patch

examples/tls_file_based.rs#L50

Added line #L50 was not covered by tests

loop {
print!("\nOptions:\ne. Edit key\nq. Quit\nEnter choice (e/q): ");
io::stdout().flush()?;

Check warning on line 54 in examples/tls_file_based.rs

View check run for this annotation

Codecov / codecov/patch

examples/tls_file_based.rs#L53-L54

Added lines #L53 - L54 were not covered by tests

let mut input = String::new();
io::stdin().read_line(&mut input)?;

Check warning on line 57 in examples/tls_file_based.rs

View check run for this annotation

Codecov / codecov/patch

examples/tls_file_based.rs#L56-L57

Added lines #L56 - L57 were not covered by tests

match input.trim() {
"e" => {
print!("Enter new data for the key: ");
io::stdout().flush()?;

Check warning on line 62 in examples/tls_file_based.rs

View check run for this annotation

Codecov / codecov/patch

examples/tls_file_based.rs#L59-L62

Added lines #L59 - L62 were not covered by tests

let mut data_input = String::new();
io::stdin().read_line(&mut data_input)?;
let data = data_input.trim().as_bytes();

Check warning on line 66 in examples/tls_file_based.rs

View check run for this annotation

Codecov / codecov/patch

examples/tls_file_based.rs#L64-L66

Added lines #L64 - L66 were not covered by tests

println!("Setting data at path: {}", path);
match client.create(path, data, &CreateMode::Ephemeral.with_acls(Acls::anyone_all())).await {
Ok(_) => println!("ZNode created successfully"),

Check warning on line 70 in examples/tls_file_based.rs

View check run for this annotation

Codecov / codecov/patch

examples/tls_file_based.rs#L68-L70

Added lines #L68 - L70 were not covered by tests
Err(NodeExists) => {
println!("ZNode already exists, updating data...");
client.set_data(path, data, None).await?;
println!("ZNode data updated successfully");

Check warning on line 74 in examples/tls_file_based.rs

View check run for this annotation

Codecov / codecov/patch

examples/tls_file_based.rs#L72-L74

Added lines #L72 - L74 were not covered by tests
},
Err(e) => {
println!("Error creating/updating ZNode: {}", e);
continue;

Check warning on line 78 in examples/tls_file_based.rs

View check run for this annotation

Codecov / codecov/patch

examples/tls_file_based.rs#L76-L78

Added lines #L76 - L78 were not covered by tests
},
}

match client.get_data(path).await {
Ok((data, _stat)) => {
println!("Current data: {}", String::from_utf8_lossy(&data));
},
Err(e) => println!("Error reading data: {}", e),

Check warning on line 86 in examples/tls_file_based.rs

View check run for this annotation

Codecov / codecov/patch

examples/tls_file_based.rs#L82-L86

Added lines #L82 - L86 were not covered by tests
}
},
"q" => {
println!("Cleaning up and exiting...");
match client.delete(path, None).await {
Ok(_) => println!("ZNode deleted successfully"),
Err(_) => println!("ZNode may not exist or already deleted"),

Check warning on line 93 in examples/tls_file_based.rs

View check run for this annotation

Codecov / codecov/patch

examples/tls_file_based.rs#L89-L93

Added lines #L89 - L93 were not covered by tests
}
break;

Check warning on line 95 in examples/tls_file_based.rs

View check run for this annotation

Codecov / codecov/patch

examples/tls_file_based.rs#L95

Added line #L95 was not covered by tests
},
_ => {
println!("Invalid choice. Please enter 'e' or 'q'.");
},

Check warning on line 99 in examples/tls_file_based.rs

View check run for this annotation

Codecov / codecov/patch

examples/tls_file_based.rs#L97-L99

Added lines #L97 - L99 were not covered by tests
}
}

println!("Example completed successfully!");
Ok(())
}

Check warning on line 105 in examples/tls_file_based.rs

View check run for this annotation

Codecov / codecov/patch

examples/tls_file_based.rs#L103-L105

Added lines #L103 - L105 were not covered by tests
41 changes: 32 additions & 9 deletions src/session/connection.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::io::{Error, ErrorKind, IoSlice, Result};
use std::pin::Pin;
#[cfg(feature = "tls")]
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

Expand All @@ -15,18 +17,18 @@

#[cfg(feature = "tls")]
mod tls {
pub use std::sync::Arc;

pub use futures_rustls::client::TlsStream;
pub use futures_rustls::TlsConnector;
pub use rustls::pki_types::ServerName;
pub use rustls::ClientConfig;
}

#[cfg(feature = "tls")]
use tls::*;

use crate::deadline::Deadline;
use crate::endpoint::{EndpointRef, IterableEndpoints};
#[cfg(feature = "tls")]
use crate::TlsOptions;

#[derive(Debug)]
pub enum Connection {
Expand Down Expand Up @@ -170,15 +172,15 @@
#[derive(Clone)]
pub struct Connector {
#[cfg(feature = "tls")]
tls: Option<TlsConnector>,
tls_options: Option<TlsOptions>,
timeout: Duration,
}

impl Connector {
#[cfg(feature = "tls")]
#[allow(dead_code)]
pub fn new() -> Self {
Self { tls: None, timeout: Duration::from_secs(10) }
Self { tls_options: None, timeout: Duration::from_secs(10) }
}

#[cfg(not(feature = "tls"))]
Expand All @@ -187,14 +189,27 @@
}

#[cfg(feature = "tls")]
pub fn with_tls(config: ClientConfig) -> Self {
Self { tls: Some(TlsConnector::from(Arc::new(config))), timeout: Duration::from_secs(10) }
pub fn with_tls_options(tls_options: TlsOptions) -> Self {
Self { tls_options: Some(tls_options), timeout: Duration::from_secs(10) }
}

#[cfg(feature = "tls")]
async fn get_current_tls_connector(&self) -> Result<TlsConnector> {
let Some(ref tls_opts) = self.tls_options else {
return Err(Error::new(ErrorKind::InvalidInput, "no TLS configuration"));

Check warning on line 199 in src/session/connection.rs

View check run for this annotation

Codecov / codecov/patch

src/session/connection.rs#L199

Added line #L199 was not covered by tests
};
let config = tls_opts
.to_config()
.await
.map_err(|e| Error::new(ErrorKind::InvalidData, format!("TLS config creation failed: {}", e)))?;
Ok(TlsConnector::from(Arc::new(config)))
}

#[cfg(feature = "tls")]
async fn connect_tls(&self, stream: TcpStream, host: &str) -> Result<Connection> {
let tls_connector = self.get_current_tls_connector().await?;
let domain = ServerName::try_from(host).unwrap().to_owned();
let stream = self.tls.as_ref().unwrap().connect(domain, stream).await?;
let stream = tls_connector.connect(domain, stream).await?;
Ok(Connection::new_tls(stream))
}

Expand All @@ -209,7 +224,7 @@
pub async fn connect(&self, endpoint: EndpointRef<'_>, deadline: &mut Deadline) -> Result<Connection> {
if endpoint.tls {
#[cfg(feature = "tls")]
if self.tls.is_none() {
if self.tls_options.is_none() {
return Err(Error::new(ErrorKind::Unsupported, "tls not supported"));
}
#[cfg(not(feature = "tls"))]
Expand Down Expand Up @@ -288,4 +303,12 @@
let err = connector.connect(endpoint, &mut Deadline::never()).await.unwrap_err();
assert_eq!(err.kind(), ErrorKind::Unsupported);
}

#[cfg(feature = "tls")]
#[test]
fn test_with_tls_options() {
let tls_options = crate::TlsOptions::default();
let connector = Connector::with_tls_options(tls_options);
assert!(connector.tls_options.is_some());
}
}
2 changes: 1 addition & 1 deletion src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ impl Builder {
return Err(Error::BadArguments(&"connection timeout must not be negative"));
}
#[cfg(feature = "tls")]
let connector = Connector::with_tls(self.tls.unwrap_or_default().into_config()?);
let connector = Connector::with_tls_options(self.tls.unwrap_or_default());
#[cfg(not(feature = "tls"))]
let connector = Connector::new();
let (state_sender, state_receiver) = asyncs::sync::watch::channel(SessionState::Disconnected);
Expand Down
Loading
Loading