@@ -67,7 +67,7 @@ use tonic::{
67
67
body:: BoxBody ,
68
68
client:: GrpcService ,
69
69
codegen:: InterceptedService ,
70
- metadata:: { MetadataKey , MetadataValue } ,
70
+ metadata:: { MetadataKey , MetadataMap , MetadataValue } ,
71
71
service:: Interceptor ,
72
72
transport:: { Certificate , Channel , Endpoint , Identity } ,
73
73
Code , Status ,
@@ -133,6 +133,15 @@ pub struct ClientOptions {
133
133
/// If set (which it is by default), HTTP2 gRPC keep alive will be enabled.
134
134
#[ builder( default = "Some(ClientKeepAliveConfig::default())" ) ]
135
135
pub keep_alive : Option < ClientKeepAliveConfig > ,
136
+
137
+ /// HTTP headers to include on every RPC call.
138
+ #[ builder( default ) ]
139
+ pub headers : Option < HashMap < String , String > > ,
140
+
141
+ /// API key which is set as the "Authorization" header with "Bearer " prepended. This will only
142
+ /// be applied if the headers don't already have an "Authorization" header.
143
+ #[ builder( default ) ]
144
+ pub api_key : Option < String > ,
136
145
}
137
146
138
147
/// Configuration options for TLS
@@ -279,7 +288,7 @@ pub enum ClientInitError {
279
288
pub struct ConfiguredClient < C > {
280
289
client : C ,
281
290
options : Arc < ClientOptions > ,
282
- headers : Arc < RwLock < HashMap < String , String > > > ,
291
+ headers : Arc < RwLock < ClientHeaders > > ,
283
292
/// Capabilities as read from the `get_system_info` RPC call made on client connection
284
293
capabilities : Option < get_system_info_response:: Capabilities > ,
285
294
workers : Arc < SlotManager > ,
@@ -288,8 +297,12 @@ pub struct ConfiguredClient<C> {
288
297
impl < C > ConfiguredClient < C > {
289
298
/// Set HTTP request headers overwriting previous headers
290
299
pub fn set_headers ( & self , headers : HashMap < String , String > ) {
291
- let mut guard = self . headers . write ( ) ;
292
- * guard = headers;
300
+ self . headers . write ( ) . user_headers = headers;
301
+ }
302
+
303
+ /// Set API key, overwriting previous
304
+ pub fn set_api_key ( & self , api_key : Option < String > ) {
305
+ self . headers . write ( ) . api_key = api_key;
293
306
}
294
307
295
308
/// Returns the options the client is configured with
@@ -309,6 +322,34 @@ impl<C> ConfiguredClient<C> {
309
322
}
310
323
}
311
324
325
+ #[ derive( Debug ) ]
326
+ struct ClientHeaders {
327
+ user_headers : HashMap < String , String > ,
328
+ api_key : Option < String > ,
329
+ }
330
+
331
+ impl ClientHeaders {
332
+ fn apply_to_metadata ( & self , metadata : & mut MetadataMap ) {
333
+ for ( key, val) in self . user_headers . iter ( ) {
334
+ // Only if not already present
335
+ if !metadata. contains_key ( key) {
336
+ // Ignore invalid keys/values
337
+ if let ( Ok ( key) , Ok ( val) ) = ( MetadataKey :: from_str ( key) , val. parse ( ) ) {
338
+ metadata. insert ( key, val) ;
339
+ }
340
+ }
341
+ }
342
+ if let Some ( api_key) = & self . api_key {
343
+ // Only if not already present
344
+ if !metadata. contains_key ( "authorization" ) {
345
+ if let Ok ( val) = format ! ( "Bearer {}" , api_key) . parse ( ) {
346
+ metadata. insert ( "authorization" , val) ;
347
+ }
348
+ }
349
+ }
350
+ }
351
+ }
352
+
312
353
// The configured client is effectively a "smart" (dumb) pointer
313
354
impl < C > Deref for ConfiguredClient < C > {
314
355
type Target = C ;
@@ -331,12 +372,8 @@ impl ClientOptions {
331
372
& self ,
332
373
namespace : impl Into < String > ,
333
374
metrics_meter : Option < TemporalMeter > ,
334
- headers : Option < Arc < RwLock < HashMap < String , String > > > > ,
335
375
) -> Result < RetryClient < Client > , ClientInitError > {
336
- let client = self
337
- . connect_no_namespace ( metrics_meter, headers)
338
- . await ?
339
- . into_inner ( ) ;
376
+ let client = self . connect_no_namespace ( metrics_meter) . await ?. into_inner ( ) ;
340
377
let client = Client :: new ( client, namespace. into ( ) ) ;
341
378
let retry_client = RetryClient :: new ( client, self . retry_config . clone ( ) ) ;
342
379
Ok ( retry_client)
@@ -349,7 +386,6 @@ impl ClientOptions {
349
386
pub async fn connect_no_namespace (
350
387
& self ,
351
388
metrics_meter : Option < TemporalMeter > ,
352
- headers : Option < Arc < RwLock < HashMap < String , String > > > > ,
353
389
) -> Result < RetryClient < ConfiguredClient < TemporalServiceClientWithMetrics > > , ClientInitError >
354
390
{
355
391
let channel = Channel :: from_shared ( self . target_url . to_string ( ) ) ?;
@@ -374,7 +410,10 @@ impl ClientOptions {
374
410
metrics : metrics_meter. clone ( ) . map ( MetricsContext :: new) ,
375
411
} )
376
412
. service ( channel) ;
377
- let headers = headers. unwrap_or_default ( ) ;
413
+ let headers = Arc :: new ( RwLock :: new ( ClientHeaders {
414
+ user_headers : self . headers . clone ( ) . unwrap_or_default ( ) ,
415
+ api_key : self . api_key . clone ( ) ,
416
+ } ) ) ;
378
417
let interceptor = ServiceCallInterceptor {
379
418
opts : self . clone ( ) ,
380
419
headers : headers. clone ( ) ,
@@ -442,7 +481,7 @@ impl ClientOptions {
442
481
pub struct ServiceCallInterceptor {
443
482
opts : ClientOptions ,
444
483
/// Only accessed as a reader
445
- headers : Arc < RwLock < HashMap < String , String > > > ,
484
+ headers : Arc < RwLock < ClientHeaders > > ,
446
485
}
447
486
448
487
impl Interceptor for ServiceCallInterceptor {
@@ -468,16 +507,7 @@ impl Interceptor for ServiceCallInterceptor {
468
507
. unwrap_or_else ( |_| MetadataValue :: from_static ( "" ) ) ,
469
508
) ;
470
509
}
471
- let headers = & * self . headers . read ( ) ;
472
- for ( k, v) in headers {
473
- if metadata. contains_key ( k) {
474
- // Don't overwrite per-request specified headers
475
- continue ;
476
- }
477
- if let ( Ok ( k) , Ok ( v) ) = ( MetadataKey :: from_str ( k) , v. parse ( ) ) {
478
- metadata. insert ( k, v) ;
479
- }
480
- }
510
+ self . headers . read ( ) . apply_to_metadata ( metadata) ;
481
511
if !metadata. contains_key ( "grpc-timeout" ) {
482
512
request. set_timeout ( OTHER_CALL_TIMEOUT ) ;
483
513
}
@@ -1559,7 +1589,7 @@ mod tests {
1559
1589
use super :: * ;
1560
1590
1561
1591
#[ test]
1562
- fn respects_per_call_headers ( ) {
1592
+ fn applies_headers ( ) {
1563
1593
let opts = ClientOptionsBuilder :: default ( )
1564
1594
. identity ( "enchicat" . to_string ( ) )
1565
1595
. target_url ( Url :: parse ( "https://smolkitty" ) . unwrap ( ) )
@@ -1568,16 +1598,55 @@ mod tests {
1568
1598
. build ( )
1569
1599
. unwrap ( ) ;
1570
1600
1571
- let mut static_headers = HashMap :: new ( ) ;
1572
- static_headers. insert ( "enchi" . to_string ( ) , "kitty" . to_string ( ) ) ;
1573
- let mut iceptor = ServiceCallInterceptor {
1601
+ // Initial header set
1602
+ let headers = Arc :: new ( RwLock :: new ( ClientHeaders {
1603
+ user_headers : HashMap :: new ( ) ,
1604
+ api_key : Some ( "my-api-key" . to_owned ( ) ) ,
1605
+ } ) ) ;
1606
+ headers
1607
+ . clone ( )
1608
+ . write ( )
1609
+ . user_headers
1610
+ . insert ( "my-meta-key" . to_owned ( ) , "my-meta-val" . to_owned ( ) ) ;
1611
+ let mut interceptor = ServiceCallInterceptor {
1574
1612
opts,
1575
- headers : Arc :: new ( RwLock :: new ( static_headers ) ) ,
1613
+ headers : headers . clone ( ) ,
1576
1614
} ;
1615
+
1616
+ // Confirm on metadata
1617
+ let req = interceptor. call ( tonic:: Request :: new ( ( ) ) ) . unwrap ( ) ;
1618
+ assert_eq ! ( req. metadata( ) . get( "my-meta-key" ) . unwrap( ) , "my-meta-val" ) ;
1619
+ assert_eq ! (
1620
+ req. metadata( ) . get( "authorization" ) . unwrap( ) ,
1621
+ "Bearer my-api-key"
1622
+ ) ;
1623
+
1624
+ // Overwrite at request time
1577
1625
let mut req = tonic:: Request :: new ( ( ) ) ;
1578
- req. metadata_mut ( ) . insert ( "enchi" , "cat" . parse ( ) . unwrap ( ) ) ;
1579
- let next_req = iceptor. call ( req) . unwrap ( ) ;
1580
- assert_eq ! ( next_req. metadata( ) . get( "enchi" ) . unwrap( ) , "cat" ) ;
1626
+ req. metadata_mut ( )
1627
+ . insert ( "my-meta-key" , "my-meta-val2" . parse ( ) . unwrap ( ) ) ;
1628
+ req. metadata_mut ( )
1629
+ . insert ( "authorization" , "my-api-key2" . parse ( ) . unwrap ( ) ) ;
1630
+ let req = interceptor. call ( req) . unwrap ( ) ;
1631
+ assert_eq ! ( req. metadata( ) . get( "my-meta-key" ) . unwrap( ) , "my-meta-val2" ) ;
1632
+ assert_eq ! ( req. metadata( ) . get( "authorization" ) . unwrap( ) , "my-api-key2" ) ;
1633
+
1634
+ // Overwrite auth on header
1635
+ headers
1636
+ . clone ( )
1637
+ . write ( )
1638
+ . user_headers
1639
+ . insert ( "authorization" . to_owned ( ) , "my-api-key3" . to_owned ( ) ) ;
1640
+ let req = interceptor. call ( tonic:: Request :: new ( ( ) ) ) . unwrap ( ) ;
1641
+ assert_eq ! ( req. metadata( ) . get( "my-meta-key" ) . unwrap( ) , "my-meta-val" ) ;
1642
+ assert_eq ! ( req. metadata( ) . get( "authorization" ) . unwrap( ) , "my-api-key3" ) ;
1643
+
1644
+ // Remove headers and auth and confirm gone
1645
+ headers. clone ( ) . write ( ) . user_headers . clear ( ) ;
1646
+ headers. clone ( ) . write ( ) . api_key . take ( ) ;
1647
+ let req = interceptor. call ( tonic:: Request :: new ( ( ) ) ) . unwrap ( ) ;
1648
+ assert ! ( !req. metadata( ) . contains_key( "my-meta-key" ) ) ;
1649
+ assert ! ( !req. metadata( ) . contains_key( "authorization" ) ) ;
1581
1650
}
1582
1651
1583
1652
#[ test]
0 commit comments