Skip to content

Commit 2eed344

Browse files
authored
Merge pull request #59 from Fishrock123/async-h1-pooling
feat: h1 connection pooling
2 parents 06994bb + 93bb1db commit 2eed344

File tree

4 files changed

+256
-32
lines changed

4 files changed

+256
-32
lines changed

Cargo.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,25 @@ rustdoc-args = ["--cfg", "feature=\"docs\""]
2222
[features]
2323
default = ["h1_client"]
2424
docs = ["h1_client", "curl_client", "wasm_client", "hyper_client"]
25-
h1_client = ["async-h1", "async-std", "async-native-tls"]
26-
h1_client_rustls = ["async-h1", "async-std", "async-tls"]
25+
h1_client = ["async-h1", "async-std", "async-native-tls", "deadpool", "futures"]
26+
h1_client_rustls = ["async-h1", "async-std", "async-tls", "deadpool", "futures"]
2727
native_client = ["curl_client", "wasm_client"]
2828
curl_client = ["isahc", "async-std"]
2929
wasm_client = ["js-sys", "web-sys", "wasm-bindgen", "wasm-bindgen-futures", "futures"]
3030
hyper_client = ["hyper", "hyper-tls", "http-types/hyperium_http", "futures-util"]
3131

3232
[dependencies]
3333
async-trait = "0.1.37"
34+
dashmap = "4.0.2"
3435
http-types = "2.3.0"
3536
log = "0.4.7"
3637

3738
# h1_client
3839
async-h1 = { version = "2.0.0", optional = true }
3940
async-std = { version = "1.6.0", default-features = false, optional = true }
4041
async-native-tls = { version = "0.3.1", optional = true }
42+
deadpool = { version = "0.7.0", optional = true }
43+
futures = { version = "0.3.8", optional = true }
4144

4245
# h1_client_rustls
4346
async-tls = { version = "0.10.0", optional = true }

src/h1.rs renamed to src/h1/mod.rs

Lines changed: 93 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,44 @@
1-
//! http-client implementation for async-h1.
1+
//! http-client implementation for async-h1, with connecton pooling ("Keep-Alive").
22
3-
use super::{async_trait, Error, HttpClient, Request, Response};
3+
use std::fmt::Debug;
4+
use std::net::SocketAddr;
45

56
use async_h1::client;
7+
use async_std::net::TcpStream;
8+
use dashmap::DashMap;
9+
use deadpool::managed::Pool;
610
use http_types::StatusCode;
711

8-
/// Async-h1 based HTTP Client.
9-
#[derive(Debug)]
12+
#[cfg(not(feature = "h1_client_rustls"))]
13+
use async_native_tls::TlsStream;
14+
#[cfg(feature = "h1_client_rustls")]
15+
use async_tls::client::TlsStream;
16+
17+
use super::{async_trait, Error, HttpClient, Request, Response};
18+
19+
mod tcp;
20+
mod tls;
21+
22+
use tcp::{TcpConnWrapper, TcpConnection};
23+
use tls::{TlsConnWrapper, TlsConnection};
24+
25+
// This number is based on a few random benchmarks and see whatever gave decent perf vs resource use.
26+
const DEFAULT_MAX_CONCURRENT_CONNECTIONS: usize = 50;
27+
28+
type HttpPool = DashMap<SocketAddr, Pool<TcpStream, std::io::Error>>;
29+
type HttpsPool = DashMap<SocketAddr, Pool<TlsStream<TcpStream>, Error>>;
30+
31+
/// Async-h1 based HTTP Client, with connecton pooling ("Keep-Alive").
1032
pub struct H1Client {
11-
_priv: (),
33+
http_pools: HttpPool,
34+
https_pools: HttpsPool,
35+
max_concurrent_connections: usize,
36+
}
37+
38+
impl Debug for H1Client {
39+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40+
f.write_str("H1Client")
41+
}
1242
}
1343

