From 201dd0becbbf8f9d4d3b33e94df02d6cd72a35ee Mon Sep 17 00:00:00 2001 From: Hasan Date: Mon, 19 May 2025 18:33:49 +0200 Subject: [PATCH] Add support for *mmsg syscalls --- src/lib.rs | 90 +++++++++++++++++++++++++++++++++++++++++++++++++ src/socket.rs | 22 ++++++++++++ src/sys/unix.rs | 77 ++++++++++++++++++++++++++++++++++++++++++ tests/socket.rs | 41 ++++++++++++++++++++++ 4 files changed, 230 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 4d39b05d..c286511f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -697,3 +697,93 @@ impl<'name, 'bufs, 'control> fmt::Debug for MsgHdrMut<'name, 'bufs, 'control> { "MsgHdrMut".fmt(fmt) } } + +/// Wraps `mmsghdr` on Unix for a `sendmmsg(2)` system call. +/// +/// Also see [`MsgHdr`] for the variant used by `sendmsg(2)`. +#[repr(transparent)] +#[cfg(any(target_os = "linux", target_os = "android",))] +pub struct MMsgHdr<'addr, 'bufs, 'control> { + inner: sys::mmsghdr, + #[allow(clippy::type_complexity)] + _lifetimes: PhantomData<(&'addr SockAddr, &'bufs IoSlice<'bufs>, &'control [u8])>, +} + +#[cfg(any(target_os = "linux", target_os = "android",))] +impl<'addr, 'bufs, 'control> MMsgHdr<'addr, 'bufs, 'control> { + /// Create a new `MMsgHdr` from `MsgHdr` and with the `msg_len` set to zero. + pub fn new(msg: MsgHdr<'_, '_, '_>) -> Self { + Self { + inner: sys::mmsghdr { + msg_hdr: msg.inner, + msg_len: 0, + }, + _lifetimes: PhantomData, + } + } + + /// Number of bytes transmitted. + pub fn transmitted_bytes(&self) -> u32 { + self.inner.msg_len + } +} + +#[cfg(any(target_os = "linux", target_os = "android",))] +impl<'addr, 'bufs, 'control> fmt::Debug for MMsgHdr<'addr, 'bufs, 'control> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&format!("MMsgHdr({})", self.transmitted_bytes())) + } +} + +/// Wraps `mmsghdr` on Unix for a `recvmmsg(2)` system call. +/// +/// Also see [`MsgHdrMut`] for the variant used by `recvmsg(2)`. +#[repr(transparent)] +#[cfg(any(target_os = "linux", target_os = "android",))] +pub struct MMsgHdrMut<'addr, 'bufs, 'control> { + inner: sys::mmsghdr, + #[allow(clippy::type_complexity)] + _lifetimes: PhantomData<( + &'addr mut SockAddr, + &'bufs mut MaybeUninitSlice<'bufs>, + &'control mut [u8], + )>, +} + +#[cfg(any(target_os = "linux", target_os = "android",))] +impl<'addr, 'bufs, 'control> MMsgHdrMut<'addr, 'bufs, 'control> { + /// Create a new `MMsgHdrMut` from `MsgHdrMut` and with the `msg_len` set to zero. + pub fn new(msg: MsgHdrMut<'_, '_, '_>) -> Self { + Self { + inner: sys::mmsghdr { + msg_hdr: msg.inner, + msg_len: 0, + }, + _lifetimes: PhantomData, + } + } + + /// Number of received bytes. + pub fn recieved_bytes(&self) -> u32 { + self.inner.msg_len + } + + /// Returns the flags of the message. + pub fn flags(&self) -> RecvFlags { + sys::msghdr_flags(&self.inner.msg_hdr) + } + + /// Gets the length of the control buffer. + /// + /// Can be used to determine how much, if any, of the control buffer was filled by `recvmsg`. + pub fn control_len(&self) -> usize { + sys::msghdr_control_len(&self.inner.msg_hdr) + } +} + +#[cfg(any(target_os = "linux", target_os = "android",))] +impl<'addr, 'bufs, 'control> fmt::Debug for MMsgHdrMut<'addr, 'bufs, 'control> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&format!("MMsgHdrMut({})", self.recieved_bytes())) + } +} diff --git a/src/socket.rs b/src/socket.rs index 06cd07f9..ef7cf167 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -27,6 +27,9 @@ use crate::{Domain, Protocol, SockAddr, TcpKeepalive, Type}; #[cfg(not(target_os = "redox"))] use crate::{MaybeUninitSlice, MsgHdr, RecvFlags}; +#[cfg(all(feature = "all", any(target_os = "android", target_os = "linux",)))] +use crate::{MMsgHdr, MMsgHdrMut}; + /// Owned wrapper around a system socket. /// /// This type simply wraps an instance of a file descriptor (`c_int`) on Unix @@ -634,6 +637,18 @@ impl Socket { sys::recvmsg(self.as_raw(), msg, flags) } + /// Receive multiple messages on the socket using a single system call. + #[doc = man_links!(unix: recvmmsg(2))] + #[cfg(all(feature = "all", any(target_os = "android", target_os = "linux",)))] + pub fn recvmmsg( + &self, + msgvec: &mut [MMsgHdrMut<'_, '_, '_>], + flags: c_int, + timeout: Option, + ) -> io::Result { + sys::recvmmsg(self.as_raw(), msgvec, flags, timeout) + } + /// Sends data on the socket to a connected peer. /// /// This is typically used on TCP sockets or datagram sockets which have @@ -735,6 +750,13 @@ impl Socket { pub fn sendmsg(&self, msg: &MsgHdr<'_, '_, '_>, flags: sys::c_int) -> io::Result { sys::sendmsg(self.as_raw(), msg, flags) } + + /// Send multiple messages on the socket using a single system call. + #[doc = man_links!(unix: sendmmsg(2))] + #[cfg(all(feature = "all", any(target_os = "linux", target_os = "android",)))] + pub fn sendmmsg(&self, msgvec: &mut [MMsgHdr<'_, '_, '_>], flags: c_int) -> io::Result { + sys::sendmmsg(self.as_raw(), msgvec, flags) + } } /// Set `SOCK_CLOEXEC` and `NO_HANDLE_INHERIT` on the `ty`pe on platforms that diff --git a/src/sys/unix.rs b/src/sys/unix.rs index 8e24f4e5..80da46ea 100644 --- a/src/sys/unix.rs +++ b/src/sys/unix.rs @@ -80,6 +80,13 @@ use crate::{Domain, Protocol, SockAddr, TcpKeepalive, Type}; #[cfg(not(target_os = "redox"))] use crate::{MsgHdr, MsgHdrMut, RecvFlags}; +#[cfg(all(feature = "all", any(target_os = "linux", target_os = "android",)))] +use crate::{MMsgHdr, MMsgHdrMut}; + +// Used in `MMsgHdr`. +#[cfg(any(target_os = "linux", target_os = "android",))] +pub(crate) use libc::mmsghdr; + pub(crate) use libc::c_int; // Used in `Domain`. @@ -1076,6 +1083,35 @@ pub(crate) fn recvmsg( syscall!(recvmsg(fd, &mut msg.inner, flags)).map(|n| n as usize) } +#[cfg(all(feature = "all", any(target_os = "linux", target_os = "android",)))] +/// This emits all the messages in a single syscall +pub(crate) fn recvmmsg( + fd: Socket, + msgvec: &mut [MMsgHdrMut<'_, '_, '_>], + flags: c_int, + timeout: Option, +) -> io::Result { + if cfg!(target_env = "musl") { + debug_assert!(flags >= 0, "socket flags must be non-negative"); + } + + let mut timeout = timeout.map(into_timespec); + let timeout_ptr = timeout + .as_mut() + .map(|t| t as *mut _) + .unwrap_or(ptr::null_mut()); + + syscall!(recvmmsg( + fd, + // SAFETY: `MMsgHdrMut` is `#[repr(transparent)]` and wraps a `libc::mmsghdr` + msgvec.as_mut_ptr() as *mut mmsghdr, + msgvec.len() as _, + flags as _, + timeout_ptr + )) + .map(|n| n as usize) +} + pub(crate) fn send(fd: Socket, buf: &[u8], flags: c_int) -> io::Result { syscall!(send( fd, @@ -1120,6 +1156,27 @@ pub(crate) fn sendmsg(fd: Socket, msg: &MsgHdr<'_, '_, '_>, flags: c_int) -> io: syscall!(sendmsg(fd, &msg.inner, flags)).map(|n| n as usize) } +#[cfg(all(feature = "all", any(target_os = "linux", target_os = "android",)))] +/// This transmits all the messages in a single syscall +pub(crate) fn sendmmsg( + fd: Socket, + msgvec: &mut [MMsgHdr<'_, '_, '_>], + flags: c_int, +) -> io::Result { + if cfg!(target_env = "musl") { + debug_assert!(flags >= 0, "socket flags must be non-negative"); + } + + syscall!(sendmmsg( + fd, + // SAFETY: `MMsgHdr` is `#[repr(transparent)]` and wraps a `libc::mmsghdr` + msgvec.as_mut_ptr() as *mut mmsghdr, + msgvec.len() as _, + flags as _ + )) + .map(|n| n as usize) +} + /// Wrapper around `getsockopt` to deal with platform specific timeouts. pub(crate) fn timeout_opt(fd: Socket, opt: c_int, val: c_int) -> io::Result> { unsafe { getsockopt(fd, opt, val).map(from_timeval) } @@ -1161,6 +1218,26 @@ fn into_timeval(duration: Option) -> libc::timeval { } } +#[cfg(all(feature = "all", any(target_os = "linux", target_os = "android",)))] +fn into_timespec(duration: Duration) -> libc::timespec { + // https://github.com/rust-lang/libc/issues/1848 + #[cfg_attr(target_env = "musl", allow(deprecated))] + libc::timespec { + tv_sec: min(duration.as_secs(), libc::time_t::MAX as u64) as libc::time_t, + #[cfg(any( + all(target_arch = "x86_64", target_pointer_width = "32"), + target_pointer_width = "64" + ))] + tv_nsec: duration.subsec_nanos() as i64, + + #[cfg(not(any( + all(target_arch = "x86_64", target_pointer_width = "32"), + target_pointer_width = "64" + )))] + tv_nsec: duration.subsec_nanos().clamp(0, i32::MAX as u32) as i32, + } +} + #[cfg(all( feature = "all", not(any(target_os = "haiku", target_os = "openbsd", target_os = "vita")) diff --git a/tests/socket.rs b/tests/socket.rs index a2dca668..40196d40 100644 --- a/tests/socket.rs +++ b/tests/socket.rs @@ -769,6 +769,47 @@ fn sendmsg() { assert_eq!(received, DATA.len()); } +#[test] +#[cfg(all(feature = "all", any(target_os = "linux", target_os = "android")))] +fn send_and_recv_batched_msgs() { + let (socket_a, socket_b) = udp_pair_unconnected(); + + const DATA: &[u8] = b"Hello, World!"; + + let addr_b = socket_b.local_addr().unwrap(); + let mut batched_msgs = Vec::new(); + let mut recv_batched_msgs = Vec::new(); + for _ in 0..10 { + let bufs = &[IoSlice::new(DATA)]; + batched_msgs.push(socket2::MMsgHdr::new( + socket2::MsgHdr::new().with_addr(&addr_b).with_buffers(bufs), + )); + + let mut buf = [MaybeUninit::new(0u8); DATA.len()]; + let recv_bufs = MaybeUninitSlice::new(buf.as_mut_slice()); + recv_batched_msgs.push(socket2::MMsgHdrMut::new( + socket2::MsgHdrMut::new().with_buffers(&mut [recv_bufs]), + )); + } + + let sent = socket_a.sendmmsg(batched_msgs.as_mut_slice(), 0).unwrap(); + + let mut sent_data = 0; + // Calculate transmitted length + for msg in batched_msgs.iter().take(sent) { + sent_data += msg.transmitted_bytes() + } + assert!(sent_data as usize == 10 * DATA.len()); + + let recvd = socket_b + .recvmmsg(recv_batched_msgs.as_mut_slice(), 0, None) + .unwrap(); + + assert!(recvd == sent); + + assert!(recv_batched_msgs[0].recieved_bytes() == DATA.len().try_into().unwrap()); +} + #[test] #[cfg(not(any(target_os = "redox", target_os = "vita")))] fn recv_vectored_truncated() {