@@ -9,6 +9,7 @@ use linkerd_app_core::{
9
9
use linkerd_app_test:: { AsyncReadExt , AsyncWriteExt } ;
10
10
use linkerd_proxy_client_policy:: { self as client_policy, tls:: sni} ;
11
11
use parking_lot:: Mutex ;
12
+ use std:: marker:: PhantomData ;
12
13
use std:: {
13
14
collections:: HashMap ,
14
15
net:: SocketAddr ,
@@ -17,7 +18,9 @@ use std::{
17
18
time:: Duration ,
18
19
} ;
19
20
use tokio:: sync:: watch;
21
+ use tokio_rustls:: rustls:: internal:: msgs:: codec:: { Codec , Reader } ;
20
22
use tokio_rustls:: rustls:: pki_types:: DnsName ;
23
+ use tokio_rustls:: rustls:: InvalidMessage ;
21
24
22
25
mod basic;
23
26
@@ -170,44 +173,57 @@ fn sni_route(backend: client_policy::Backend, sni: sni::MatchSni) -> client_poli
170
173
// generates a sample ClientHello TLS message for testing
171
174
fn generate_client_hello ( sni : & str ) -> Vec < u8 > {
172
175
use tokio_rustls:: rustls:: {
173
- internal:: msgs:: {
174
- base:: Payload ,
175
- codec:: { Codec , Reader } ,
176
- enums:: Compression ,
177
- handshake:: {
178
- ClientExtension , ClientHelloPayload , HandshakeMessagePayload , HandshakePayload ,
179
- Random , ServerName , SessionId ,
180
- } ,
181
- message:: { MessagePayload , PlainMessage } ,
182
- } ,
183
- CipherSuite , ContentType , HandshakeType , ProtocolVersion ,
176
+ internal:: msgs:: { base:: Payload , codec:: Codec , message:: PlainMessage } ,
177
+ ContentType , ProtocolVersion ,
184
178
} ;
185
179
186
180
let sni = DnsName :: try_from ( sni. to_string ( ) ) . unwrap ( ) ;
187
181
let sni = trim_hostname_trailing_dot_for_sni ( & sni) ;
188
182
189
- let mut server_name_bytes = vec ! [ ] ;
190
- 0u8 . encode ( & mut server_name_bytes) ; // encode the type first
191
- ( sni. as_ref ( ) . len ( ) as u16 ) . encode ( & mut server_name_bytes) ; // then the length as u16
192
- server_name_bytes. extend_from_slice ( sni. as_ref ( ) . as_bytes ( ) ) ; // then the server name itself
193
-
194
- let server_name =
195
- ServerName :: read ( & mut Reader :: init ( & server_name_bytes) ) . expect ( "Server name is valid" ) ;
196
-
197
- let hs_payload = HandshakeMessagePayload {
198
- typ : HandshakeType :: ClientHello ,
199
- payload : HandshakePayload :: ClientHello ( ClientHelloPayload {
200
- client_version : ProtocolVersion :: TLSv1_2 ,
201
- random : Random :: from ( [ 0 ; 32 ] ) ,
202
- session_id : SessionId :: read ( & mut Reader :: init ( & [ 0 ] ) ) . unwrap ( ) ,
203
- cipher_suites : vec ! [ CipherSuite :: TLS_NULL_WITH_NULL_NULL ] ,
204
- compression_methods : vec ! [ Compression :: Null ] ,
205
- extensions : vec ! [ ClientExtension :: ServerName ( vec![ server_name] ) ] ,
206
- } ) ,
207
- } ;
183
+ // rustls has internal-only types that can encode a ClientHello, but they are mostly
184
+ // inaccessible and an unstable part of the public API anyway. Manually encode one here for
185
+ // testing only instead.
186
+
187
+ let mut hs_payload_bytes = vec ! [ ] ;
188
+ 1u8 . encode ( & mut hs_payload_bytes) ; // client hello ID
189
+
190
+ let mut client_hello_body = {
191
+ let mut payload = LengthPayload :: < U24 > :: empty ( ) ;
192
+
193
+ payload. buf . extend_from_slice ( & [ 0x03 , 0x03 ] ) ; // client version, TLSv1.2
194
+
195
+ payload. buf . extend_from_slice ( & [ 0u8 ; 32 ] ) ; // random
196
+
197
+ 0u8 . encode ( & mut payload. buf ) ; // session ID
198
+
199
+ LengthPayload :: < u16 > :: from_slice ( & [ 0x00 , 0x00 ] /* TLS_NULL_WITH_NULL_NULL */ )
200
+ . encode ( & mut payload. buf ) ;
201
+
202
+ LengthPayload :: < u8 > :: from_slice ( & [ 0x00 ] /* no compression */ ) . encode ( & mut payload. buf ) ;
208
203
209
- let mut hs_payload_bytes = Vec :: default ( ) ;
210
- MessagePayload :: handshake ( hs_payload) . encode ( & mut hs_payload_bytes) ;
204
+ let mut extensions = {
205
+ let mut payload = LengthPayload :: < u16 > :: empty ( ) ;
206
+ 0u16 . encode ( & mut payload. buf ) ; // server name extension ID
207
+
208
+ let server_name_extension = {
209
+ let mut payload = LengthPayload :: < u16 > :: empty ( ) ;
210
+ let server_name = {
211
+ let mut payload = LengthPayload :: < u16 > :: empty ( ) ;
212
+ 0u8 . encode ( & mut payload. buf ) ; // DNS hostname ID
213
+ LengthPayload :: < u16 > :: from_slice ( sni. as_ref ( ) . as_bytes ( ) )
214
+ . encode ( & mut payload. buf ) ;
215
+ payload
216
+ } ;
217
+ server_name. encode ( & mut payload. buf ) ;
218
+ payload
219
+ } ;
220
+ server_name_extension. encode ( & mut payload. buf ) ;
221
+ payload
222
+ } ;
223
+ extensions. encode ( & mut payload. buf ) ;
224
+ payload
225
+ } ;
226
+ client_hello_body. encode ( & mut hs_payload_bytes) ;
211
227
212
228
let message = PlainMessage {
213
229
typ : ContentType :: Handshake ,
@@ -218,6 +234,65 @@ fn generate_client_hello(sni: &str) -> Vec<u8> {
218
234
message. into_unencrypted_opaque ( ) . encode ( )
219
235
}
220
236
237
+ #[ derive( Debug ) ]
238
+ struct LengthPayload < T > {
239
+ buf : Vec < u8 > ,
240
+ _boo : PhantomData < fn ( ) -> T > ,
241
+ }
242
+
243
+ impl < T > LengthPayload < T > {
244
+ fn empty ( ) -> Self {
245
+ Self {
246
+ buf : vec ! [ ] ,
247
+ _boo : PhantomData ,
248
+ }
249
+ }
250
+
251
+ fn from_slice ( s : & [ u8 ] ) -> Self {
252
+ Self {
253
+ buf : s. to_vec ( ) ,
254
+ _boo : PhantomData ,
255
+ }
256
+ }
257
+ }
258
+
259
+ impl Codec < ' _ > for LengthPayload < u8 > {
260
+ fn encode ( & self , bytes : & mut Vec < u8 > ) {
261
+ ( self . buf . len ( ) as u8 ) . encode ( bytes) ;
262
+ bytes. extend_from_slice ( & self . buf ) ;
263
+ }
264
+
265
+ fn read ( _: & mut Reader < ' _ > ) -> std:: result:: Result < Self , InvalidMessage > {
266
+ unimplemented ! ( )
267
+ }
268
+ }
269
+
270
+ impl Codec < ' _ > for LengthPayload < u16 > {
271
+ fn encode ( & self , bytes : & mut Vec < u8 > ) {
272
+ ( self . buf . len ( ) as u16 ) . encode ( bytes) ;
273
+ bytes. extend_from_slice ( & self . buf ) ;
274
+ }
275
+
276
+ fn read ( _: & mut Reader < ' _ > ) -> std:: result:: Result < Self , InvalidMessage > {
277
+ unimplemented ! ( )
278
+ }
279
+ }
280
+
281
+ #[ derive( Debug ) ]
282
+ struct U24 ;
283
+
284
+ impl Codec < ' _ > for LengthPayload < U24 > {
285
+ fn encode ( & self , bytes : & mut Vec < u8 > ) {
286
+ let len = self . buf . len ( ) as u32 ;
287
+ bytes. extend_from_slice ( & len. to_be_bytes ( ) [ 1 ..] ) ;
288
+ bytes. extend_from_slice ( & self . buf ) ;
289
+ }
290
+
291
+ fn read ( _: & mut Reader < ' _ > ) -> std:: result:: Result < Self , InvalidMessage > {
292
+ unimplemented ! ( )
293
+ }
294
+ }
295
+
221
296
fn trim_hostname_trailing_dot_for_sni ( dns_name : & DnsName < ' _ > ) -> DnsName < ' static > {
222
297
let dns_name_str = dns_name. as_ref ( ) ;
223
298
0 commit comments