diff --git a/src/network.rs b/src/network.rs index fd33147..1cbf405 100644 --- a/src/network.rs +++ b/src/network.rs @@ -5,7 +5,7 @@ mod config; mod mock; #[cfg(any(feature = "test_helpers", test))] -pub use mock::{NoRecvNetwork, UnboundedDuplexStream, MockNetwork}; +pub use mock::{MockNetwork, NoRecvNetwork, UnboundedDuplexStream}; use async_trait::async_trait; use quinn::{Endpoint, RecvStream, SendStream}; @@ -27,6 +27,8 @@ const BYTES_PER_U64: usize = 8; /// Error message emitted when reading a message length from the stream fails const ERR_READ_MESSAGE_LENGTH: &str = "error reading message length from stream"; +/// Error thrown when a stream finishes early +const ERR_STREAM_FINISHED_EARLY: &str = "stream finished early"; // --------- // | Trait | @@ -121,6 +123,62 @@ pub enum ReadWriteOrder { WriteFirst, } +/// A wrapper around a raw `&[u8]` buffer that tracks a cursor within the buffer +/// to allow partial fills across cancelled futures +/// +/// Similar to `tokio::io::ReadBuf` but takes ownership of the underlying buffer to +/// avoid coloring interfaces with lifetime parameters +/// +/// TODO: Replace this with `std::io::Cursor` once it is stabilized +#[derive(Debug)] +struct BufferWithCursor { + /// The underlying buffer + buffer: Vec, + /// The current cursor position + cursor: usize, +} + +impl BufferWithCursor { + /// Create a new buffer with a cursor at the start of the buffer + pub fn new(buf: Vec) -> Self { + assert_eq!( + buf.len(), + buf.capacity(), + "buffer must be fully initialized" + ); + + Self { + buffer: buf, + cursor: 0, + } + } + + /// The number of bytes remaining in the buffer + pub fn remaining(&self) -> usize { + self.buffer.capacity() - self.cursor + } + + /// Whether the buffer is full + pub fn is_full(&self) -> bool { + self.remaining() == 0 + } + + /// Get a mutable reference to the empty section of the underlying buffer + pub fn get_unfilled(&mut self) -> &mut [u8] { + &mut self.buffer[self.cursor..] + } + + /// Advance the cursor by `n` bytes + pub fn advance_cursor(&mut self, n: usize) { + self.cursor += n + } + + /// Take ownership of the underlying buffer + pub fn into_vec(self) -> Vec { + self.buffer + } +} + /// Implements an MpcNetwork on top of QUIC #[derive(Debug)] pub struct QuicTwoPartyNet { @@ -132,6 +190,19 @@ pub struct QuicTwoPartyNet { local_addr: SocketAddr, /// Addresses of the counterparties in the MPC peer_addr: SocketAddr, + /// A buffered message length read from the stream + /// + /// In the case that the whole message is not available yet, reads may block + /// and the `read_message` future may be cancelled by the executor. + /// We buffer the message length to avoid re-reading the message length incorrectly from + /// the stream + buffered_message_length: Option, + /// A buffered partial message read from the stream + /// + /// This buffer exists to provide cancellation safety to a `read` future as the underlying `quinn` + /// stream is not cancellation safe, i.e. if a `ReadBuf` future is dropped, the buffer is dropped with + /// it and the partially read data is skipped + buffered_message: Option, /// The send side of the bidirectional stream send_stream: Option, /// The receive side of the bidirectional stream @@ -148,6 +219,8 @@ impl<'a> QuicTwoPartyNet { local_addr, peer_addr, connected: false, + buffered_message_length: None, + buffered_message: None, send_stream: None, recv_stream: None, } @@ -244,14 +317,7 @@ impl<'a> QuicTwoPartyNet { /// Read a message length from the stream async fn read_message_length(&mut self) -> Result { - let mut read_buffer = vec![0u8; BYTES_PER_U64]; - self.recv_stream - .as_mut() - .unwrap() - .read_exact(&mut read_buffer) - .await - .map_err(|e| MpcNetworkError::RecvError(e.to_string()))?; - + let read_buffer = self.read_bytes(BYTES_PER_U64).await?; Ok(u64::from_le_bytes(read_buffer.try_into().map_err( |_| MpcNetworkError::SerializationError(ERR_READ_MESSAGE_LENGTH.to_string()), )?)) @@ -269,15 +335,30 @@ impl<'a> QuicTwoPartyNet { /// Read exactly `n` bytes from the stream async fn read_bytes(&mut self, num_bytes: usize) -> Result, MpcNetworkError> { - let mut read_buffer = vec![0u8; num_bytes]; - self.recv_stream - .as_mut() - .unwrap() - .read_exact(&mut read_buffer) - .await - .map_err(|e| MpcNetworkError::RecvError(e.to_string()))?; + // Allocate a buffer for the next message if one does not already exist + if self.buffered_message.is_none() { + self.buffered_message = Some(BufferWithCursor::new(vec![0u8; num_bytes])); + } + + // Read until the buffer is full + let read_buffer = self.buffered_message.as_mut().unwrap(); + while !read_buffer.is_full() { + let bytes_read = self + .recv_stream + .as_mut() + .unwrap() + .read(read_buffer.get_unfilled()) + .await + .map_err(|e| MpcNetworkError::RecvError(e.to_string()))? + .ok_or(MpcNetworkError::RecvError( + ERR_STREAM_FINISHED_EARLY.to_string(), + ))?; + + read_buffer.advance_cursor(bytes_read); + } - Ok(read_buffer.to_vec()) + // Take ownership of the buffer, and reset the buffered message to `None` + Ok(self.buffered_message.take().unwrap().into_vec()) } } @@ -298,10 +379,18 @@ impl MpcNetwork for QuicTwoPartyNet { } async fn receive_message(&mut self) -> Result { - // Read the message length from the buffer - let len = self.read_message_length().await?; + // Read the message length from the buffer if available + if self.buffered_message_length.is_none() { + self.buffered_message_length = Some(self.read_message_length().await?); + } + + // Read the data from the stream + let len = self.buffered_message_length.unwrap(); let bytes = self.read_bytes(len as usize).await?; + // Reset the message length buffer after the data has been pulled from the stream + self.buffered_message_length = None; + // Deserialize the message serde_json::from_slice(&bytes) .map_err(|err| MpcNetworkError::SerializationError(err.to_string()))