Skip to content

Commit 5570090

Browse files
committed
fix(rustls): Remove dependency on most rustls internal types
We only used these types for generating a ClientHello message for testing. Instead, we can manually encode a sample message based on the TLS spec. Signed-off-by: Scott Fleener <scott@buoyant.io>
1 parent a91d9db commit 5570090

File tree

1 file changed

+107
-32
lines changed
  • linkerd/app/outbound/src/tls/logical

1 file changed

+107
-32
lines changed

linkerd/app/outbound/src/tls/logical/tests.rs

Lines changed: 107 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use linkerd_app_core::{
99
use linkerd_app_test::{AsyncReadExt, AsyncWriteExt};
1010
use linkerd_proxy_client_policy::{self as client_policy, tls::sni};
1111
use parking_lot::Mutex;
12+
use std::marker::PhantomData;
1213
use std::{
1314
collections::HashMap,
1415
net::SocketAddr,
@@ -17,7 +18,9 @@ use std::{
1718
time::Duration,
1819
};
1920
use tokio::sync::watch;
21+
use tokio_rustls::rustls::internal::msgs::codec::{Codec, Reader};
2022
use tokio_rustls::rustls::pki_types::DnsName;
23+
use tokio_rustls::rustls::InvalidMessage;
2124

2225
mod basic;
2326

@@ -170,44 +173,57 @@ fn sni_route(backend: client_policy::Backend, sni: sni::MatchSni) -> client_poli
170173
// generates a sample ClientHello TLS message for testing
171174
fn generate_client_hello(sni: &str) -> Vec<u8> {
172175
use tokio_rustls::rustls::{
173-
internal::msgs::{
174-
base::Payload,
175-
codec::{Codec, Reader},
176-
enums::Compression,
177-
handshake::{
178-
ClientExtension, ClientHelloPayload, HandshakeMessagePayload, HandshakePayload,
179-
Random, ServerName, SessionId,
180-
},
181-
message::{MessagePayload, PlainMessage},
182-
},
183-
CipherSuite, ContentType, HandshakeType, ProtocolVersion,
176+
internal::msgs::{base::Payload, codec::Codec, message::PlainMessage},
177+
ContentType, ProtocolVersion,
184178
};
185179

186180
let sni = DnsName::try_from(sni.to_string()).unwrap();
187181
let sni = trim_hostname_trailing_dot_for_sni(&sni);
188182

189-
let mut server_name_bytes = vec![];
190-
0u8.encode(&mut server_name_bytes); // encode the type first
191-
(sni.as_ref().len() as u16).encode(&mut server_name_bytes); // then the length as u16
192-
server_name_bytes.extend_from_slice(sni.as_ref().as_bytes()); // then the server name itself
193-
194-
let server_name =
195-
ServerName::read(&mut Reader::init(&server_name_bytes)).expect("Server name is valid");
196-
197-
let hs_payload = HandshakeMessagePayload {
198-
typ: HandshakeType::ClientHello,
199-
payload: HandshakePayload::ClientHello(ClientHelloPayload {
200-
client_version: ProtocolVersion::TLSv1_2,
201-
random: Random::from([0; 32]),
202-
session_id: SessionId::read(&mut Reader::init(&[0])).unwrap(),
203-
cipher_suites: vec![CipherSuite::TLS_NULL_WITH_NULL_NULL],
204-
compression_methods: vec![Compression::Null],
205-
extensions: vec![ClientExtension::ServerName(vec![server_name])],
206-
}),
207-
};
183+
// rustls has internal-only types that can encode a ClientHello, but they are mostly
184+
// inaccessible and an unstable part of the public API anyway. Manually encode one here for
185+
// testing only instead.
186+
187+
let mut hs_payload_bytes = vec![];
188+
1u8.encode(&mut hs_payload_bytes); // client hello ID
189+
190+
let mut client_hello_body = {
191+
let mut payload = LengthPayload::<U24>::empty();
192+
193+
payload.buf.extend_from_slice(&[0x03, 0x03]); // client version, TLSv1.2
194+
195+
payload.buf.extend_from_slice(&[0u8; 32]); // random
196+
197+
0u8.encode(&mut payload.buf); // session ID
198+
199+
LengthPayload::<u16>::from_slice(&[0x00, 0x00] /* TLS_NULL_WITH_NULL_NULL */)
200+
.encode(&mut payload.buf);
201+
202+
LengthPayload::<u8>::from_slice(&[0x00] /* no compression */).encode(&mut payload.buf);
208203

209-
let mut hs_payload_bytes = Vec::default();
210-
MessagePayload::handshake(hs_payload).encode(&mut hs_payload_bytes);
204+
let mut extensions = {
205+
let mut payload = LengthPayload::<u16>::empty();
206+
0u16.encode(&mut payload.buf); // server name extension ID
207+
208+
let server_name_extension = {
209+
let mut payload = LengthPayload::<u16>::empty();
210+
let server_name = {
211+
let mut payload = LengthPayload::<u16>::empty();
212+
0u8.encode(&mut payload.buf); // DNS hostname ID
213+
LengthPayload::<u16>::from_slice(sni.as_ref().as_bytes())
214+
.encode(&mut payload.buf);
215+
payload
216+
};
217+
server_name.encode(&mut payload.buf);
218+
payload
219+
};
220+
server_name_extension.encode(&mut payload.buf);
221+
payload
222+
};
223+
extensions.encode(&mut payload.buf);
224+
payload
225+
};
226+
client_hello_body.encode(&mut hs_payload_bytes);
211227

212228
let message = PlainMessage {
213229
typ: ContentType::Handshake,
@@ -218,6 +234,65 @@ fn generate_client_hello(sni: &str) -> Vec<u8> {
218234
message.into_unencrypted_opaque().encode()
219235
}
220236

237+
#[derive(Debug)]
238+
struct LengthPayload<T> {
239+
buf: Vec<u8>,
240+
_boo: PhantomData<fn() -> T>,
241+
}
242+
243+
impl<T> LengthPayload<T> {
244+
fn empty() -> Self {
245+
Self {
246+
buf: vec![],
247+
_boo: PhantomData,
248+
}
249+
}
250+
251+
fn from_slice(s: &[u8]) -> Self {
252+
Self {
253+
buf: s.to_vec(),
254+
_boo: PhantomData,
255+
}
256+
}
257+
}
258+
259+
impl Codec<'_> for LengthPayload<u8> {
260+
fn encode(&self, bytes: &mut Vec<u8>) {
261+
(self.buf.len() as u8).encode(bytes);
262+
bytes.extend_from_slice(&self.buf);
263+
}
264+
265+
fn read(_: &mut Reader<'_>) -> std::result::Result<Self, InvalidMessage> {
266+
unimplemented!()
267+
}
268+
}
269+
270+
impl Codec<'_> for LengthPayload<u16> {
271+
fn encode(&self, bytes: &mut Vec<u8>) {
272+
(self.buf.len() as u16).encode(bytes);
273+
bytes.extend_from_slice(&self.buf);
274+
}
275+
276+
fn read(_: &mut Reader<'_>) -> std::result::Result<Self, InvalidMessage> {
277+
unimplemented!()
278+
}
279+
}
280+
281+
#[derive(Debug)]
282+
struct U24;
283+
284+
impl Codec<'_> for LengthPayload<U24> {
285+
fn encode(&self, bytes: &mut Vec<u8>) {
286+
let len = self.buf.len() as u32;
287+
bytes.extend_from_slice(&len.to_be_bytes()[1..]);
288+
bytes.extend_from_slice(&self.buf);
289+
}
290+
291+
fn read(_: &mut Reader<'_>) -> std::result::Result<Self, InvalidMessage> {
292+
unimplemented!()
293+
}
294+
}
295+
221296
fn trim_hostname_trailing_dot_for_sni(dns_name: &DnsName<'_>) -> DnsName<'static> {
222297
let dns_name_str = dns_name.as_ref();
223298

0 commit comments

Comments
 (0)