
Commit 5ae4c2c

Merge pull request #214 from gkumbhat/add_header_passthrough

✨ Add header passthrough for NLP and detector clients

2 parents f26d490 + bfdc403

10 files changed: +318 -62 lines changed

config/config.yaml

Lines changed: 5 additions & 0 deletions

@@ -53,3 +53,8 @@ tls:
   detector_bundle_no_ca:
     cert_path: /path/to/client-bundle.pem
     insecure: true
+
+# Following section can be used to configure the allowed headers that orchestrator will pass to
+# NLP provider and detectors. Note that, this section takes header keys, not values.
+# passthrough_headers:
+#   - header-key
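As a minimal usage sketch, with illustrative header names rather than defaults shipped with the orchestrator, the uncommented section might look like:

passthrough_headers:
  - authorization
  - x-correlation-id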

src/clients/detector.rs

Lines changed: 7 additions & 1 deletion

@@ -17,7 +17,7 @@

 use std::collections::HashMap;

-use hyper::StatusCode;
+use hyper::{HeaderMap, StatusCode};
 use serde::{Deserialize, Serialize};

 use super::{create_http_clients, Error, HttpClient};
@@ -75,11 +75,13 @@ impl DetectorClient {
         &self,
         model_id: &str,
         request: ContentAnalysisRequest,
+        headers: HeaderMap,
     ) -> Result<Vec<Vec<ContentAnalysisResponse>>, Error> {
         let client = self.client(model_id)?;
         let url = client.base_url().as_str();
         let response = client
             .post(url)
+            .headers(headers)
             .header(DETECTOR_ID_HEADER_NAME, model_id)
             .json(&request)
             .send()
@@ -104,11 +106,13 @@ impl DetectorClient {
         &self,
         model_id: &str,
         request: GenerationDetectionRequest,
+        headers: HeaderMap,
     ) -> Result<Vec<DetectionResult>, Error> {
         let client = self.client(model_id)?;
         let url = client.base_url().as_str();
         let response = client
             .post(url)
+            .headers(headers)
             .header(DETECTOR_ID_HEADER_NAME, model_id)
             .json(&request)
             .send()
@@ -133,11 +137,13 @@ impl DetectorClient {
         &self,
         model_id: &str,
         request: ContextDocsDetectionRequest,
+        headers: HeaderMap,
     ) -> Result<Vec<DetectionResult>, Error> {
         let client = self.client(model_id)?;
         let url = client.base_url().as_str();
         let response = client
             .post(url)
+            .headers(headers)
             .header(DETECTOR_ID_HEADER_NAME, model_id)
             .json(&request)
             .send()
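Every detector call now takes a hyper HeaderMap, applied to the outgoing HTTP request via .headers(headers) before the detector-id header is set. A hedged sketch of what a caller might construct and pass in (the header key and value are illustrative, not part of this commit):

use hyper::header::{HeaderMap, HeaderName, HeaderValue};

// Illustrative: in the orchestrator these would be the incoming request
// headers already filtered against the configured `passthrough_headers`.
let mut headers = HeaderMap::new();
headers.insert(
    HeaderName::from_static("x-correlation-id"),
    HeaderValue::from_static("abc-123"),
);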

src/clients/generation.rs

Lines changed: 12 additions & 6 deletions

@@ -18,6 +18,7 @@
 use std::collections::HashMap;

 use futures::{StreamExt, TryStreamExt};
+use hyper::HeaderMap;
 use tracing::debug;

 use super::{BoxStream, Error, NlpClient, TgisClient};
@@ -85,6 +86,7 @@ impl GenerationClient {
         &self,
         model_id: String,
         text: String,
+        headers: HeaderMap,
     ) -> Result<(u32, Vec<String>), Error> {
         match &self.0 {
             Some(GenerationClientInner::Tgis(client)) => {
@@ -96,15 +98,17 @@ impl GenerationClient {
                     truncate_input_tokens: 0,
                 };
                 debug!(%model_id, provider = "tgis", ?request, "sending tokenize request");
-                let mut response = client.tokenize(request).await?;
+                let mut response = client.tokenize(request, headers).await?;
                 debug!(%model_id, provider = "tgis", ?response, "received tokenize response");
                 let response = response.responses.swap_remove(0);
                 Ok((response.token_count, response.tokens))
             }
             Some(GenerationClientInner::Nlp(client)) => {
                 let request = TokenizationTaskRequest { text };
                 debug!(%model_id, provider = "nlp", ?request, "sending tokenize request");
-                let response = client.tokenization_task_predict(&model_id, request).await?;
+                let response = client
+                    .tokenization_task_predict(&model_id, request, headers)
+                    .await?;
                 debug!(%model_id, provider = "nlp", ?response, "received tokenize response");
                 let tokens = response
                     .results
@@ -122,6 +126,7 @@ impl GenerationClient {
         model_id: String,
         text: String,
         params: Option<GuardrailsTextGenerationParameters>,
+        headers: HeaderMap,
     ) -> Result<ClassifiedGeneratedTextResult, Error> {
         match &self.0 {
             Some(GenerationClientInner::Tgis(client)) => {
@@ -133,7 +138,7 @@ impl GenerationClient {
                     params,
                 };
                 debug!(%model_id, provider = "tgis", ?request, "sending generate request");
-                let response = client.generate(request).await?;
+                let response = client.generate(request, headers).await?;
                 debug!(%model_id, provider = "tgis", ?response, "received generate response");
                 Ok(response.into())
             }
@@ -171,7 +176,7 @@ impl GenerationClient {
                 };
                 debug!(%model_id, provider = "nlp", ?request, "sending generate request");
                 let response = client
-                    .text_generation_task_predict(&model_id, request)
+                    .text_generation_task_predict(&model_id, request, headers)
                     .await?;
                 debug!(%model_id, provider = "nlp", ?response, "received generate response");
                 Ok(response.into())
@@ -185,6 +190,7 @@ impl GenerationClient {
         model_id: String,
         text: String,
         params: Option<GuardrailsTextGenerationParameters>,
+        headers: HeaderMap,
     ) -> Result<BoxStream<Result<ClassifiedGeneratedTextStreamResult, Error>>, Error> {
         match &self.0 {
             Some(GenerationClientInner::Tgis(client)) => {
@@ -197,7 +203,7 @@ impl GenerationClient {
                 };
                 debug!(%model_id, provider = "tgis", ?request, "sending generate_stream request");
                 let response_stream = client
-                    .generate_stream(request)
+                    .generate_stream(request, headers)
                     .await?
                     .map_ok(Into::into)
                     .boxed();
@@ -237,7 +243,7 @@ impl GenerationClient {
                 };
                 debug!(%model_id, provider = "nlp", ?request, "sending generate_stream request");
                 let response_stream = client
-                    .server_streaming_text_generation_task_predict(&model_id, request)
+                    .server_streaming_text_generation_task_predict(&model_id, request, headers)
                     .await?
                     .map_ok(Into::into)
                     .boxed();
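All three GenerationClient entry points (tokenize, generate, generate_stream) now accept the same HeaderMap and hand it to whichever backend, TGIS or NLP, is configured. A hedged call-site sketch, assuming a `generation_client` and a pre-filtered `headers` map are in scope (the model id and text are illustrative):

let (token_count, tokens) = generation_client
    .tokenize("my-model".to_string(), "some text".to_string(), headers)
    .await?;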

src/clients/nlp.rs

Lines changed: 13 additions & 7 deletions

@@ -17,9 +17,10 @@

 use std::collections::HashMap;

+use axum::http::{Extensions, HeaderMap};
 use futures::{StreamExt, TryStreamExt};
 use ginepro::LoadBalancedChannel;
-use tonic::Request;
+use tonic::{metadata::MetadataMap, Request};

 use super::{create_grpc_clients, BoxStream, Error};
 use crate::{
@@ -94,8 +95,9 @@ impl NlpClient {
         &self,
         model_id: &str,
         request: TokenizationTaskRequest,
+        headers: HeaderMap,
     ) -> Result<TokenizationResults, Error> {
-        let request = request_with_model_id(request, model_id);
+        let request = request_with_model_id(request, model_id, headers);
         Ok(self
             .client(model_id)?
             .tokenization_task_predict(request)
@@ -107,8 +109,9 @@ impl NlpClient {
         &self,
         model_id: &str,
         request: TokenClassificationTaskRequest,
+        headers: HeaderMap,
     ) -> Result<TokenClassificationResults, Error> {
-        let request = request_with_model_id(request, model_id);
+        let request = request_with_model_id(request, model_id, headers);
         Ok(self
             .client(model_id)?
             .token_classification_task_predict(request)
@@ -120,8 +123,9 @@ impl NlpClient {
         &self,
         model_id: &str,
         request: TextGenerationTaskRequest,
+        headers: HeaderMap,
     ) -> Result<GeneratedTextResult, Error> {
-        let request = request_with_model_id(request, model_id);
+        let request = request_with_model_id(request, model_id, headers);
         Ok(self
             .client(model_id)?
             .text_generation_task_predict(request)
@@ -133,8 +137,9 @@ impl NlpClient {
         &self,
         model_id: &str,
         request: ServerStreamingTextGenerationTaskRequest,
+        headers: HeaderMap,
     ) -> Result<BoxStream<Result<GeneratedTextStreamResult, Error>>, Error> {
-        let request = request_with_model_id(request, model_id);
+        let request = request_with_model_id(request, model_id, headers);
         let response_stream = self
             .client(model_id)?
             .server_streaming_text_generation_task_predict(request)
@@ -146,8 +151,9 @@ impl NlpClient {
     }
 }

-fn request_with_model_id<T>(request: T, model_id: &str) -> Request<T> {
-    let mut request = Request::new(request);
+fn request_with_model_id<T>(request: T, model_id: &str, headers: HeaderMap) -> Request<T> {
+    let metadata = MetadataMap::from_headers(headers);
+    let mut request = Request::from_parts(metadata, Extensions::new(), request);
     request
         .metadata_mut()
         .insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap());
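The reworked request_with_model_id is where headers cross the HTTP-to-gRPC boundary: MetadataMap::from_headers converts every header entry into tonic request metadata. A small sketch of that conversion in isolation (assuming the same axum/tonic versions the crate uses; the header key is illustrative):

use axum::http::{HeaderMap, HeaderValue};
use tonic::metadata::MetadataMap;

let mut headers = HeaderMap::new();
// Illustrative key; any passthrough header would behave the same way.
headers.insert("x-correlation-id", HeaderValue::from_static("abc-123"));

// Each HeaderMap entry becomes an ASCII metadata entry on the gRPC request.
let metadata = MetadataMap::from_headers(headers);
assert!(metadata.get("x-correlation-id").is_some());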

src/clients/tgis.rs

Lines changed: 4 additions & 1 deletion

@@ -14,9 +14,9 @@
 limitations under the License.

 */
-
 use std::collections::HashMap;

+use axum::http::HeaderMap;
 use futures::{StreamExt, TryStreamExt};
 use ginepro::LoadBalancedChannel;
 use tonic::Code;
@@ -99,6 +99,7 @@ impl TgisClient {
     pub async fn generate(
         &self,
         request: BatchedGenerationRequest,
+        _headers: HeaderMap,
     ) -> Result<BatchedGenerationResponse, Error> {
         let model_id = request.model_id.as_str();
         Ok(self.client(model_id)?.generate(request).await?.into_inner())
@@ -107,6 +108,7 @@ impl TgisClient {
     pub async fn generate_stream(
         &self,
         request: SingleGenerationRequest,
+        _headers: HeaderMap,
     ) -> Result<BoxStream<Result<GenerationResponse, Error>>, Error> {
         let model_id = request.model_id.as_str();
         let response_stream = self
@@ -122,6 +124,7 @@ impl TgisClient {
     pub async fn tokenize(
         &self,
         request: BatchedTokenizeRequest,
+        _headers: HeaderMap,
     ) -> Result<BatchedTokenizeResponse, Error> {
         let model_id = request.model_id.as_str();
         Ok(self.client(model_id)?.tokenize(request).await?.into_inner())
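Note that the TGIS client accepts the headers only to keep the client interfaces uniform: the leading underscore on _headers signals that they are not yet forwarded on the underlying gRPC calls.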

src/config.rs

Lines changed: 90 additions & 2 deletions

@@ -16,15 +16,18 @@
 */

 use std::{
-    collections::HashMap,
+    collections::{HashMap, HashSet},
     path::{Path, PathBuf},
 };

 use serde::Deserialize;
-use tracing::{debug, error, warn};
+use tracing::{debug, error, info, warn};

 use crate::clients::chunker::DEFAULT_MODEL_ID;

+// Placeholder to add default allowed headers
+const DEFAULT_ALLOWED_HEADERS: &[&str] = &[];
+
 #[derive(Debug, thiserror::Error)]
 pub enum Error {
     #[error("failed to read config from `{path}`: {error}")]
@@ -143,6 +146,9 @@ pub struct OrchestratorConfig {
     /// Map of TLS connections, allowing reuse across services
     /// that may require the same TLS information
     pub tls: Option<HashMap<String, TlsConfig>>,
+    // List of header keys allowed to be passed to downstream servers
+    #[serde(default)]
+    pub passthrough_headers: HashSet<String>,
 }

 impl OrchestratorConfig {
@@ -166,6 +172,27 @@ impl OrchestratorConfig {
             warn!("no chunker configs provided");
         }

+        if config.passthrough_headers.is_empty() {
+            info!("No allowed headers specified");
+        }
+
+        // Add default headers to allowed_headers list
+        debug!(
+            "Adding default headers: [{}]. ",
+            DEFAULT_ALLOWED_HEADERS.join(", ")
+        );
+
+        // Lowercase all header for case-insensitive comparison
+        config.passthrough_headers = config
+            .passthrough_headers
+            .into_iter()
+            .map(|h| h.to_lowercase())
+            .collect::<HashSet<String>>();
+
+        config
+            .passthrough_headers
+            .extend(DEFAULT_ALLOWED_HEADERS.iter().map(|h| h.to_lowercase()));
+
         config.apply_named_tls_configs()?;
         config.validate()?;

@@ -521,4 +548,65 @@ tls:
             .expect_err("Config should not have been validated");
         assert!(matches!(error, Error::DetectorChunkerNotFound { .. }))
     }
+
+    #[test]
+    fn test_passthrough_headers_empty_config() -> Result<(), Error> {
+        let s = r#"
+generation:
+    provider: tgis
+    service:
+        hostname: localhost
+        port: 8000
+chunkers:
+    sentence-en:
+        type: sentence
+        service:
+            hostname: localhost
+            port: 9000
+detectors:
+    hap:
+        service:
+            hostname: localhost
+            port: 9000
+        tls: detector
+        chunker_id: sentence-fr
+        default_threshold: 0.5
+        "#;
+        let config: OrchestratorConfig = serde_yml::from_str(s).unwrap();
+        assert!(config.passthrough_headers.is_empty());
+        Ok(())
+    }
+    #[test]
+    fn test_allowed_headers_passthrough_in_config() -> Result<(), Error> {
+        let s = r#"
+generation:
+    provider: tgis
+    service:
+        hostname: localhost
+        port: 8000
+chunkers:
+    sentence-en:
+        type: sentence
+        service:
+            hostname: localhost
+            port: 9000
+detectors:
+    hap:
+        service:
+            hostname: localhost
+            port: 9000
+        tls: detector
+        chunker_id: sentence-fr
+        default_threshold: 0.5
+
+passthrough_headers:
+    - test-header
+        "#;
+        let config: OrchestratorConfig = serde_yml::from_str(s).unwrap();
+        assert_eq!(
+            config.passthrough_headers,
+            HashSet::from(["test-header".to_string()])
+        );
+        Ok(())
+    }
 }
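Lowercasing at load time makes the allow-list comparison case-insensitive, since HeaderMap normalizes header names to lowercase. The filtering step itself is not shown in this diff; a hypothetical helper consistent with this design might look like:

use std::collections::HashSet;

use axum::http::HeaderMap;

// Hypothetical helper: keep only headers whose names appear in the
// configured allow-list. HeaderMap names are already lowercase, matching
// the lowercasing applied to `passthrough_headers` at config load.
fn filter_headers(headers: &HeaderMap, allowed: &HashSet<String>) -> HeaderMap {
    headers
        .iter()
        .filter(|(name, _)| allowed.contains(name.as_str()))
        .map(|(name, value)| (name.clone(), value.clone()))
        .collect()
}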
