Skip to content

Commit 3780d33

Browse files
authored
Merge pull request #265 from Berrysoft/dev/poll-fd
2 parents dec8b68 + f7e5923 commit 3780d33

File tree

8 files changed

+504
-4
lines changed

8 files changed

+504
-4
lines changed

compio-net/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
66
#![warn(missing_docs)]
77

8+
mod poll_fd;
89
mod resolve;
910
mod socket;
1011
pub(crate) mod split;
1112
mod tcp;
1213
mod udp;
1314
mod unix;
1415

16+
pub use poll_fd::*;
1517
pub use resolve::ToSocketAddrsAsync;
1618
pub(crate) use resolve::{each_addr, first_addr_buf};
1719
pub(crate) use socket::*;

compio-net/src/poll_fd/mod.rs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
cfg_if::cfg_if! {
2+
if #[cfg(windows)] {
3+
#[path = "windows.rs"]
4+
mod sys;
5+
} else if #[cfg(unix)] {
6+
#[path = "unix.rs"]
7+
mod sys;
8+
}
9+
}
10+
11+
#[cfg(windows)]
12+
use std::os::windows::io::{AsRawSocket, RawSocket};
13+
use std::{io, ops::Deref};
14+
15+
use compio_buf::IntoInner;
16+
use compio_driver::{AsRawFd, RawFd, SharedFd, ToSharedFd};
17+
18+
/// A wrapper for socket, providing functionalities to wait for readiness.
19+
#[derive(Debug)]
20+
pub struct PollFd<T: AsRawFd>(sys::PollFd<T>);
21+
22+
impl<T: AsRawFd> PollFd<T> {
23+
/// Create [`PollFd`] without attaching the source. Ready-based sources need
24+
/// not to be attached.
25+
pub fn new(source: T) -> io::Result<Self> {
26+
Self::from_shared_fd(SharedFd::new(source))
27+
}
28+
29+
pub(crate) fn from_shared_fd(inner: SharedFd<T>) -> io::Result<Self> {
30+
Ok(Self(sys::PollFd::new(inner)?))
31+
}
32+
}
33+
34+
impl<T: AsRawFd + 'static> PollFd<T> {
35+
/// Wait for accept readiness, before calling `accept`, or after `accept`
36+
/// returns `WouldBlock`.
37+
pub async fn accept_ready(&self) -> io::Result<()> {
38+
self.0.accept_ready().await
39+
}
40+
41+
/// Wait for connect readiness.
42+
pub async fn connect_ready(&self) -> io::Result<()> {
43+
self.0.connect_ready().await
44+
}
45+
46+
/// Wait for read readiness.
47+
pub async fn read_ready(&self) -> io::Result<()> {
48+
self.0.read_ready().await
49+
}
50+
51+
/// Wait for write readiness.
52+
pub async fn write_ready(&self) -> io::Result<()> {
53+
self.0.write_ready().await
54+
}
55+
}
56+
57+
impl<T: AsRawFd> IntoInner for PollFd<T> {
58+
type Inner = SharedFd<T>;
59+
60+
fn into_inner(self) -> Self::Inner {
61+
self.0.into_inner()
62+
}
63+
}
64+
65+
impl<T: AsRawFd> ToSharedFd<T> for PollFd<T> {
66+
fn to_shared_fd(&self) -> SharedFd<T> {
67+
self.0.to_shared_fd()
68+
}
69+
}
70+
71+
impl<T: AsRawFd> AsRawFd for PollFd<T> {
72+
fn as_raw_fd(&self) -> RawFd {
73+
self.0.as_raw_fd()
74+
}
75+
}
76+
77+
#[cfg(windows)]
78+
impl<T: AsRawFd + AsRawSocket> AsRawSocket for PollFd<T> {
79+
fn as_raw_socket(&self) -> RawSocket {
80+
self.0.as_raw_socket()
81+
}
82+
}
83+
84+
impl<T: AsRawFd> Deref for PollFd<T> {
85+
type Target = T;
86+
87+
fn deref(&self) -> &Self::Target {
88+
&self.0
89+
}
90+
}

