@@ -4,6 +4,8 @@ use std::os::unix::io::{AsRawFd, RawFd};
4
4
#[ cfg( windows) ]
5
5
use std:: os:: windows:: io:: { AsRawSocket , RawSocket } ;
6
6
use std:: pin:: Pin ;
7
+ #[ cfg( feature = "early-data" ) ]
8
+ use std:: task:: Waker ;
7
9
use std:: task:: { Context , Poll } ;
8
10
9
11
use rustls:: ClientConnection ;
@@ -20,7 +22,7 @@ pub struct TlsStream<IO> {
20
22
pub ( crate ) state : TlsState ,
21
23
22
24
#[ cfg( feature = "early-data" ) ]
23
- pub ( crate ) early_waker : Option < std :: task :: Waker > ,
25
+ pub ( crate ) early_waker : Option < Waker > ,
24
26
}
25
27
26
28
impl < IO > TlsStream < IO > {
@@ -152,78 +154,70 @@ where
152
154
let mut stream =
153
155
Stream :: new ( & mut this. io , & mut this. session ) . set_eof ( !this. state . readable ( ) ) ;
154
156
155
- #[ allow( clippy:: match_single_binding) ]
156
- match this. state {
157
- #[ cfg( feature = "early-data" ) ]
158
- TlsState :: EarlyData ( ref mut pos, ref mut data) => {
159
- use std:: io:: Write ;
160
-
161
- // write early data
162
- if let Some ( mut early_data) = stream. session . early_data ( ) {
163
- let len = match early_data. write ( buf) {
164
- Ok ( n) => n,
165
- Err ( err) => return Poll :: Ready ( Err ( err) ) ,
166
- } ;
167
- if len != 0 {
168
- data. extend_from_slice ( & buf[ ..len] ) ;
169
- return Poll :: Ready ( Ok ( len) ) ;
170
- }
171
- }
172
-
173
- // complete handshake
174
- while stream. session . is_handshaking ( ) {
175
- ready ! ( stream. handshake( cx) ) ?;
176
- }
177
-
178
- // write early data (fallback)
179
- if !stream. session . is_early_data_accepted ( ) {
180
- while * pos < data. len ( ) {
181
- let len = ready ! ( stream. as_mut_pin( ) . poll_write( cx, & data[ * pos..] ) ) ?;
182
- * pos += len;
183
- }
184
- }
185
-
186
- // end
187
- this. state = TlsState :: Stream ;
188
-
189
- if let Some ( waker) = this. early_waker . take ( ) {
190
- waker. wake ( ) ;
191
- }
192
-
193
- stream. as_mut_pin ( ) . poll_write ( cx, buf)
157
+ #[ cfg( feature = "early-data" ) ]
158
+ {
159
+ let bufs = [ io:: IoSlice :: new ( buf) ] ;
160
+ let written = ready ! ( poll_handle_early_data(
161
+ & mut this. state,
162
+ & mut stream,
163
+ & mut this. early_waker,
164
+ cx,
165
+ & bufs
166
+ ) ) ?;
167
+ if written != 0 {
168
+ return Poll :: Ready ( Ok ( written) ) ;
194
169
}
195
- _ => stream. as_mut_pin ( ) . poll_write ( cx, buf) ,
196
170
}
171
+
172
+ stream. as_mut_pin ( ) . poll_write ( cx, buf)
197
173
}
198
174
199
- fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
175
+ /// Note: that it does not guarantee the final data to be sent.
176
+ /// To be cautious, you must manually call `flush`.
177
+ fn poll_write_vectored (
178
+ self : Pin < & mut Self > ,
179
+ cx : & mut Context < ' _ > ,
180
+ bufs : & [ io:: IoSlice < ' _ > ] ,
181
+ ) -> Poll < io:: Result < usize > > {
200
182
let this = self . get_mut ( ) ;
201
183
let mut stream =
202
184
Stream :: new ( & mut this. io , & mut this. session ) . set_eof ( !this. state . readable ( ) ) ;
203
185
204
186
#[ cfg( feature = "early-data" ) ]
205
187
{
206
- if let TlsState :: EarlyData ( ref mut pos, ref mut data) = this. state {
207
- // complete handshake
208
- while stream. session . is_handshaking ( ) {
209
- ready ! ( stream. handshake( cx) ) ?;
210
- }
188
+ let written = ready ! ( poll_handle_early_data(
189
+ & mut this. state,
190
+ & mut stream,
191
+ & mut this. early_waker,
192
+ cx,
193
+ bufs
194
+ ) ) ?;
195
+ if written != 0 {
196
+ return Poll :: Ready ( Ok ( written) ) ;
197
+ }
198
+ }
211
199
212
- // write early data (fallback)
213
- if !stream. session . is_early_data_accepted ( ) {
214
- while * pos < data. len ( ) {
215
- let len = ready ! ( stream. as_mut_pin( ) . poll_write( cx, & data[ * pos..] ) ) ?;
216
- * pos += len;
217
- }
218
- }
200
+ stream. as_mut_pin ( ) . poll_write_vectored ( cx, bufs)
201
+ }
219
202
220
- this. state = TlsState :: Stream ;
203
+ #[ inline]
204
+ fn is_write_vectored ( & self ) -> bool {
205
+ true
206
+ }
221
207
222
- if let Some ( waker) = this. early_waker . take ( ) {
223
- waker. wake ( ) ;
224
- }
225
- }
226
- }
208
+ fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
209
+ let this = self . get_mut ( ) ;
210
+ let mut stream =
211
+ Stream :: new ( & mut this. io , & mut this. session ) . set_eof ( !this. state . readable ( ) ) ;
212
+
213
+ #[ cfg( feature = "early-data" ) ]
214
+ ready ! ( poll_handle_early_data(
215
+ & mut this. state,
216
+ & mut stream,
217
+ & mut this. early_waker,
218
+ cx,
219
+ & [ ]
220
+ ) ) ?;
227
221
228
222
stream. as_mut_pin ( ) . poll_flush ( cx)
229
223
}
@@ -248,3 +242,69 @@ where
248
242
stream. as_mut_pin ( ) . poll_shutdown ( cx)
249
243
}
250
244
}
245
+
246
+ #[ cfg( feature = "early-data" ) ]
247
+ fn poll_handle_early_data < IO > (
248
+ state : & mut TlsState ,
249
+ stream : & mut Stream < IO , ClientConnection > ,
250
+ early_waker : & mut Option < Waker > ,
251
+ cx : & mut Context < ' _ > ,
252
+ bufs : & [ io:: IoSlice < ' _ > ] ,
253
+ ) -> Poll < io:: Result < usize > >
254
+ where
255
+ IO : AsyncRead + AsyncWrite + Unpin ,
256
+ {
257
+ if let TlsState :: EarlyData ( pos, data) = state {
258
+ use std:: io:: Write ;
259
+
260
+ // write early data
261
+ if let Some ( mut early_data) = stream. session . early_data ( ) {
262
+ let mut written = 0 ;
263
+
264
+ for buf in bufs {
265
+ if buf. is_empty ( ) {
266
+ continue ;
267
+ }
268
+
269
+ let len = match early_data. write ( buf) {
270
+ Ok ( 0 ) => break ,
271
+ Ok ( n) => n,
272
+ Err ( err) => return Poll :: Ready ( Err ( err) ) ,
273
+ } ;
274
+
275
+ written += len;
276
+ data. extend_from_slice ( & buf[ ..len] ) ;
277
+
278
+ if len < buf. len ( ) {
279
+ break ;
280
+ }
281
+ }
282
+
283
+ if written != 0 {
284
+ return Poll :: Ready ( Ok ( written) ) ;
285
+ }
286
+ }
287
+
288
+ // complete handshake
289
+ while stream. session . is_handshaking ( ) {
290
+ ready ! ( stream. handshake( cx) ) ?;
291
+ }
292
+
293
+ // write early data (fallback)
294
+ if !stream. session . is_early_data_accepted ( ) {
295
+ while * pos < data. len ( ) {
296
+ let len = ready ! ( stream. as_mut_pin( ) . poll_write( cx, & data[ * pos..] ) ) ?;
297
+ * pos += len;
298
+ }
299
+ }
300
+
301
+ // end
302
+ * state = TlsState :: Stream ;
303
+
304
+ if let Some ( waker) = early_waker. take ( ) {
305
+ waker. wake ( ) ;
306
+ }
307
+ }
308
+
309
+ Poll :: Ready ( Ok ( 0 ) )
310
+ }
0 commit comments