Skip to content

Commit b61deee

Browse files
committed
Change Socket::recv to accept [MaybeUninit<u8>]
Allow uninitialised buffers to be used. Also in the following functions: * Socket::recv_out_of_band * Socket::recv_with_flags * Socket::peek * Socket::recv_from * Socket::recv_from_with_flags * Socket::peek_from
1 parent a37d708 commit b61deee

File tree

4 files changed

+69
-14
lines changed

4 files changed

+69
-14
lines changed

src/socket.rs

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use std::fmt;
1010
use std::io::{self, Read, Write};
1111
#[cfg(not(target_os = "redox"))]
1212
use std::io::{IoSlice, IoSliceMut};
13+
use std::mem::MaybeUninit;
1314
use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown};
1415
#[cfg(unix)]
1516
use std::os::unix::io::{FromRawFd, IntoRawFd};
@@ -286,7 +287,20 @@ impl Socket {
286287
/// This method might fail if the socket is not connected.
287288
///
288289
/// [`connect`]: Socket::connect
289-
pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
290+
///
291+
/// # Safety
292+
///
293+
/// Normally casting a `&mut [u8]` to `&mut [MaybeUninit<u8>]` would be
294+
/// unsound, as that allows us to write uninitialised bytes to the buffer.
295+
/// However this implementation promises to not write uninitialised bytes to
296+
/// the `buf`fer and passes it directly to `recv(2)` system call. This
297+
/// promise ensures that this function can be called using a `buf`fer of
298+
/// type `&mut [u8]`.
299+
///
300+
/// Note that the [`io::Read::read`] implementation calls this function with
301+
/// a `buf`fer of type `&mut [u8]`, allowing initialised buffers to be used
302+
/// without using `unsafe`.
303+
pub fn recv(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
290304
self.recv_with_flags(buf, 0)
291305
}
292306

@@ -297,15 +311,19 @@ impl Socket {
297311
///
298312
/// [`recv`]: Socket::recv
299313
/// [`out_of_band_inline`]: Socket::out_of_band_inline
300-
pub fn recv_out_of_band(&self, buf: &mut [u8]) -> io::Result<usize> {
314+
pub fn recv_out_of_band(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
301315
self.recv_with_flags(buf, sys::MSG_OOB)
302316
}
303317

304318
/// Identical to [`recv`] but allows for specification of arbitrary flags to
305319
/// the underlying `recv` call.
306320
///
307321
/// [`recv`]: Socket::recv
308-
pub fn recv_with_flags(&self, buf: &mut [u8], flags: sys::c_int) -> io::Result<usize> {
322+
pub fn recv_with_flags(
323+
&self,
324+
buf: &mut [MaybeUninit<u8>],
325+
flags: sys::c_int,
326+
) -> io::Result<usize> {
309327
sys::recv(self.inner, buf, flags)
310328
}
311329

@@ -345,13 +363,27 @@ impl Socket {
345363
///
346364
/// Successive calls return the same data. This is accomplished by passing
347365
/// `MSG_PEEK` as a flag to the underlying `recv` system call.
348-
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
366+
///
367+
/// # Safety
368+
///
369+
/// `peek` makes the same safety guarantees regarding the `buf`fer as
370+
/// [`recv`].
371+
///
372+
/// [`recv`]: Socket::recv
373+
pub fn peek(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
349374
self.recv_with_flags(buf, sys::MSG_PEEK)
350375
}
351376

352377
/// Receives data from the socket. On success, returns the number of bytes
353378
/// read and the address from whence the data came.
354-
pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> {
379+
///
380+
/// # Safety
381+
///
382+
/// `recv_from` makes the same safety guarantees regarding the `buf`fer as
383+
/// [`recv`].
384+
///
385+
/// [`recv`]: Socket::recv
386+
pub fn recv_from(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<(usize, SockAddr)> {
355387
self.recv_from_with_flags(buf, 0)
356388
}
357389

@@ -361,7 +393,7 @@ impl Socket {
361393
/// [`recv_from`]: Socket::recv_from
362394
pub fn recv_from_with_flags(
363395
&self,
364-
buf: &mut [u8],
396+
buf: &mut [MaybeUninit<u8>],
365397
flags: i32,
366398
) -> io::Result<(usize, SockAddr)> {
367399
sys::recv_from(self.inner, buf, flags)
@@ -400,7 +432,14 @@ impl Socket {
400432
///
401433
/// On success, returns the number of bytes peeked and the address from
402434
/// whence the data came.
403-
pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SockAddr)> {
435+
///
436+
/// # Safety
437+
///
438+
/// `peek_from` makes the same safety guarantees regarding the `buf`fer as
439+
/// [`recv`].
440+
///
441+
/// [`recv`]: Socket::recv
442+
pub fn peek_from(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<(usize, SockAddr)> {
404443
self.recv_from_with_flags(buf, sys::MSG_PEEK)
405444
}
406445

@@ -1245,6 +1284,9 @@ impl Socket {
12451284

12461285
impl Read for Socket {
12471286
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1287+
// Safety: the `recv` implementation promises not to write uninitialised
1288+
// bytes to the `buf`fer, so this casting is safe.
1289+
let buf = unsafe { &mut *(buf as *mut [u8] as *mut [MaybeUninit<u8>]) };
12481290
self.recv(buf)
12491291
}
12501292

@@ -1256,6 +1298,8 @@ impl Read for Socket {
12561298

12571299
impl<'a> Read for &'a Socket {
12581300
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1301+
// Safety: see other `Read::read` impl.
1302+
let buf = unsafe { &mut *(buf as *mut [u8] as *mut [MaybeUninit<u8>]) };
12591303
self.recv(buf)
12601304
}
12611305

src/sys/unix.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ pub(crate) fn shutdown(fd: Socket, how: Shutdown) -> io::Result<()> {
419419
syscall!(shutdown(fd, how)).map(|_| ())
420420
}
421421

422-
pub(crate) fn recv(fd: Socket, buf: &mut [u8], flags: c_int) -> io::Result<usize> {
422+
pub(crate) fn recv(fd: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int) -> io::Result<usize> {
423423
syscall!(recv(
424424
fd,
425425
buf.as_mut_ptr().cast(),
@@ -429,7 +429,11 @@ pub(crate) fn recv(fd: Socket, buf: &mut [u8], flags: c_int) -> io::Result<usize
429429
.map(|n| n as usize)
430430
}
431431

432-
pub(crate) fn recv_from(fd: Socket, buf: &mut [u8], flags: c_int) -> io::Result<(usize, SockAddr)> {
432+
pub(crate) fn recv_from(
433+
fd: Socket,
434+
buf: &mut [MaybeUninit<u8>],
435+
flags: c_int,
436+
) -> io::Result<(usize, SockAddr)> {
433437
// Safety: `recvfrom` initialises the `SockAddr` for us.
434438
unsafe {
435439
SockAddr::init(|addr, addrlen| {

src/sys/windows.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ pub(crate) fn shutdown(socket: Socket, how: Shutdown) -> io::Result<()> {
273273
syscall!(shutdown(socket, how), PartialEq::eq, sock::SOCKET_ERROR).map(|_| ())
274274
}
275275

276-
pub(crate) fn recv(socket: Socket, buf: &mut [u8], flags: c_int) -> io::Result<usize> {
276+
pub(crate) fn recv(socket: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int) -> io::Result<usize> {
277277
let res = syscall!(
278278
recv(
279279
socket,
@@ -325,7 +325,7 @@ pub(crate) fn recv_vectored(
325325

326326
pub(crate) fn recv_from(
327327
socket: Socket,
328-
buf: &mut [u8],
328+
buf: &mut [MaybeUninit<u8>],
329329
flags: c_int,
330330
) -> io::Result<(usize, SockAddr)> {
331331
// Safety: `recvfrom` initialises the `SockAddr` for us.

tests/socket.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use std::io::Read;
77
use std::io::Write;
88
#[cfg(not(target_os = "redox"))]
99
use std::io::{IoSlice, IoSliceMut};
10+
use std::mem::MaybeUninit;
1011
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
1112
#[cfg(unix)]
1213
use std::os::unix::io::AsRawFd;
@@ -359,14 +360,14 @@ fn out_of_band() {
359360
// this from happening we'll sleep to ensure the data is present.
360361
thread::sleep(Duration::from_millis(10));
361362

362-
let mut buf = [1; DATA.len() + 1];
363+
let mut buf = [MaybeUninit::new(1); DATA.len() + 1];
363364
let n = receiver.recv_out_of_band(&mut buf).unwrap();
364365
assert_eq!(n, FIRST.len());
365-
assert_eq!(&buf[..n], FIRST);
366+
assert_eq!(unsafe { assume_init(&buf[..n]) }, FIRST);
366367

367368
let n = receiver.recv(&mut buf).unwrap();
368369
assert_eq!(n, DATA.len());
369-
assert_eq!(&buf[..n], DATA);
370+
assert_eq!(unsafe { assume_init(&buf[..n]) }, DATA);
370371
}
371372

372373
#[test]
@@ -643,6 +644,12 @@ fn any_ipv4() -> SockAddr {
643644
SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0).into()
644645
}
645646

647+
/// Assume the `buf`fer to be initialised.
648+
// TODO: replace with `MaybeUninit::slice_assume_init_ref` once stable.
649+
unsafe fn assume_init(buf: &[MaybeUninit<u8>]) -> &[u8] {
650+
&*(buf as *const [MaybeUninit<u8>] as *const [u8])
651+
}
652+
646653
/// Macro to create a simple test to set and get a socket option.
647654
macro_rules! test {
648655
// Test using the `arg`ument as expected return value.

0 commit comments

Comments
 (0)