Skip to content

Commit 9946af2

Browse files
authored
Merge pull request #2 from NordSecurity/Add_sendmmsg
Add support *mmsg syscalls
2 parents 36dec54 + 201dd0b commit 9946af2

File tree

4 files changed

+230
-0
lines changed

4 files changed

+230
-0
lines changed

src/lib.rs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,3 +697,93 @@ impl<'name, 'bufs, 'control> fmt::Debug for MsgHdrMut<'name, 'bufs, 'control> {
697697
"MsgHdrMut".fmt(fmt)
698698
}
699699
}
700+
701+
/// Wraps `mmsghdr` on Unix for a `sendmmsg(2)` system call.
702+
///
703+
/// Also see [`MsgHdr`] for the variant used by `sendmsg(2)`.
704+
#[repr(transparent)]
705+
#[cfg(any(target_os = "linux", target_os = "android",))]
706+
pub struct MMsgHdr<'addr, 'bufs, 'control> {
707+
inner: sys::mmsghdr,
708+
#[allow(clippy::type_complexity)]
709+
_lifetimes: PhantomData<(&'addr SockAddr, &'bufs IoSlice<'bufs>, &'control [u8])>,
710+
}
711+
712+
#[cfg(any(target_os = "linux", target_os = "android",))]
713+
impl<'addr, 'bufs, 'control> MMsgHdr<'addr, 'bufs, 'control> {
714+
/// Create a new `MMsgHdr` from `MsgHdr` and with the `msg_len` set to zero.
715+
pub fn new(msg: MsgHdr<'_, '_, '_>) -> Self {
716+
Self {
717+
inner: sys::mmsghdr {
718+
msg_hdr: msg.inner,
719+
msg_len: 0,
720+
},
721+
_lifetimes: PhantomData,
722+
}
723+
}
724+
725+
/// Number of bytes transmitted.
726+
pub fn transmitted_bytes(&self) -> u32 {
727+
self.inner.msg_len
728+
}
729+
}
730+
731+
#[cfg(any(target_os = "linux", target_os = "android",))]
732+
impl<'addr, 'bufs, 'control> fmt::Debug for MMsgHdr<'addr, 'bufs, 'control> {
733+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
734+
f.write_str(&format!("MMsgHdr({})", self.transmitted_bytes()))
735+
}
736+
}
737+
738+
/// Wraps `mmsghdr` on Unix for a `recvmmsg(2)` system call.
739+
///
740+
/// Also see [`MsgHdrMut`] for the variant used by `recvmsg(2)`.
741+
#[repr(transparent)]
742+
#[cfg(any(target_os = "linux", target_os = "android",))]
743+
pub struct MMsgHdrMut<'addr, 'bufs, 'control> {
744+
inner: sys::mmsghdr,
745+
#[allow(clippy::type_complexity)]
746+
_lifetimes: PhantomData<(
747+
&'addr mut SockAddr,
748+
&'bufs mut MaybeUninitSlice<'bufs>,
749+
&'control mut [u8],
750+
)>,
751+
}
752+
753+
#[cfg(any(target_os = "linux", target_os = "android",))]
754+
impl<'addr, 'bufs, 'control> MMsgHdrMut<'addr, 'bufs, 'control> {
755+
/// Create a new `MMsgHdrMut` from `MsgHdrMut` and with the `msg_len` set to zero.
756+
pub fn new(msg: MsgHdrMut<'_, '_, '_>) -> Self {
757+
Self {
758+
inner: sys::mmsghdr {
759+
msg_hdr: msg.inner,
760+
msg_len: 0,
761+
},
762+
_lifetimes: PhantomData,
763+
}
764+
}
765+
766+
/// Number of received bytes.
767+
pub fn recieved_bytes(&self) -> u32 {
768+
self.inner.msg_len
769+
}
770+
771+
/// Returns the flags of the message.
772+
pub fn flags(&self) -> RecvFlags {
773+
sys::msghdr_flags(&self.inner.msg_hdr)
774+
}
775+
776+
/// Gets the length of the control buffer.
777+
///
778+
/// Can be used to determine how much, if any, of the control buffer was filled by `recvmsg`.
779+
pub fn control_len(&self) -> usize {
780+
sys::msghdr_control_len(&self.inner.msg_hdr)
781+
}
782+
}
783+
784+
#[cfg(any(target_os = "linux", target_os = "android",))]
785+
impl<'addr, 'bufs, 'control> fmt::Debug for MMsgHdrMut<'addr, 'bufs, 'control> {
786+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
787+
f.write_str(&format!("MMsgHdrMut({})", self.recieved_bytes()))
788+
}
789+
}

src/socket.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ use crate::{Domain, Protocol, SockAddr, TcpKeepalive, Type};
2727
#[cfg(not(target_os = "redox"))]
2828
use crate::{MaybeUninitSlice, MsgHdr, RecvFlags};
2929

