Skip to content

Commit 15c13ad

Browse files
committed
Add support for *mmsg syscalls
1 parent 36dec54 commit 15c13ad

File tree

4 files changed

+244
-0
lines changed

4 files changed

+244
-0
lines changed

src/lib.rs

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,20 @@ macro_rules! man_links {
172172
};
173173
}
174174

175+
macro_rules! define_mmsg_if_supported {
176+
($($item:item)*) => {
177+
$(
178+
#[cfg(all(
179+
any(
180+
target_os = "linux",
181+
target_os = "android",
182+
)
183+
))]
184+
$item
185+
)*
186+
};
187+
}
188+
175189
mod sockaddr;
176190
mod socket;
177191
mod sockref;
@@ -697,3 +711,89 @@ impl<'name, 'bufs, 'control> fmt::Debug for MsgHdrMut<'name, 'bufs, 'control> {
697711
"MsgHdrMut".fmt(fmt)
698712
}
699713
}
714+
715+
define_mmsg_if_supported! {
716+
/// Wraps `mmsghdr` on Unix for a `sendmmsg(2)` system call.
717+
///
718+
/// Also see [`MsgHdr`] for the variant used by `sendmsg(2)`.
719+
#[repr(transparent)]
720+
pub struct MMsgHdr<'addr, 'bufs, 'control> {
721+
inner: sys::mmsghdr,
722+
#[allow(clippy::type_complexity)]
723+
_lifetimes: PhantomData<(&'addr SockAddr, &'bufs IoSlice<'bufs>, &'control [u8])>,
724+
}
725+
726+
impl<'addr, 'bufs, 'control> MMsgHdr<'addr, 'bufs, 'control> {
727+
/// Create a new `MMsgHdr` from `MsgHdr` and with the `msg_len` set to zero.
728+
pub fn new(msg: MsgHdr<'_, '_, '_>) -> Self {
729+
Self {
730+
inner: sys::mmsghdr {
731+
msg_hdr: msg.inner,
732+
msg_len: 0,
733+
},
734+
_lifetimes: PhantomData,
735+
}
736+
}
737+
738+
/// Number of bytes transmitted.
739+
pub fn transmitted_bytes(&self) -> u32 {
740+
self.inner.msg_len
741+
}
742+
}
743+
744+
impl<'addr, 'bufs, 'control> fmt::Debug for MMsgHdr<'addr, 'bufs, 'control> {
745+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
746+
f.write_str(&format!("MMsgHdr({})", self.transmitted_bytes()))
747+
}
748+
}
749+
750+
/// Wraps `mmsghdr` on Unix for a `recvmmsg(2)` system call.
751+
///
752+
/// Also see [`MsgHdrMut`] for the variant used by `recvmsg(2)`.
753+
#[repr(transparent)]
754+
pub struct MMsgHdrMut<'addr, 'bufs, 'control> {
755+
inner: sys::mmsghdr,
756+
#[allow(clippy::type_complexity)]
757+
_lifetimes: PhantomData<(
758+
&'addr mut SockAddr,
759+
&'bufs mut MaybeUninitSlice<'bufs>,
760+
&'control mut [u8],
761+
)>,
762+
}
763+
764+
impl<'addr, 'bufs, 'control> MMsgHdrMut<'addr, 'bufs, 'control> {
765+
/// Create a new `MMsgHdrMut` from `MsgHdrMut` and with the `msg_len` set to zero.
766+
pub fn new(msg: MsgHdrMut<'_, '_, '_>) -> Self {
767+
Self {
768+
inner: sys::mmsghdr {
769+
msg_hdr: msg.inner,
770+
msg_len: 0,
771+
},
772+
_lifetimes: PhantomData,
773+
}
774+
}
775+
776+
/// Number of received bytes.
777+
pub fn recieved_bytes(&self) -> u32 {
778+
self.inner.msg_len
779+
}
780+
781+
/// Returns the flags of the message.
782+
pub fn flags(&self) -> RecvFlags {
783+
sys::msghdr_flags(&self.inner.msg_hdr)
784+
}
785+
786+
/// Gets the length of the control buffer.
787+
///
788+
/// Can be used to determine how much, if any, of the control buffer was filled by `recvmsg`.
789+
pub fn control_len(&self) -> usize {
790+
sys::msghdr_control_len(&self.inner.msg_hdr)
791+
}
792+
}
793+
794+
impl<'addr, 'bufs, 'control> fmt::Debug for MMsgHdrMut<'addr, 'bufs, 'control> {
795+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
796+
f.write_str(&format!("MMsgHdrMut({})", self.recieved_bytes()))
797+
}
798+
}
799+
}

