Skip to content

Commit 97eb90a

Browse files
committed
feat(driver): add RecvMsg/SendMsg opcode
1 parent 1dd5683 commit 97eb90a

File tree

5 files changed

+621
-8
lines changed

5 files changed

+621
-8
lines changed

compio-driver/src/iocp/op.rs

Lines changed: 263 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ use std::{
1111
};
1212

1313
use aligned_array::{Aligned, A8};
14-
use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
14+
use compio_buf::{
15+
BufResult, IntoInner, IoBuf, IoBufMut, IoSlice, IoSliceMut, IoVectoredBuf, IoVectoredBufMut,
16+
};
1517
#[cfg(not(feature = "once_cell_try"))]
1618
use once_cell::sync::OnceCell as OnceLock;
1719
use socket2::SockAddr;
@@ -25,10 +27,11 @@ use windows_sys::{
2527
},
2628
Networking::WinSock::{
2729
closesocket, setsockopt, shutdown, socklen_t, WSAIoctl, WSARecv, WSARecvFrom, WSASend,
28-
WSASendTo, LPFN_ACCEPTEX, LPFN_CONNECTEX, LPFN_GETACCEPTEXSOCKADDRS, SD_BOTH,
29-
SD_RECEIVE, SD_SEND, SIO_GET_EXTENSION_FUNCTION_POINTER, SOCKADDR, SOCKADDR_STORAGE,
30-
SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, SO_UPDATE_CONNECT_CONTEXT, WSAID_ACCEPTEX,
31-
WSAID_CONNECTEX, WSAID_GETACCEPTEXSOCKADDRS,
30+
WSASendMsg, WSASendTo, LPFN_ACCEPTEX, LPFN_CONNECTEX, LPFN_GETACCEPTEXSOCKADDRS,
31+
LPFN_WSARECVMSG, SD_BOTH, SD_RECEIVE, SD_SEND, SIO_GET_EXTENSION_FUNCTION_POINTER,
32+
SOCKADDR, SOCKADDR_STORAGE, SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT,
33+
SO_UPDATE_CONNECT_CONTEXT, WSABUF, WSAID_ACCEPTEX, WSAID_CONNECTEX,
34+
WSAID_GETACCEPTEXSOCKADDRS, WSAID_WSARECVMSG, WSAMSG,
3235
},
3336
Storage::FileSystem::{FlushFileBuffers, ReadFile, WriteFile},
3437
System::{
@@ -774,6 +777,261 @@ impl<T: IoVectoredBuf, S: AsRawFd> OpCode for SendToVectored<T, S> {
774777
}
775778
}
776779

