Skip to content

network: Buffer partial reads for cancellation safety #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 1, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 108 additions & 19 deletions src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ mod config;
mod mock;

#[cfg(any(feature = "test_helpers", test))]
pub use mock::{NoRecvNetwork, UnboundedDuplexStream, MockNetwork};
pub use mock::{MockNetwork, NoRecvNetwork, UnboundedDuplexStream};

use async_trait::async_trait;
use quinn::{Endpoint, RecvStream, SendStream};
Expand All @@ -27,6 +27,8 @@ const BYTES_PER_U64: usize = 8;

/// Error message emitted when reading a message length from the stream fails
const ERR_READ_MESSAGE_LENGTH: &str = "error reading message length from stream";
/// Error thrown when a stream finishes early
const ERR_STREAM_FINISHED_EARLY: &str = "stream finished early";

// ---------
// | Trait |
Expand Down Expand Up @@ -121,6 +123,62 @@ pub enum ReadWriteOrder {
WriteFirst,
}

/// A wrapper around a raw `&[u8]` buffer that tracks a cursor within the buffer
/// to allow partial fills across cancelled futures
///
/// Similar to `tokio::io::ReadBuf` but takes ownership of the underlying buffer to
/// avoid coloring interfaces with lifetime parameters
///
/// TODO: Replace this with `std::io::Cursor` once it is stabilized
#[derive(Debug)]
struct BufferWithCursor {
/// The underlying buffer
buffer: Vec<u8>,
/// The current cursor position
cursor: usize,
}

impl BufferWithCursor {
/// Create a new buffer with a cursor at the start of the buffer
pub fn new(buf: Vec<u8>) -> Self {
assert_eq!(
buf.len(),
buf.capacity(),
"buffer must be fully initialized"
);

Self {
buffer: buf,
cursor: 0,
}
}

/// The number of bytes remaining in the buffer
pub fn remaining(&self) -> usize {
self.buffer.capacity() - self.cursor
}

/// Whether the buffer is full
pub fn is_full(&self) -> bool {
self.remaining() == 0
}

/// Get a mutable reference to the empty section of the underlying buffer
pub fn get_unfilled(&mut self) -> &mut [u8] {
&mut self.buffer[self.cursor..]
}

/// Advance the cursor by `n` bytes
pub fn advance_cursor(&mut self, n: usize) {
self.cursor += n
}

/// Take ownership of the underlying buffer
pub fn into_vec(self) -> Vec<u8> {
self.buffer
}
}

/// Implements an MpcNetwork on top of QUIC
#[derive(Debug)]
pub struct QuicTwoPartyNet {
Expand All @@ -132,6 +190,19 @@ pub struct QuicTwoPartyNet {
local_addr: SocketAddr,
/// Addresses of the counterparties in the MPC
peer_addr: SocketAddr,
/// A buffered message length read from the stream
///
/// In the case that the whole message is not available yet, reads may block
/// and the `read_message` future may be cancelled by the executor.
/// We buffer the message length to avoid re-reading the message length incorrectly from
/// the stream
buffered_message_length: Option<u64>,
/// A buffered partial message read from the stream
///
/// This buffer exists to provide cancellation safety to a `read` future as the underlying `quinn`
/// stream is not cancellation safe, i.e. if a `ReadBuf` future is dropped, the buffer is dropped with
/// it and the partially read data is skipped
buffered_message: Option<BufferWithCursor>,
/// The send side of the bidirectional stream
send_stream: Option<SendStream>,
/// The receive side of the bidirectional stream
Expand All @@ -148,6 +219,8 @@ impl<'a> QuicTwoPartyNet {
local_addr,
peer_addr,
connected: false,
buffered_message_length: None,
buffered_message: None,
send_stream: None,
recv_stream: None,
}
Expand Down Expand Up @@ -244,14 +317,7 @@ impl<'a> QuicTwoPartyNet {

/// Read a message length from the stream
async fn read_message_length(&mut self) -> Result<u64, MpcNetworkError> {
let mut read_buffer = vec![0u8; BYTES_PER_U64];
self.recv_stream
.as_mut()
.unwrap()
.read_exact(&mut read_buffer)
.await
.map_err(|e| MpcNetworkError::RecvError(e.to_string()))?;

let read_buffer = self.read_bytes(BYTES_PER_U64).await?;
Ok(u64::from_le_bytes(read_buffer.try_into().map_err(
|_| MpcNetworkError::SerializationError(ERR_READ_MESSAGE_LENGTH.to_string()),
)?))
Expand All @@ -269,15 +335,30 @@ impl<'a> QuicTwoPartyNet {

/// Read exactly `n` bytes from the stream
async fn read_bytes(&mut self, num_bytes: usize) -> Result<Vec<u8>, MpcNetworkError> {
let mut read_buffer = vec![0u8; num_bytes];
self.recv_stream
.as_mut()
.unwrap()
.read_exact(&mut read_buffer)
.await
.map_err(|e| MpcNetworkError::RecvError(e.to_string()))?;
// Allocate a buffer for the next message if one does not already exist
if self.buffered_message.is_none() {
self.buffered_message = Some(BufferWithCursor::new(vec![0u8; num_bytes]));
}

// Read until the buffer is full
let read_buffer = self.buffered_message.as_mut().unwrap();
while !read_buffer.is_full() {
let bytes_read = self
.recv_stream
.as_mut()
.unwrap()
.read(read_buffer.get_unfilled())
.await
.map_err(|e| MpcNetworkError::RecvError(e.to_string()))?
.ok_or(MpcNetworkError::RecvError(
ERR_STREAM_FINISHED_EARLY.to_string(),
))?;

read_buffer.advance_cursor(bytes_read);
}

Ok(read_buffer.to_vec())
// Take ownership of the buffer, and reset the buffered message to `None`
Ok(self.buffered_message.take().unwrap().into_vec())
}
}

Expand All @@ -298,10 +379,18 @@ impl MpcNetwork for QuicTwoPartyNet {
}

async fn receive_message(&mut self) -> Result<NetworkOutbound, MpcNetworkError> {
// Read the message length from the buffer
let len = self.read_message_length().await?;
// Read the message length from the buffer if available
if self.buffered_message_length.is_none() {
self.buffered_message_length = Some(self.read_message_length().await?);
}

// Read the data from the stream
let len = self.buffered_message_length.unwrap();
let bytes = self.read_bytes(len as usize).await?;

// Reset the message length buffer after the data has been pulled from the stream
self.buffered_message_length = None;

// Deserialize the message
serde_json::from_slice(&bytes)
.map_err(|err| MpcNetworkError::SerializationError(err.to_string()))
Expand Down