1
1
use crate :: error:: Error ;
2
- use crate :: net:: Socket ;
2
+ use crate :: net:: { Socket , SocketExt } ;
3
3
use bytes:: BytesMut ;
4
+ use futures_util:: Sink ;
4
5
use std:: ops:: ControlFlow ;
6
+ use std:: pin:: Pin ;
7
+ use std:: task:: { ready, Context , Poll } ;
5
8
use std:: { cmp, io} ;
6
9
7
10
use crate :: io:: { AsyncRead , AsyncReadExt , ProtocolDecode , ProtocolEncode } ;
@@ -13,6 +16,7 @@ pub struct BufferedSocket<S> {
13
16
socket : S ,
14
17
write_buf : WriteBuffer ,
15
18
read_buf : ReadBuffer ,
19
+ wants_bytes : usize ,
16
20
}
17
21
18
22
pub struct WriteBuffer {
@@ -42,6 +46,7 @@ impl<S: Socket> BufferedSocket<S> {
42
46
read : BytesMut :: new ( ) ,
43
47
available : BytesMut :: with_capacity ( DEFAULT_BUF_SIZE ) ,
44
48
} ,
49
+ wants_bytes : 0 ,
45
50
}
46
51
}
47
52
@@ -56,6 +61,25 @@ impl<S: Socket> BufferedSocket<S> {
56
61
. await
57
62
}
58
63
64
+ pub fn poll_try_read < F , R > (
65
+ & mut self ,
66
+ cx : & mut Context < ' _ > ,
67
+ mut try_read : F ,
68
+ ) -> Poll < Result < R , Error > >
69
+ where
70
+ F : FnMut ( & mut BytesMut ) -> Result < ControlFlow < R , usize > , Error > ,
71
+ {
72
+ loop {
73
+ // Read if we want bytes
74
+ ready ! ( self . poll_handle_read( cx) ?) ;
75
+
76
+ match try_read ( & mut self . read_buf . read ) ? {
77
+ ControlFlow :: Continue ( read_len) => self . wants_bytes = read_len,
78
+ ControlFlow :: Break ( ret) => return Poll :: Ready ( Ok ( ret) ) ,
79
+ } ;
80
+ }
81
+ }
82
+
59
83
/// Retryable read operation.
60
84
///
61
85
/// The callback should check the contents of the buffer passed to it and either:
@@ -125,8 +149,8 @@ impl<S: Socket> BufferedSocket<S> {
125
149
pub async fn flush ( & mut self ) -> io:: Result < ( ) > {
126
150
while !self . write_buf . is_empty ( ) {
127
151
let written = self . socket . write ( self . write_buf . get ( ) ) . await ?;
152
+ // Consume does the sanity check
128
153
self . write_buf . consume ( written) ;
129
- self . write_buf . sanity_check ( ) ;
130
154
}
131
155
132
156
self . socket . flush ( ) . await ?;
@@ -154,7 +178,37 @@ impl<S: Socket> BufferedSocket<S> {
154
178
socket : Box :: new ( self . socket ) ,
155
179
write_buf : self . write_buf ,
156
180
read_buf : self . read_buf ,
181
+ wants_bytes : self . wants_bytes ,
182
+ }
183
+ }
184
+
185
+ fn poll_handle_read ( & mut self , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
186
+ // Because of how `BytesMut` works, we should only be shifting capacity back and forth
187
+ // between `read` and `available` unless we have to read an oversize message.
188
+ while self . read_buf . len ( ) < self . wants_bytes {
189
+ self . read_buf
190
+ . reserve ( self . wants_bytes - self . read_buf . len ( ) ) ;
191
+
192
+ let read = ready ! ( self . socket. poll_read( cx, & mut self . read_buf. available) ?) ;
193
+
194
+ if read == 0 {
195
+ return Poll :: Ready ( Err ( io:: Error :: new (
196
+ io:: ErrorKind :: UnexpectedEof ,
197
+ format ! (
198
+ "expected to read {} bytes, got {} bytes at EOF" ,
199
+ self . wants_bytes,
200
+ self . read_buf. len( )
201
+ ) ,
202
+ ) ) ) ;
203
+ }
204
+
205
+ // we've read at least enough for `wants_bytes`, so we don't want more.
206
+ self . wants_bytes = 0 ;
207
+
208
+ self . read_buf . advance ( read) ;
157
209
}
210
+
211
+ Poll :: Ready ( Ok ( ( ) ) )
158
212
}
159
213
}
160
214
@@ -326,4 +380,41 @@ impl ReadBuffer {
326
380
self . available = BytesMut :: with_capacity ( DEFAULT_BUF_SIZE ) ;
327
381
}
328
382
}
383
+
384
+ fn len ( & self ) -> usize {
385
+ self . read . len ( )
386
+ }
387
+ }
388
+
389
+ impl < S : Socket > Sink < & [ u8 ] > for BufferedSocket < S > {
390
+ type Error = crate :: Error ;
391
+
392
+ fn poll_ready ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < crate :: Result < ( ) > > {
393
+ if self . write_buf . bytes_written >= DEFAULT_BUF_SIZE {
394
+ self . poll_flush ( cx)
395
+ } else {
396
+ Poll :: Ready ( Ok ( ( ) ) )
397
+ }
398
+ }
399
+
400
+ fn start_send ( mut self : Pin < & mut Self > , item : & [ u8 ] ) -> crate :: Result < ( ) > {
401
+ self . write_buffer_mut ( ) . put_slice ( item) ;
402
+ Ok ( ( ) )
403
+ }
404
+
405
+ fn poll_flush ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < crate :: Result < ( ) > > {
406
+ let this = & mut * self ;
407
+
408
+ while !this. write_buf . is_empty ( ) {
409
+ let written = ready ! ( this. socket. poll_write( cx, this. write_buf. get( ) ) ?) ;
410
+ // Consume does the sanity check
411
+ this. write_buf . consume ( written) ;
412
+ }
413
+ this. socket . poll_flush ( cx) . map_err ( Into :: into)
414
+ }
415
+
416
+ fn poll_close ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < crate :: Result < ( ) > > {
417
+ ready ! ( self . as_mut( ) . poll_flush( cx) ) ?;
418
+ self . socket . poll_shutdown ( cx) . map_err ( Into :: into)
419
+ }
329
420
}
0 commit comments