780+
static WSA_RECVMSG: OnceLock<LPFN_WSARECVMSG> = OnceLock::new();
781+
782+
struct RecvMsgHeader<S> {
783+
fd: SharedFd<S>,
784+
addr: SOCKADDR_STORAGE,
785+
addr_len: socklen_t,
786+
_p: PhantomPinned,
787+
}
788+
789+
impl<S> RecvMsgHeader<S> {
790+
fn new(fd: SharedFd<S>) -> Self {
791+
Self {
792+
fd,
793+
addr: unsafe { std::mem::zeroed() },
794+
addr_len: std::mem::size_of::<SOCKADDR_STORAGE>() as _,
795+
_p: PhantomPinned,
796+
}
797+
}
798+
799+
fn into_addr(self) -> (SOCKADDR_STORAGE, socklen_t) {
800+
(self.addr, self.addr_len)
801+
}
802+
}
803+
804+
impl<S: AsRawFd> RecvMsgHeader<S> {
805+
unsafe fn operate(
806+
&mut self,
807+
slices: &mut [IoSliceMut],
808+
control: IoSliceMut,
809+
optr: *mut OVERLAPPED,
810+
) -> Poll<io::Result<usize>> {
811+
let recvmsg_fn = WSA_RECVMSG
812+
.get_or_try_init(|| get_wsa_fn(self.fd.as_raw_fd(), WSAID_WSARECVMSG))?
813+
.ok_or_else(|| {
814+
io::Error::new(io::ErrorKind::Unsupported, "cannot retrieve WSARecvMsg")
815+
})?;
816+
817+
let mut msg = WSAMSG {
818+
name: &mut self.addr as *mut _ as _,
819+
namelen: self.addr_len,
820+
lpBuffers: slices.as_mut_ptr() as _,
821+
dwBufferCount: slices.len() as _,
822+
Control: std::mem::transmute::<IoSliceMut, WSABUF>(control),
823+
dwFlags: 0,
824+
};
825+
826+
let mut received = 0;
827+
let res = recvmsg_fn(
828+
self.fd.as_raw_fd() as _,
829+
&mut msg,
830+
&mut received,
831+
optr,
832+
None,
833+
);
834+
winsock_result(res, received)
835+
}
836+
}
837+
838+
/// Receive data and source address with ancillary data.
839+
pub struct RecvMsg<T: IoBufMut, C: IoBufMut, S> {
840+
header: RecvMsgHeader<S>,
841+
buffer: MsgBuf<T, C>,
842+
}
843+
844+
impl<T: IoBufMut, C: IoBufMut, S> RecvMsg<T, C, S> {
845+
/// Create [`RecvMsg`].
846+
pub fn new(fd: SharedFd<S>, buffer: MsgBuf<T, C>) -> Self {
847+
Self {
848+
header: RecvMsgHeader::new(fd),
849+
buffer,
850+
}
851+
}
852+
}
853+
854+
impl<T: IoBufMut, C: IoBufMut, S> IntoInner for RecvMsg<T, C, S> {
855+
type Inner = (MsgBuf<T, C>, SOCKADDR_STORAGE, socklen_t);
856+
857+
fn into_inner(self) -> Self::Inner {
858+
let (addr, addr_len) = self.header.into_addr();
859+
(self.buffer, addr, addr_len)
860+
}
861+
}
862+
863+
impl<T: IoBufMut, C: IoBufMut, S: AsRawFd> OpCode for RecvMsg<T, C, S> {
864+
unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
865+
let this = self.get_unchecked_mut();
866+
this.header.operate(
867+
&mut [this.buffer.inner.as_io_slice_mut()],
868+
this.buffer.control.as_io_slice_mut(),
869+
optr,
870+
)
871+
}
872+
873+
unsafe fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> {
874+
cancel(self.header.fd.as_raw_fd(), optr)
875+
}
876+
}
877+
878+
/// Receive data and source address with ancillary data into vectored buffer.
879+
pub struct RecvMsgVectored<T: IoVectoredBufMut, C: IoBufMut, S> {
880+
header: RecvMsgHeader<S>,
881+
buffer: MsgBuf<T, C>,
882+
}
883+
884+
impl<T: IoVectoredBufMut, C: IoBufMut, S> RecvMsgVectored<T, C, S> {
885+
/// Create [`RecvMsgVectored`].
886+
pub fn new(fd: SharedFd<S>, buffer: MsgBuf<T, C>) -> Self {
887+
Self {
888+
header: RecvMsgHeader::new(fd),
889+
buffer,
890+
}
891+
}
892+
}
893+
894+
impl<T: IoVectoredBufMut, C: IoBufMut, S> IntoInner for RecvMsgVectored<T, C, S> {
895+
type Inner = (MsgBuf<T, C>, SOCKADDR_STORAGE, socklen_t);
896+
897+
fn into_inner(self) -> Self::Inner {
898+
let (addr, addr_len) = self.header.into_addr();
899+
(self.buffer, addr, addr_len)
900+
}
901+
}
902+
903+
impl<T: IoVectoredBufMut, C: IoBufMut, S: AsRawFd> OpCode for RecvMsgVectored<T, C, S> {
904+
unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
905+
let this = self.get_unchecked_mut();
906+
this.header.operate(
907+
&mut this.buffer.inner.as_io_slices_mut(),
908+
this.buffer.control.as_io_slice_mut(),
909+
optr,
910+
)
911+
}
912+
913+
unsafe fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> {
914+
cancel(self.header.fd.as_raw_fd(), optr)
915+
}
916+
}
917+
918+
struct SendMsgHeader<S> {
919+
fd: SharedFd<S>,
920+
addr: SockAddr,
921+
_p: PhantomPinned,
922+
}
923+
924+
impl<S> SendMsgHeader<S> {
925+
pub fn new(fd: SharedFd<S>, addr: SockAddr) -> Self {
926+
Self {
927+
fd,
928+
addr,
929+
_p: PhantomPinned,
930+
}
931+
}
932+
}
933+
934+
impl<S: AsRawFd> SendMsgHeader<S> {
935+
unsafe fn operate(
936+
&mut self,
937+
slices: &mut [IoSlice],
938+
control: IoSlice,
939+
optr: *mut OVERLAPPED,
940+
) -> Poll<io::Result<usize>> {
941+
let msg = WSAMSG {
942+
name: self.addr.as_ptr() as _,
943+
namelen: self.addr.len(),
944+
lpBuffers: slices.as_ptr() as _,
945+
dwBufferCount: slices.len() as _,
946+
Control: std::mem::transmute::<IoSlice, WSABUF>(control),
947+
dwFlags: 0,
948+
};
949+
950+
let mut sent = 0;
951+
let res = WSASendMsg(self.fd.as_raw_fd() as _, &msg, 0, &mut sent, optr, None);
952+
winsock_result(res, sent)
953+
}
954+
}
955+
956+
/// Send data to specified address accompanied by ancillary data.
957+
pub struct SendMsg<T: IoBuf, C: IoBuf, S> {
958+
header: SendMsgHeader<S>,
959+
buffer: MsgBuf<T, C>,
960+
}
961+
962+
impl<T: IoBuf, C: IoBuf, S> SendMsg<T, C, S> {
963+
/// Create [`SendMsg`].
964+
pub fn new(fd: SharedFd<S>, buffer: MsgBuf<T, C>, addr: SockAddr) -> Self {
965+
Self {
966+
header: SendMsgHeader::new(fd, addr),
967+
buffer,
968+
}
969+
}
970+
}
971+
972+
impl<T: IoBuf, C: IoBuf, S> IntoInner for SendMsg<T, C, S> {
973+
type Inner = MsgBuf<T, C>;
974+
975+
fn into_inner(self) -> Self::Inner {
976+
self.buffer
977+
}
978+
}
979+
980+
impl<T: IoBuf, C: IoBuf, S: AsRawFd> OpCode for SendMsg<T, C, S> {
981+
unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
982+
let this = self.get_unchecked_mut();
983+
this.header.operate(
984+
&mut [this.buffer.inner.as_io_slice()],
985+
this.buffer.control.as_io_slice(),
986+
optr,
987+
)
988+
}
989+
990+
unsafe fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> {
991+
cancel(self.header.fd.as_raw_fd(), optr)
992+
}
993+
}
994+
995+
/// Send data to specified address accompanied by ancillary data from vectored
996+
/// buffer.
997+
pub struct SendMsgVectored<T: IoVectoredBuf, C: IoBuf, S> {
998+
header: SendMsgHeader<S>,
999+
buffer: MsgBuf<T, C>,
1000+
}
1001+
1002+
impl<T: IoVectoredBuf, C: IoBuf, S> SendMsgVectored<T, C, S> {
1003+
/// Create [`SendMsgVectored`].
1004+
pub fn new(fd: SharedFd<S>, buffer: MsgBuf<T, C>, addr: SockAddr) -> Self {
1005+
Self {
1006+
header: SendMsgHeader::new(fd, addr),
1007+
buffer,
1008+
}
1009+
}
1010+
}
1011+
1012+
impl<T: IoVectoredBuf, C: IoBuf, S> IntoInner for SendMsgVectored<T, C, S> {
1013+
type Inner = MsgBuf<T, C>;
1014+
1015+
fn into_inner(self) -> Self::Inner {
1016+
self.buffer
1017+
}
1018+
}
1019+
1020+
impl<T: IoVectoredBuf, C: IoBuf, S: AsRawFd> OpCode for SendMsgVectored<T, C, S> {
1021+
unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
1022+
let this = self.get_unchecked_mut();
1023+
this.header.operate(
1024+
&mut this.buffer.inner.as_io_slices(),
1025+
this.buffer.control.as_io_slice(),
1026+
optr,
1027+
)
1028+
}
1029+
1030+
unsafe fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> {
1031+
cancel(self.header.fd.as_raw_fd(), optr)
1032+
}
1033+
}
1034+
7771035
/// Connect a named pipe server.
7781036
pub struct ConnectNamedPipe<S> {
7791037
pub(crate) fd: SharedFd<S>,

compio-driver/src/iour/op.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,54 @@ impl<T: IoVectoredBuf, S> IntoInner for SendToVectored<T, S> {
556556
}
557557
}
558558

