Skip to content

Commit bcda5d9

Browse files
authored
Explicit API key option on client (#699)
1 parent 8da219d commit bcda5d9

File tree

12 files changed

+115
-50
lines changed

12 files changed

+115
-50
lines changed

client/src/lib.rs

Lines changed: 99 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ use tonic::{
6767
body::BoxBody,
6868
client::GrpcService,
6969
codegen::InterceptedService,
70-
metadata::{MetadataKey, MetadataValue},
70+
metadata::{MetadataKey, MetadataMap, MetadataValue},
7171
service::Interceptor,
7272
transport::{Certificate, Channel, Endpoint, Identity},
7373
Code, Status,
@@ -133,6 +133,15 @@ pub struct ClientOptions {
133133
/// If set (which it is by default), HTTP2 gRPC keep alive will be enabled.
134134
#[builder(default = "Some(ClientKeepAliveConfig::default())")]
135135
pub keep_alive: Option<ClientKeepAliveConfig>,
136+
137+
/// HTTP headers to include on every RPC call.
138+
#[builder(default)]
139+
pub headers: Option<HashMap<String, String>>,
140+
141+
/// API key which is set as the "Authorization" header with "Bearer " prepended. This will only
142+
/// be applied if the headers don't already have an "Authorization" header.
143+
#[builder(default)]
144+
pub api_key: Option<String>,
136145
}
137146

138147
/// Configuration options for TLS
@@ -279,7 +288,7 @@ pub enum ClientInitError {
279288
pub struct ConfiguredClient<C> {
280289
client: C,
281290
options: Arc<ClientOptions>,
282-
headers: Arc<RwLock<HashMap<String, String>>>,
291+
headers: Arc<RwLock<ClientHeaders>>,
283292
/// Capabilities as read from the `get_system_info` RPC call made on client connection
284293
capabilities: Option<get_system_info_response::Capabilities>,
285294
workers: Arc<SlotManager>,
@@ -288,8 +297,12 @@ pub struct ConfiguredClient<C> {
288297
impl<C> ConfiguredClient<C> {
289298
/// Set HTTP request headers overwriting previous headers
290299
pub fn set_headers(&self, headers: HashMap<String, String>) {
291-
let mut guard = self.headers.write();
292-
*guard = headers;
300+
self.headers.write().user_headers = headers;
301+
}
302+
303+
/// Set API key, overwriting previous
304+
pub fn set_api_key(&self, api_key: Option<String>) {
305+
self.headers.write().api_key = api_key;
293306
}
294307

295308
/// Returns the options the client is configured with
@@ -309,6 +322,34 @@ impl<C> ConfiguredClient<C> {
309322
}
310323
}
311324

325+
#[derive(Debug)]
326+
struct ClientHeaders {
327+
user_headers: HashMap<String, String>,
328+
api_key: Option<String>,
329+
}
330+
331+
impl ClientHeaders {
332+
fn apply_to_metadata(&self, metadata: &mut MetadataMap) {
333+
for (key, val) in self.user_headers.iter() {
334+
// Only if not already present
335+
if !metadata.contains_key(key) {
336+
// Ignore invalid keys/values
337+
if let (Ok(key), Ok(val)) = (MetadataKey::from_str(key), val.parse()) {
338+
metadata.insert(key, val);
339+
}
340+
}
341+
}
342+
if let Some(api_key) = &self.api_key {
343+
// Only if not already present
344+
if !metadata.contains_key("authorization") {
345+
if let Ok(val) = format!("Bearer {}", api_key).parse() {
346+
metadata.insert("authorization", val);
347+
}
348+
}
349+
}
350+
}
351+
}
352+
312353
// The configured client is effectively a "smart" (dumb) pointer
313354
impl<C> Deref for ConfiguredClient<C> {
314355
type Target = C;
@@ -331,12 +372,8 @@ impl ClientOptions {
331372
&self,
332373
namespace: impl Into<String>,
333374
metrics_meter: Option<TemporalMeter>,
334-
headers: Option<Arc<RwLock<HashMap<String, String>>>>,
335375
) -> Result<RetryClient<Client>, ClientInitError> {
336-
let client = self
337-
.connect_no_namespace(metrics_meter, headers)
338-
.await?
339-
.into_inner();
376+
let client = self.connect_no_namespace(metrics_meter).await?.into_inner();
340377
let client = Client::new(client, namespace.into());
341378
let retry_client = RetryClient::new(client, self.retry_config.clone());
342379
Ok(retry_client)
@@ -349,7 +386,6 @@ impl ClientOptions {
349386
pub async fn connect_no_namespace(
350387
&self,
351388
metrics_meter: Option<TemporalMeter>,
352-
headers: Option<Arc<RwLock<HashMap<String, String>>>>,
353389
) -> Result<RetryClient<ConfiguredClient<TemporalServiceClientWithMetrics>>, ClientInitError>
354390
{
355391
let channel = Channel::from_shared(self.target_url.to_string())?;
@@ -374,7 +410,10 @@ impl ClientOptions {
374410
metrics: metrics_meter.clone().map(MetricsContext::new),
375411
})
376412
.service(channel);
377-
let headers = headers.unwrap_or_default();
413+
let headers = Arc::new(RwLock::new(ClientHeaders {
414+
user_headers: self.headers.clone().unwrap_or_default(),
415+
api_key: self.api_key.clone(),
416+
}));
378417
let interceptor = ServiceCallInterceptor {
379418
opts: self.clone(),
380419
headers: headers.clone(),
@@ -442,7 +481,7 @@ impl ClientOptions {
442481
pub struct ServiceCallInterceptor {
443482
opts: ClientOptions,
444483
/// Only accessed as a reader
445-
headers: Arc<RwLock<HashMap<String, String>>>,
484+
headers: Arc<RwLock<ClientHeaders>>,
446485
}
447486

448487
impl Interceptor for ServiceCallInterceptor {
@@ -468,16 +507,7 @@ impl Interceptor for ServiceCallInterceptor {
468507
.unwrap_or_else(|_| MetadataValue::from_static("")),
469508
);
470509
}
471-
let headers = &*self.headers.read();
472-
for (k, v) in headers {
473-
if metadata.contains_key(k) {
474-
// Don't overwrite per-request specified headers
475-
continue;
476-
}
477-
if let (Ok(k), Ok(v)) = (MetadataKey::from_str(k), v.parse()) {
478-
metadata.insert(k, v);
479-
}
480-
}
510+
self.headers.read().apply_to_metadata(metadata);
481511
if !metadata.contains_key("grpc-timeout") {
482512
request.set_timeout(OTHER_CALL_TIMEOUT);
483513
}
@@ -1559,7 +1589,7 @@ mod tests {
15591589
use super::*;
15601590

15611591
#[test]
1562-
fn respects_per_call_headers() {
1592+
fn applies_headers() {
15631593
let opts = ClientOptionsBuilder::default()
15641594
.identity("enchicat".to_string())
15651595
.target_url(Url::parse("https://smolkitty").unwrap())
@@ -1568,16 +1598,55 @@ mod tests {
15681598
.build()
15691599
.unwrap();
15701600

1571-
let mut static_headers = HashMap::new();
1572-
static_headers.insert("enchi".to_string(), "kitty".to_string());
1573-
let mut iceptor = ServiceCallInterceptor {
1601+
// Initial header set
1602+
let headers = Arc::new(RwLock::new(ClientHeaders {
1603+
user_headers: HashMap::new(),
1604+
api_key: Some("my-api-key".to_owned()),
1605+
}));
1606+
headers
1607+
.clone()
1608+
.write()
1609+
.user_headers
1610+
.insert("my-meta-key".to_owned(), "my-meta-val".to_owned());
1611+
let mut interceptor = ServiceCallInterceptor {
15741612
opts,
1575-
headers: Arc::new(RwLock::new(static_headers)),
1613+
headers: headers.clone(),
15761614
};
1615+
1616+
// Confirm on metadata
1617+
let req = interceptor.call(tonic::Request::new(())).unwrap();
1618+
assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val");
1619+
assert_eq!(
1620+
req.metadata().get("authorization").unwrap(),
1621+
"Bearer my-api-key"
1622+
);
1623+
1624+
// Overwrite at request time
15771625
let mut req = tonic::Request::new(());
1578-
req.metadata_mut().insert("enchi", "cat".parse().unwrap());
1579-
let next_req = iceptor.call(req).unwrap();
1580-
assert_eq!(next_req.metadata().get("enchi").unwrap(), "cat");
1626+
req.metadata_mut()
1627+
.insert("my-meta-key", "my-meta-val2".parse().unwrap());
1628+
req.metadata_mut()
1629+
.insert("authorization", "my-api-key2".parse().unwrap());
1630+
let req = interceptor.call(req).unwrap();
1631+
assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val2");
1632+
assert_eq!(req.metadata().get("authorization").unwrap(), "my-api-key2");
1633+
1634+
// Overwrite auth on header
1635+
headers
1636+
.clone()
1637+
.write()
1638+
.user_headers
1639+
.insert("authorization".to_owned(), "my-api-key3".to_owned());
1640+
let req = interceptor.call(tonic::Request::new(())).unwrap();
1641+
assert_eq!(req.metadata().get("my-meta-key").unwrap(), "my-meta-val");
1642+
assert_eq!(req.metadata().get("authorization").unwrap(), "my-api-key3");
1643+
1644+
// Remove headers and auth and confirm gone
1645+
headers.clone().write().user_headers.clear();
1646+
headers.clone().write().api_key.take();
1647+
let req = interceptor.call(tonic::Request::new(())).unwrap();
1648+
assert!(!req.metadata().contains_key("my-meta-key"));
1649+
assert!(!req.metadata().contains_key("authorization"));
15811650
}
15821651

15831652
#[test]

client/src/raw.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1015,7 +1015,7 @@ mod tests {
10151015
#[allow(dead_code)]
10161016
async fn raw_client_retry_compiles() {
10171017
let opts = ClientOptionsBuilder::default().build().unwrap();
1018-
let raw_client = opts.connect_no_namespace(None, None).await.unwrap();
1018+
let raw_client = opts.connect_no_namespace(None).await.unwrap();
10191019
let mut retry_client = RetryClient::new(raw_client, opts.retry_config);
10201020

10211021
let list_ns_req = ListNamespacesRequest::default();

core/src/ephemeral_server/mod.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -307,11 +307,7 @@ impl EphemeralServer {
307307
.build()?;
308308
for _ in 0..50 {
309309
sleep(Duration::from_millis(100)).await;
310-
if client_options
311-
.connect_no_namespace(None, None)
312-
.await
313-
.is_ok()
314-
{
310+
if client_options.connect_no_namespace(None).await.is_ok() {
315311
return success;
316312
}
317313
}

sdk/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
1818
//! let server_options = sdk_client_options(Url::from_str("http://localhost:7233")?).build()?;
1919
//!
20-
//! let client = server_options.connect("default", None, None).await?;
20+
//! let client = server_options.connect("default", None).await?;
2121
//!
2222
//! let telemetry_options = TelemetryOptionsBuilder::default().build()?;
2323
//! let runtime = CoreRuntime::new_assume_tokio(telemetry_options)?;

test-utils/src/histfetch.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use temporal_sdk_core_test_utils::get_integ_server_options;
1111
#[tokio::main]
1212
async fn main() -> Result<(), anyhow::Error> {
1313
let gw_opts = get_integ_server_options();
14-
let client = gw_opts.connect("default", None, None).await?;
14+
let client = gw_opts.connect("default", None).await?;
1515
let wf_id = std::env::args()
1616
.nth(1)
1717
.expect("must provide workflow id as only argument");

test-utils/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ impl CoreWfStarter {
378378
.expect("Worker config must be valid");
379379
let client = Arc::new(
380380
get_integ_server_options()
381-
.connect(cfg.namespace.clone(), None, None)
381+
.connect(cfg.namespace.clone(), None)
382382
.await
383383
.expect("Must connect"),
384384
);

tests/integ_tests/client_tests.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ async fn can_use_retry_client() {
1717
#[tokio::test]
1818
async fn can_use_retry_raw_client() {
1919
let opts = get_integ_server_options();
20-
let raw_client = opts.connect_no_namespace(None, None).await.unwrap();
20+
let raw_client = opts.connect_no_namespace(None).await.unwrap();
2121
let mut retry_client = RetryClient::new(raw_client, opts.retry_config);
2222
retry_client
2323
.describe_namespace(DescribeNamespaceRequest {
@@ -31,6 +31,6 @@ async fn can_use_retry_raw_client() {
3131
#[tokio::test]
3232
async fn calls_get_system_info() {
3333
let opts = get_integ_server_options();
34-
let raw_client = opts.connect_no_namespace(None, None).await.unwrap();
34+
let raw_client = opts.connect_no_namespace(None).await.unwrap();
3535
assert!(raw_client.get_client().capabilities().is_some());
3636
}

tests/integ_tests/ephemeral_server_tests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ async fn assert_ephemeral_server(server: &EphemeralServer) {
124124
.client_version("0.1.0".to_string())
125125
.build()
126126
.unwrap()
127-
.connect_no_namespace(None, None)
127+
.connect_no_namespace(None)
128128
.await
129129
.unwrap();
130130
let resp = client

tests/integ_tests/metrics_tests.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ async fn prometheus_metrics_exported() {
7272
let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap();
7373
let opts = get_integ_server_options();
7474
let mut raw_client = opts
75-
.connect_no_namespace(rt.telemetry().get_temporal_metric_meter(), None)
75+
.connect_no_namespace(rt.telemetry().get_temporal_metric_meter())
7676
.await
7777
.unwrap();
7878
assert!(raw_client.get_client().capabilities().is_some());
@@ -125,7 +125,7 @@ async fn one_slot_worker_reports_available_slot() {
125125

126126
let client = Arc::new(
127127
get_integ_server_options()
128-
.connect(worker_cfg.namespace.clone(), None, None)
128+
.connect(worker_cfg.namespace.clone(), None)
129129
.await
130130
.expect("Must connect"),
131131
);
@@ -453,7 +453,7 @@ fn runtime_new() {
453453
let opts = get_integ_server_options();
454454
handle.block_on(async {
455455
let mut raw_client = opts
456-
.connect_no_namespace(rt.telemetry().get_temporal_metric_meter(), None)
456+
.connect_no_namespace(rt.telemetry().get_temporal_metric_meter())
457457
.await
458458
.unwrap();
459459
assert!(raw_client.get_client().capabilities().is_some());

tests/integ_tests/visibility_tests.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ async fn client_list_open_closed_workflow_executions() {
9191
async fn client_create_namespace() {
9292
let client = Arc::new(
9393
get_integ_server_options()
94-
.connect(NAMESPACE.to_owned(), None, None)
94+
.connect(NAMESPACE.to_owned(), None)
9595
.await
9696
.expect("Must connect"),
9797
);
@@ -138,7 +138,7 @@ async fn client_create_namespace() {
138138
async fn client_describe_namespace() {
139139
let client = Arc::new(
140140
get_integ_server_options()
141-
.connect(NAMESPACE.to_owned(), None, None)
141+
.connect(NAMESPACE.to_owned(), None)
142142
.await
143143
.expect("Must connect"),
144144
);

0 commit comments

Comments
 (0)