Skip to content

Commit 79434f2

Browse files
authored
Migrate clients implementation from reqwest to hyper (#252)
* migrate from reqwest to hyper Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * updated test + some nits Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * reorder dependencies Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * more clean up. TODOs left: docstrings, debug! and info! lines where missing Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * one more formatting nit Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * minor http client refactor and update detector endpoint handling Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * update openai client Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * removed eventsource implementation, connected http_body_utils stream impl to eventsource_stream eventsource impl Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * simplified RequestLike, ResponseLike Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * use &self instead of self for HttpClient get/post Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * make sure error responses are handled properly and error handling improvements Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * doc strings Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * removed RequestLike and ResponseLike Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * [image tested successfully on cluster] marginal improvement to debug log statement Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * fix tokio spawn tracing (needs explicit instrumentation to avoid creating new trace) and add missing traceparent propagation from responses Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * more async tracing instrumentation fixes Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * more tracing logic fixes - traces are being split due to async instrumentation logic Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * added incoming response debug print - in process of debugging traceparent injection issue 
Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * attempt fix traceparent injection issue - testing on cluster Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * more debug testing Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * nits & add RequestBody, ResponseBody Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * simplified client error handling and added missing handling for chat generation clients Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * simplify error handling Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * small nits Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * rebase formatting Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * fix client connect and request timeouts Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> * fix variable misuse and revert error message changes Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com> --------- Signed-off-by: Paul Scoropan <1paulscoropan@gmail.com>
1 parent a4964b1 commit 79434f2

26 files changed

+1654
-774
lines changed

Cargo.lock

Lines changed: 520 additions & 154 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,28 +20,36 @@ async-stream = "0.3.5"
2020
axum = { version = "0.7.5", features = ["json"] }
2121
axum-extra = "0.9.3"
2222
clap = { version = "4.5.15", features = ["derive", "env"] }
23+
eventsource = "0.5.0"
24+
eventsource-stream = "0.2.3"
2325
futures = "0.3.30"
26+
futures-core = "0.3.30"
27+
futures-timer = "3.0.3"
2428
ginepro = "0.8.1"
29+
http-body-util = "0.1.2"
2530
http-serde = "2.1.1"
2631
hyper = { version = "1.4.1", features = ["http1", "http2", "server"] }
32+
hyper-rustls = { version = "0.27.3", features = ["ring"]}
2733
hyper-util = { version = "0.1.7", features = ["server-auto", "server-graceful", "tokio"] }
34+
mime = "0.3.17"
2835
mio = "1.0.2"
2936
opentelemetry = { version = "0.24.0", features = ["trace", "metrics"] }
3037
opentelemetry-http = { version = "0.13.0", features = ["reqwest"] }
3138
opentelemetry-otlp = { version = "0.17.0", features = ["http-proto"] }
3239
opentelemetry_sdk = { version = "0.24.1", features = ["rt-tokio", "metrics"] }
40+
pin-project-lite = "0.2.15"
3341
prost = "0.13.1"
3442
reqwest = { version = "0.12.5", features = ["blocking", "rustls-tls", "json"] }
3543
reqwest-eventsource = "0.6.0"
36-
rustls = {version = "0.23.12", default-features = false, features = ["std"]}
44+
rustls = {version = "0.23.12", default-features = false, features = ["std", "ring"]}
3745
rustls-pemfile = "2.1.3"
3846
rustls-webpki = "0.102.6"
3947
serde = { version = "1.0.206", features = ["derive"] }
4048
serde_json = "1.0.124"
4149
serde_yml = "0.0.11"
4250
thiserror = "1.0.63"
4351
tokio = { version = "1.39.2", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "fs"] }
44-
tokio-rustls = { version = "0.26.0" }
52+
tokio-rustls = { version = "0.26.0", features = ["ring"]}
4553
tokio-stream = { version = "0.1.15", features = ["sync"] }
4654
tonic = { version = "0.12.1", features = ["tls", "tls-roots", "tls-webpki-roots"] }
4755
tower-http = { version = "0.5.2", features = ["trace"] }
@@ -51,6 +59,7 @@ tracing-opentelemetry = "0.25.0"
5159
tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
5260
url = "2.5.2"
5361
uuid = { version = "1.10.0", features = ["v4", "fast-rng"] }
62+
hyper-timeout = "0.5.2"
5463

