diff --git a/Cargo.lock b/Cargo.lock index 602d5ff05e..d122bbf0a7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1450,7 +1450,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" dependencies = [ "cfg-if", - "windows-targets 0.52.6", + "windows-targets 0.48.5", ] [[package]] @@ -3754,9 +3754,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.26" +version = "0.23.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df51b5869f3a441595eac5e8ff14d486ff285f7b8c0df8770e49c3b56351f0f0" +checksum = "7160e3e10bf4535308537f3c4e1641468cd0e485175d6163087c0393c7d46643" dependencies = [ "aws-lc-rs", "log", @@ -3779,15 +3779,18 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" +checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" +dependencies = [ + "zeroize", +] [[package]] name = "rustls-webpki" -version = "0.103.1" +version = "0.103.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fef8b8769aaccf73098557a87cd1816b4f9c7c16811c9c77142aa695c16f2c03" +checksum = "e4a72fe2bcf7a6ac6fd7d0b9e5cb68aeb7d4c0a0271730218b3e92d43b4eb435" dependencies = [ "aws-lc-rs", "ring", diff --git a/linkerd/app/outbound/src/tls/logical/tests.rs b/linkerd/app/outbound/src/tls/logical/tests.rs index 8802841a0a..dc7a677b7b 100644 --- a/linkerd/app/outbound/src/tls/logical/tests.rs +++ b/linkerd/app/outbound/src/tls/logical/tests.rs @@ -11,13 +11,18 @@ use linkerd_proxy_client_policy::{self as client_policy, tls::sni}; use parking_lot::Mutex; use std::{ collections::HashMap, + marker::PhantomData, net::SocketAddr, sync::Arc, task::{Context, Poll}, time::Duration, }; use tokio::sync::watch; -use tokio_rustls::rustls::pki_types::DnsName; +use tokio_rustls::rustls::{ + internal::msgs::codec::{Codec, Reader}, + pki_types::DnsName, + InvalidMessage, +}; mod basic; @@ -170,44 +175,57 @@ fn sni_route(backend: client_policy::Backend, sni: sni::MatchSni) -> client_poli // generates a sample ClientHello TLS message for testing fn generate_client_hello(sni: &str) -> Vec { use tokio_rustls::rustls::{ - internal::msgs::{ - base::Payload, - codec::{Codec, Reader}, - enums::Compression, - handshake::{ - ClientExtension, ClientHelloPayload, HandshakeMessagePayload, HandshakePayload, - Random, ServerName, SessionId, - }, - message::{MessagePayload, PlainMessage}, - }, - CipherSuite, ContentType, HandshakeType, ProtocolVersion, + internal::msgs::{base::Payload, codec::Codec, message::PlainMessage}, + ContentType, ProtocolVersion, }; let sni = DnsName::try_from(sni.to_string()).unwrap(); let sni = trim_hostname_trailing_dot_for_sni(&sni); - let mut server_name_bytes = vec![]; - 0u8.encode(&mut server_name_bytes); // encode the type first - (sni.as_ref().len() as u16).encode(&mut server_name_bytes); // then the length as u16 - server_name_bytes.extend_from_slice(sni.as_ref().as_bytes()); // then the server name itself - - let server_name = - ServerName::read(&mut Reader::init(&server_name_bytes)).expect("Server name is valid"); - - let hs_payload = HandshakeMessagePayload { - typ: HandshakeType::ClientHello, - payload: HandshakePayload::ClientHello(ClientHelloPayload { - client_version: ProtocolVersion::TLSv1_2, - random: Random::from([0; 32]), - session_id: SessionId::read(&mut Reader::init(&[0])).unwrap(), - cipher_suites: vec![CipherSuite::TLS_NULL_WITH_NULL_NULL], - compression_methods: vec![Compression::Null], - extensions: vec![ClientExtension::ServerName(vec![server_name])], - }), - }; + // rustls has internal-only types that can encode a ClientHello, but they are mostly + // inaccessible and an unstable part of the public API anyway. Manually encode one here for + // testing only instead. + + let mut hs_payload_bytes = vec![]; + 1u8.encode(&mut hs_payload_bytes); // client hello ID + + let client_hello_body = { + let mut payload = LengthPayload::::empty(); + + payload.buf.extend_from_slice(&[0x03, 0x03]); // client version, TLSv1.2 + + payload.buf.extend_from_slice(&[0u8; 32]); // random + + 0u8.encode(&mut payload.buf); // session ID + + LengthPayload::::from_slice(&[0x00, 0x00] /* TLS_NULL_WITH_NULL_NULL */) + .encode(&mut payload.buf); - let mut hs_payload_bytes = Vec::default(); - MessagePayload::handshake(hs_payload).encode(&mut hs_payload_bytes); + LengthPayload::::from_slice(&[0x00] /* no compression */).encode(&mut payload.buf); + + let extensions = { + let mut payload = LengthPayload::::empty(); + 0u16.encode(&mut payload.buf); // server name extension ID + + let server_name_extension = { + let mut payload = LengthPayload::::empty(); + let server_name = { + let mut payload = LengthPayload::::empty(); + 0u8.encode(&mut payload.buf); // DNS hostname ID + LengthPayload::::from_slice(sni.as_ref().as_bytes()) + .encode(&mut payload.buf); + payload + }; + server_name.encode(&mut payload.buf); + payload + }; + server_name_extension.encode(&mut payload.buf); + payload + }; + extensions.encode(&mut payload.buf); + payload + }; + client_hello_body.encode(&mut hs_payload_bytes); let message = PlainMessage { typ: ContentType::Handshake, @@ -218,6 +236,65 @@ fn generate_client_hello(sni: &str) -> Vec { message.into_unencrypted_opaque().encode() } +#[derive(Debug)] +struct LengthPayload { + buf: Vec, + _boo: PhantomData T>, +} + +impl LengthPayload { + fn empty() -> Self { + Self { + buf: vec![], + _boo: PhantomData, + } + } + + fn from_slice(s: &[u8]) -> Self { + Self { + buf: s.to_vec(), + _boo: PhantomData, + } + } +} + +impl Codec<'_> for LengthPayload { + fn encode(&self, bytes: &mut Vec) { + (self.buf.len() as u8).encode(bytes); + bytes.extend_from_slice(&self.buf); + } + + fn read(_: &mut Reader<'_>) -> std::result::Result { + unimplemented!() + } +} + +impl Codec<'_> for LengthPayload { + fn encode(&self, bytes: &mut Vec) { + (self.buf.len() as u16).encode(bytes); + bytes.extend_from_slice(&self.buf); + } + + fn read(_: &mut Reader<'_>) -> std::result::Result { + unimplemented!() + } +} + +#[derive(Debug)] +struct U24; + +impl Codec<'_> for LengthPayload { + fn encode(&self, bytes: &mut Vec) { + let len = self.buf.len() as u32; + bytes.extend_from_slice(&len.to_be_bytes()[1..]); + bytes.extend_from_slice(&self.buf); + } + + fn read(_: &mut Reader<'_>) -> std::result::Result { + unimplemented!() + } +} + fn trim_hostname_trailing_dot_for_sni(dns_name: &DnsName<'_>) -> DnsName<'static> { let dns_name_str = dns_name.as_ref(); diff --git a/linkerd/meshtls/rustls/Cargo.toml b/linkerd/meshtls/rustls/Cargo.toml index dc811d120d..7dd84f4193 100644 --- a/linkerd/meshtls/rustls/Cargo.toml +++ b/linkerd/meshtls/rustls/Cargo.toml @@ -17,7 +17,7 @@ test-util = ["linkerd-tls-test-util"] futures = { version = "0.3", default-features = false } ring = { version = "0.17", features = ["std"] } rustls-pemfile = "2.2" -rustls-webpki = { version = "0.103.1", default-features = false, features = ["std"] } +rustls-webpki = { version = "0.103.3", default-features = false, features = ["std"] } thiserror = "2" tokio = { version = "1", features = ["macros", "rt", "sync"] } tokio-rustls = { workspace = true }