Skip to content

Commit 0345a68

Browse files
committed
feat(net): use WSAEventSelect on Windows for PollFd
1 parent 7a0a6d0 commit 0345a68

File tree

5 files changed

+193
-44
lines changed

5 files changed

+193
-44
lines changed

compio-net/src/poll_fd.rs

Lines changed: 181 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,38 @@
1-
#[cfg(unix)]
2-
use std::os::fd::FromRawFd;
31
#[cfg(windows)]
4-
use std::os::windows::io::{AsRawSocket, FromRawSocket, RawSocket};
2+
use std::os::windows::io::{AsRawSocket, RawSocket};
53
use std::{io, ops::Deref};
64

7-
use compio_buf::{BufResult, IntoInner};
8-
#[cfg(unix)]
9-
use compio_driver::op::{Interest, PollOnce};
5+
use compio_buf::IntoInner;
106
use compio_driver::{AsRawFd, RawFd, SharedFd, ToSharedFd};
7+
#[cfg(windows)]
8+
use windows_sys::Win32::Networking::WinSock::{FD_ACCEPT, FD_CONNECT, FD_READ, FD_WRITE};
9+
#[cfg(unix)]
10+
use {
11+
compio_buf::BufResult,
12+
compio_driver::op::{Interest, PollOnce},
13+
};
1114

1215
/// A wrapper for socket, providing functionalities to wait for readiness.
1316
#[derive(Debug)]
1417
pub struct PollFd<T: AsRawFd> {
1518
inner: SharedFd<T>,
19+
#[cfg(windows)]
20+
event: sys::WSAEvent,
1621
}
1722

1823
impl<T: AsRawFd> PollFd<T> {
1924
/// Create [`PollFd`] without attaching the source. Ready-based sources need
2025
/// not to be attached.
21-
pub fn new(source: T) -> Self {
22-
Self {
23-
inner: SharedFd::new(source),
24-
}
26+
pub fn new(source: T) -> io::Result<Self> {
27+
Self::from_shared_fd(SharedFd::new(source))
2528
}
2629

27-
pub(crate) fn from_shared_fd(inner: SharedFd<T>) -> Self {
28-
Self { inner }
30+
pub(crate) fn from_shared_fd(inner: SharedFd<T>) -> io::Result<Self> {
31+
Ok(Self {
32+
inner,
33+
#[cfg(windows)]
34+
event: sys::WSAEvent::new()?,
35+
})
2936
}
3037
}
3138

@@ -59,6 +66,30 @@ impl<T: AsRawFd + 'static> PollFd<T> {
5966
}
6067
}
6168

69+
#[cfg(windows)]
70+
impl<T: AsRawFd + 'static> PollFd<T> {
71+
/// Wait for accept readiness, before calling `accept`, or after `accept`
72+
/// returns `WouldBlock`.
73+
pub async fn accept_ready(&self) -> io::Result<()> {
74+
self.event.wait(self.to_shared_fd(), FD_ACCEPT).await
75+
}
76+
77+
/// Wait for connect readiness.
78+
pub async fn connect_ready(&self) -> io::Result<()> {
79+
self.event.wait(self.to_shared_fd(), FD_CONNECT).await
80+
}
81+
82+
/// Wait for read readiness.
83+
pub async fn read_ready(&self) -> io::Result<()> {
84+
self.event.wait(self.to_shared_fd(), FD_READ).await
85+
}
86+
87+
/// Wait for write readiness.
88+
pub async fn write_ready(&self) -> io::Result<()> {
89+
self.event.wait(self.to_shared_fd(), FD_WRITE).await
90+
}
91+
}
92+
6293
impl<T: AsRawFd> IntoInner for PollFd<T> {
6394
type Inner = SharedFd<T>;
6495

@@ -73,14 +104,6 @@ impl<T: AsRawFd> ToSharedFd<T> for PollFd<T> {
73104
}
74105
}
75106

