@@ -5,7 +5,7 @@ mod config;
5
5
mod mock;
6
6
7
7
#[ cfg( any( feature = "test_helpers" , test) ) ]
8
- pub use mock:: { NoRecvNetwork , UnboundedDuplexStream , MockNetwork } ;
8
+ pub use mock:: { MockNetwork , NoRecvNetwork , UnboundedDuplexStream } ;
9
9
10
10
use async_trait:: async_trait;
11
11
use quinn:: { Endpoint , RecvStream , SendStream } ;
@@ -27,6 +27,8 @@ const BYTES_PER_U64: usize = 8;
27
27
28
28
/// Error message emitted when reading a message length from the stream fails
29
29
const ERR_READ_MESSAGE_LENGTH : & str = "error reading message length from stream" ;
30
+ /// Error thrown when a stream finishes early
31
+ const ERR_STREAM_FINISHED_EARLY : & str = "stream finished early" ;
30
32
31
33
// ---------
32
34
// | Trait |
@@ -121,6 +123,62 @@ pub enum ReadWriteOrder {
121
123
WriteFirst ,
122
124
}
123
125
126
+ /// A wrapper around a raw `&[u8]` buffer that tracks a cursor within the buffer
127
+ /// to allow partial fills across cancelled futures
128
+ ///
129
+ /// Similar to `tokio::io::ReadBuf` but takes ownership of the underlying buffer to
130
+ /// avoid coloring interfaces with lifetime parameters
131
+ ///
132
+ /// TODO: Replace this with `std::io::Cursor` once it is stabilized
133
+ #[ derive( Debug ) ]
134
+ struct BufferWithCursor {
135
+ /// The underlying buffer
136
+ buffer : Vec < u8 > ,
137
+ /// The current cursor position
138
+ cursor : usize ,
139
+ }
140
+
141
+ impl BufferWithCursor {
142
+ /// Create a new buffer with a cursor at the start of the buffer
143
+ pub fn new ( buf : Vec < u8 > ) -> Self {
144
+ assert_eq ! (
145
+ buf. len( ) ,
146
+ buf. capacity( ) ,
147
+ "buffer must be fully initialized"
148
+ ) ;
149
+
150
+ Self {
151
+ buffer : buf,
152
+ cursor : 0 ,
153
+ }
154
+ }
155
+
156
+ /// The number of bytes remaining in the buffer
157
+ pub fn remaining ( & self ) -> usize {
158
+ self . buffer . capacity ( ) - self . cursor
159
+ }
160
+
161
+ /// Whether the buffer is full
162
+ pub fn is_full ( & self ) -> bool {
163
+ self . remaining ( ) == 0
164
+ }
165
+
166
+ /// Get a mutable reference to the empty section of the underlying buffer
167
+ pub fn get_unfilled ( & mut self ) -> & mut [ u8 ] {
168
+ & mut self . buffer [ self . cursor ..]
169
+ }
170
+
171
+ /// Advance the cursor by `n` bytes
172
+ pub fn advance_cursor ( & mut self , n : usize ) {
173
+ self . cursor += n
174
+ }
175
+
176
+ /// Take ownership of the underlying buffer
177
+ pub fn into_vec ( self ) -> Vec < u8 > {
178
+ self . buffer
179
+ }
180
+ }
181
+
124
182
/// Implements an MpcNetwork on top of QUIC
125
183
#[ derive( Debug ) ]
126
184
pub struct QuicTwoPartyNet {
@@ -132,6 +190,19 @@ pub struct QuicTwoPartyNet {
132
190
local_addr : SocketAddr ,
133
191
/// Addresses of the counterparties in the MPC
134
192
peer_addr : SocketAddr ,
193
+ /// A buffered message length read from the stream
194
+ ///
195
+ /// In the case that the whole message is not available yet, reads may block
196
+ /// and the `read_message` future may be cancelled by the executor.
197
+ /// We buffer the message length to avoid re-reading the message length incorrectly from
198
+ /// the stream
199
+ buffered_message_length : Option < u64 > ,
200
+ /// A buffered partial message read from the stream
201
+ ///
202
+ /// This buffer exists to provide cancellation safety to a `read` future as the underlying `quinn`
203
+ /// stream is not cancellation safe, i.e. if a `ReadBuf` future is dropped, the buffer is dropped with
204
+ /// it and the partially read data is skipped
205
+ buffered_message : Option < BufferWithCursor > ,
135
206
/// The send side of the bidirectional stream
136
207
send_stream : Option < SendStream > ,
137
208
/// The receive side of the bidirectional stream
@@ -148,6 +219,8 @@ impl<'a> QuicTwoPartyNet {
148
219
local_addr,
149
220
peer_addr,
150
221
connected : false ,
222
+ buffered_message_length : None ,
223
+ buffered_message : None ,
151
224
send_stream : None ,
152
225
recv_stream : None ,
153
226
}
@@ -244,14 +317,7 @@ impl<'a> QuicTwoPartyNet {
244
317
245
318
/// Read a message length from the stream
246
319
async fn read_message_length ( & mut self ) -> Result < u64 , MpcNetworkError > {
247
- let mut read_buffer = vec ! [ 0u8 ; BYTES_PER_U64 ] ;
248
- self . recv_stream
249
- . as_mut ( )
250
- . unwrap ( )
251
- . read_exact ( & mut read_buffer)
252
- . await
253
- . map_err ( |e| MpcNetworkError :: RecvError ( e. to_string ( ) ) ) ?;
254
-
320
+ let read_buffer = self . read_bytes ( BYTES_PER_U64 ) . await ?;
255
321
Ok ( u64:: from_le_bytes ( read_buffer. try_into ( ) . map_err (
256
322
|_| MpcNetworkError :: SerializationError ( ERR_READ_MESSAGE_LENGTH . to_string ( ) ) ,
257
323
) ?) )
@@ -269,15 +335,30 @@ impl<'a> QuicTwoPartyNet {
269
335
270
336
/// Read exactly `n` bytes from the stream
271
337
async fn read_bytes ( & mut self , num_bytes : usize ) -> Result < Vec < u8 > , MpcNetworkError > {
272
- let mut read_buffer = vec ! [ 0u8 ; num_bytes] ;
273
- self . recv_stream
274
- . as_mut ( )
275
- . unwrap ( )
276
- . read_exact ( & mut read_buffer)
277
- . await
278
- . map_err ( |e| MpcNetworkError :: RecvError ( e. to_string ( ) ) ) ?;
338
+ // Allocate a buffer for the next message if one does not already exist
339
+ if self . buffered_message . is_none ( ) {
340
+ self . buffered_message = Some ( BufferWithCursor :: new ( vec ! [ 0u8 ; num_bytes] ) ) ;
341
+ }
342
+
343
+ // Read until the buffer is full
344
+ let read_buffer = self . buffered_message . as_mut ( ) . unwrap ( ) ;
345
+ while !read_buffer. is_full ( ) {
346
+ let bytes_read = self
347
+ . recv_stream
348
+ . as_mut ( )
349
+ . unwrap ( )
350
+ . read ( read_buffer. get_unfilled ( ) )
351
+ . await
352
+ . map_err ( |e| MpcNetworkError :: RecvError ( e. to_string ( ) ) ) ?
353
+ . ok_or ( MpcNetworkError :: RecvError (
354
+ ERR_STREAM_FINISHED_EARLY . to_string ( ) ,
355
+ ) ) ?;
356
+
357
+ read_buffer. advance_cursor ( bytes_read) ;
358
+ }
279
359
280
- Ok ( read_buffer. to_vec ( ) )
360
+ // Take ownership of the buffer, and reset the buffered message to `None`
361
+ Ok ( self . buffered_message . take ( ) . unwrap ( ) . into_vec ( ) )
281
362
}
282
363
}
283
364
@@ -298,10 +379,18 @@ impl MpcNetwork for QuicTwoPartyNet {
298
379
}
299
380
300
381
async fn receive_message ( & mut self ) -> Result < NetworkOutbound , MpcNetworkError > {
301
- // Read the message length from the buffer
302
- let len = self . read_message_length ( ) . await ?;
382
+ // Read the message length from the buffer if available
383
+ if self . buffered_message_length . is_none ( ) {
384
+ self . buffered_message_length = Some ( self . read_message_length ( ) . await ?) ;
385
+ }
386
+
387
+ // Read the data from the stream
388
+ let len = self . buffered_message_length . unwrap ( ) ;
303
389
let bytes = self . read_bytes ( len as usize ) . await ?;
304
390
391
+ // Reset the message length buffer after the data has been pulled from the stream
392
+ self . buffered_message_length = None ;
393
+
305
394
// Deserialize the message
306
395
serde_json:: from_slice ( & bytes)
307
396
. map_err ( |err| MpcNetworkError :: SerializationError ( err. to_string ( ) ) )
0 commit comments