compio-net/src/poll_fd/unix.rs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
use std::{io, ops::Deref};
2+
3+
use compio_buf::{BufResult, IntoInner};
4+
use compio_driver::{
5+
op::{Interest, PollOnce},
6+
AsRawFd, RawFd, SharedFd, ToSharedFd,
7+
};
8+
9+
#[derive(Debug)]
10+
pub struct PollFd<T: AsRawFd> {
11+
inner: SharedFd<T>,
12+
}
13+
14+
impl<T: AsRawFd> PollFd<T> {
15+
pub fn new(inner: SharedFd<T>) -> io::Result<Self> {
16+
Ok(Self { inner })
17+
}
18+
}
19+
20+
impl<T: AsRawFd + 'static> PollFd<T> {
21+
pub async fn accept_ready(&self) -> io::Result<()> {
22+
self.read_ready().await
23+
}
24+
25+
pub async fn connect_ready(&self) -> io::Result<()> {
26+
self.write_ready().await
27+
}
28+
29+
pub async fn read_ready(&self) -> io::Result<()> {
30+
let op = PollOnce::new(self.to_shared_fd(), Interest::Readable);
31+
let BufResult(res, _) = compio_runtime::submit(op).await;
32+
res?;
33+
Ok(())
34+
}
35+
36+
pub async fn write_ready(&self) -> io::Result<()> {
37+
let op = PollOnce::new(self.to_shared_fd(), Interest::Writable);
38+
let BufResult(res, _) = compio_runtime::submit(op).await;
39+
res?;
40+
Ok(())
41+
}
42+
}
43+
44+
impl<T: AsRawFd> IntoInner for PollFd<T> {
45+
type Inner = SharedFd<T>;
46+
47+
fn into_inner(self) -> Self::Inner {
48+
self.inner
49+
}
50+
}
51+
52+
impl<T: AsRawFd> ToSharedFd<T> for PollFd<T> {
53+
fn to_shared_fd(&self) -> SharedFd<T> {
54+
self.inner.clone()
55+
}
56+
}
57+
58+
impl<T: AsRawFd> AsRawFd for PollFd<T> {
59+
fn as_raw_fd(&self) -> RawFd {
60+
self.inner.as_raw_fd()
61+
}
62+
}
63+
64+
impl<T: AsRawFd> Deref for PollFd<T> {
65+
type Target = T;
66+
67+
fn deref(&self) -> &Self::Target {
68+
&self.inner
69+
}
70+
}

