@@ -10,23 +10,15 @@ use bytes::buf::BufMut;
10
10
use futures:: io:: BufReader ;
11
11
use futures:: prelude:: * ;
12
12
use futures_lite:: AsyncReadExt ;
13
+ #[ cfg( feature = "tls" ) ]
14
+ pub use futures_rustls:: client:: TlsStream ;
13
15
use ignore_result:: Ignore ;
14
16
use tracing:: { debug, trace} ;
15
17
16
- #[ cfg( feature = "tls" ) ]
17
- mod tls {
18
- pub use std:: sync:: Arc ;
19
-
20
- pub use futures_rustls:: client:: TlsStream ;
21
- pub use futures_rustls:: TlsConnector ;
22
- pub use rustls:: pki_types:: ServerName ;
23
- pub use rustls:: ClientConfig ;
24
- }
25
- #[ cfg( feature = "tls" ) ]
26
- use tls:: * ;
27
-
28
18
use crate :: deadline:: Deadline ;
29
19
use crate :: endpoint:: { EndpointRef , IterableEndpoints } ;
20
+ #[ cfg( feature = "tls" ) ]
21
+ use crate :: tls:: TlsClient ;
30
22
31
23
#[ derive( Debug ) ]
32
24
pub enum Connection {
@@ -170,7 +162,7 @@ impl Connection {
170
162
#[ derive( Clone ) ]
171
163
pub struct Connector {
172
164
#[ cfg( feature = "tls" ) ]
173
- tls : Option < TlsConnector > ,
165
+ tls : Option < TlsClient > ,
174
166
timeout : Duration ,
175
167
}
176
168
@@ -186,15 +178,8 @@ impl Connector {
186
178
}
187
179
188
180
#[ cfg( feature = "tls" ) ]
189
- pub fn with_tls ( config : ClientConfig ) -> Self {
190
- Self { tls : Some ( TlsConnector :: from ( Arc :: new ( config) ) ) , timeout : Duration :: from_secs ( 10 ) }
191
- }
192
-
193
- #[ cfg( feature = "tls" ) ]
194
- async fn connect_tls ( & self , stream : TcpStream , host : & str ) -> Result < Connection > {
195
- let domain = ServerName :: try_from ( host) . unwrap ( ) . to_owned ( ) ;
196
- let stream = self . tls . as_ref ( ) . unwrap ( ) . connect ( domain, stream) . await ?;
197
- Ok ( Connection :: new_tls ( stream) )
181
+ pub fn with_tls ( client : TlsClient ) -> Self {
182
+ Self { tls : Some ( client) , timeout : Duration :: from_secs ( 10 ) }
198
183
}
199
184
200
185
pub fn timeout ( & self ) -> Duration {
@@ -205,34 +190,25 @@ impl Connector {
205
190
self . timeout = timeout;
206
191
}
207
192
208
- pub async fn connect ( & self , endpoint : EndpointRef < ' _ > , deadline : & mut Deadline ) -> Result < Connection > {
193
+ async fn connect_endpoint ( & self , endpoint : EndpointRef < ' _ > ) -> Result < Connection > {
209
194
if endpoint. tls {
210
195
#[ cfg( feature = "tls" ) ]
211
- if self . tls . is_none ( ) {
212
- return Err ( Error :: new ( ErrorKind :: Unsupported , "tls not configured" ) ) ;
213
- }
196
+ return match self . tls . as_ref ( ) {
197
+ None => return Err ( Error :: new ( ErrorKind :: Unsupported , "tls not configured" ) ) ,
198
+ Some ( client) => client. connect ( endpoint. host , endpoint. port ) . await . map ( Connection :: new_tls) ,
199
+ } ;
214
200
#[ cfg( not( feature = "tls" ) ) ]
215
201
return Err ( Error :: new ( ErrorKind :: Unsupported , "tls not supported" ) ) ;
216
202
}
203
+ TcpStream :: connect ( ( endpoint. host , endpoint. port ) ) . await . map ( Connection :: new_raw)
204
+ }
205
+
206
+ pub async fn connect ( & self , endpoint : EndpointRef < ' _ > , deadline : & mut Deadline ) -> Result < Connection > {
217
207
select ! {
208
+ biased;
209
+ r = self . connect_endpoint( endpoint) => r,
218
210
_ = unsafe { Pin :: new_unchecked( deadline) } => Err ( Error :: new( ErrorKind :: TimedOut , "deadline exceed" ) ) ,
219
211
_ = Timer :: after( self . timeout) => Err ( Error :: new( ErrorKind :: TimedOut , format!( "connection timeout{:?} exceed" , self . timeout) ) ) ,
220
- r = TcpStream :: connect( ( endpoint. host, endpoint. port) ) => {
221
- match r {
222
- Err ( err) => Err ( err) ,
223
- Ok ( sock) => {
224
- let connection = if endpoint. tls {
225
- #[ cfg( not( feature = "tls" ) ) ]
226
- unreachable!( "tls not supported" ) ;
227
- #[ cfg( feature = "tls" ) ]
228
- self . connect_tls( sock, endpoint. host) . await ?
229
- } else {
230
- Connection :: new_raw( sock)
231
- } ;
232
- Ok ( connection)
233
- } ,
234
- }
235
- } ,
236
212
}
237
213
}
238
214
0 commit comments