Skip to content

Commit 7a0a6d0

Browse files
committed
feat(net): add PollFd for waiting readiness
1 parent 9939253 commit 7a0a6d0

File tree

6 files changed

+243
-4
lines changed

6 files changed

+243
-4
lines changed

compio-net/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
66
#![warn(missing_docs)]
77

8+
mod poll_fd;
89
mod resolve;
910
mod socket;
1011
pub(crate) mod split;
1112
mod tcp;
1213
mod udp;
1314
mod unix;
1415

16+
pub use poll_fd::*;
1517
pub use resolve::ToSocketAddrsAsync;
1618
pub(crate) use resolve::{each_addr, first_addr_buf};
1719
pub(crate) use socket::*;

compio-net/src/poll_fd.rs

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
#[cfg(unix)]
2+
use std::os::fd::FromRawFd;
3+
#[cfg(windows)]
4+
use std::os::windows::io::{AsRawSocket, FromRawSocket, RawSocket};
5+
use std::{io, ops::Deref};
6+
7+
use compio_buf::{BufResult, IntoInner};
8+
#[cfg(unix)]
9+
use compio_driver::op::{Interest, PollOnce};
10+
use compio_driver::{AsRawFd, RawFd, SharedFd, ToSharedFd};
11+
12+
/// A wrapper for socket, providing functionalities to wait for readiness.
13+
#[derive(Debug)]
14+
pub struct PollFd<T: AsRawFd> {
15+
inner: SharedFd<T>,
16+
}
17+
18+
impl<T: AsRawFd> PollFd<T> {
19+
/// Create [`PollFd`] without attaching the source. Ready-based sources need
20+
/// not to be attached.
21+
pub fn new(source: T) -> Self {
22+
Self {
23+
inner: SharedFd::new(source),
24+
}
25+
}
26+
27+
pub(crate) fn from_shared_fd(inner: SharedFd<T>) -> Self {
28+
Self { inner }
29+
}
30+
}
31+
32+
#[cfg(unix)]
33+
impl<T: AsRawFd + 'static> PollFd<T> {
34+
/// Wait for accept readiness, before calling `accept`, or after `accept`
35+
/// returns `WouldBlock`.
36+
pub async fn accept_ready(&self) -> io::Result<()> {
37+
self.read_ready().await
38+
}
39+
40+
/// Wait for connect readiness.
41+
pub async fn connect_ready(&self) -> io::Result<()> {
42+
self.write_ready().await
43+
}
44+
45+
/// Wait for read readiness.
46+
pub async fn read_ready(&self) -> io::Result<()> {
47+
let op = PollOnce::new(self.to_shared_fd(), Interest::Readable);
48+
let BufResult(res, _) = compio_runtime::submit(op).await;
49+
res?;
50+
Ok(())
51+
}
52+
53+
/// Wait for write readiness.
54+
pub async fn write_ready(&self) -> io::Result<()> {
55+
let op = PollOnce::new(self.to_shared_fd(), Interest::Writable);
56+
let BufResult(res, _) = compio_runtime::submit(op).await;
57+
res?;
58+
Ok(())
59+
}
60+
}
61+
62+
impl<T: AsRawFd> IntoInner for PollFd<T> {
63+
type Inner = SharedFd<T>;
64+
65+
fn into_inner(self) -> Self::Inner {
66+
self.inner
67+
}
68+
}
69+
70+
impl<T: AsRawFd> ToSharedFd<T> for PollFd<T> {
71+
fn to_shared_fd(&self) -> SharedFd<T> {
72+
self.inner.clone()
73+
}
74+
}
75+
76+
impl<T: AsRawFd> Clone for PollFd<T> {
77+
fn clone(&self) -> Self {
78+
Self {
79+
inner: self.inner.clone(),
80+
}
81+
}
82+
}
83+
84+
impl<T: AsRawFd> AsRawFd for PollFd<T> {
85+
fn as_raw_fd(&self) -> RawFd {
86+
self.inner.as_raw_fd()
87+
}
88+
}
89+
90+
#[cfg(windows)]
91+
impl<T: AsRawFd + AsRawSocket> AsRawSocket for PollFd<T> {
92+
fn as_raw_socket(&self) -> RawSocket {
93+
self.inner.as_raw_socket()
94+
}
95+
}
96+
97+
#[cfg(unix)]
98+
impl<T: AsRawFd + FromRawFd> FromRawFd for PollFd<T> {
99+
unsafe fn from_raw_fd(fd: RawFd) -> Self {
100+
Self::new(FromRawFd::from_raw_fd(fd))
101+
}
102+
}
103+
104+
#[cfg(windows)]
105+
impl<T: AsRawFd + FromRawSocket> FromRawSocket for PollFd<T> {
106+
unsafe fn from_raw_socket(sock: RawSocket) -> Self {
107+
Self::new(FromRawSocket::from_raw_socket(sock))
108+
}
109+
}
110+
111+
impl<T: AsRawFd> Deref for PollFd<T> {
112+
type Target = T;
113+
114+
fn deref(&self) -> &Self::Target {
115+
&self.inner
116+
}
117+
}