compio-net/src/poll_fd/windows.rs

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
use std::{
2+
io,
3+
ops::Deref,
4+
os::windows::io::{AsRawHandle, AsRawSocket, FromRawHandle, OwnedHandle, RawSocket},
5+
pin::Pin,
6+
ptr::null,
7+
sync::atomic::{AtomicI32, AtomicUsize, Ordering},
8+
task::Poll,
9+
};
10+
11+
use compio_buf::{BufResult, IntoInner};
12+
use compio_driver::{syscall, AsRawFd, OpCode, OpType, RawFd, SharedFd, ToSharedFd};
13+
use windows_sys::Win32::{
14+
Foundation::ERROR_IO_PENDING,
15+
Networking::WinSock::{
16+
WSAEnumNetworkEvents, WSAEventSelect, FD_ACCEPT, FD_CONNECT, FD_MAX_EVENTS, FD_READ,
17+
FD_WRITE, WSANETWORKEVENTS,
18+
},
19+
System::{Threading::CreateEventW, IO::OVERLAPPED},
20+
};
21+
22+
#[derive(Debug)]
23+
pub struct PollFd<T: AsRawFd> {
24+
inner: SharedFd<T>,
25+
event: WSAEvent,
26+
}
27+
28+
impl<T: AsRawFd> PollFd<T> {
29+
pub fn new(inner: SharedFd<T>) -> io::Result<Self> {
30+
Ok(Self {
31+
inner,
32+
event: WSAEvent::new()?,
33+
})
34+
}
35+
}
36+
37+
impl<T: AsRawFd + 'static> PollFd<T> {
38+
pub async fn accept_ready(&self) -> io::Result<()> {
39+
self.event.wait(self.to_shared_fd(), FD_ACCEPT).await
40+
}
41+
42+
pub async fn connect_ready(&self) -> io::Result<()> {
43+
self.event.wait(self.to_shared_fd(), FD_CONNECT).await
44+
}
45+
46+
pub async fn read_ready(&self) -> io::Result<()> {
47+
self.event.wait(self.to_shared_fd(), FD_READ).await
48+
}
49+
50+
pub async fn write_ready(&self) -> io::Result<()> {
51+
self.event.wait(self.to_shared_fd(), FD_WRITE).await
52+
}
53+
}
54+
55+
impl<T: AsRawFd> IntoInner for PollFd<T> {
56+
type Inner = SharedFd<T>;
57+
58+
fn into_inner(self) -> Self::Inner {
59+
self.inner
60+
}
61+
}
62+
63+
impl<T: AsRawFd> ToSharedFd<T> for PollFd<T> {
64+
fn to_shared_fd(&self) -> SharedFd<T> {
65+
self.inner.clone()
66+
}
67+
}
68+
69+
impl<T: AsRawFd> AsRawFd for PollFd<T> {
70+
fn as_raw_fd(&self) -> RawFd {
71+
self.inner.as_raw_fd()
72+
}
73+
}
74+
75+
impl<T: AsRawFd + AsRawSocket> AsRawSocket for PollFd<T> {
76+
fn as_raw_socket(&self) -> RawSocket {
77+
self.inner.as_raw_socket()
78+
}
79+
}
80+
81+
impl<T: AsRawFd> Deref for PollFd<T> {
82+
type Target = T;
83+
84+
fn deref(&self) -> &Self::Target {
85+
&self.inner
86+
}
87+
}
88+
89+
#[derive(Debug)]
90+
pub struct WSAEvent {
91+
ev_object: SharedFd<OwnedHandle>,
92+
ev_record: [AtomicUsize; FD_MAX_EVENTS as usize],
93+
events: AtomicI32,
94+
}
95+
96+
impl WSAEvent {
97+
pub fn new() -> io::Result<Self> {
98+
Ok(Self {
99+
ev_object: SharedFd::new(unsafe {
100+
OwnedHandle::from_raw_handle(
101+
syscall!(HANDLE, CreateEventW(null(), 1, 0, null()))? as _
102+
)
103+
}),
104+
ev_record: Default::default(),
105+
events: AtomicI32::new(0),
106+
})
107+
}
108+
109+
pub async fn wait<T: AsRawFd + 'static>(
110+
&self,
111+
mut socket: SharedFd<T>,
112+
event: u32,
113+
) -> io::Result<()> {
114+
struct EventGuard<'a> {
115+
wsa_event: &'a WSAEvent,
116+
event: i32,
117+
}
118+
119+
impl Drop for EventGuard<'_> {
120+
fn drop(&mut self) {
121+
let index = self.event.ilog2() as usize;
122+
if self.wsa_event.ev_record[index].fetch_sub(1, Ordering::Relaxed) == 1 {
123+
self.wsa_event
124+
.events
125+
.fetch_add(!self.event, Ordering::Relaxed);
126+
}
127+
}
128+
}
129+
130+
let event = event as i32;
131+
let mut ev_object = self.ev_object.clone();
132+
133+
let index = event.ilog2() as usize;
134+
let events = if self.ev_record[index].fetch_add(1, Ordering::Relaxed) == 0 {
135+
self.events.fetch_or(event, Ordering::Relaxed) | event
136+
} else {
137+
self.events.load(Ordering::Relaxed)
138+
};
139+
syscall!(
140+
SOCKET,
141+
WSAEventSelect(
142+
socket.as_raw_fd() as _,
143+
ev_object.as_raw_handle() as _,
144+
events
145+
)
146+
)?;
147+
let _guard = EventGuard {
148+
wsa_event: self,
149+
event,
150+
};
151+
loop {
152+
let op = WaitWSAEvent::new(socket, ev_object, event);
153+
let BufResult(res, op) = compio_runtime::submit(op).await;
154+
WaitWSAEvent {
155+
socket,
156+
ev_object,
157+
..
158+
} = op;
159+
match res {
160+
Ok(_) => break Ok(()),
161+
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
162+
Err(e) => break Err(e),
163+
}
164+
}
165+
}
166+
}
167+
168+
struct WaitWSAEvent<T> {
169+
socket: SharedFd<T>,
170+
ev_object: SharedFd<OwnedHandle>,
171+
event: i32,
172+
}
173+
174+
impl<T> WaitWSAEvent<T> {
175+
pub fn new(socket: SharedFd<T>, ev_object: SharedFd<OwnedHandle>, event: i32) -> Self {
176+
Self {
177+
socket,
178+
ev_object,
179+
event,
180+
}
181+
}
182+
}
183+
184+
impl<T> IntoInner for WaitWSAEvent<T> {
185+
type Inner = SharedFd<OwnedHandle>;
186+
187+
fn into_inner(self) -> Self::Inner {
188+
self.ev_object
189+
}
190+
}
191+
192+
impl<T: AsRawFd> OpCode for WaitWSAEvent<T> {
193+
fn op_type(&self) -> OpType {
194+
OpType::Event(self.ev_object.as_raw_fd())
195+
}
196+
197+
unsafe fn operate(self: Pin<&mut Self>, _optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
198+
let mut events: WSANETWORKEVENTS = unsafe { std::mem::zeroed() };
199+
syscall!(
200+
SOCKET,
201+
WSAEnumNetworkEvents(
202+
self.socket.as_raw_fd() as _,
203+
self.ev_object.as_raw_handle() as _,
204+
&mut events
205+
)
206+
)?;
207+
let res = if (events.lNetworkEvents & self.event) != 0 {
208+
events.iErrorCode[self.event.ilog2() as usize]
209+
} else {
210+
ERROR_IO_PENDING as _
211+
};
212+
if res == 0 {
213+
Poll::Ready(Ok(0))
214+
} else {
215+
Poll::Ready(Err(io::Error::from_raw_os_error(res)))
216+
}
217+
}
218+
}

0 commit comments

Comments
 (0)