Skip to content

Commit 409e74e

Browse files
authored
HTTP CONNECT proxy support (#714)
Fixes #309
1 parent 00b5507 commit 409e74e

File tree

3 files changed

+99
-1
lines changed

3 files changed

+99
-1
lines changed

client/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@ telemetry = ["dep:opentelemetry"]
1717
anyhow = "1.0"
1818
async-trait = "0.1"
1919
backoff = "0.4"
20+
base64 = "0.21.7"
2021
derive_builder = { workspace = true }
2122
derive_more = "0.99"
2223
futures = "0.3"
2324
futures-retry = "0.6.0"
2425
http = "0.2"
26+
hyper = { version = "0.14.28" }
2527
once_cell = { workspace = true }
2628
opentelemetry = { workspace = true, features = ["metrics"], optional = true }
2729
parking_lot = "0.12"

client/src/lib.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
extern crate tracing;
99

1010
mod metrics;
11+
mod proxy;
1112
mod raw;
1213
mod retry;
1314
mod worker_registry;
1415
mod workflow_handle;
1516

17+
pub use crate::proxy::HttpConnectProxyOptions;
1618
pub use crate::retry::{CallType, RetryClient, RETRYABLE_ERROR_CODES};
1719
pub use raw::{HealthService, OperatorService, TestService, WorkflowService};
1820
pub use temporal_sdk_core_protos::temporal::api::{
@@ -142,6 +144,10 @@ pub struct ClientOptions {
142144
/// be applied if the headers don't already have an "Authorization" header.
143145
#[builder(default)]
144146
pub api_key: Option<String>,
147+
148+
/// HTTP CONNECT proxy to use for this client.
149+
#[builder(default)]
150+
pub http_connect_proxy: Option<HttpConnectProxyOptions>,
145151
}
146152

147153
/// Configuration options for TLS
@@ -403,7 +409,12 @@ impl ClientOptions {
403409
} else {
404410
channel
405411
};
406-
let channel = channel.connect().await?;
412+
// If there is a proxy, we have to connect that way
413+
let channel = if let Some(proxy) = self.http_connect_proxy.as_ref() {
414+
proxy.connect_endpoint(&channel).await?
415+
} else {
416+
channel.connect().await?
417+
};
407418
let service = ServiceBuilder::new()
408419
.layer_fn(move |channel| GrpcMetricSvc {
409420
inner: channel,

client/src/proxy.rs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
use base64::prelude::*;
2+
use hyper::header;
3+
use std::future::Future;
4+
use std::pin::Pin;
5+
use std::task::Context;
6+
use std::task::Poll;
7+
use tokio::net::TcpStream;
8+
use tonic::transport::Channel;
9+
use tonic::transport::Endpoint;
10+
use tower::{service_fn, Service};
11+
12+
/// Options for HTTP CONNECT proxy.
13+
#[derive(Clone, Debug)]
14+
pub struct HttpConnectProxyOptions {
15+
/// The host:port to proxy through.
16+
pub target_addr: String,
17+
/// Optional HTTP basic auth for the proxy as user/pass tuple.
18+
pub basic_auth: Option<(String, String)>,
19+
}
20+
21+
impl HttpConnectProxyOptions {
22+
/// Create a channel from the given endpoint that uses the HTTP CONNECT proxy.
23+
pub async fn connect_endpoint(
24+
&self,
25+
endpoint: &Endpoint,
26+
) -> Result<Channel, tonic::transport::Error> {
27+
let proxy_options = self.clone();
28+
let svc_fn = service_fn(move |uri: tonic::transport::Uri| {
29+
let proxy_options = proxy_options.clone();
30+
async move { proxy_options.connect(uri).await }
31+
});
32+
endpoint.connect_with_connector(svc_fn).await
33+
}
34+
35+
async fn connect(
36+
&self,
37+
uri: tonic::transport::Uri,
38+
) -> anyhow::Result<hyper::upgrade::Upgraded> {
39+
debug!("Connecting to {} via proxy at {}", uri, self.target_addr);
40+
// Create CONNECT request
41+
let mut req_build = hyper::Request::builder().method("CONNECT").uri(uri);
42+
if let Some((user, pass)) = &self.basic_auth {
43+
let creds = BASE64_STANDARD.encode(format!("{}:{}", user, pass));
44+
req_build = req_build.header(header::PROXY_AUTHORIZATION, format!("Basic {}", creds));
45+
}
46+
let req = req_build.body(hyper::Body::empty())?;
47+
48+
// We have to create a client with a specific connector because Hyper is
49+
// not letting us change the HTTP/2 authority
50+
let client =
51+
hyper::Client::builder().build(OverrideAddrConnector(self.target_addr.clone()));
52+
53+
// Send request
54+
let res = client.request(req).await?;
55+
if res.status().is_success() {
56+
Ok(hyper::upgrade::on(res).await?)
57+
} else {
58+
Err(anyhow::anyhow!(
59+
"CONNECT call failed with status: {}",
60+
res.status()
61+
))
62+
}
63+
}
64+
}
65+
66+
#[derive(Clone)]
67+
struct OverrideAddrConnector(String);
68+
69+
impl Service<hyper::Uri> for OverrideAddrConnector {
70+
type Response = TcpStream;
71+
72+
type Error = anyhow::Error;
73+
74+
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
75+
76+
fn poll_ready(&mut self, _ctx: &mut Context<'_>) -> Poll<anyhow::Result<()>> {
77+
Poll::Ready(Ok(()))
78+
}
79+
80+
fn call(&mut self, _uri: hyper::Uri) -> Self::Future {
81+
let target_addr = self.0.clone();
82+
let fut = async move { Ok(TcpStream::connect(target_addr).await?) };
83+
Box::pin(fut)
84+
}
85+
}

0 commit comments

Comments
 (0)