1
- use std:: { error:: Error , sync:: Arc } ;
1
+ use std:: { error:: Error , future :: Future , pin :: Pin , sync:: Arc } ;
2
2
3
3
use anyhow:: Context ;
4
+ use bytes:: Bytes ;
4
5
use http:: { header:: HOST , Request } ;
5
- use http_body_util:: BodyExt ;
6
+ use http_body_util:: { combinators:: BoxBody , BodyExt } ;
7
+ use hyper_util:: rt:: TokioExecutor ;
6
8
use spin_factor_outbound_networking:: {
7
9
config:: { allowed_hosts:: OutboundAllowedHosts , blocked_networks:: BlockedNetworks } ,
8
10
ComponentTlsClientConfigs , TlsClientConfig ,
@@ -259,7 +261,7 @@ async fn send_request_handler(
259
261
_ => ErrorCode :: ConnectionRefused ,
260
262
} ) ?;
261
263
262
- let ( mut sender, worker) = if use_tls {
264
+ let ( mut sender, worker, is_http2 ) = if use_tls {
263
265
#[ cfg( any( target_arch = "riscv64" , target_arch = "s390x" ) ) ]
264
266
{
265
267
return Err ( ErrorCode :: InternalError ( Some (
@@ -270,7 +272,11 @@ async fn send_request_handler(
270
272
#[ cfg( not( any( target_arch = "riscv64" , target_arch = "s390x" ) ) ) ]
271
273
{
272
274
use rustls:: pki_types:: ServerName ;
273
- let connector = tokio_rustls:: TlsConnector :: from ( tls_client_config. inner ( ) ) ;
275
+
276
+ let mut tls_client_config = ( * tls_client_config) . clone ( ) ;
277
+ tls_client_config. alpn_protocols = vec ! [ b"h2" . to_vec( ) , b"http/1.1" . to_vec( ) ] ;
278
+
279
+ let connector = tokio_rustls:: TlsConnector :: from ( Arc :: new ( tls_client_config) ) ;
274
280
let mut parts = authority_str. split ( ':' ) ;
275
281
let host = parts. next ( ) . unwrap_or ( & authority_str) ;
276
282
let domain = ServerName :: try_from ( host)
@@ -283,15 +289,30 @@ async fn send_request_handler(
283
289
tracing:: warn!( "tls protocol error: {e:?}" ) ;
284
290
ErrorCode :: TlsProtocolError
285
291
} ) ?;
292
+
293
+ let is_http2 = stream. get_ref ( ) . 1 . alpn_protocol ( ) == Some ( b"h2" ) ;
294
+
286
295
let stream = TokioIo :: new ( stream) ;
287
296
288
- let ( sender, conn) = timeout (
289
- connect_timeout,
290
- hyper:: client:: conn:: http1:: handshake ( stream) ,
291
- )
292
- . await
293
- . map_err ( |_| ErrorCode :: ConnectionTimeout ) ?
294
- . map_err ( hyper_request_error) ?;
297
+ let ( sender, conn) = if is_http2 {
298
+ timeout (
299
+ connect_timeout,
300
+ hyper:: client:: conn:: http2:: handshake ( TokioExecutor :: default ( ) , stream) ,
301
+ )
302
+ . await
303
+ . map_err ( |_| ErrorCode :: ConnectionTimeout ) ?
304
+ . map_err ( hyper_request_error)
305
+ . map ( |( sender, conn) | ( HttpSender :: Http2 ( sender) , HttpConn :: Http2 ( conn) ) ) ?
306
+ } else {
307
+ timeout (
308
+ connect_timeout,
309
+ hyper:: client:: conn:: http1:: handshake ( stream) ,
310
+ )
311
+ . await
312
+ . map_err ( |_| ErrorCode :: ConnectionTimeout ) ?
313
+ . map_err ( hyper_request_error)
314
+ . map ( |( sender, conn) | ( HttpSender :: Http1 ( sender) , HttpConn :: Http1 ( conn) ) ) ?
315
+ } ;
295
316
296
317
let worker = wasmtime_wasi:: runtime:: spawn ( async move {
297
318
match conn. await {
@@ -302,18 +323,37 @@ async fn send_request_handler(
302
323
}
303
324
} ) ;
304
325
305
- ( sender, worker)
326
+ ( sender, worker, is_http2 )
306
327
}
307
328
} else {
308
329
let tcp_stream = TokioIo :: new ( tcp_stream) ;
309
- let ( sender, conn) = timeout (
310
- connect_timeout,
311
- // TODO: we should plumb the builder through the http context, and use it here
312
- hyper:: client:: conn:: http1:: handshake ( tcp_stream) ,
313
- )
314
- . await
315
- . map_err ( |_| ErrorCode :: ConnectionTimeout ) ?
316
- . map_err ( hyper_request_error) ?;
330
+
331
+ let is_http2 = std:: env:: var_os ( "SPIN_OUTBOUND_H2C_PRIOR_KNOWLEDGE" ) . is_some_and ( |v| {
332
+ request
333
+ . uri ( )
334
+ . authority ( )
335
+ . is_some_and ( |authority| authority. as_str ( ) == v)
336
+ } ) ;
337
+
338
+ let ( sender, conn) = if is_http2 {
339
+ timeout (
340
+ connect_timeout,
341
+ hyper:: client:: conn:: http2:: handshake ( TokioExecutor :: default ( ) , tcp_stream) ,
342
+ )
343
+ . await
344
+ . map_err ( |_| ErrorCode :: ConnectionTimeout ) ?
345
+ . map_err ( hyper_request_error)
346
+ . map ( |( sender, conn) | ( HttpSender :: Http2 ( sender) , HttpConn :: Http2 ( conn) ) ) ?
347
+ } else {
348
+ timeout (
349
+ connect_timeout,
350
+ hyper:: client:: conn:: http1:: handshake ( tcp_stream) ,
351
+ )
352
+ . await
353
+ . map_err ( |_| ErrorCode :: ConnectionTimeout ) ?
354
+ . map_err ( hyper_request_error)
355
+ . map ( |( sender, conn) | ( HttpSender :: Http1 ( sender) , HttpConn :: Http1 ( conn) ) ) ?
356
+ } ;
317
357
318
358
let worker = wasmtime_wasi:: runtime:: spawn ( async move {
319
359
match conn. await {
@@ -323,22 +363,24 @@ async fn send_request_handler(
323
363
}
324
364
} ) ;
325
365
326
- ( sender, worker)
366
+ ( sender, worker, is_http2 )
327
367
} ;
328
368
329
- // at this point, the request contains the scheme and the authority, but
330
- // the http packet should only include those if addressing a proxy, so
331
- // remove them here, since SendRequest::send_request does not do it for us
332
- * request. uri_mut ( ) = http:: Uri :: builder ( )
333
- . path_and_query (
334
- request
335
- . uri ( )
336
- . path_and_query ( )
337
- . map ( |p| p. as_str ( ) )
338
- . unwrap_or ( "/" ) ,
339
- )
340
- . build ( )
341
- . expect ( "comes from valid request" ) ;
369
+ if !is_http2 {
370
+ // at this point, the request contains the scheme and the authority, but
371
+ // the http packet should only include those if addressing a proxy, so
372
+ // remove them here, since SendRequest::send_request does not do it for us
373
+ * request. uri_mut ( ) = http:: Uri :: builder ( )
374
+ . path_and_query (
375
+ request
376
+ . uri ( )
377
+ . path_and_query ( )
378
+ . map ( |p| p. as_str ( ) )
379
+ . unwrap_or ( "/" ) ,
380
+ )
381
+ . build ( )
382
+ . expect ( "comes from valid request" ) ;
383
+ }
342
384
343
385
let resp = timeout ( first_byte_timeout, sender. send_request ( request) )
344
386
. await
@@ -355,6 +397,43 @@ async fn send_request_handler(
355
397
} )
356
398
}
357
399
400
+ enum HttpSender {
401
+ Http1 ( hyper:: client:: conn:: http1:: SendRequest < BoxBody < Bytes , ErrorCode > > ) ,
402
+ Http2 ( hyper:: client:: conn:: http2:: SendRequest < BoxBody < Bytes , ErrorCode > > ) ,
403
+ }
404
+
405
+ #[ allow( clippy:: large_enum_variant) ]
406
+ enum HttpConn < T : hyper:: rt:: Read + hyper:: rt:: Write + Unpin + Send + ' static > {
407
+ Http1 ( hyper:: client:: conn:: http1:: Connection < T , BoxBody < Bytes , ErrorCode > > ) ,
408
+ Http2 ( hyper:: client:: conn:: http2:: Connection < T , BoxBody < Bytes , ErrorCode > , TokioExecutor > ) ,
409
+ }
410
+
411
+ impl < T : hyper:: rt:: Read + hyper:: rt:: Write + Unpin + Send > Future for HttpConn < T > {
412
+ type Output = Result < ( ) , hyper:: Error > ;
413
+
414
+ fn poll (
415
+ self : Pin < & mut Self > ,
416
+ cx : & mut std:: task:: Context < ' _ > ,
417
+ ) -> std:: task:: Poll < Self :: Output > {
418
+ match self . get_mut ( ) {
419
+ HttpConn :: Http1 ( conn) => Pin :: new ( conn) . poll ( cx) ,
420
+ HttpConn :: Http2 ( conn) => Pin :: new ( conn) . poll ( cx) ,
421
+ }
422
+ }
423
+ }
424
+
425
+ impl HttpSender {
426
+ async fn send_request (
427
+ & mut self ,
428
+ request : http:: Request < BoxBody < Bytes , ErrorCode > > ,
429
+ ) -> Result < http:: Response < hyper:: body:: Incoming > , hyper:: Error > {
430
+ match self {
431
+ HttpSender :: Http1 ( sender) => sender. send_request ( request) . await ,
432
+ HttpSender :: Http2 ( sender) => sender. send_request ( request) . await ,
433
+ }
434
+ }
435
+ }
436
+
358
437
/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
359
438
fn hyper_request_error ( err : hyper:: Error ) -> ErrorCode {
360
439
// If there's a source, we might be able to extract a wasi-http error from it.
0 commit comments