Skip to content

Commit 840f237

Browse files
committed
feat(net): helpers to handle control messages
1 parent 14697fa commit 840f237

File tree

4 files changed

+380
-0
lines changed

4 files changed

+380
-0
lines changed

compio-net/src/cmsg/mod.rs

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
use std::marker::PhantomData;
2+
3+
use compio_buf::{IoBuf, IoBufMut};
4+
5+
cfg_if::cfg_if! {
6+
if #[cfg(windows)] {
7+
#[path = "windows.rs"]
8+
mod sys;
9+
} else if #[cfg(unix)] {
10+
#[path = "unix.rs"]
11+
mod sys;
12+
}
13+
}
14+
15+
pub use sys::CMsgRef;
16+
17+
/// An iterator for control messages.
18+
pub struct CMsgIter<'a> {
19+
inner: sys::CMsgIter,
20+
_p: PhantomData<&'a ()>,
21+
}
22+
23+
impl<'a> CMsgIter<'a> {
24+
/// Create [`CMsgIter`] with the given buffer.
25+
///
26+
/// # Panics
27+
///
28+
/// This function will panic if the buffer is too short or not properly
29+
/// aligned.
30+
///
31+
/// # Safety
32+
///
33+
/// The buffer should contain valid control messages.
34+
pub unsafe fn new<B: IoBuf>(buffer: &'a B) -> Self {
35+
Self {
36+
inner: sys::CMsgIter::new(buffer.as_buf_ptr(), buffer.buf_len()),
37+
_p: PhantomData,
38+
}
39+
}
40+
}
41+
42+
impl<'a> Iterator for CMsgIter<'a> {
43+
type Item = CMsgRef<'a>;
44+
45+
fn next(&mut self) -> Option<Self::Item> {
46+
unsafe {
47+
let cmsg = self.inner.current();
48+
self.inner.next();
49+
cmsg
50+
}
51+
}
52+
}
53+
54+
/// Helper to construct control message.
55+
pub struct CMsgBuilder<B> {
56+
inner: sys::CMsgIter,
57+
buffer: B,
58+
len: usize,
59+
}
60+
61+
impl<B> CMsgBuilder<B> {
62+
/// Finishes building, returns the buffer and the length of the control
63+
/// message.
64+
pub fn build(self) -> (B, usize) {
65+
(self.buffer, self.len)
66+
}
67+
68+
/// Try to append a control message entry into the buffer. If the buffer
69+
/// does not have enough space or is not properly aligned with the value
70+
/// type, returns `None`.
71+
///
72+
/// # Safety
73+
///
74+
/// TODO: This function may be safe? Given that the buffer is zeroed,
75+
/// properly aligned and has enough space, safety conditions of all unsafe
76+
/// functions involved are satisfied, except for `CMSG_*`/`wsa_cmsg_*`, as
77+
/// their safety are not documented.
78+
pub unsafe fn try_push<T>(
79+
&mut self,
80+
level: sys::c_int,
81+
ty: sys::c_int,
82+
value: T,
83+
) -> Option<()> {
84+
if !self.inner.is_aligned::<T>() || !self.inner.is_space_enough::<T>() {
85+
return None;
86+
}
87+
88+
let mut cmsg = self.inner.current_mut()?;
89+
cmsg.set_level(level);
90+
cmsg.set_ty(ty);
91+
cmsg.set_data(value);
92+
93+
self.inner.next();
94+
self.len += sys::space_of::<T>();
95+
Some(())
96+
}
97+
}
98+
99+
impl<B: IoBufMut> CMsgBuilder<B> {
100+
/// Create [`CMsgBuilder`] with the given buffer. The buffer will be zeroed
101+
/// on creation.
102+
///
103+
/// # Panics
104+
///
105+
/// This function will panic if the buffer is too short or not properly
106+
/// aligned.
107+
pub fn new(mut buffer: B) -> Self {
108+
buffer.as_mut_slice().fill(std::mem::MaybeUninit::zeroed());
109+
Self {
110+
inner: sys::CMsgIter::new(buffer.as_buf_mut_ptr(), buffer.buf_len()),
111+
buffer,
112+
len: 0,
113+
}
114+
}
115+
}

