@@ -5,7 +5,7 @@ mod config;
5
5
mod mock;
6
6
7
7
#[ cfg( any( feature = "test_helpers" , test) ) ]
8
- pub use mock:: { NoRecvNetwork , UnboundedDuplexStream , MockNetwork } ;
8
+ pub use mock:: { MockNetwork , NoRecvNetwork , UnboundedDuplexStream } ;
9
9
10
10
use async_trait:: async_trait;
11
11
use quinn:: { Endpoint , RecvStream , SendStream } ;
@@ -27,6 +27,8 @@ const BYTES_PER_U64: usize = 8;
27
27
28
28
/// Error message emitted when reading a message length from the stream fails
29
29
const ERR_READ_MESSAGE_LENGTH : & str = "error reading message length from stream" ;
30
+ /// Error thrown when a stream finishes early
31
+ const ERR_STREAM_FINISHED_EARLY : & str = "stream finished early" ;
30
32
31
33
// ---------
32
34
// | Trait |
@@ -121,6 +123,60 @@ pub enum ReadWriteOrder {
121
123
WriteFirst ,
122
124
}
123
125
126
+ /// A wrapper around a raw [u8] buffer that tracks the cursor within the buffer
127
+ /// to allow partial fills across cancelled futures
128
+ ///
129
+ /// Similar to `tokio::io::ReadBuf` but takes ownership of the underlying buffer to
130
+ /// avoid coloring interfaces with lifetime parameters
131
+ #[ derive( Debug ) ]
132
+ struct BufferWithCursor {
133
+ /// The underlying buffer
134
+ buffer : Vec < u8 > ,
135
+ /// The current cursor position
136
+ cursor : usize ,
137
+ }
138
+
139
+ impl BufferWithCursor {
140
+ /// Create a new buffer with a cursor at the start of the buffer
141
+ pub fn new ( buf : Vec < u8 > ) -> Self {
142
+ assert_eq ! (
143
+ buf. len( ) ,
144
+ buf. capacity( ) ,
145
+ "buffer must be fully initialized"
146
+ ) ;
147
+
148
+ Self {
149
+ buffer : buf,
150
+ cursor : 0 ,
151
+ }
152
+ }
153
+
154
+ /// The number of bytes remaining in the buffer
155
+ pub fn remaining ( & self ) -> usize {
156
+ self . buffer . capacity ( ) - self . cursor
157
+ }
158
+
159
+ /// Whether the buffer is full
160
+ pub fn is_full ( & self ) -> bool {
161
+ self . remaining ( ) == 0
162
+ }
163
+
164
+ /// Get a mutable reference to the empty section of the underlying buffer
165
+ pub fn get_unfilled ( & mut self ) -> & mut [ u8 ] {
166
+ & mut self . buffer [ self . cursor ..]
167
+ }
168
+
169
+ /// Advance the cursor by `n` bytes
170
+ pub fn advance_cursor ( & mut self , n : usize ) {
171
+ self . cursor += n
172
+ }
173
+
174
+ /// Take ownership of the underlying buffer
175
+ pub fn into_vec ( self ) -> Vec < u8 > {
176
+ self . buffer
177
+ }
178
+ }
179
+
124
180
/// Implements an MpcNetwork on top of QUIC
125
181
#[ derive( Debug ) ]
126
182
pub struct QuicTwoPartyNet {
@@ -132,6 +188,19 @@ pub struct QuicTwoPartyNet {
132
188
local_addr : SocketAddr ,
133
189
/// Addresses of the counterparties in the MPC
134
190
peer_addr : SocketAddr ,
191
+ /// A buffered message length read from the stream
192
+ ///
193
+ /// In the case that the whole message is not available yet, `read_exact` may block
194
+ /// and the `read` method may be cancelled. We buffer the message length to avoid re-reading
195
+ /// the message length incorrectly from the stream. Essentially this field gives cancellation
196
+ /// safety to the `read` method.
197
+ buffered_message_length : Option < u64 > ,
198
+ /// A buffered partial message read from the stream
199
+ ///
200
+ /// This buffer exists to provide cancellation safety to a `read` future as the underlying `quinn`
201
+ /// stream is not cancellation safe, i.e. if a `ReadBuf` future is dropped, the buffer is dropped with
202
+ /// it and the message skipped
203
+ buffered_message : Option < BufferWithCursor > ,
135
204
/// The send side of the bidirectional stream
136
205
send_stream : Option < SendStream > ,
137
206
/// The receive side of the bidirectional stream
@@ -148,6 +217,8 @@ impl<'a> QuicTwoPartyNet {
148
217
local_addr,
149
218
peer_addr,
150
219
connected : false ,
220
+ buffered_message_length : None ,
221
+ buffered_message : None ,
151
222
send_stream : None ,
152
223
recv_stream : None ,
153
224
}
@@ -244,14 +315,7 @@ impl<'a> QuicTwoPartyNet {
244
315
245
316
/// Read a message length from the stream
246
317
async fn read_message_length ( & mut self ) -> Result < u64 , MpcNetworkError > {
247
- let mut read_buffer = vec ! [ 0u8 ; BYTES_PER_U64 ] ;
248
- self . recv_stream
249
- . as_mut ( )
250
- . unwrap ( )
251
- . read_exact ( & mut read_buffer)
252
- . await
253
- . map_err ( |e| MpcNetworkError :: RecvError ( e. to_string ( ) ) ) ?;
254
-
318
+ let read_buffer = self . read_bytes ( BYTES_PER_U64 ) . await ?;
255
319
Ok ( u64:: from_le_bytes ( read_buffer. try_into ( ) . map_err (
256
320
|_| MpcNetworkError :: SerializationError ( ERR_READ_MESSAGE_LENGTH . to_string ( ) ) ,
257
321
) ?) )
@@ -269,15 +333,30 @@ impl<'a> QuicTwoPartyNet {
269
333
270
334
/// Read exactly `n` bytes from the stream
271
335
async fn read_bytes ( & mut self , num_bytes : usize ) -> Result < Vec < u8 > , MpcNetworkError > {
272
- let mut read_buffer = vec ! [ 0u8 ; num_bytes] ;
273
- self . recv_stream
274
- . as_mut ( )
275
- . unwrap ( )
276
- . read_exact ( & mut read_buffer)
277
- . await
278
- . map_err ( |e| MpcNetworkError :: RecvError ( e. to_string ( ) ) ) ?;
336
+ // Allocate a buffer for the next message if one does not already exist
337
+ if self . buffered_message . is_none ( ) {
338
+ self . buffered_message = Some ( BufferWithCursor :: new ( vec ! [ 0u8 ; num_bytes] ) ) ;
339
+ }
340
+
341
+ // Read until the buffer is full
342
+ let read_buffer = self . buffered_message . as_mut ( ) . unwrap ( ) ;
343
+ while !read_buffer. is_full ( ) {
344
+ let bytes_read = self
345
+ . recv_stream
346
+ . as_mut ( )
347
+ . unwrap ( )
348
+ . read ( read_buffer. get_unfilled ( ) )
349
+ . await
350
+ . map_err ( |e| MpcNetworkError :: RecvError ( e. to_string ( ) ) ) ?
351
+ . ok_or ( MpcNetworkError :: RecvError (
352
+ ERR_STREAM_FINISHED_EARLY . to_string ( ) ,
353
+ ) ) ?;
354
+
355
+ read_buffer. advance_cursor ( bytes_read) ;
356
+ }
279
357
280
- Ok ( read_buffer. to_vec ( ) )
358
+ // Take ownership of the buffer, and reset the buffered message to `None`
359
+ Ok ( self . buffered_message . take ( ) . unwrap ( ) . into_vec ( ) )
281
360
}
282
361
}
283
362
@@ -298,10 +377,18 @@ impl MpcNetwork for QuicTwoPartyNet {
298
377
}
299
378
300
379
async fn receive_message ( & mut self ) -> Result < NetworkOutbound , MpcNetworkError > {
301
- // Read the message length from the buffer
302
- let len = self . read_message_length ( ) . await ?;
380
+ // Read the message length from the buffer if already read from the stream
381
+ if self . buffered_message_length . is_none ( ) {
382
+ self . buffered_message_length = Some ( self . read_message_length ( ) . await ?) ;
383
+ }
384
+
385
+ // Read the data from the stream
386
+ let len = self . buffered_message_length . unwrap ( ) ;
303
387
let bytes = self . read_bytes ( len as usize ) . await ?;
304
388
389
+ // Reset the message length buffer after the data has been pulled from the stream
390
+ self . buffered_message_length = None ;
391
+
305
392
// Deserialize the message
306
393
serde_json:: from_slice ( & bytes)
307
394
. map_err ( |err| MpcNetworkError :: SerializationError ( err. to_string ( ) ) )
0 commit comments