Skip to content

Commit a55a636

Browse files
committed
network: Buffer partial reads for cancellation safety
1 parent aae4236 commit a55a636

File tree

1 file changed

+108
-19
lines changed

1 file changed

+108
-19
lines changed

src/network.rs

Lines changed: 108 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ mod config;
55
mod mock;
66

77
#[cfg(any(feature = "test_helpers", test))]
8-
pub use mock::{NoRecvNetwork, UnboundedDuplexStream, MockNetwork};
8+
pub use mock::{MockNetwork, NoRecvNetwork, UnboundedDuplexStream};
99

1010
use async_trait::async_trait;
1111
use quinn::{Endpoint, RecvStream, SendStream};
@@ -27,6 +27,8 @@ const BYTES_PER_U64: usize = 8;
2727

2828
/// Error message emitted when reading a message length from the stream fails
2929
const ERR_READ_MESSAGE_LENGTH: &str = "error reading message length from stream";
30+
/// Error thrown when a stream finishes early
31+
const ERR_STREAM_FINISHED_EARLY: &str = "stream finished early";
3032

3133
// ---------
3234
// | Trait |
@@ -121,6 +123,62 @@ pub enum ReadWriteOrder {
121123
WriteFirst,
122124
}
123125

126+
/// A wrapper around a raw `&[u8]` buffer that tracks a cursor within the buffer
127+
/// to allow partial fills across cancelled futures
128+
///
129+
/// Similar to `tokio::io::ReadBuf` but takes ownership of the underlying buffer to
130+
/// avoid coloring interfaces with lifetime parameters
131+
///
132+
/// TODO: Replace this with `std::io::Cursor` once it is stabilized
133+
#[derive(Debug)]
134+
struct BufferWithCursor {
135+
/// The underlying buffer
136+
buffer: Vec<u8>,
137+
/// The current cursor position
138+
cursor: usize,
139+
}
140+
141+
impl BufferWithCursor {
142+
/// Create a new buffer with a cursor at the start of the buffer
143+
pub fn new(buf: Vec<u8>) -> Self {
144+
assert_eq!(
145+
buf.len(),
146+
buf.capacity(),
147+
"buffer must be fully initialized"
148+
);
149+
150+
Self {
151+
buffer: buf,
152+
cursor: 0,
153+
}
154+
}
155+
156+
/// The number of bytes remaining in the buffer
157+
pub fn remaining(&self) -> usize {
158+
self.buffer.capacity() - self.cursor
159+
}
160+
161+
/// Whether the buffer is full
162+
pub fn is_full(&self) -> bool {
163+
self.remaining() == 0
164+
}
165+
166+
/// Get a mutable reference to the empty section of the underlying buffer
167+
pub fn get_unfilled(&mut self) -> &mut [u8] {
168+
&mut self.buffer[self.cursor..]
169+
}
170+
171+
/// Advance the cursor by `n` bytes
172+
pub fn advance_cursor(&mut self, n: usize) {
173+
self.cursor += n
174+
}
175+
176+
/// Take ownership of the underlying buffer
177+
pub fn into_vec(self) -> Vec<u8> {
178+
self.buffer
179+
}
180+
}
181+
124182
/// Implements an MpcNetwork on top of QUIC
125183
#[derive(Debug)]
126184
pub struct QuicTwoPartyNet {
@@ -132,6 +190,19 @@ pub struct QuicTwoPartyNet {
132190
local_addr: SocketAddr,
133191
/// Addresses of the counterparties in the MPC
134192
peer_addr: SocketAddr,
193+
/// A buffered message length read from the stream
194+
///
195+
/// In the case that the whole message is not available yet, reads may block
196+
/// and the `read_message` future may be cancelled by the executor.
197+
/// We buffer the message length to avoid re-reading the message length incorrectly from
198+
/// the stream
199+
buffered_message_length: Option<u64>,
200+
/// A buffered partial message read from the stream
201+
///
202+
/// This buffer exists to provide cancellation safety to a `read` future as the underlying `quinn`
203+
/// stream is not cancellation safe, i.e. if a `ReadBuf` future is dropped, the buffer is dropped with
204+
/// it and the partially read data is skipped
205+
buffered_message: Option<BufferWithCursor>,
135206
/// The send side of the bidirectional stream
136207
send_stream: Option<SendStream>,
137208
/// The receive side of the bidirectional stream
@@ -148,6 +219,8 @@ impl<'a> QuicTwoPartyNet {
148219
local_addr,
149220
peer_addr,
150221
connected: false,
222+
buffered_message_length: None,
223+
buffered_message: None,
151224
send_stream: None,
152225
recv_stream: None,
153226
}
@@ -244,14 +317,7 @@ impl<'a> QuicTwoPartyNet {
244317

245318
/// Read a message length from the stream
246319
async fn read_message_length(&mut self) -> Result<u64, MpcNetworkError> {
247-
let mut read_buffer = vec![0u8; BYTES_PER_U64];
248-
self.recv_stream
249-
.as_mut()
250-
.unwrap()
251-
.read_exact(&mut read_buffer)
252-
.await
253-
.map_err(|e| MpcNetworkError::RecvError(e.to_string()))?;
254-
320+
let read_buffer = self.read_bytes(BYTES_PER_U64).await?;
255321
Ok(u64::from_le_bytes(read_buffer.try_into().map_err(
256322
|_| MpcNetworkError::SerializationError(ERR_READ_MESSAGE_LENGTH.to_string()),
257323
)?))
@@ -269,15 +335,30 @@ impl<'a> QuicTwoPartyNet {
269335

270336
/// Read exactly `n` bytes from the stream
271337
async fn read_bytes(&mut self, num_bytes: usize) -> Result<Vec<u8>, MpcNetworkError> {
272-
let mut read_buffer = vec![0u8; num_bytes];
273-
self.recv_stream
274-
.as_mut()
275-
.unwrap()
276-
.read_exact(&mut read_buffer)
277-
.await
278-
.map_err(|e| MpcNetworkError::RecvError(e.to_string()))?;
338+
// Allocate a buffer for the next message if one does not already exist
339+
if self.buffered_message.is_none() {
340+
self.buffered_message = Some(BufferWithCursor::new(vec![0u8; num_bytes]));
341+
}
342+
343+
// Read until the buffer is full
344+
let read_buffer = self.buffered_message.as_mut().unwrap();
345+
while !read_buffer.is_full() {
346+
let bytes_read = self
347+
.recv_stream
348+
.as_mut()
349+
.unwrap()
350+
.read(read_buffer.get_unfilled())
351+
.await
352+
.map_err(|e| MpcNetworkError::RecvError(e.to_string()))?
353+
.ok_or(MpcNetworkError::RecvError(
354+
ERR_STREAM_FINISHED_EARLY.to_string(),
355+
))?;
356+
357+
read_buffer.advance_cursor(bytes_read);
358+
}
279359

280-
Ok(read_buffer.to_vec())
360+
// Take ownership of the buffer, and reset the buffered message to `None`
361+
Ok(self.buffered_message.take().unwrap().into_vec())
281362
}
282363
}
283364

@@ -298,10 +379,18 @@ impl MpcNetwork for QuicTwoPartyNet {
298379
}
299380

300381
async fn receive_message(&mut self) -> Result<NetworkOutbound, MpcNetworkError> {
301-
// Read the message length from the buffer
302-
let len = self.read_message_length().await?;
382+
// Read the message length from the buffer if available
383+
if self.buffered_message_length.is_none() {
384+
self.buffered_message_length = Some(self.read_message_length().await?);
385+
}
386+
387+
// Read the data from the stream
388+
let len = self.buffered_message_length.unwrap();
303389
let bytes = self.read_bytes(len as usize).await?;
304390

391+
// Reset the message length buffer after the data has been pulled from the stream
392+
self.buffered_message_length = None;
393+
305394
// Deserialize the message
306395
serde_json::from_slice(&bytes)
307396
.map_err(|err| MpcNetworkError::SerializationError(err.to_string()))

0 commit comments

Comments
 (0)