compio-net/src/cmsg/unix.rs

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
use std::mem;
2+
3+
pub use libc::c_int;
4+
use libc::{cmsghdr, msghdr, CMSG_DATA, CMSG_FIRSTHDR, CMSG_LEN, CMSG_NXTHDR, CMSG_SPACE};
5+
6+
/// Reference to a control message.
7+
pub struct CMsgRef<'a>(&'a cmsghdr);
8+
9+
impl<'a> CMsgRef<'a> {
10+
/// Returns the level of the control message.
11+
pub fn level(&self) -> c_int {
12+
self.0.cmsg_level
13+
}
14+
15+
/// Returns the type of the control message.
16+
pub fn ty(&self) -> c_int {
17+
self.0.cmsg_type
18+
}
19+
20+
/// Returns the length of the control message.
21+
#[allow(clippy::len_without_is_empty)]
22+
pub fn len(&self) -> usize {
23+
self.0.cmsg_len as _
24+
}
25+
26+
/// Returns a reference to the data of the control message.
27+
///
28+
/// # Safety
29+
///
30+
/// The data part must be properly aligned and contains an initialized
31+
/// instance of `T`.
32+
pub unsafe fn data<T>(&self) -> &T {
33+
let data_ptr = CMSG_DATA(self.0);
34+
data_ptr.cast::<T>().as_ref().unwrap()
35+
}
36+
}
37+
38+
pub(crate) struct CMsgMut<'a>(&'a mut cmsghdr);
39+
40+
impl<'a> CMsgMut<'a> {
41+
pub(crate) fn set_level(&mut self, level: c_int) {
42+
self.0.cmsg_level = level;
43+
}
44+
45+
pub(crate) fn set_ty(&mut self, ty: c_int) {
46+
self.0.cmsg_type = ty;
47+
}
48+
49+
pub(crate) unsafe fn set_data<T>(&mut self, data: T) {
50+
self.0.cmsg_len = CMSG_LEN(mem::size_of::<T>() as _) as _;
51+
let data_ptr = CMSG_DATA(self.0);
52+
std::ptr::write(data_ptr.cast::<T>(), data);
53+
}
54+
}
55+
56+
pub(crate) struct CMsgIter {
57+
msg: msghdr,
58+
cmsg: *mut cmsghdr,
59+
}
60+
61+
impl CMsgIter {
62+
pub(crate) fn new(ptr: *const u8, len: usize) -> Self {
63+
assert!(len >= unsafe { CMSG_SPACE(0) as _ }, "buffer too short");
64+
assert!(ptr.cast::<cmsghdr>().is_aligned(), "misaligned buffer");
65+
66+
let mut msg: msghdr = unsafe { mem::zeroed() };
67+
msg.msg_control = ptr as _;
68+
msg.msg_controllen = len as _;
69+
// SAFETY: msg is initialized and valid
70+
let cmsg = unsafe { CMSG_FIRSTHDR(&msg) };
71+
Self { msg, cmsg }
72+
}
73+
74+
pub(crate) unsafe fn current<'a>(&self) -> Option<CMsgRef<'a>> {
75+
self.cmsg.as_ref().map(CMsgRef)
76+
}
77+
78+
pub(crate) unsafe fn next(&mut self) {
79+
if !self.cmsg.is_null() {
80+
self.cmsg = CMSG_NXTHDR(&self.msg, self.cmsg);
81+
}
82+
}
83+
84+
pub(crate) unsafe fn current_mut<'a>(&self) -> Option<CMsgMut<'a>> {
85+
self.cmsg.as_mut().map(CMsgMut)
86+
}
87+
88+
pub(crate) fn is_aligned<T>(&self) -> bool {
89+
self.msg.msg_control.cast::<T>().is_aligned()
90+
}
91+
92+
pub(crate) fn is_space_enough<T>(&self) -> bool {
93+
if !self.cmsg.is_null() {
94+
let space = unsafe { CMSG_SPACE(mem::size_of::<T>() as _) as usize };
95+
#[allow(clippy::unnecessary_cast)]
96+
let max = self.msg.msg_control as usize + self.msg.msg_controllen as usize;
97+
self.cmsg as usize + space <= max
98+
} else {
99+
false
100+
}
101+
}
102+
}
103+
104+
pub(crate) fn space_of<T>() -> usize {
105+
unsafe { CMSG_SPACE(mem::size_of::<T>() as _) as _ }
106+
}