src/socket.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ use crate::{Domain, Protocol, SockAddr, TcpKeepalive, Type};
2727
#[cfg(not(target_os = "redox"))]
2828
use crate::{MaybeUninitSlice, MsgHdr, RecvFlags};
2929

30+
define_mmsg_if_supported! {
31+
use crate::{MMsgHdr, MMsgHdrMut};
32+
}
33+
3034
/// Owned wrapper around a system socket.
3135
///
3236
/// This type simply wraps an instance of a file descriptor (`c_int`) on Unix
@@ -634,6 +638,19 @@ impl Socket {
634638
sys::recvmsg(self.as_raw(), msg, flags)
635639
}
636640

641+
define_mmsg_if_supported! {
642+
/// Receive multiple messages on the socket using a single system call.
643+
#[doc = man_links!(unix: recvmmsg(2))]
644+
pub fn recvmmsg(
645+
&self,
646+
msgvec: &mut [MMsgHdrMut<'_, '_, '_>],
647+
flags: c_int,
648+
timeout: Option<Duration>,
649+
) -> io::Result<usize> {
650+
sys::recvmmsg(self.as_raw(), msgvec, flags, timeout)
651+
}
652+
}
653+
637654
/// Sends data on the socket to a connected peer.
638655
///
639656
/// This is typically used on TCP sockets or datagram sockets which have
@@ -735,6 +752,14 @@ impl Socket {
735752
pub fn sendmsg(&self, msg: &MsgHdr<'_, '_, '_>, flags: sys::c_int) -> io::Result<usize> {
736753
sys::sendmsg(self.as_raw(), msg, flags)
737754
}
755+
756+
define_mmsg_if_supported! {
757+
/// Send multiple messages on the socket using a single system call.
758+
#[doc = man_links!(unix: sendmmsg(2))]
759+
pub fn sendmmsg(&self, msgvec: &mut [MMsgHdr<'_, '_, '_>], flags: c_int) -> io::Result<usize> {
760+
sys::sendmmsg(self.as_raw(), msgvec, flags)
761+
}
762+
}
738763
}
739764

740765
/// Set `SOCK_CLOEXEC` and `NO_HANDLE_INHERIT` on the `ty`pe on platforms that

src/sys/unix.rs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,13 @@ use crate::{Domain, Protocol, SockAddr, TcpKeepalive, Type};
8080
#[cfg(not(target_os = "redox"))]
8181
use crate::{MsgHdr, MsgHdrMut, RecvFlags};
8282

83+
define_mmsg_if_supported! {
84+
use crate::{MMsgHdr, MMsgHdrMut};
85+
86+
// Used in `MMsgHdr`.
87+
pub(crate) use libc::mmsghdr;
88+
}
89+
8390
pub(crate) use libc::c_int;
8491

8592
// Used in `Domain`.
@@ -1076,6 +1083,35 @@ pub(crate) fn recvmsg(
10761083
syscall!(recvmsg(fd, &mut msg.inner, flags)).map(|n| n as usize)
10771084
}
10781085

1086+
define_mmsg_if_supported! {
1087+
pub(crate) fn recvmmsg(
1088+
fd: Socket,
1089+
msgvec: &mut [MMsgHdrMut<'_, '_, '_>],
1090+
flags: c_int,
1091+
timeout: Option<Duration>,
1092+
) -> io::Result<usize> {
1093+
if cfg!(target_env = "musl") {
1094+
debug_assert!(flags >= 0, "socket flags must be non-negative");
1095+
}
1096+
1097+
let mut timeout = timeout.map(into_timespec);
1098+
let timeout_ptr = timeout
1099+
.as_mut()
1100+
.map(|t| t as *mut _)
1101+
.unwrap_or(ptr::null_mut());
1102+
1103+
syscall!(recvmmsg(
1104+
fd,
1105+
// SAFETY: `MMsgHdrMut` is `#[repr(transparent)]` and wraps a `libc::mmsghdr`
1106+
msgvec.as_mut_ptr() as *mut mmsghdr,
1107+
msgvec.len() as _,
1108+
flags as _,
1109+
timeout_ptr
1110+
))
1111+
.map(|n| n as usize)
1112+
}
1113+
}
1114+
10791115
pub(crate) fn send(fd: Socket, buf: &[u8], flags: c_int) -> io::Result<usize> {
10801116
syscall!(send(
10811117
fd,
@@ -1120,6 +1156,27 @@ pub(crate) fn sendmsg(fd: Socket, msg: &MsgHdr<'_, '_, '_>, flags: c_int) -> io:
11201156
syscall!(sendmsg(fd, &msg.inner, flags)).map(|n| n as usize)
11211157
}
11221158

1159+
define_mmsg_if_supported! {
1160+
pub(crate) fn sendmmsg(
1161+
fd: Socket,
1162+
msgvec: &mut [MMsgHdr<'_, '_, '_>],
1163+
flags: c_int,
1164+
) -> io::Result<usize> {
1165+
if cfg!(target_env = "musl") {
1166+
debug_assert!(flags >= 0, "socket flags must be non-negative");
1167+
}
1168+
1169+
syscall!(sendmmsg(
1170+
fd,
1171+
// SAFETY: `MMsgHdr` is `#[repr(transparent)]` and wraps a `libc::mmsghdr`
1172+
msgvec.as_mut_ptr() as *mut mmsghdr,
1173+
msgvec.len() as _,
1174+
flags as _
1175+
))
1176+
.map(|n| n as usize)
1177+
}
1178+
}
1179+
11231180
/// Wrapper around `getsockopt` to deal with platform specific timeouts.
11241181
pub(crate) fn timeout_opt(fd: Socket, opt: c_int, val: c_int) -> io::Result<Option<Duration>> {
11251182
unsafe { getsockopt(fd, opt, val).map(from_timeval) }
@@ -1161,6 +1218,27 @@ fn into_timeval(duration: Option<Duration>) -> libc::timeval {
11611218
}
11621219
}
11631220

1221+
define_mmsg_if_supported! {
1222+
fn into_timespec(duration: Duration) -> libc::timespec {
1223+
// https://github.com/rust-lang/libc/issues/1848
1224+
#[cfg_attr(target_env = "musl", allow(deprecated))]
1225+
libc::timespec {
1226+
tv_sec: min(duration.as_secs(), libc::time_t::MAX as u64) as libc::time_t,
1227+
#[cfg(any(
1228+
all(target_arch = "x86_64", target_pointer_width = "32"),
1229+
target_pointer_width = "64"
1230+
))]
1231+
tv_nsec: duration.subsec_nanos() as i64,
1232+
1233+
#[cfg(not(any(
1234+
all(target_arch = "x86_64", target_pointer_width = "32"),
1235+
target_pointer_width = "64"
1236+
)))]
1237+
tv_nsec: duration.subsec_nanos().clamp(0, i32::MAX as u32) as i32,
1238+
}
1239+
}
1240+
}
1241+
11641242
#[cfg(all(
11651243
feature = "all",
11661244
not(any(target_os = "haiku", target_os = "openbsd", target_os = "vita"))