5564
[build-dependencies]
5665
tonic-build = "0.12.1"

src/clients.rs

Lines changed: 51 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,17 @@ use async_trait::async_trait;
2727
use axum::http::{Extensions, HeaderMap};
2828
use futures::Stream;
2929
use ginepro::LoadBalancedChannel;
30-
use tokio::{fs::File, io::AsyncReadExt};
30+
use hyper_timeout::TimeoutConnector;
31+
use hyper_util::rt::TokioExecutor;
3132
use tonic::{metadata::MetadataMap, Request};
32-
use tracing::{debug, instrument};
33+
use tracing::{debug, instrument, Span};
34+
use tracing_opentelemetry::OpenTelemetrySpanExt;
3335
use url::Url;
3436

3537
use crate::{
3638
config::{ServiceConfig, Tls},
3739
health::HealthCheckResult,
38-
tracing_utils::with_traceparent_header,
40+
utils::{tls, trace::with_traceparent_header},
3941
};
4042

4143
pub mod errors;
@@ -60,7 +62,7 @@ pub use generation::GenerationClient;
6062

6163
pub mod openai;
6264

63-
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(60);
65+
const DEFAULT_CONNECT_TIMEOUT_SEC: u64 = 60;
6466
const DEFAULT_REQUEST_TIMEOUT_SEC: u64 = 600;
6567

6668
pub type BoxStream<T> = Pin<Box<dyn Stream<Item = T> + Send>>;
@@ -198,7 +200,10 @@ impl ClientMap {
198200
}
199201

