Skip to content

Commit e19282a

Browse files
committed
network: Buffer partial reads for cancellation safety
1 parent aae4236 commit e19282a

File tree

1 file changed

+106
-19
lines changed

1 file changed

+106
-19
lines changed

src/network.rs

Lines changed: 106 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ mod config;
55
mod mock;
66

77
#[cfg(any(feature = "test_helpers", test))]
8-
pub use mock::{NoRecvNetwork, UnboundedDuplexStream, MockNetwork};
8+
pub use mock::{MockNetwork, NoRecvNetwork, UnboundedDuplexStream};
99

1010
use async_trait::async_trait;
1111
use quinn::{Endpoint, RecvStream, SendStream};
@@ -27,6 +27,8 @@ const BYTES_PER_U64: usize = 8;
2727

2828
/// Error message emitted when reading a message length from the stream fails
2929
const ERR_READ_MESSAGE_LENGTH: &str = "error reading message length from stream";
30+
/// Error thrown when a stream finishes early
31+
const ERR_STREAM_FINISHED_EARLY: &str = "stream finished early";
3032

3133
// ---------
3234
// | Trait |
@@ -121,6 +123,60 @@ pub enum ReadWriteOrder {
121123
WriteFirst,
122124
}
123125

126+
/// A wrapper around a raw [u8] buffer that tracks the cursor within the buffer
127+
/// to allow partial fills across cancelled futures
128+
///
129+
/// Similar to `tokio::io::ReadBuf` but takes ownership of the underlying buffer to
130+
/// avoid coloring interfaces with lifetime parameters
131+
#[derive(Debug)]
132+
struct BufferWithCursor {
133+
/// The underlying buffer
134+
buffer: Vec<u8>,
135+
/// The current cursor position
136+
cursor: usize,
137+
}
138+
139+
impl BufferWithCursor {
140+
/// Create a new buffer with a cursor at the start of the buffer
141+
pub fn new(buf: Vec<u8>) -> Self {
142+
assert_eq!(
143+
buf.len(),
144+
buf.capacity(),
145+
"buffer must be fully initialized"
146+
);
147+
148+
Self {
149+
buffer: buf,
150+
cursor: 0,
151+
}
152+
}
153+
154+
/// The number of bytes remaining in the buffer
155+
pub fn remaining(&self) -> usize {
156+
self.buffer.capacity() - self.cursor
157+
}
158+
159+
/// Whether the buffer is full
160+
pub fn is_full(&self) -> bool {
161+
self.remaining() == 0
162+
}
163+
164+
/// Get a mutable reference to the empty section of the underlying buffer
165+
pub fn get_unfilled(&mut self) -> &mut [u8] {
166+
&mut self.buffer[self.cursor..]
167+
}
168+
169+
/// Advance the cursor by `n` bytes
170+
pub fn advance_cursor(&mut self, n: usize) {
171+
self.cursor += n
172+
}
173+
174+
/// Take ownership of the underlying buffer
175+
pub fn into_vec(self) -> Vec<u8> {
176+
self.buffer
177+
}
178+
}
179+
124180
/// Implements an MpcNetwork on top of QUIC
125181
#[derive(Debug)]
126182
pub struct QuicTwoPartyNet {
@@ -132,6 +188,19 @@ pub struct QuicTwoPartyNet {
132188
local_addr: SocketAddr,
133189
/// Addresses of the counterparties in the MPC
134190
peer_addr: SocketAddr,
191+
/// A buffered message length read from the stream
192+
///
193+
/// In the case that the whole message is not available yet, `read_exact` may block
194+
/// and the `read` method may be cancelled. We buffer the message length to avoid re-reading
195+
/// the message length incorrectly from the stream. Essentially this field gives cancellation
196+
/// safety to the `read` method.
197+
buffered_message_length: Option<u64>,
198+
/// A buffered partial message read from the stream
199+
///
200+
/// This buffer exists to provide cancellation safety to a `read` future as the underlying `quinn`
201+
/// stream is not cancellation safe, i.e. if a `ReadBuf` future is dropped, the buffer is dropped with
202+
/// it and the message skipped
203+
buffered_message: Option<BufferWithCursor>,
135204
/// The send side of the bidirectional stream
136205
send_stream: Option<SendStream>,
137206
/// The receive side of the bidirectional stream
@@ -148,6 +217,8 @@ impl<'a> QuicTwoPartyNet {
148217
local_addr,
149218
peer_addr,
150219
connected: false,
220+
buffered_message_length: None,
221+
buffered_message: None,
151222
send_stream: None,
152223
recv_stream: None,
153224
}
@@ -244,14 +315,7 @@ impl<'a> QuicTwoPartyNet {
244315

245316
/// Read a message length from the stream
246317
async fn read_message_length(&mut self) -> Result<u64, MpcNetworkError> {
247-
let mut read_buffer = vec![0u8; BYTES_PER_U64];
248-
self.recv_stream
249-
.as_mut()
250-
.unwrap()
251-
.read_exact(&mut read_buffer)
252-
.await
253-
.map_err(|e| MpcNetworkError::RecvError(e.to_string()))?;
254-
318+
let read_buffer = self.read_bytes(BYTES_PER_U64).await?;
255319
Ok(u64::from_le_bytes(read_buffer.try_into().map_err(
256320
|_| MpcNetworkError::SerializationError(ERR_READ_MESSAGE_LENGTH.to_string()),
257321
)?))
@@ -269,15 +333,30 @@ impl<'a> QuicTwoPartyNet {
269333

270334
/// Read exactly `n` bytes from the stream
271335
async fn read_bytes(&mut self, num_bytes: usize) -> Result<Vec<u8>, MpcNetworkError> {
272-
let mut read_buffer = vec![0u8; num_bytes];
273-
self.recv_stream
274-
.as_mut()
275-
.unwrap()
276-
.read_exact(&mut read_buffer)
277-
.await
278-
.map_err(|e| MpcNetworkError::RecvError(e.to_string()))?;
336+
// Allocate a buffer for the next message if one does not already exist
337+
if self.buffered_message.is_none() {
338+
self.buffered_message = Some(BufferWithCursor::new(vec![0u8; num_bytes]));
339+
}
340+
341+
// Read until the buffer is full
342+
let read_buffer = self.buffered_message.as_mut().unwrap();
343+
while !read_buffer.is_full() {
344+
let bytes_read = self
345+
.recv_stream
346+
.as_mut()
347+
.unwrap()
348+
.read(read_buffer.get_unfilled())
349+
.await
350+
.map_err(|e| MpcNetworkError::RecvError(e.to_string()))?
351+
.ok_or(MpcNetworkError::RecvError(
352+
ERR_STREAM_FINISHED_EARLY.to_string(),
353+
))?;
354+
355+
read_buffer.advance_cursor(bytes_read);
356+
}
279357

280-
Ok(read_buffer.to_vec())
358+
// Take ownership of the buffer, and reset the buffered message to `None`
359+
Ok(self.buffered_message.take().unwrap().into_vec())
281360
}
282361
}
283362

@@ -298,10 +377,18 @@ impl MpcNetwork for QuicTwoPartyNet {
298377
}
299378

300379
async fn receive_message(&mut self) -> Result<NetworkOutbound, MpcNetworkError> {
301-
// Read the message length from the buffer
302-
let len = self.read_message_length().await?;
380+
// Read the message length from the buffer if already read from the stream
381+
if self.buffered_message_length.is_none() {
382+
self.buffered_message_length = Some(self.read_message_length().await?);
383+
}
384+
385+
// Read the data from the stream
386+
let len = self.buffered_message_length.unwrap();
303387
let bytes = self.read_bytes(len as usize).await?;
304388

389+
// Reset the message length buffer after the data has been pulled from the stream
390+
self.buffered_message_length = None;
391+
305392
// Deserialize the message
306393
serde_json::from_slice(&bytes)
307394
.map_err(|err| MpcNetworkError::SerializationError(err.to_string()))

0 commit comments

Comments
 (0)