tests/socket.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,47 @@ fn sendmsg() {
769769
assert_eq!(received, DATA.len());
770770
}
771771

772+
#[test]
773+
#[cfg(any(target_os = "linux", target_os = "android"))]
774+
fn send_and_recv_batched_msgs() {
775+
let (socket_a, socket_b) = udp_pair_unconnected();
776+
777+
const DATA: &[u8] = b"Hello, World!";
778+
779+
let addr_b = socket_b.local_addr().unwrap();
780+
let mut batched_msgs = Vec::new();
781+
let mut recv_batched_msgs = Vec::new();
782+
for _ in 0..10 {
783+
let bufs = &[IoSlice::new(DATA)];
784+
batched_msgs.push(socket2::MMsgHdr::new(
785+
socket2::MsgHdr::new().with_addr(&addr_b).with_buffers(bufs),
786+
));
787+
788+
let mut buf = [MaybeUninit::new(0u8); 20];
789+
let recv_bufs = MaybeUninitSlice::new(buf.as_mut_slice());
790+
recv_batched_msgs.push(socket2::MMsgHdrMut::new(
791+
socket2::MsgHdrMut::new().with_buffers(&mut [recv_bufs]),
792+
));
793+
}
794+
795+
let sent = socket_a.sendmmsg(batched_msgs.as_mut_slice(), 0).unwrap();
796+
797+
let mut sent_data = 0;
798+
// Calculate transmitted length
799+
for msg in batched_msgs.iter().take(sent) {
800+
sent_data += msg.transmitted_bytes()
801+
}
802+
assert!(sent_data as usize == 10 * DATA.len());
803+
804+
let recvd = socket_b
805+
.recvmmsg(recv_batched_msgs.as_mut_slice(), 0, None)
806+
.unwrap();
807+
808+
assert!(recvd == sent);
809+
810+
assert!(recv_batched_msgs[0].recieved_bytes() == DATA.len().try_into().unwrap());
811+
}
812+
772813
#[test]
773814
#[cfg(not(any(target_os = "redox", target_os = "vita")))]
774815
fn recv_vectored_truncated() {

0 commit comments

Comments
 (0)