559+
impl<S: AsRawFd> RecvMsgHeader<S> {
560+
pub fn create_entry(&mut self) -> OpEntry {
561+
opcode::RecvMsg::new(Fd(self.fd.as_raw_fd()), &mut self.msg)
562+
.build()
563+
.into()
564+
}
565+
}
566+
567+
impl<T: IoBufMut, C: IoBufMut, S: AsRawFd> OpCode for RecvMsg<T, C, S> {
568+
fn create_entry(self: Pin<&mut Self>) -> OpEntry {
569+
let this = unsafe { self.get_unchecked_mut() };
570+
this.set_msg();
571+
this.header.create_entry()
572+
}
573+
}
574+
575+
impl<T: IoVectoredBufMut, C: IoBufMut, S: AsRawFd> OpCode for RecvMsgVectored<T, C, S> {
576+
fn create_entry(self: Pin<&mut Self>) -> OpEntry {
577+
let this = unsafe { self.get_unchecked_mut() };
578+
this.set_msg();
579+
this.header.create_entry()
580+
}
581+
}
582+
583+
impl<S: AsRawFd> SendMsgHeader<S> {
584+
pub fn create_entry(&mut self) -> OpEntry {
585+
opcode::SendMsg::new(Fd(self.fd.as_raw_fd()), &self.msg)
586+
.build()
587+
.into()
588+
}
589+
}
590+
591+
impl<T: IoBuf, C: IoBuf, S: AsRawFd> OpCode for SendMsg<T, C, S> {
592+
fn create_entry(self: Pin<&mut Self>) -> OpEntry {
593+
let this = unsafe { self.get_unchecked_mut() };
594+
this.set_msg();
595+
this.header.create_entry()
596+
}
597+
}
598+
599+
impl<T: IoVectoredBuf, C: IoBuf, S: AsRawFd> OpCode for SendMsgVectored<T, C, S> {
600+
fn create_entry(self: Pin<&mut Self>) -> OpEntry {
601+
let this = unsafe { self.get_unchecked_mut() };
602+
this.set_msg();
603+
this.header.create_entry()
604+
}
605+
}
606+
559607
impl<S: AsRawFd> OpCode for PollOnce<S> {
560608
fn create_entry(self: Pin<&mut Self>) -> OpEntry {
561609
let flags = match self.interest {

compio-driver/src/op.rs

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ use socket2::SockAddr;
1111
#[cfg(windows)]
1212
pub use crate::sys::op::ConnectNamedPipe;
1313
pub use crate::sys::op::{
14-
Accept, Recv, RecvFrom, RecvFromVectored, RecvVectored, Send, SendTo, SendToVectored,
15-
SendVectored,
14+
Accept, Recv, RecvFrom, RecvFromVectored, RecvMsg, RecvMsgVectored, RecvVectored, Send,
15+
SendMsg, SendMsgVectored, SendTo, SendToVectored, SendVectored,
1616
};
1717
#[cfg(unix)]
1818
pub use crate::sys::op::{
@@ -72,6 +72,36 @@ impl<T> RecvResultExt for BufResult<usize, (T, sockaddr_storage, socklen_t)> {
7272
}
7373
}
7474

75+
// FIXME: Using this struct instead of a simple tuple because we can implement
76+
// neither `BufResultExt` on `BufResult<(usize, O), (T, C)>` nor `SetBufInit` on
77+
// `(T, C)`. But it's not elegant. `.map_advanced` call happens in `compio-net`
78+
// so we must expose this struct. There should be better ways to do this.
79+
/// Helper struct for [`RecvMsg`], [`SendMsg`], and vectored variants.
80+
pub struct MsgBuf<T, C> {
81+
/// The buffer for message
82+
pub inner: T,
83+
/// The buffer for ancillary data
84+
pub control: C,
85+
}
86+
87+
impl<T, C> MsgBuf<T, C> {
88+
/// Create [`MsgBuf`].
89+
pub fn new(inner: T, control: C) -> Self {
90+
Self { inner, control }
91+
}
92+
93+
/// Unpack to tuple.
94+
pub fn into_tuple(self) -> (T, C) {
95+
(self.inner, self.control)
96+
}
97+
}
98+
99+
impl<T: SetBufInit, C> SetBufInit for MsgBuf<T, C> {
100+
unsafe fn set_buf_init(&mut self, len: usize) {
101+
self.inner.set_buf_init(len);
102+
}
103+
}
104+
75105
/// Spawn a blocking function in the thread pool.
76106
pub struct Asyncify<F, D> {
77107
pub(crate) f: Option<F>,

0 commit comments

Comments
 (0)