30+
#[cfg(all(feature = "all", any(target_os = "android", target_os = "linux",)))]
31+
use crate::{MMsgHdr, MMsgHdrMut};
32+
3033
/// Owned wrapper around a system socket.
3134
///
3235
/// This type simply wraps an instance of a file descriptor (`c_int`) on Unix
@@ -634,6 +637,18 @@ impl Socket {
634637
sys::recvmsg(self.as_raw(), msg, flags)
635638
}
636639

640+
/// Receive multiple messages on the socket using a single system call.
641+
#[doc = man_links!(unix: recvmmsg(2))]
642+
#[cfg(all(feature = "all", any(target_os = "android", target_os = "linux",)))]
643+
pub fn recvmmsg(
644+
&self,
645+
msgvec: &mut [MMsgHdrMut<'_, '_, '_>],
646+
flags: c_int,
647+
timeout: Option<Duration>,
648+
) -> io::Result<usize> {
649+
sys::recvmmsg(self.as_raw(), msgvec, flags, timeout)
650+
}
651+
637652
/// Sends data on the socket to a connected peer.
638653
///
639654
/// This is typically used on TCP sockets or datagram sockets which have
@@ -735,6 +750,13 @@ impl Socket {
735750
pub fn sendmsg(&self, msg: &MsgHdr<'_, '_, '_>, flags: sys::c_int) -> io::Result<usize> {
736751
sys::sendmsg(self.as_raw(), msg, flags)
737752
}
753+
754+
/// Send multiple messages on the socket using a single system call.
755+
#[doc = man_links!(unix: sendmmsg(2))]
756+
#[cfg(all(feature = "all", any(target_os = "linux", target_os = "android",)))]
757+
pub fn sendmmsg(&self, msgvec: &mut [MMsgHdr<'_, '_, '_>], flags: c_int) -> io::Result<usize> {
758+
sys::sendmmsg(self.as_raw(), msgvec, flags)
759+
}
738760
}
739761

740762
/// Set `SOCK_CLOEXEC` and `NO_HANDLE_INHERIT` on the `ty`pe on platforms that

src/sys/unix.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,13 @@ use crate::{Domain, Protocol, SockAddr, TcpKeepalive, Type};
8080
#[cfg(not(target_os = "redox"))]
8181
use crate::{MsgHdr, MsgHdrMut, RecvFlags};
8282

83+
#[cfg(all(feature = "all", any(target_os = "linux", target_os = "android",)))]
84+
use crate::{MMsgHdr, MMsgHdrMut};
85+
86+
// Used in `MMsgHdr`.
87+
#[cfg(any(target_os = "linux", target_os = "android",))]
88+
pub(crate) use libc::mmsghdr;
89+
8390
pub(crate) use libc::c_int;
8491

8592
// Used in `Domain`.
@@ -1076,6 +1083,35 @@ pub(crate) fn recvmsg(
10761083
syscall!(recvmsg(fd, &mut msg.inner, flags)).map(|n| n as usize)
10771084
}
10781085

1086+
#[cfg(all(feature = "all", any(target_os = "linux", target_os = "android",)))]
1087+
/// This emits all the messages in a single syscall
1088+
pub(crate) fn recvmmsg(
1089+
fd: Socket,
1090+
msgvec: &mut [MMsgHdrMut<'_, '_, '_>],
1091+
flags: c_int,
1092+
timeout: Option<Duration>,
1093+
) -> io::Result<usize> {
1094+
if cfg!(target_env = "musl") {
1095+
debug_assert!(flags >= 0, "socket flags must be non-negative");
1096+
}
1097+
1098+
let mut timeout = timeout.map(into_timespec);
1099+
let timeout_ptr = timeout
1100+
.as_mut()
1101+
.map(|t| t as *mut _)
1102+
.unwrap_or(ptr::null_mut());
1103+
1104+
syscall!(recvmmsg(
1105+
fd,
1106+
// SAFETY: `MMsgHdrMut` is `#[repr(transparent)]` and wraps a `libc::mmsghdr`
1107+
msgvec.as_mut_ptr() as *mut mmsghdr,
1108+
msgvec.len() as _,
1109+
flags as _,
1110+
timeout_ptr
1111+
))
1112+
.map(|n| n as usize)
1113+
}
1114+
10791115
pub(crate) fn send(fd: Socket, buf: &[u8], flags: c_int) -> io::Result<usize> {
10801116
syscall!(send(
10811117
fd,
@@ -1120,6 +1156,27 @@ pub(crate) fn sendmsg(fd: Socket, msg: &MsgHdr<'_, '_, '_>, flags: c_int) -> io:
11201156
syscall!(sendmsg(fd, &msg.inner, flags)).map(|n| n as usize)
11211157
}
11221158

1159+
#[cfg(all(feature = "all", any(target_os = "linux", target_os = "android",)))]
1160+
/// This transmits all the messages in a single syscall
1161+
pub(crate) fn sendmmsg(
1162+
fd: Socket,
1163+
msgvec: &mut [MMsgHdr<'_, '_, '_>],
1164+
flags: c_int,
1165+
) -> io::Result<usize> {
1166+
if cfg!(target_env = "musl") {
1167+
debug_assert!(flags >= 0, "socket flags must be non-negative");
1168+
}
1169+
1170+
syscall!(sendmmsg(
1171+
fd,
1172+
// SAFETY: `MMsgHdr` is `#[repr(transparent)]` and wraps a `libc::mmsghdr`
1173+
msgvec.as_mut_ptr() as *mut mmsghdr,
1174+
msgvec.len() as _,
1175+
flags as _
1176+
))
1177+
.map(|n| n as usize)
1178+
}
1179+
11231180
/// Wrapper around `getsockopt` to deal with platform specific timeouts.
11241181
pub(crate) fn timeout_opt(fd: Socket, opt: c_int, val: c_int) -> io::Result<Option<Duration>> {
11251182
unsafe { getsockopt(fd, opt, val).map(from_timeval) }
@@ -1161,6 +1218,26 @@ fn into_timeval(duration: Option<Duration>) -> libc::timeval {
11611218
}
11621219
}
11631220

1221+
#[cfg(all(feature = "all", any(target_os = "linux", target_os = "android",)))]
1222+
fn into_timespec(duration: Duration) -> libc::timespec {
1223+
// https://github.com/rust-lang/libc/issues/1848
1224+
#[cfg_attr(target_env = "musl", allow(deprecated))]
1225+
libc::timespec {
1226+
tv_sec: min(duration.as_secs(), libc::time_t::MAX as u64) as libc::time_t,
1227+
#[cfg(any(
1228+
all(target_arch = "x86_64", target_pointer_width = "32"),
1229+
target_pointer_width = "64"
1230+
))]
1231+
tv_nsec: duration.subsec_nanos() as i64,
1232+
1233+
#[cfg(not(any(
1234+
all(target_arch = "x86_64", target_pointer_width = "32"),
1235+
target_pointer_width = "64"
1236+
)))]
1237+
tv_nsec: duration.subsec_nanos().clamp(0, i32::MAX as u32) as i32,
1238+
}
1239+
}
1240+
11641241
#[cfg(all(
11651242
feature = "all",
11661243
not(any(target_os = "haiku", target_os = "openbsd", target_os = "vita"))

