Skip to content

Commit e9df5ba

Browse files
committed
Change Socket::recv_vectored to accept uninitialised buffers
Also includes: * Socket::recv_vectored_with_flags * Socket::recv_from_vectored * Socket::recv_from_vectored_with_flags
1 parent b61deee commit e9df5ba

File tree

5 files changed

+199
-55
lines changed

5 files changed

+199
-55
lines changed

src/lib.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,10 @@
6666
// Disallow warnings in examples.
6767
#![doc(test(attr(deny(warnings))))]
6868

69+
use std::fmt;
70+
use std::mem::MaybeUninit;
6971
use std::net::SocketAddr;
72+
use std::ops::{Deref, DerefMut};
7073
use std::time::Duration;
7174

7275
/// Macro to implement `fmt::Debug` for a type, printing the constant names
@@ -277,6 +280,47 @@ impl RecvFlags {
277280
}
278281
}
279282

283+
/// A version of [`IoSliceMut`] that allows the buffer to be uninitialised.
284+
///
285+
/// [`IoSliceMut`]: std::io::IoSliceMut
286+
#[repr(transparent)]
287+
pub struct MaybeUninitSlice<'a>(sys::MaybeUninitSlice<'a>);
288+
289+
unsafe impl<'a> Send for MaybeUninitSlice<'a> {}
290+
291+
unsafe impl<'a> Sync for MaybeUninitSlice<'a> {}
292+
293+
impl<'a> fmt::Debug for MaybeUninitSlice<'a> {
294+
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
295+
fmt::Debug::fmt(self.0.as_slice(), fmt)
296+
}
297+
}
298+
299+
impl<'a> MaybeUninitSlice<'a> {
300+
/// Creates a new `MaybeUninitSlice` wrapping a byte slice.
301+
///
302+
/// # Panics
303+
///
304+
/// Panics on Windows if the slice is larger than 4GB.
305+
pub fn new(buf: &'a mut [MaybeUninit<u8>]) -> MaybeUninitSlice<'a> {
306+
MaybeUninitSlice(sys::MaybeUninitSlice::new(buf))
307+
}
308+
}
309+
310+
impl<'a> Deref for MaybeUninitSlice<'a> {
311+
type Target = [MaybeUninit<u8>];
312+
313+
fn deref(&self) -> &[MaybeUninit<u8>] {
314+
self.0.as_slice()
315+
}
316+
}
317+
318+
impl<'a> DerefMut for MaybeUninitSlice<'a> {
319+
fn deref_mut(&mut self) -> &mut [MaybeUninit<u8>] {
320+
self.0.as_mut_slice()
321+
}
322+
}
323+
280324
/// Configures a socket's TCP keepalive parameters.
281325
///
282326
/// See [`Socket::set_tcp_keepalive`].

src/socket.rs

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use std::time::Duration;
2121
use crate::sys::{self, c_int, getsockopt, setsockopt, Bool};
2222
#[cfg(not(target_os = "redox"))]
2323
use crate::RecvFlags;
24-
use crate::{Domain, Protocol, SockAddr, TcpKeepalive, Type};
24+
use crate::{Domain, MaybeUninitSlice, Protocol, SockAddr, TcpKeepalive, Type};
2525

2626
/// Owned wrapper around a system socket.
2727
///
@@ -339,19 +339,42 @@ impl Socket {
339339
///
340340
/// [`recv`]: Socket::recv
341341
/// [`connect`]: Socket::connect
342+
///
343+
/// # Safety
344+
///
345+
/// Normally casting a `IoSliceMut` to `MaybeUninitSlice` would be unsound,
346+
/// as that allows us to write uninitialised bytes to the buffer. However
347+
/// this implementation promises to not write uninitialised bytes to the
348+
/// `bufs` and passes it directly to `recvmsg(2)` system call. This promise
349+
/// ensures that this function can be called using `bufs` of type `&mut
350+
/// [IoSliceMut]`.
351+
///
352+
/// Note that the [`io::Read::read_vectored`] implementation calls this
353+
/// function with `buf`s of type `&mut [IoSliceMut]`, allowing initialised
354+
/// buffers to be used without using `unsafe`.
342355
#[cfg(not(target_os = "redox"))]
343-
pub fn recv_vectored(&self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<(usize, RecvFlags)> {
356+
pub fn recv_vectored(
357+
&self,
358+
bufs: &mut [MaybeUninitSlice<'_>],
359+
) -> io::Result<(usize, RecvFlags)> {
344360
self.recv_vectored_with_flags(bufs, 0)
345361
}
346362

347363
/// Identical to [`recv_vectored`] but allows for specification of arbitrary
348364
/// flags to the underlying `recvmsg`/`WSARecv` call.
349365
///
350366
/// [`recv_vectored`]: Socket::recv_vectored
367+
///
368+
/// # Safety
369+
///
370+
/// `recv_from_vectored` makes the same safety guarantees regarding `bufs`
371+
/// as [`recv_vectored`].
372+
///
373+
/// [`recv_vectored`]: Socket::recv_vectored
351374
#[cfg(not(target_os = "redox"))]
352375
pub fn recv_vectored_with_flags(
353376
&self,
354-
bufs: &mut [IoSliceMut<'_>],
377+
bufs: &mut [MaybeUninitSlice<'_>],
355378
flags: i32,
356379
) -> io::Result<(usize, RecvFlags)> {
357380
sys::recv_vectored(self.inner, bufs, flags)
@@ -404,10 +427,17 @@ impl Socket {
404427
/// [`recv_from`] this allows passing multiple buffers.
405428
///
406429
/// [`recv_from`]: Socket::recv_from
430+
///
431+
/// # Safety
432+
///
433+
/// `recv_from_vectored` makes the same safety guarantees regarding `bufs`
434+
/// as [`recv_vectored`].
435+
///
436+
/// [`recv_vectored`]: Socket::recv_vectored
407437
#[cfg(not(target_os = "redox"))]
408438
pub fn recv_from_vectored(
409439
&self,
410-
bufs: &mut [IoSliceMut<'_>],
440+
bufs: &mut [MaybeUninitSlice<'_>],
411441
) -> io::Result<(usize, RecvFlags, SockAddr)> {
412442
self.recv_from_vectored_with_flags(bufs, 0)
413443
}
@@ -416,10 +446,17 @@ impl Socket {
416446
/// arbitrary flags to the underlying `recvmsg`/`WSARecvFrom` call.
417447
///
418448
/// [`recv_from_vectored`]: Socket::recv_from_vectored
449+
///
450+
/// # Safety
451+
///
452+
/// `recv_from_vectored` makes the same safety guarantees regarding `bufs`
453+
/// as [`recv_vectored`].
454+
///
455+
/// [`recv_vectored`]: Socket::recv_vectored
419456
#[cfg(not(target_os = "redox"))]
420457
pub fn recv_from_vectored_with_flags(
421458
&self,
422-
bufs: &mut [IoSliceMut<'_>],
459+
bufs: &mut [MaybeUninitSlice<'_>],
423460
flags: i32,
424461
) -> io::Result<(usize, RecvFlags, SockAddr)> {
425462
sys::recv_from_vectored(self.inner, bufs, flags)
@@ -1292,6 +1329,11 @@ impl Read for Socket {
12921329

12931330
#[cfg(not(target_os = "redox"))]
12941331
fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
1332+
// Safety: both `IoSliceMut` and `MaybeUninitSlice` promise to have the
1333+
// same layout, that of `iovec`/`WSABUF`. Furthermore `recv_vectored`
1334+
// promises to not write unitialised bytes to the `bufs` and pass it
1335+
// directly to the `recvmsg` system call, so this is safe.
1336+
let bufs = unsafe { &mut *(bufs as *mut [IoSliceMut<'_>] as *mut [MaybeUninitSlice<'_>]) };
12951337
self.recv_vectored(bufs).map(|(n, _)| n)
12961338
}
12971339
}
@@ -1305,6 +1347,8 @@ impl<'a> Read for &'a Socket {
13051347

13061348
#[cfg(not(target_os = "redox"))]
13071349
fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
1350+
// Safety: see other `Read::read` impl.
1351+
let bufs = unsafe { &mut *(bufs as *mut [IoSliceMut<'_>] as *mut [MaybeUninitSlice<'_>]) };
13081352
self.recv_vectored(bufs).map(|(n, _)| n)
13091353
}
13101354
}

src/sys/unix.rs

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ use std::cmp::min;
1010
#[cfg(all(feature = "all", target_os = "linux"))]
1111
use std::ffi::{CStr, CString};
1212
#[cfg(not(target_os = "redox"))]
13-
use std::io::{IoSlice, IoSliceMut};
13+
use std::io::IoSlice;
14+
use std::marker::PhantomData;
1415
use std::mem::{self, size_of, MaybeUninit};
1516
use std::net::Shutdown;
1617
use std::net::{Ipv4Addr, Ipv6Addr};
@@ -21,10 +22,8 @@ use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd};
2122
use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream};
2223
#[cfg(feature = "all")]
2324
use std::path::Path;
24-
#[cfg(all(feature = "all", target_os = "linux"))]
25-
use std::slice;
2625
use std::time::Duration;
27-
use std::{io, ptr};
26+
use std::{io, ptr, slice};
2827

2928
#[cfg(not(target_vendor = "apple"))]
3029
use libc::ssize_t;
@@ -308,6 +307,32 @@ impl std::fmt::Debug for RecvFlags {
308307
}
309308
}
310309

310+
#[repr(transparent)]
311+
pub struct MaybeUninitSlice<'a> {
312+
vec: libc::iovec,
313+
_lifetime: PhantomData<&'a mut [MaybeUninit<u8>]>,
314+
}
315+
316+
impl<'a> MaybeUninitSlice<'a> {
317+
pub(crate) fn new(buf: &'a mut [MaybeUninit<u8>]) -> MaybeUninitSlice<'a> {
318+
MaybeUninitSlice {
319+
vec: libc::iovec {
320+
iov_base: buf.as_mut_ptr().cast(),
321+
iov_len: buf.len(),
322+
},
323+
_lifetime: PhantomData,
324+
}
325+
}
326+
327+
pub(crate) fn as_slice(&self) -> &[MaybeUninit<u8>] {
328+
unsafe { slice::from_raw_parts(self.vec.iov_base.cast(), self.vec.iov_len) }
329+
}
330+
331+
pub(crate) fn as_mut_slice(&mut self) -> &mut [MaybeUninit<u8>] {
332+
unsafe { slice::from_raw_parts_mut(self.vec.iov_base.cast(), self.vec.iov_len) }
333+
}
334+
}
335+
311336
/// Unix only API.
312337
impl SockAddr {
313338
/// Constructs a `SockAddr` with the family `AF_UNIX` and the provided path.
@@ -453,7 +478,7 @@ pub(crate) fn recv_from(
453478
#[cfg(not(target_os = "redox"))]
454479
pub(crate) fn recv_vectored(
455480
fd: Socket,
456-
bufs: &mut [IoSliceMut<'_>],
481+
bufs: &mut [crate::MaybeUninitSlice<'_>],
457482
flags: c_int,
458483
) -> io::Result<(usize, RecvFlags)> {
459484
recvmsg(fd, ptr::null_mut(), bufs, flags).map(|(n, _, recv_flags)| (n, recv_flags))
@@ -462,7 +487,7 @@ pub(crate) fn recv_vectored(
462487
#[cfg(not(target_os = "redox"))]
463488
pub(crate) fn recv_from_vectored(
464489
fd: Socket,
465-
bufs: &mut [IoSliceMut<'_>],
490+
bufs: &mut [crate::MaybeUninitSlice<'_>],
466491
flags: c_int,
467492
) -> io::Result<(usize, RecvFlags, SockAddr)> {
468493
// Safety: `recvmsg` initialises the address storage and we set the length
@@ -484,7 +509,7 @@ pub(crate) fn recv_from_vectored(
484509
fn recvmsg(
485510
fd: Socket,
486511
msg_name: *mut sockaddr_storage,
487-
bufs: &mut [IoSliceMut<'_>],
512+
bufs: &mut [crate::MaybeUninitSlice<'_>],
488513
flags: c_int,
489514
) -> io::Result<(usize, libc::socklen_t, RecvFlags)> {
490515
let msg_namelen = if msg_name.is_null() {

src/sys/windows.rs

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,24 @@
77
// except according to those terms.
88

99
use std::cmp::min;
10-
use std::io::{self, IoSlice, IoSliceMut};
10+
use std::io::{self, IoSlice};
11+
use std::marker::PhantomData;
1112
use std::mem::{self, size_of, MaybeUninit};
1213
use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown};
1314
use std::os::windows::prelude::*;
14-
use std::ptr;
1515
use std::sync::Once;
1616
use std::time::Duration;
17+
use std::{ptr, slice};
1718

1819
use winapi::ctypes::c_long;
1920
use winapi::shared::in6addr::*;
2021
use winapi::shared::inaddr::*;
2122
use winapi::shared::minwindef::DWORD;
23+
use winapi::shared::minwindef::ULONG;
2224
use winapi::shared::mstcpip::{tcp_keepalive, SIO_KEEPALIVE_VALS};
2325
use winapi::shared::ntdef::HANDLE;
2426
use winapi::shared::ws2def;
27+
use winapi::shared::ws2def::WSABUF;
2528
use winapi::um::handleapi::SetHandleInformation;
2629
use winapi::um::processthreadsapi::GetCurrentProcessId;
2730
use winapi::um::winbase::{self, INFINITE};
@@ -144,6 +147,33 @@ impl std::fmt::Debug for RecvFlags {
144147
}
145148
}
146149

150+
#[repr(transparent)]
151+
pub struct MaybeUninitSlice<'a> {
152+
vec: WSABUF,
153+
_lifetime: PhantomData<&'a mut [MaybeUninit<u8>]>,
154+
}
155+
156+
impl<'a> MaybeUninitSlice<'a> {
157+
pub fn new(buf: &'a mut [MaybeUninit<u8>]) -> MaybeUninitSlice<'a> {
158+
assert!(buf.len() <= ULONG::MAX as usize);
159+
MaybeUninitSlice {
160+
vec: WSABUF {
161+
len: buf.len() as ULONG,
162+
buf: buf.as_mut_ptr().cast(),
163+
},
164+
_lifetime: PhantomData,
165+
}
166+
}
167+
168+
pub fn as_slice(&self) -> &[MaybeUninit<u8>] {
169+
unsafe { slice::from_raw_parts(self.vec.buf.cast(), self.vec.len as usize) }
170+
}
171+
172+
pub fn as_mut_slice(&mut self) -> &mut [MaybeUninit<u8>] {
173+
unsafe { slice::from_raw_parts_mut(self.vec.buf.cast(), self.vec.len as usize) }
174+
}
175+
}
176+
147177
fn init() {
148178
static INIT: Once = Once::new();
149179

@@ -293,7 +323,7 @@ pub(crate) fn recv(socket: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int) ->
293323

294324
pub(crate) fn recv_vectored(
295325
socket: Socket,
296-
bufs: &mut [IoSliceMut<'_>],
326+
bufs: &mut [crate::MaybeUninitSlice<'_>],
297327
flags: c_int,
298328
) -> io::Result<(usize, RecvFlags)> {
299329
let mut nread = 0;
@@ -354,7 +384,7 @@ pub(crate) fn recv_from(
354384

355385
pub(crate) fn recv_from_vectored(
356386
socket: Socket,
357-
bufs: &mut [IoSliceMut<'_>],
387+
bufs: &mut [crate::MaybeUninitSlice<'_>],
358388
flags: c_int,
359389
) -> io::Result<(usize, RecvFlags, SockAddr)> {
360390
// Safety: `recvfrom` initialises the `SockAddr` for us.

0 commit comments

Comments
 (0)