Skip to content

Improve TLS certificate loading, handling and validation #478

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion lib/src/component/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,12 @@ impl From<error::Error> for types::Error {
| Error::InvalidAlpnRepsonse { .. }
| Error::DeviceDetectionError(_)
| Error::Again
| Error::SharedMemory => types::Error::GenericError,
| Error::SharedMemory
| Error::TlsNoCertsAdded
| Error::TlsNoCAAvailable
| Error::TlsNoValidCACerts
| Error::TlsInvalidHost
| Error::TlsCertificateValidationFailed => types::Error::GenericError,
}
}
}
17 changes: 16 additions & 1 deletion lib/src/config/backends/client_cert_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ pub enum ClientCertError {
InvalidToml,
#[error("No certificates found in client cert definition")]
NoCertsFound,
#[error("Invalid certificate data: {0}")]
InvalidCertificateData(String),
#[error("Expected a string value for key {0}, got something else")]
InvalidTomlData(&'static str),
}
Expand All @@ -42,14 +44,27 @@ impl ClientCertInfo {

for item in cert_info.into_iter().chain(key_info) {
match item {
rustls_pemfile::Item::X509Certificate(x) => certificates.push(Certificate(x)),
rustls_pemfile::Item::X509Certificate(x) => {
// Basic validation of certificate data
if x.is_empty() {
return Err(ClientCertError::InvalidCertificateData(
"Empty certificate data".to_string(),
));
}
certificates.push(Certificate(x))
}
rustls_pemfile::Item::RSAKey(x) => keys.push(PrivateKey(x)),
rustls_pemfile::Item::PKCS8Key(x) => keys.push(PrivateKey(x)),
rustls_pemfile::Item::ECKey(x) => keys.push(PrivateKey(x)),
_ => {}
}
}

// Ensure certificates were found
if certificates.is_empty() {
return Err(ClientCertError::NoCertsFound);
}

let key = if keys.is_empty() {
return Err(ClientCertError::NoKeysFound);
} else if keys.len() > 1 {
Expand Down
20 changes: 20 additions & 0 deletions lib/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,21 @@ pub enum Error {
#[error("Could not load native certificates: {0}")]
BadCerts(std::io::Error),

#[error("No valid CA certificates could be added")]
TlsNoCertsAdded,

#[error("No CA certificates available")]
TlsNoCAAvailable,

#[error("No valid CA certificates found in provided certificate bundle")]
TlsNoValidCACerts,

#[error("Invalid or missing host for TLS connection")]
TlsInvalidHost,

#[error("TLS certificate validation failed")]
TlsCertificateValidationFailed,

#[error("Could not generate new backend name from '{0}'")]
BackendNameRegistryError(String),

Expand Down Expand Up @@ -197,6 +212,11 @@ impl Error {
Error::AbiVersionMismatch
| Error::BackendUrl(_)
| Error::BadCerts(_)
| Error::TlsNoCertsAdded
| Error::TlsNoCAAvailable
| Error::TlsNoValidCACerts
| Error::TlsInvalidHost
| Error::TlsCertificateValidationFailed
| Error::DownstreamRequestError(_)
| Error::DownstreamRespSending
| Error::FastlyConfig(_)
Expand Down
66 changes: 57 additions & 9 deletions lib/src/upstream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,24 @@ impl TlsConfig {
let mut roots = rustls::RootCertStore::empty();
match rustls_native_certs::load_native_certs() {
Ok(certs) => {
let mut cert_added = false;
for cert in certs {
if let Err(e) = roots.add(&rustls::Certificate(cert.0)) {
warn!("failed to load certificate: {e}");
match roots.add(&rustls::Certificate(cert.0)) {
Ok(_) => cert_added = true,
Err(e) => {
// Log but continue trying other certs
warn!("failed to load certificate: {e}");
}
}
}
if !cert_added {
return Err(Error::TlsNoCertsAdded);
}
}
Err(err) => return Err(Error::BadCerts(err)),
}
if roots.is_empty() {
warn!("no CA certificates available");
return Err(Error::TlsNoCAAvailable);
}

let partial_config = rustls::ClientConfig::builder().with_safe_defaults();
Expand Down Expand Up @@ -143,6 +151,11 @@ impl hyper::service::Service<Uri> for BackendConnector {
ignored
);
}
if added == 0 && !self.backend.ca_certs.is_empty() {
return Box::pin(std::future::ready(Err(
Box::new(Error::TlsNoValidCACerts).into()
)));
}
let config = if self.backend.ca_certs.is_empty() {
config
.partial_config
Expand All @@ -153,7 +166,12 @@ impl hyper::service::Service<Uri> for BackendConnector {
};

Box::pin(async move {
let tcp = connect_fut.await.map_err(Box::new)?;
let tcp = connect_fut.await.map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::Other,
format!("TCP connection error: {}", e),
)
})?;

let remote_addr = tcp.peer_addr()?;
let metadata = ConnMetadata {
Expand All @@ -163,7 +181,15 @@ impl hyper::service::Service<Uri> for BackendConnector {

let conn = if backend.uri.scheme_str() == Some("https") {
let mut config = if let Some(certed_key) = &backend.client_cert {
config.with_client_auth_cert(certed_key.certs(), certed_key.key())?
config
.with_client_auth_cert(certed_key.certs(), certed_key.key())
.map_err(|_| {
Error::InvalidClientCert(
crate::config::ClientCertError::InvalidCertificateData(
"Client certificate validation failed".to_string(),
),
)
})?
} else {
config.with_no_client_auth()
};
Expand All @@ -177,10 +203,32 @@ impl hyper::service::Service<Uri> for BackendConnector {
.cert_host
.as_deref()
.or_else(|| backend.uri.host())
.unwrap_or_default();
let dnsname = ServerName::try_from(cert_host).map_err(Box::new)?;

let tls = connector.connect(dnsname, tcp).await.map_err(Box::new)?;
.ok_or(Error::TlsInvalidHost)?;

let dnsname = ServerName::try_from(cert_host).map_err(|_| {
let err_msg = format!("Invalid DNS name: {}", cert_host);
tracing::error!("{}", err_msg);
Error::TlsInvalidHost
})?;

// Connect with proper validation
let tls = connector
.connect(dnsname, tcp)
.await
.inspect_err(|e| {
// Log detailed error information for certificate issues
tracing::error!("TLS certificate validation failed: {}", e);
})
.map_err(|e| {
if e.to_string().contains("certificate validation failed") {
Error::TlsCertificateValidationFailed
} else {
Error::IoError(std::io::Error::new(
std::io::ErrorKind::Other,
format!("TLS connection error: {}", e),
))
}
})?;

if backend.grpc {
let (_, tls_state) = tls.get_ref();
Expand Down
Loading