200202
#[instrument(skip_all, fields(hostname = service_config.hostname))]
201-
pub async fn create_http_client(default_port: u16, service_config: &ServiceConfig) -> HttpClient {
203+
pub async fn create_http_client(
204+
default_port: u16,
205+
service_config: &ServiceConfig,
206+
) -> Result<HttpClient, Error> {
202207
let port = service_config.port.unwrap_or(default_port);
203208
let protocol = match service_config.tls {
204209
Some(_) => "https",
@@ -210,53 +215,36 @@ pub async fn create_http_client(default_port: u16, service_config: &ServiceConfi
210215
.set_port(Some(port))
211216
.unwrap_or_else(|_| panic!("error setting port: {}", port));
212217
debug!(%base_url, "creating HTTP client");
218+
219+
let connect_timeout = Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SEC);
213220
let request_timeout = Duration::from_secs(
214221
service_config
215222
.request_timeout
216223
.unwrap_or(DEFAULT_REQUEST_TIMEOUT_SEC),
217224
);
218-
let mut builder = reqwest::ClientBuilder::new()
219-
.connect_timeout(DEFAULT_CONNECT_TIMEOUT)
220-
.timeout(request_timeout);
221-
if let Some(Tls::Config(tls_config)) = &service_config.tls {
222-
let mut cert_buf = Vec::new();
223-
let cert_path = tls_config.cert_path.as_ref().unwrap().as_path();
224-
File::open(cert_path)
225-
.await
226-
.unwrap_or_else(|error| panic!("error reading cert from {cert_path:?}: {error}"))
227-
.read_to_end(&mut cert_buf)
228-
.await
229-
.unwrap();
230-
231-
if let Some(key_path) = &tls_config.key_path {
232-
File::open(key_path)
233-
.await
234-
.unwrap_or_else(|error| panic!("error reading key from {key_path:?}: {error}"))
235-
.read_to_end(&mut cert_buf)
236-
.await
237-
.unwrap();
238-
}
239-
let identity = reqwest::Identity::from_pem(&cert_buf)
240-
.unwrap_or_else(|error| panic!("error parsing bundled client certificate: {error}"));
241-
242-
builder = builder.use_rustls_tls().identity(identity);
243-
builder = builder.danger_accept_invalid_certs(tls_config.insecure.unwrap_or(false));
244225

245-
if let Some(client_ca_cert_path) = &tls_config.client_ca_cert_path {
246-
let ca_cert = tokio::fs::read(client_ca_cert_path)
226+
let https_conn_builder = match &service_config.tls {
227+
Some(Tls::Config(tls)) => hyper_rustls::HttpsConnectorBuilder::new().with_tls_config(
228+
tls::build_client_config(tls)
247229
.await
248-
.unwrap_or_else(|error| {
249-
panic!("error reading cert from {client_ca_cert_path:?}: {error}")
250-
});
251-
let cacert = reqwest::Certificate::from_pem(&ca_cert)
252-
.unwrap_or_else(|error| panic!("error parsing ca cert: {error}"));
253-
builder = builder.add_root_certificate(cacert)
254-
}
255-
}
256-
let client = builder
257-
.build()
258-
.unwrap_or_else(|error| panic!("error creating http client: {error}"));
259-
HttpClient::new(base_url, client)
230+
.map_err(|e| e.into_client_error())?,
231+
),
232+
Some(_) => panic!("unexpected unresolved TLS in client builder"),
233+
None => hyper_rustls::HttpsConnectorBuilder::new()
234+
.with_tls_config(tls::build_insecure_client_config()),
235+
};
236+
let https_conn = https_conn_builder
237+
.https_or_http()
238+
.enable_http1()
239+
.enable_http2()
240+
.build();
241+
242+
let mut timeout_conn = TimeoutConnector::new(https_conn);
243+
timeout_conn.set_connect_timeout(Some(connect_timeout));
244+
245+
let client =
246+
hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build(timeout_conn);
247+
Ok(HttpClient::new(base_url, request_timeout, client))
260248
}
261249

262250
#[instrument(skip_all, fields(hostname = service_config.hostname))]
@@ -273,13 +261,14 @@ pub async fn create_grpc_client<C>(
273261
let mut base_url = Url::parse(&format!("{}://{}", protocol, &service_config.hostname)).unwrap();
274262
base_url.set_port(Some(port)).unwrap();
275263
debug!(%base_url, "creating gRPC client");
264+
let connect_timeout = Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SEC); // connect timeout must use the connect default, not the (much longer) request default
276265
let request_timeout = Duration::from_secs(
277266
service_config
278267
.request_timeout
279268
.unwrap_or(DEFAULT_REQUEST_TIMEOUT_SEC),
280269
);
281270
let mut builder = LoadBalancedChannel::builder((service_config.hostname.clone(), port))
282-
.connect_timeout(DEFAULT_CONNECT_TIMEOUT)
271+
.connect_timeout(connect_timeout)
283272
.timeout(request_timeout);
284273

285274
let client_tls_config = if let Some(Tls::Config(tls_config)) = &service_config.tls {
@@ -349,31 +338,39 @@ pub fn is_valid_hostname(hostname: &str) -> bool {
349338
/// Turns a gRPC client request body of type `T` and header map into a `tonic::Request<T>`.
350339
/// Will also inject the current `traceparent` header into the request based on the current span.
351340
fn grpc_request_with_headers<T>(request: T, headers: HeaderMap) -> Request<T> {
352-
let headers = with_traceparent_header(headers);
341+
let ctx = Span::current().context();
342+
let headers = with_traceparent_header(&ctx, headers);
353343
let metadata = MetadataMap::from_headers(headers);
354344
Request::from_parts(metadata, Extensions::new(), request)
355345
}
356346

357347
#[cfg(test)]
358348
mod tests {
359349
use errors::grpc_to_http_code;
350+
use http_body_util::BodyExt;
360351
use hyper::{http, StatusCode};
361-
use reqwest::Response;
362352

363353
use super::*;
364354
use crate::{
355+
clients::http::Response,
365356
health::{HealthCheckResult, HealthStatus},
366357
pb::grpc::health::v1::{health_check_response::ServingStatus, HealthCheckResponse},
367358
};
368359

369-
async fn mock_http_response(
370-
status: StatusCode,
371-
body: &str,
372-
) -> Result<Response, reqwest::Error> {
373-
Ok(reqwest::Response::from(
360+
async fn mock_http_response(status: StatusCode, body: &str) -> Result<Response, Error> {
361+
Ok(Response(
374362
http::Response::builder()
375363
.status(status)
376-
.body(body.to_string())
364+
.body(
365+
body.to_string()
366+
.map_err(|e| {
367+
panic!(
368+
"infallible error parsing string body in test response: {}",
369+
e
370+
)
371+
})
372+
.boxed(),
373+
)
377374
.unwrap(),
378375
))
379376
}

src/clients/chunker.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ use crate::{
3939
caikit_data_model::nlp::{ChunkerTokenizationStreamResult, Token, TokenizationResults},
4040
grpc::health::v1::{health_client::HealthClient, HealthCheckRequest},
4141
},
42-
tracing_utils::trace_context_from_grpc_response,
42+
utils::trace::trace_context_from_grpc_response,
4343
};
4444

4545
const DEFAULT_PORT: u16 = 8085;

src/clients/detector.rs

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,15 @@ use std::fmt::Debug;
1919

2020
use axum::http::HeaderMap;
2121
use hyper::StatusCode;
22-
use reqwest::Response;
23-
use serde::{Deserialize, Serialize};
24-
use tracing::info;
22+
use serde::Deserialize;
23+
use tracing::instrument;
2524
use url::Url;
2625

26+
use super::{
27+
http::{HttpClientExt, RequestBody, ResponseBody},
28+
Error,
29+
};
30+
2731
pub mod text_contents;
2832
pub use text_contents::*;
2933
pub mod text_chat;
@@ -33,9 +37,6 @@ pub use text_context_doc::*;
3337
pub mod text_generation;
3438
pub use text_generation::*;
3539

36-
use super::{Error, HttpClient};
37-
use crate::tracing_utils::{trace_context_from_http_response, with_traceparent_header};
38-
3940
const DEFAULT_PORT: u16 = 8080;
4041
const DETECTOR_ID_HEADER_NAME: &str = "detector-id";
4142

@@ -54,24 +55,55 @@ impl From<DetectorError> for Error {
5455
}
5556
}
5657

57-
/// Make a POST request for an HTTP detector client and return the response.
58-
/// Also injects the `traceparent` header from the current span and traces the response.
59-
pub async fn post_with_headers<T: Debug + Serialize>(
60-
client: HttpClient,
61-
url: Url,
62-
request: T,
63-
headers: HeaderMap,
64-
model_id: &str,
65-
) -> Result<Response, Error> {
66-
let headers = with_traceparent_header(headers);
67-
info!(?url, ?headers, ?request, "sending client request");
68-
let response = client
69-
.post(url)
70-
.headers(headers)
71-
.header(DETECTOR_ID_HEADER_NAME, model_id)
72-
.json(&request)
73-
.send()
74-
.await?;
75-
trace_context_from_http_response(&response);
76-
Ok(response)
58+
/// This trait should be implemented by all detectors.
59+
/// If the detector has an HTTP client (currently all detector clients are HTTP) this trait will
60+
/// implicitly extend the client with an HTTP detector specific post function.
61+
pub trait DetectorClient {}
62+
63+
/// Provides a helper extension for HTTP detector clients.
64+
pub trait DetectorClientExt: HttpClientExt {
65+
/// Wraps the post function with extra detector functionality
66+
/// (detector id header injection & error handling)
67+
async fn post_to_detector<U: ResponseBody>(
68+
&self,
69+
model_id: &str,
70+
url: Url,
71+
headers: HeaderMap,
72+
request: impl RequestBody,
73+
) -> Result<U, Error>;
74+
75+
/// Wraps call to inner HTTP client endpoint function.
76+
fn endpoint(&self, path: &str) -> Url;
77+
}
78+
79+
impl<C: DetectorClient + HttpClientExt> DetectorClientExt for C {
80+
#[instrument(skip_all, fields(model_id, url))]
81+
async fn post_to_detector<U: ResponseBody>(
82+
&self,
83+
model_id: &str,
84+
url: Url,
85+
headers: HeaderMap,
86+
request: impl RequestBody,
87+
) -> Result<U, Error> {
88+
let mut headers = headers;
89+
headers.append(DETECTOR_ID_HEADER_NAME, model_id.parse().unwrap());
90+
let response = self.inner().post(url, headers, request).await?;
91+
92+
let status = response.status();
93+
match status {
94+
StatusCode::OK => Ok(response.json().await?),
95+
_ => Err(response
96+
.json::<DetectorError>()
97+
.await
98+
.unwrap_or(DetectorError {
99+
code: status.as_u16(),
100+
message: "".into(),
101+
})
102+
.into()),
103+
}
104+
}
105+
106+
fn endpoint(&self, path: &str) -> Url {
107+
self.inner().endpoint(path)
108+
}
77109
}

0 commit comments

Comments
 (0)