Skip to content

Commit f7e5923

Browse files
committed
fix(net,windows): use select correctly
1 parent 1840254 commit f7e5923

File tree

1 file changed

+30
-18
lines changed

1 file changed

+30
-18
lines changed

compio-net/src/poll_fd/windows.rs

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@ use std::{
1111
use compio_buf::{BufResult, IntoInner};
1212
use compio_driver::{syscall, AsRawFd, OpCode, OpType, RawFd, SharedFd, ToSharedFd};
1313
use windows_sys::Win32::{
14+
Foundation::ERROR_IO_PENDING,
1415
Networking::WinSock::{
15-
WSAEnumNetworkEvents, WSAEventSelect, FD_ACCEPT, FD_CONNECT, FD_READ, FD_WRITE,
16-
WSANETWORKEVENTS,
16+
WSAEnumNetworkEvents, WSAEventSelect, FD_ACCEPT, FD_CONNECT, FD_MAX_EVENTS, FD_READ,
17+
FD_WRITE, WSANETWORKEVENTS,
1718
},
1819
System::{Threading::CreateEventW, IO::OVERLAPPED},
1920
};
@@ -85,12 +86,10 @@ impl<T: AsRawFd> Deref for PollFd<T> {
8586
}
8687
}
8788

88-
const EVENT_COUNT: usize = 5;
89-
9089
#[derive(Debug)]
9190
pub struct WSAEvent {
9291
ev_object: SharedFd<OwnedHandle>,
93-
ev_record: [AtomicUsize; EVENT_COUNT],
92+
ev_record: [AtomicUsize; FD_MAX_EVENTS as usize],
9493
events: AtomicI32,
9594
}
9695

@@ -109,7 +108,7 @@ impl WSAEvent {
109108

110109
pub async fn wait<T: AsRawFd + 'static>(
111110
&self,
112-
socket: SharedFd<T>,
111+
mut socket: SharedFd<T>,
113112
event: u32,
114113
) -> io::Result<()> {
115114
struct EventGuard<'a> {
@@ -119,7 +118,7 @@ impl WSAEvent {
119118

120119
impl Drop for EventGuard<'_> {
121120
fn drop(&mut self) {
122-
let index = (self.event.ilog2() - 1) as usize;
121+
let index = self.event.ilog2() as usize;
123122
if self.wsa_event.ev_record[index].fetch_sub(1, Ordering::Relaxed) == 1 {
124123
self.wsa_event
125124
.events
@@ -129,9 +128,9 @@ impl WSAEvent {
129128
}
130129

131130
let event = event as i32;
132-
let ev_object = self.ev_object.clone();
131+
let mut ev_object = self.ev_object.clone();
133132

134-
let index = (event.ilog2() - 1) as usize;
133+
let index = event.ilog2() as usize;
135134
let events = if self.ev_record[index].fetch_add(1, Ordering::Relaxed) == 0 {
136135
self.events.fetch_or(event, Ordering::Relaxed) | event
137136
} else {
@@ -149,25 +148,35 @@ impl WSAEvent {
149148
wsa_event: self,
150149
event,
151150
};
152-
let op = WaitWSAEvent::new(socket, ev_object, index + 1);
153-
let BufResult(res, _) = compio_runtime::submit(op).await;
154-
res?;
155-
Ok(())
151+
loop {
152+
let op = WaitWSAEvent::new(socket, ev_object, event);
153+
let BufResult(res, op) = compio_runtime::submit(op).await;
154+
WaitWSAEvent {
155+
socket,
156+
ev_object,
157+
..
158+
} = op;
159+
match res {
160+
Ok(_) => break Ok(()),
161+
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
162+
Err(e) => break Err(e),
163+
}
164+
}
156165
}
157166
}
158167

159168
struct WaitWSAEvent<T> {
160169
socket: SharedFd<T>,
161170
ev_object: SharedFd<OwnedHandle>,
162-
index: usize,
171+
event: i32,
163172
}
164173

165174
impl<T> WaitWSAEvent<T> {
166-
pub fn new(socket: SharedFd<T>, ev_object: SharedFd<OwnedHandle>, index: usize) -> Self {
175+
pub fn new(socket: SharedFd<T>, ev_object: SharedFd<OwnedHandle>, event: i32) -> Self {
167176
Self {
168177
socket,
169178
ev_object,
170-
index,
179+
event,
171180
}
172181
}
173182
}
@@ -187,7 +196,6 @@ impl<T: AsRawFd> OpCode for WaitWSAEvent<T> {
187196

188197
unsafe fn operate(self: Pin<&mut Self>, _optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
189198
let mut events: WSANETWORKEVENTS = unsafe { std::mem::zeroed() };
190-
events.lNetworkEvents = 10;
191199
syscall!(
192200
SOCKET,
193201
WSAEnumNetworkEvents(
@@ -196,7 +204,11 @@ impl<T: AsRawFd> OpCode for WaitWSAEvent<T> {
196204
&mut events
197205
)
198206
)?;
199-
let res = events.iErrorCode[self.index + 1];
207+
let res = if (events.lNetworkEvents & self.event) != 0 {
208+
events.iErrorCode[self.event.ilog2() as usize]
209+
} else {
210+
ERROR_IO_PENDING as _
211+
};
200212
if res == 0 {
201213
Poll::Ready(Ok(0))
202214
} else {

0 commit comments

Comments
 (0)