76-
impl<T: AsRawFd> Clone for PollFd<T> {
77-
fn clone(&self) -> Self {
78-
Self {
79-
inner: self.inner.clone(),
80-
}
81-
}
82-
}
83-
84107
impl<T: AsRawFd> AsRawFd for PollFd<T> {
85108
fn as_raw_fd(&self) -> RawFd {
86109
self.inner.as_raw_fd()
@@ -94,24 +117,150 @@ impl<T: AsRawFd + AsRawSocket> AsRawSocket for PollFd<T> {
94117
}
95118
}
96119

97-
#[cfg(unix)]
98-
impl<T: AsRawFd + FromRawFd> FromRawFd for PollFd<T> {
99-
unsafe fn from_raw_fd(fd: RawFd) -> Self {
100-
Self::new(FromRawFd::from_raw_fd(fd))
120+
impl<T: AsRawFd> Deref for PollFd<T> {
121+
type Target = T;
122+
123+
fn deref(&self) -> &Self::Target {
124+
&self.inner
101125
}
102126
}
103127

104128
#[cfg(windows)]
105-
impl<T: AsRawFd + FromRawSocket> FromRawSocket for PollFd<T> {
106-
unsafe fn from_raw_socket(sock: RawSocket) -> Self {
107-
Self::new(FromRawSocket::from_raw_socket(sock))
129+
mod sys {
130+
use std::{
131+
io,
132+
os::windows::io::{AsRawHandle, FromRawHandle, OwnedHandle},
133+
pin::Pin,
134+
ptr::null,
135+
sync::atomic::{AtomicI32, AtomicUsize, Ordering},
136+
task::Poll,
137+
};
138+
139+
use compio_buf::{BufResult, IntoInner};
140+
use compio_driver::{syscall, AsRawFd, OpCode, OpType, SharedFd};
141+
use windows_sys::Win32::{
142+
Networking::WinSock::{WSAEnumNetworkEvents, WSAEventSelect, WSANETWORKEVENTS},
143+
System::{Threading::CreateEventW, IO::OVERLAPPED},
144+
};
145+
146+
const EVENT_COUNT: usize = 5;
147+
148+
#[derive(Debug)]
149+
pub struct WSAEvent {
150+
ev_object: SharedFd<OwnedHandle>,
151+
ev_record: [AtomicUsize; EVENT_COUNT],
152+
events: AtomicI32,
108153
}
109-
}
110154

111-
impl<T: AsRawFd> Deref for PollFd<T> {
112-
type Target = T;
155+
impl WSAEvent {
156+
pub fn new() -> io::Result<Self> {
157+
Ok(Self {
158+
ev_object: SharedFd::new(unsafe {
159+
OwnedHandle::from_raw_handle(syscall!(
160+
HANDLE,
161+
CreateEventW(null(), 1, 0, null())
162+
)? as _)
163+
}),
164+
ev_record: Default::default(),
165+
events: AtomicI32::new(0),
166+
})
167+
}
113168

114-
fn deref(&self) -> &Self::Target {
115-
&self.inner
169+
pub async fn wait<T: AsRawFd + 'static>(
170+
&self,
171+
socket: SharedFd<T>,
172+
event: u32,
173+
) -> io::Result<()> {
174+
struct EventGuard<'a> {
175+
wsa_event: &'a WSAEvent,
176+
event: i32,
177+
}
178+
179+
impl Drop for EventGuard<'_> {
180+
fn drop(&mut self) {
181+
let index = (self.event.ilog2() - 1) as usize;
182+
if self.wsa_event.ev_record[index].fetch_sub(1, Ordering::Relaxed) == 1 {
183+
self.wsa_event
184+
.events
185+
.fetch_add(!self.event, Ordering::Relaxed);
186+
}
187+
}
188+
}
189+
190+
let event = event as i32;
191+
let ev_object = self.ev_object.clone();
192+
193+
let index = (event.ilog2() - 1) as usize;
194+
let events = if self.ev_record[index].fetch_add(1, Ordering::Relaxed) == 0 {
195+
self.events.fetch_or(event, Ordering::Relaxed) | event
196+
} else {
197+
self.events.load(Ordering::Relaxed)
198+
};
199+
syscall!(
200+
SOCKET,
201+
WSAEventSelect(
202+
socket.as_raw_fd() as _,
203+
ev_object.as_raw_handle() as _,
204+
events
205+
)
206+
)?;
207+
let _guard = EventGuard {
208+
wsa_event: self,
209+
event,
210+
};
211+
let op = WaitWSAEvent::new(socket, ev_object, index + 1);
212+
let BufResult(res, _) = compio_runtime::submit(op).await;
213+
res?;
214+
Ok(())
215+
}
216+
}
217+
218+
struct WaitWSAEvent<T> {
219+
socket: SharedFd<T>,
220+
ev_object: SharedFd<OwnedHandle>,
221+
index: usize,
222+
}
223+
224+
impl<T> WaitWSAEvent<T> {
225+
pub fn new(socket: SharedFd<T>, ev_object: SharedFd<OwnedHandle>, index: usize) -> Self {
226+
Self {
227+
socket,
228+
ev_object,
229+
index,
230+
}
231+
}
232+
}
233+
234+
impl<T> IntoInner for WaitWSAEvent<T> {
235+
type Inner = SharedFd<OwnedHandle>;
236+
237+
fn into_inner(self) -> Self::Inner {
238+
self.ev_object
239+
}
240+
}
241+
242+
impl<T: AsRawFd> OpCode for WaitWSAEvent<T> {
243+
fn op_type(&self) -> OpType {
244+
OpType::Event(self.ev_object.as_raw_fd())
245+
}
246+
247+
unsafe fn operate(self: Pin<&mut Self>, _optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
248+
let mut events: WSANETWORKEVENTS = unsafe { std::mem::zeroed() };
249+
events.lNetworkEvents = 10;
250+
syscall!(
251+
SOCKET,
252+
WSAEnumNetworkEvents(
253+
self.socket.as_raw_fd() as _,
254+
self.ev_object.as_raw_handle() as _,
255+
&mut events
256+
)
257+
)?;
258+
let res = events.iErrorCode[self.index + 1];
259+
if res == 0 {
260+
Poll::Ready(Ok(0))
261+
} else {
262+
Poll::Ready(Err(io::Error::from_raw_os_error(res)))
263+
}
264+
}
116265
}
117266
}

compio-net/src/socket.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ impl Socket {
3636
self.socket.local_addr()
3737
}
3838

39-
pub fn to_poll_fd(&self) -> PollFd<Socket2> {
39+
pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
4040
PollFd::from_shared_fd(self.to_shared_fd())
4141
}
4242

43-
pub fn into_poll_fd(self) -> PollFd<Socket2> {
43+
pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
4444
PollFd::from_shared_fd(self.socket.into_inner())
4545
}
4646

compio-net/src/tcp.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,12 +207,12 @@ impl TcpStream {
207207
}
208208

209209
/// Create [`PollFd`] from inner socket.
210-
pub fn to_poll_fd(&self) -> PollFd<Socket2> {
210+
pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
211211
self.inner.to_poll_fd()
212212
}
213213

214214
/// Create [`PollFd`] from inner socket.
215-
pub fn into_poll_fd(self) -> PollFd<Socket2> {
215+
pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
216216
self.inner.into_poll_fd()
217217
}
218218
}

compio-net/src/unix.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,12 +189,12 @@ impl UnixStream {
189189
}
190190

191191
/// Create [`PollFd`] from inner socket.
192-
pub fn to_poll_fd(&self) -> PollFd<Socket2> {
192+
pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
193193
self.inner.to_poll_fd()
194194
}
195195

196196
/// Create [`PollFd`] from inner socket.
197-
pub fn into_poll_fd(self) -> PollFd<Socket2> {
197+
pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
198198
self.inner.into_poll_fd()
199199
}
200200
}

compio-net/tests/poll.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,21 @@ async fn poll_connect() {
2626
.unwrap();
2727
listener.listen(4).unwrap();
2828
let addr = listener.local_addr().unwrap();
29-
let listener = PollFd::new(listener);
29+
let listener = PollFd::new(listener).unwrap();
3030
let accept_task = async {
3131
loop {
3232
listener.accept_ready().await.unwrap();
3333
match listener.accept() {
3434
Ok(res) => break res,
3535
Err(e) if e.kind() == io::ErrorKind::WouldBlock => continue,
36-
Err(e) => Err(e).unwrap(),
36+
Err(e) => panic!("{:?}", e),
3737
}
3838
}
3939
};
4040

4141
let client = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP)).unwrap();
4242
client.set_nonblocking(true).unwrap();
43-
let client = PollFd::new(client);
43+
let client = PollFd::new(client).unwrap();
4444
let res = client.connect(&addr);
4545
let tx = if let Err(e) = res {
4646
assert!(is_would_block(&e));
@@ -53,14 +53,14 @@ async fn poll_connect() {
5353
};
5454

5555
tx.set_nonblocking(true).unwrap();
56-
let tx = PollFd::new(tx);
56+
let tx = PollFd::new(tx).unwrap();
5757

5858
let send_task = async {
5959
loop {
6060
match tx.send(b"Hello world!") {
6161
Ok(res) => break res,
6262
Err(e) if is_would_block(&e) => {}
63-
Err(e) => Err(e).unwrap(),
63+
Err(e) => panic!("{:?}", e),
6464
}
6565
tx.write_ready().await.unwrap();
6666
}
@@ -75,7 +75,7 @@ async fn poll_connect() {
7575
break res;
7676
}
7777
Err(e) if is_would_block(&e) => {}
78-
Err(e) => Err(e).unwrap(),
78+
Err(e) => panic!("{:?}", e),
7979
}
8080
client.read_ready().await.unwrap();
8181
}

0 commit comments

Comments
 (0)