compio-net/src/socket.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ use compio_driver::{
1414
use compio_runtime::Attacher;
1515
use socket2::{Domain, Protocol, SockAddr, Socket as Socket2, Type};
1616

17+
use crate::PollFd;
18+
1719
#[derive(Debug, Clone)]
1820
pub struct Socket {
1921
socket: Attacher<Socket2>,
@@ -34,6 +36,14 @@ impl Socket {
3436
self.socket.local_addr()
3537
}
3638

39+
pub fn to_poll_fd(&self) -> PollFd<Socket2> {
40+
PollFd::from_shared_fd(self.to_shared_fd())
41+
}
42+
43+
pub fn into_poll_fd(self) -> PollFd<Socket2> {
44+
PollFd::from_shared_fd(self.socket.into_inner())
45+
}
46+
3747
#[cfg(windows)]
3848
pub async fn new(domain: Domain, ty: Type, protocol: Option<Protocol>) -> io::Result<Self> {
3949
use std::panic::resume_unwind;

compio-net/src/tcp.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@ use std::{future::Future, io, net::SocketAddr};
33
use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
44
use compio_driver::impl_raw_fd;
55
use compio_io::{AsyncRead, AsyncWrite};
6-
use socket2::{Protocol, SockAddr, Type};
6+
use socket2::{Protocol, SockAddr, Socket as Socket2, Type};
77

8-
use crate::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, ToSocketAddrsAsync, WriteHalf};
8+
use crate::{
9+
OwnedReadHalf, OwnedWriteHalf, PollFd, ReadHalf, Socket, ToSocketAddrsAsync, WriteHalf,
10+
};
911

1012
/// A TCP socket server, listening for connections.
1113
///
@@ -203,6 +205,16 @@ impl TcpStream {
203205
pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
204206
crate::into_split(self)
205207
}
208+
209+
/// Create [`PollFd`] from inner socket.
210+
pub fn to_poll_fd(&self) -> PollFd<Socket2> {
211+
self.inner.to_poll_fd()
212+
}
213+
214+
/// Create [`PollFd`] from inner socket.
215+
pub fn into_poll_fd(self) -> PollFd<Socket2> {
216+
self.inner.into_poll_fd()
217+
}
206218
}
207219

208220
impl AsyncRead for TcpStream {

compio-net/src/unix.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ use std::{future::Future, io, path::Path};
33
use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
44
use compio_driver::impl_raw_fd;
55
use compio_io::{AsyncRead, AsyncWrite};
6-
use socket2::{SockAddr, Type};
6+
use socket2::{SockAddr, Socket as Socket2, Type};
77

8-
use crate::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, WriteHalf};
8+
use crate::{OwnedReadHalf, OwnedWriteHalf, PollFd, ReadHalf, Socket, WriteHalf};
99

1010
/// A Unix socket server, listening for connections.
1111
///
@@ -187,6 +187,16 @@ impl UnixStream {
187187
pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
188188
crate::into_split(self)
189189
}
190+
191+
/// Create [`PollFd`] from inner socket.
192+
pub fn to_poll_fd(&self) -> PollFd<Socket2> {
193+
self.inner.to_poll_fd()
194+
}
195+
196+
/// Create [`PollFd`] from inner socket.
197+
pub fn into_poll_fd(self) -> PollFd<Socket2> {
198+
self.inner.into_poll_fd()
199+
}
190200
}
191201

192202
impl AsyncRead for UnixStream {

compio-net/tests/poll.rs

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
use std::{
2+
io,
3+
net::{Ipv4Addr, SocketAddrV4},
4+
};
5+
6+
use compio_net::PollFd;
7+
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
8+
9+
fn is_would_block(e: &io::Error) -> bool {
10+
#[cfg(unix)]
11+
{
12+
e.kind() == io::ErrorKind::WouldBlock || e.raw_os_error() == Some(libc::EINPROGRESS)
13+
}
14+
#[cfg(not(unix))]
15+
{
16+
e.kind() == io::ErrorKind::WouldBlock
17+
}
18+
}
19+
20+
#[compio_macros::test]
21+
async fn poll_connect() {
22+
let listener = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP)).unwrap();
23+
listener.set_nonblocking(true).unwrap();
24+
listener
25+
.bind(&SockAddr::from(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)))
26+
.unwrap();
27+
listener.listen(4).unwrap();
28+
let addr = listener.local_addr().unwrap();
29+
let listener = PollFd::new(listener);
30+
let accept_task = async {
31+
loop {
32+
listener.accept_ready().await.unwrap();
33+
match listener.accept() {
34+
Ok(res) => break res,
35+
Err(e) if e.kind() == io::ErrorKind::WouldBlock => continue,
36+
Err(e) => Err(e).unwrap(),
37+
}
38+
}
39+
};
40+
41+
let client = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP)).unwrap();
42+
client.set_nonblocking(true).unwrap();
43+
let client = PollFd::new(client);
44+
let res = client.connect(&addr);
45+
let tx = if let Err(e) = res {
46+
assert!(is_would_block(&e));
47+
let (tx, _) = accept_task.await;
48+
tx
49+
} else {
50+
let ((tx, _), res) = futures_util::join!(accept_task, client.connect_ready());
51+
res.unwrap();
52+
tx
53+
};
54+
55+
tx.set_nonblocking(true).unwrap();
56+
let tx = PollFd::new(tx);
57+
58+
let send_task = async {
59+
loop {
60+
match tx.send(b"Hello world!") {
61+
Ok(res) => break res,
62+
Err(e) if is_would_block(&e) => {}
63+
Err(e) => Err(e).unwrap(),
64+
}
65+
tx.write_ready().await.unwrap();
66+
}
67+
};
68+
69+
let mut buffer = Vec::with_capacity(12);
70+
let recv_task = async {
71+
loop {
72+
match client.recv(buffer.spare_capacity_mut()) {
73+
Ok(res) => {
74+
unsafe { buffer.set_len(res) };
75+
break res;
76+
}
77+
Err(e) if is_would_block(&e) => {}
78+
Err(e) => Err(e).unwrap(),
79+
}
80+
client.read_ready().await.unwrap();
81+
}
82+
};
83+
84+
let (write, read) = futures_util::join!(send_task, recv_task);
85+
assert_eq!(write, 12);
86+
assert_eq!(read, 12);
87+
assert_eq!(buffer, b"Hello world!");
88+
}

0 commit comments

Comments
 (0)