compio-net/src/cmsg/windows.rs

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
use std::{mem, ptr::null_mut};
2+
3+
pub use i32 as c_int;
4+
use windows_sys::Win32::Networking::WinSock::{CMSGHDR, WSABUF, WSAMSG};
5+
6+
// Macros from https://github.com/microsoft/win32metadata/blob/main/generation/WinSDK/RecompiledIdlHeaders/shared/ws2def.h
7+
#[inline]
8+
const fn wsa_cmsghdr_align(length: usize) -> usize {
9+
(length + mem::align_of::<CMSGHDR>() - 1) & !(mem::align_of::<CMSGHDR>() - 1)
10+
}
11+
12+
// WSA_CMSGDATA_ALIGN(sizeof(CMSGHDR))
13+
const WSA_CMSGDATA_OFFSET: usize =
14+
(mem::size_of::<CMSGHDR>() + mem::align_of::<usize>() - 1) & !(mem::align_of::<usize>() - 1);
15+
16+
#[inline]
17+
unsafe fn wsa_cmsg_firsthdr(msg: *const WSAMSG) -> *mut CMSGHDR {
18+
if (*msg).Control.len as usize >= mem::size_of::<CMSGHDR>() {
19+
(*msg).Control.buf as _
20+
} else {
21+
null_mut()
22+
}
23+
}
24+
25+
#[inline]
26+
unsafe fn wsa_cmsg_nxthdr(msg: *const WSAMSG, cmsg: *const CMSGHDR) -> *mut CMSGHDR {
27+
if cmsg.is_null() {
28+
wsa_cmsg_firsthdr(msg)
29+
} else {
30+
let next = cmsg as usize + wsa_cmsghdr_align((*cmsg).cmsg_len);
31+
if next + mem::size_of::<CMSGHDR>()
32+
> (*msg).Control.buf as usize + (*msg).Control.len as usize
33+
{
34+
null_mut()
35+
} else {
36+
next as _
37+
}
38+
}
39+
}
40+
41+
#[inline]
42+
unsafe fn wsa_cmsg_data(cmsg: *const CMSGHDR) -> *mut u8 {
43+
(cmsg as usize + WSA_CMSGDATA_OFFSET) as _
44+
}
45+
46+
#[inline]
47+
const fn wsa_cmsg_space(length: usize) -> usize {
48+
WSA_CMSGDATA_OFFSET + wsa_cmsghdr_align(length)
49+
}
50+
51+
#[inline]
52+
const fn wsa_cmsg_len(length: usize) -> usize {
53+
WSA_CMSGDATA_OFFSET + length
54+
}
55+
56+
/// Reference to a control message.
57+
pub struct CMsgRef<'a>(&'a CMSGHDR);
58+
59+
impl<'a> CMsgRef<'a> {
60+
/// Returns the level of the control message.
61+
pub fn level(&self) -> i32 {
62+
self.0.cmsg_level
63+
}
64+
65+
/// Returns the type of the control message.
66+
pub fn ty(&self) -> i32 {
67+
self.0.cmsg_type
68+
}
69+
70+
/// Returns the length of the control message.
71+
#[allow(clippy::len_without_is_empty)]
72+
pub fn len(&self) -> usize {
73+
self.0.cmsg_len
74+
}
75+
76+
/// Returns a reference to the data of the control message.
77+
///
78+
/// # Safety
79+
///
80+
/// The data part must be properly aligned and contains an initialized
81+
/// instance of `T`.
82+
pub unsafe fn data<T>(&self) -> &T {
83+
let data_ptr = wsa_cmsg_data(self.0);
84+
data_ptr.cast::<T>().as_ref().unwrap()
85+
}
86+
}
87+
88+
pub(crate) struct CMsgMut<'a>(&'a mut CMSGHDR);
89+
90+
impl<'a> CMsgMut<'a> {
91+
pub(crate) fn set_level(&mut self, level: i32) {
92+
self.0.cmsg_level = level;
93+
}
94+
95+
pub(crate) fn set_ty(&mut self, ty: i32) {
96+
self.0.cmsg_type = ty;
97+
}
98+
99+
pub(crate) unsafe fn set_data<T>(&mut self, data: T) {
100+
self.0.cmsg_len = wsa_cmsg_len(mem::size_of::<T>() as _) as _;
101+
let data_ptr = wsa_cmsg_data(self.0);
102+
std::ptr::write(data_ptr.cast::<T>(), data);
103+
}
104+
}
105+
106+
pub(crate) struct CMsgIter {
107+
msg: WSAMSG,
108+
cmsg: *mut CMSGHDR,
109+
}
110+
111+
impl CMsgIter {
112+
pub(crate) fn new(ptr: *const u8, len: usize) -> Self {
113+
assert!(len >= wsa_cmsg_space(0) as _, "buffer too short");
114+
assert!(ptr.cast::<CMSGHDR>().is_aligned(), "misaligned buffer");
115+
116+
let mut msg: WSAMSG = unsafe { mem::zeroed() };
117+
msg.Control = WSABUF {
118+
len: len as _,
119+
buf: ptr as _,
120+
};
121+
// SAFETY: msg is initialized and valid
122+
let cmsg = unsafe { wsa_cmsg_firsthdr(&msg) };
123+
Self { msg, cmsg }
124+
}
125+
126+
pub(crate) unsafe fn current<'a>(&self) -> Option<CMsgRef<'a>> {
127+
self.cmsg.as_ref().map(CMsgRef)
128+
}
129+
130+
pub(crate) unsafe fn next(&mut self) {
131+
if !self.cmsg.is_null() {
132+
self.cmsg = wsa_cmsg_nxthdr(&self.msg, self.cmsg);
133+
}
134+
}
135+
136+
pub(crate) unsafe fn current_mut<'a>(&self) -> Option<CMsgMut<'a>> {
137+
self.cmsg.as_mut().map(CMsgMut)
138+
}
139+
140+
pub(crate) fn is_aligned<T>(&self) -> bool {
141+
self.msg.Control.buf.cast::<T>().is_aligned()
142+
}
143+
144+
pub(crate) fn is_space_enough<T>(&self) -> bool {
145+
if !self.cmsg.is_null() {
146+
let space = wsa_cmsg_space(mem::size_of::<T>() as _);
147+
let max = self.msg.Control.buf as usize + self.msg.Control.len as usize;
148+
self.cmsg as usize + space <= max
149+
} else {
150+
false
151+
}
152+
}
153+
}
154+
155+
pub(crate) fn space_of<T>() -> usize {
156+
wsa_cmsg_space(mem::size_of::<T>() as _)
157+
}

compio-net/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
66
#![warn(missing_docs)]
77

8+
mod cmsg;
89
mod poll_fd;
910
mod resolve;
1011
mod socket;
@@ -13,6 +14,7 @@ mod tcp;
1314
mod udp;
1415
mod unix;
1516

17+
pub use cmsg::*;
1618
pub use poll_fd::*;
1719
pub use resolve::ToSocketAddrsAsync;
1820
pub(crate) use resolve::{each_addr, first_addr_buf};

0 commit comments

Comments
 (0)