1444
impl Default for H1Client {
@@ -20,13 +50,28 @@ impl Default for H1Client {
2050
impl H1Client {
2151
/// Create a new instance.
2252
pub fn new() -> Self {
23-
Self { _priv: () }
53+
Self {
54+
http_pools: DashMap::new(),
55+
https_pools: DashMap::new(),
56+
max_concurrent_connections: DEFAULT_MAX_CONCURRENT_CONNECTIONS,
57+
}
58+
}
59+
60+
/// Create a new instance.
61+
pub fn with_max_connections(max: usize) -> Self {
62+
Self {
63+
http_pools: DashMap::new(),
64+
https_pools: DashMap::new(),
65+
max_concurrent_connections: max,
66+
}
2467
}
2568
}
2669

2770
#[async_trait]
2871
impl HttpClient for H1Client {
2972
async fn send(&self, mut req: Request) -> Result<Response, Error> {
73+
req.insert_header("Connection", "keep-alive");
74+
3075
// Insert host
3176
let host = req
3277
.url()
@@ -57,40 +102,58 @@ impl HttpClient for H1Client {
57102

58103
match scheme {
59104
"http" => {
60-
let stream = async_std::net::TcpStream::connect(addr).await?;
105+
let pool_ref = if let Some(pool_ref) = self.http_pools.get(&addr) {
106+
pool_ref
107+
} else {
108+
let manager = TcpConnection::new(addr);
109+
let pool = Pool::<TcpStream, std::io::Error>::new(
110+
manager,
111+
self.max_concurrent_connections,
112+
);
113+
self.http_pools.insert(addr, pool);
114+
self.http_pools.get(&addr).unwrap()
115+
};
116+
117+
// Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await.
118+
let pool = pool_ref.clone();
119+
std::mem::drop(pool_ref);
120+
121+
let stream = pool.get().await?;
61122
req.set_peer_addr(stream.peer_addr().ok());
62123
req.set_local_addr(stream.local_addr().ok());
63-
client::connect(stream, req).await
124+
client::connect(TcpConnWrapper::new(stream), req).await
64125
}
65126
"https" => {
66-
let raw_stream = async_std::net::TcpStream::connect(addr).await?;
67-
req.set_peer_addr(raw_stream.peer_addr().ok());
68-
req.set_local_addr(raw_stream.local_addr().ok());
69-
let tls_stream = add_tls(host, raw_stream).await?;
70-
client::connect(tls_stream, req).await
127+
let pool_ref = if let Some(pool_ref) = self.https_pools.get(&addr) {
128+
pool_ref
129+
} else {
130+
let manager = TlsConnection::new(host.clone(), addr);
131+
let pool = Pool::<TlsStream<TcpStream>, Error>::new(
132+
manager,
133+
self.max_concurrent_connections,
134+
);
135+
self.https_pools.insert(addr, pool);
136+
self.https_pools.get(&addr).unwrap()
137+
};
138+
139+
// Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await.
140+
let pool = pool_ref.clone();
141+
std::mem::drop(pool_ref);
142+
143+
let stream = pool
144+
.get()
145+
.await
146+
.map_err(|e| Error::from_str(400, e.to_string()))?;
147+
req.set_peer_addr(stream.get_ref().peer_addr().ok());
148+
req.set_local_addr(stream.get_ref().local_addr().ok());
149+
150+
client::connect(TlsConnWrapper::new(stream), req).await
71151
}
72152
_ => unreachable!(),
73153
}
74154
}
75155
}
76156

77-
#[cfg(not(feature = "h1_client_rustls"))]
78-
async fn add_tls(
79-
host: String,
80-
stream: async_std::net::TcpStream,
81-
) -> Result<async_native_tls::TlsStream<async_std::net::TcpStream>, async_native_tls::Error> {
82-
async_native_tls::connect(host, stream).await
83-
}
84-
85-
#[cfg(feature = "h1_client_rustls")]
86-
async fn add_tls(
87-
host: String,
88-
stream: async_std::net::TcpStream,
89-
) -> std::io::Result<async_tls::client::TlsStream<async_std::net::TcpStream>> {
90-
let connector = async_tls::TlsConnector::default();
91-
connector.connect(host, stream).await
92-
}
93-
94157
#[cfg(test)]
95158
mod tests {
96159
use super::*;

src/h1/tcp.rs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
use std::fmt::Debug;
2+
use std::net::SocketAddr;
3+
use std::pin::Pin;
4+
5+
use async_std::net::TcpStream;
6+
use async_trait::async_trait;
7+
use deadpool::managed::{Manager, Object, RecycleResult};
8+
use futures::io::{AsyncRead, AsyncWrite};
9+
use futures::task::{Context, Poll};
10+
11+
#[derive(Clone, Debug)]
12+
pub(crate) struct TcpConnection {
13+
addr: SocketAddr,
14+
}
15+
impl TcpConnection {
16+
pub(crate) fn new(addr: SocketAddr) -> Self {
17+
Self { addr }
18+
}
19+
}
20+
21+
pub(crate) struct TcpConnWrapper {
22+
conn: Object<TcpStream, std::io::Error>,
23+
}
24+
impl TcpConnWrapper {
25+
pub(crate) fn new(conn: Object<TcpStream, std::io::Error>) -> Self {
26+
Self { conn }
27+
}
28+
}
29+
30+
impl AsyncRead for TcpConnWrapper {
31+
fn poll_read(
32+
mut self: Pin<&mut Self>,
33+
cx: &mut Context<'_>,
34+
buf: &mut [u8],
35+
) -> Poll<Result<usize, std::io::Error>> {
36+
Pin::new(&mut *self.conn).poll_read(cx, buf)
37+
}
38+
}
39+
40+
impl AsyncWrite for TcpConnWrapper {
41+
fn poll_write(
42+
mut self: Pin<&mut Self>,
43+
cx: &mut Context<'_>,
44+
buf: &[u8],
45+
) -> Poll<std::io::Result<usize>> {
46+
Pin::new(&mut *self.conn).poll_write(cx, buf)
47+
}
48+
49+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
50+
Pin::new(&mut *self.conn).poll_flush(cx)
51+
}
52+
53+
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
54+
Pin::new(&mut *self.conn).poll_close(cx)
55+
}
56+
}
57+
58+
#[async_trait]
59+
impl Manager<TcpStream, std::io::Error> for TcpConnection {
60+
async fn create(&self) -> Result<TcpStream, std::io::Error> {
61+
Ok(TcpStream::connect(self.addr).await?)
62+
}
63+
64+
async fn recycle(&self, _conn: &mut TcpStream) -> RecycleResult<std::io::Error> {
65+
Ok(())
66+
}
67+
}

src/h1/tls.rs

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
use std::fmt::Debug;
2+
use std::net::SocketAddr;
3+
use std::pin::Pin;
4+
5+
use async_std::net::TcpStream;
6+
use async_trait::async_trait;
7+
use deadpool::managed::{Manager, Object, RecycleResult};
8+
use futures::io::{AsyncRead, AsyncWrite};
9+
use futures::task::{Context, Poll};
10+
11+
#[cfg(not(feature = "h1_client_rustls"))]
12+
use async_native_tls::TlsStream;
13+
#[cfg(feature = "h1_client_rustls")]
14+
use async_tls::client::TlsStream;
15+
16+
use crate::Error;
17+
18+
#[derive(Clone, Debug)]
19+
pub(crate) struct TlsConnection {
20+
host: String,
21+
addr: SocketAddr,
22+
}
23+
impl TlsConnection {
24+
pub(crate) fn new(host: String, addr: SocketAddr) -> Self {
25+
Self { host, addr }
26+
}
27+
}
28+
29+
pub(crate) struct TlsConnWrapper {
30+
conn: Object<TlsStream<TcpStream>, Error>,
31+
}
32+
impl TlsConnWrapper {
33+
pub(crate) fn new(conn: Object<TlsStream<TcpStream>, Error>) -> Self {
34+
Self { conn }
35+
}
36+
}
37+
38+
impl AsyncRead for TlsConnWrapper {
39+
fn poll_read(
40+
mut self: Pin<&mut Self>,
41+
cx: &mut Context<'_>,
42+
buf: &mut [u8],
43+
) -> Poll<Result<usize, std::io::Error>> {
44+
Pin::new(&mut *self.conn).poll_read(cx, buf)
45+
}
46+
}
47+
48+
impl AsyncWrite for TlsConnWrapper {
49+
fn poll_write(
50+
mut self: Pin<&mut Self>,
51+
cx: &mut Context<'_>,
52+
buf: &[u8],
53+
) -> Poll<std::io::Result<usize>> {
54+
Pin::new(&mut *self.conn).poll_write(cx, buf)
55+
}
56+
57+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
58+
Pin::new(&mut *self.conn).poll_flush(cx)
59+
}
60+
61+
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
62+
Pin::new(&mut *self.conn).poll_close(cx)
63+
}
64+
}
65+
66+
#[async_trait]
67+
impl Manager<TlsStream<TcpStream>, Error> for TlsConnection {
68+
async fn create(&self) -> Result<TlsStream<TcpStream>, Error> {
69+
let raw_stream = async_std::net::TcpStream::connect(self.addr).await?;
70+
let tls_stream = add_tls(&self.host, raw_stream).await?;
71+
Ok(tls_stream)
72+
}
73+
74+
async fn recycle(&self, _conn: &mut TlsStream<TcpStream>) -> RecycleResult<Error> {
75+
Ok(())
76+
}
77+
}
78+
79+
#[cfg(not(feature = "h1_client_rustls"))]
80+
async fn add_tls(
81+
host: &str,
82+
stream: TcpStream,
83+
) -> Result<async_native_tls::TlsStream<TcpStream>, async_native_tls::Error> {
84+
async_native_tls::connect(host, stream).await
85+
}
86+
87+
#[cfg(feature = "h1_client_rustls")]
88+
async fn add_tls(host: &str, stream: TcpStream) -> Result<TlsStream<TcpStream>, std::io::Error> {
89+
let connector = async_tls::TlsConnector::default();
90+
connector.connect(host, stream).await
91+
}

0 commit comments

Comments
 (0)