tests/socket.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,47 @@ fn sendmsg() {
769769
assert_eq!(received, DATA.len());
770770
}
771771

772+
#[test]
773+
#[cfg(all(feature = "all", any(target_os = "linux", target_os = "android")))]
774+
fn send_and_recv_batched_msgs() {
775+
let (socket_a, socket_b) = udp_pair_unconnected();
776+
777+
const DATA: &[u8] = b"Hello, World!";
778+
779+
let addr_b = socket_b.local_addr().unwrap();
780+
let mut batched_msgs = Vec::new();
781+
let mut recv_batched_msgs = Vec::new();
782+
for _ in 0..10 {
783+
let bufs = &[IoSlice::new(DATA)];
784+
batched_msgs.push(socket2::MMsgHdr::new(
785+
socket2::MsgHdr::new().with_addr(&addr_b).with_buffers(bufs),
786+
));
787+
788+
let mut buf = [MaybeUninit::new(0u8); DATA.len()];
789+
let recv_bufs = MaybeUninitSlice::new(buf.as_mut_slice());
790+
recv_batched_msgs.push(socket2::MMsgHdrMut::new(
791+
socket2::MsgHdrMut::new().with_buffers(&mut [recv_bufs]),
792+
));
793+
}
794+
795+
let sent = socket_a.sendmmsg(batched_msgs.as_mut_slice(), 0).unwrap();
796+
797+
let mut sent_data = 0;
798+
// Calculate transmitted length
799+
for msg in batched_msgs.iter().take(sent) {
800+
sent_data += msg.transmitted_bytes()
801+
}
802+
assert!(sent_data as usize == 10 * DATA.len());
803+
804+
let recvd = socket_b
805+
.recvmmsg(recv_batched_msgs.as_mut_slice(), 0, None)
806+
.unwrap();
807+
808+
assert!(recvd == sent);
809+
810+
assert!(recv_batched_msgs[0].recieved_bytes() == DATA.len().try_into().unwrap());
811+
}
812+
772813
#[test]
773814
#[cfg(not(any(target_os = "redox", target_os = "vita")))]
774815
fn recv_vectored_truncated() {

0 commit comments

Comments
 (0)