Skip to content

Commit 925460c

Browse files
authored
Merge pull request #230 from Berrysoft/refactor/shared-fd
refactor(driver): add SharedFd & OwnedFd and require them in operations
2 parents b09746d + e0edbac commit 925460c

File tree

36 files changed

+732
-544
lines changed

36 files changed

+732
-544
lines changed

compio-driver/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ compio-log = { workspace = true }
3535
# Utils
3636
cfg-if = { workspace = true }
3737
crossbeam-channel = { workspace = true }
38+
futures-util = { workspace = true }
3839
slab = { workspace = true }
3940
socket2 = { workspace = true }
4041

compio-driver/src/fd.rs

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
#[cfg(unix)]
2+
use std::os::fd::FromRawFd;
3+
#[cfg(windows)]
4+
use std::os::windows::io::{
5+
FromRawHandle, FromRawSocket, OwnedHandle, OwnedSocket, RawHandle, RawSocket,
6+
};
7+
use std::{
8+
future::{poll_fn, Future},
9+
mem::ManuallyDrop,
10+
panic::RefUnwindSafe,
11+
sync::{
12+
atomic::{AtomicBool, Ordering},
13+
Arc,
14+
},
15+
task::Poll,
16+
};
17+
18+
use futures_util::task::AtomicWaker;
19+
20+
use crate::{AsRawFd, OwnedFd, RawFd};
21+
22+
#[derive(Debug)]
23+
struct Inner {
24+
fd: OwnedFd,
25+
// whether there is a future waiting
26+
waits: AtomicBool,
27+
waker: AtomicWaker,
28+
}
29+
30+
impl RefUnwindSafe for Inner {}
31+
32+
/// A shared fd. It is passed to the operations to make sure the fd won't be
33+
/// closed before the operations complete.
34+
#[derive(Debug, Clone)]
35+
pub struct SharedFd(Arc<Inner>);
36+
37+
impl SharedFd {
38+
/// Create the shared fd from an owned fd.
39+
pub fn new(fd: impl Into<OwnedFd>) -> Self {
40+
Self(Arc::new(Inner {
41+
fd: fd.into(),
42+
waits: AtomicBool::new(false),
43+
waker: AtomicWaker::new(),
44+
}))
45+
}
46+
47+
/// Try to take the inner owned fd.
48+
pub fn try_unwrap(self) -> Result<OwnedFd, Self> {
49+
let this = ManuallyDrop::new(self);
50+
if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
51+
Ok(fd)
52+
} else {
53+
Err(ManuallyDrop::into_inner(this))
54+
}
55+
}
56+
57+
// SAFETY: if `Some` is returned, the method should not be called again.
58+
unsafe fn try_unwrap_inner(this: &ManuallyDrop<Self>) -> Option<OwnedFd> {
59+
let ptr = ManuallyDrop::new(std::ptr::read(&this.0));
60+
// The ptr is duplicated without increasing the strong count, should forget.
61+
match Arc::try_unwrap(ManuallyDrop::into_inner(ptr)) {
62+
Ok(inner) => Some(inner.fd),
63+
Err(ptr) => {
64+
std::mem::forget(ptr);
65+
None
66+
}
67+
}
68+
}
69+
70+
/// Wait and take the inner owned fd.
71+
pub fn take(self) -> impl Future<Output = Option<OwnedFd>> {
72+
let this = ManuallyDrop::new(self);
73+
async move {
74+
if !this.0.waits.swap(true, Ordering::AcqRel) {
75+
poll_fn(move |cx| {
76+
if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
77+
return Poll::Ready(Some(fd));
78+
}
79+
80+
this.0.waker.register(cx.waker());
81+
82+
if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
83+
Poll::Ready(Some(fd))
84+
} else {
85+
Poll::Pending
86+
}
87+
})
88+
.await
89+
} else {
90+
None
91+
}
92+
}
93+
}
94+
}
95+
96+
impl Drop for SharedFd {
97+
fn drop(&mut self) {
98+
// It's OK to wake multiple times.
99+
if Arc::strong_count(&self.0) == 2 {
100+
self.0.waker.wake()
101+
}
102+
}
103+
}
104+
105+
#[cfg(windows)]
106+
#[doc(hidden)]
107+
impl SharedFd {
108+
pub unsafe fn to_file(&self) -> ManuallyDrop<std::fs::File> {
109+
ManuallyDrop::new(std::fs::File::from_raw_handle(self.as_raw_fd() as _))
110+
}
111+
112+
pub unsafe fn to_socket(&self) -> ManuallyDrop<socket2::Socket> {
113+
ManuallyDrop::new(socket2::Socket::from_raw_socket(self.as_raw_fd() as _))
114+
}
115+
}
116+
117+
#[cfg(unix)]
118+
#[doc(hidden)]
119+
impl SharedFd {
120+
pub unsafe fn to_file(&self) -> ManuallyDrop<std::fs::File> {
121+
ManuallyDrop::new(std::fs::File::from_raw_fd(self.as_raw_fd() as _))
122+
}
123+
124+
pub unsafe fn to_socket(&self) -> ManuallyDrop<socket2::Socket> {
125+
ManuallyDrop::new(socket2::Socket::from_raw_fd(self.as_raw_fd() as _))
126+
}
127+
}
128+
129+
impl AsRawFd for SharedFd {
130+
fn as_raw_fd(&self) -> RawFd {
131+
self.0.fd.as_raw_fd()
132+
}
133+
}
134+
135+
#[cfg(windows)]
136+
impl FromRawHandle for SharedFd {
137+
unsafe fn from_raw_handle(handle: RawHandle) -> Self {
138+
Self::new(OwnedFd::File(OwnedHandle::from_raw_handle(handle)))
139+
}
140+
}
141+
142+
#[cfg(windows)]
143+
impl FromRawSocket for SharedFd {
144+
unsafe fn from_raw_socket(sock: RawSocket) -> Self {
145+
Self::new(OwnedFd::Socket(OwnedSocket::from_raw_socket(sock)))
146+
}
147+
}
148+
149+
#[cfg(unix)]
150+
impl FromRawFd for SharedFd {
151+
unsafe fn from_raw_fd(fd: RawFd) -> Self {
152+
Self::new(OwnedFd::from_raw_fd(fd))
153+
}
154+
}
155+
156+
impl From<OwnedFd> for SharedFd {
157+
fn from(value: OwnedFd) -> Self {
158+
Self::new(value)
159+
}
160+
}
161+
162+
/// Get a clone of [`SharedFd`].
163+
pub trait ToSharedFd {
164+
/// Return a cloned [`SharedFd`].
165+
fn to_shared_fd(&self) -> SharedFd;
166+
}
167+
168+
impl ToSharedFd for SharedFd {
169+
fn to_shared_fd(&self) -> SharedFd {
170+
self.clone()
171+
}
172+
}

compio-driver/src/fusion/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ mod iour;
77
pub(crate) mod op;
88

99
#[cfg_attr(all(doc, docsrs), doc(cfg(all())))]
10-
pub use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
10+
pub use std::os::fd::{AsRawFd, OwnedFd, RawFd};
1111
use std::{io, task::Poll, time::Duration};
1212

1313
pub use driver_type::DriverType;

compio-driver/src/fusion/op.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use socket2::SockAddr;
55

66
use super::*;
77
pub use crate::unix::op::*;
8+
use crate::SharedFd;
89

910
macro_rules! op {
1011
(<$($ty:ident: $trait:ident),* $(,)?> $name:ident( $($arg:ident: $arg_t:ident),* $(,)? )) => {
@@ -91,9 +92,9 @@ mod iour { pub use crate::sys::iour::{op::*, OpCode}; }
9192
#[rustfmt::skip]
9293
mod poll { pub use crate::sys::poll::{op::*, OpCode}; }
9394

94-
op!(<T: IoBufMut> RecvFrom(fd: RawFd, buffer: T));
95-
op!(<T: IoBuf> SendTo(fd: RawFd, buffer: T, addr: SockAddr));
96-
op!(<T: IoVectoredBufMut> RecvFromVectored(fd: RawFd, buffer: T));
97-
op!(<T: IoVectoredBuf> SendToVectored(fd: RawFd, buffer: T, addr: SockAddr));
98-
op!(<> FileStat(fd: RawFd));
95+
op!(<T: IoBufMut> RecvFrom(fd: SharedFd, buffer: T));
96+
op!(<T: IoBuf> SendTo(fd: SharedFd, buffer: T, addr: SockAddr));
97+
op!(<T: IoVectoredBufMut> RecvFromVectored(fd: SharedFd, buffer: T));
98+
op!(<T: IoVectoredBuf> SendToVectored(fd: SharedFd, buffer: T, addr: SockAddr));
99+
op!(<> FileStat(fd: SharedFd));
99100
op!(<> PathStat(path: CString, follow_symlink: bool));

compio-driver/src/iocp/cp/global.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ fn iocp_start() -> io::Result<()> {
7474
)
7575
) {
7676
error!(
77-
"fail to dispatch entry ({}, {}, {:p}) to driver {:p}: {:?}",
77+
"fail to dispatch entry ({}, {}, {:p}) to driver {:x}: {:?}",
7878
entry.dwNumberOfBytesTransferred,
7979
entry.lpCompletionKey,
8080
entry.lpOverlapped,

compio-driver/src/iocp/cp/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ impl CompletionPort {
158158
)
159159
) {
160160
error!(
161-
"fail to repost entry ({}, {}, {:p}) to driver {:p}: {:?}",
161+
"fail to repost entry ({}, {}, {:p}) to driver {:x}: {:?}",
162162
entry.dwNumberOfBytesTransferred,
163163
entry.lpCompletionKey,
164164
entry.lpOverlapped,

compio-driver/src/iocp/cp/multi.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ impl Port {
2828
}
2929

3030
pub fn poll(&self, timeout: Option<Duration>) -> io::Result<impl Iterator<Item = Entry> + '_> {
31-
let current_id = self.as_raw_handle();
31+
let current_id = self.as_raw_handle() as _;
3232
self.port.poll(timeout, Some(current_id))
3333
}
3434
}

compio-driver/src/iocp/mod.rs

Lines changed: 29 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ use std::{
44
mem::ManuallyDrop,
55
os::{
66
raw::c_void,
7-
windows::prelude::{
8-
AsRawHandle, AsRawSocket, FromRawHandle, FromRawSocket, IntoRawHandle, IntoRawSocket,
9-
RawHandle,
7+
windows::{
8+
io::{OwnedHandle, OwnedSocket},
9+
prelude::{AsRawHandle, AsRawSocket},
1010
},
1111
},
1212
pin::Pin,
@@ -44,66 +44,53 @@ pub(crate) use windows_sys::Win32::Networking::WinSock::{
4444
/// On windows, handle and socket are in the same size.
4545
/// Both of them could be attached to an IOCP.
4646
/// Therefore, both could be seen as fd.
47-
pub type RawFd = RawHandle;
47+
pub type RawFd = isize;
4848

4949
/// Extracts raw fds.
5050
pub trait AsRawFd {
5151
/// Extracts the raw fd.
5252
fn as_raw_fd(&self) -> RawFd;
5353
}
5454

55-
/// Construct IO objects from raw fds.
56-
pub trait FromRawFd {
57-
/// Constructs an IO object from the specified raw fd.
58-
///
59-
/// # Safety
60-
///
61-
/// The `fd` passed in must:
62-
/// - be a valid open handle or socket,
63-
/// - be opened with `FILE_FLAG_OVERLAPPED` if it's a file handle,
64-
/// - have not been attached to a driver.
65-
unsafe fn from_raw_fd(fd: RawFd) -> Self;
66-
}
67-
68-
/// Consumes an object and acquire ownership of its raw fd.
69-
pub trait IntoRawFd {
70-
/// Consumes this object, returning the raw underlying fd.
71-
fn into_raw_fd(self) -> RawFd;
55+
/// Owned handle or socket on Windows.
56+
#[derive(Debug)]
57+
pub enum OwnedFd {
58+
/// Win32 handle.
59+
File(OwnedHandle),
60+
/// Windows socket handle.
61+
Socket(OwnedSocket),
7262
}
7363

74-
impl AsRawFd for std::fs::File {
64+
impl AsRawFd for OwnedFd {
7565
fn as_raw_fd(&self) -> RawFd {
76-
self.as_raw_handle()
77-
}
78-
}
79-
80-
impl AsRawFd for socket2::Socket {
81-
fn as_raw_fd(&self) -> RawFd {
82-
self.as_raw_socket() as _
66+
match self {
67+
Self::File(fd) => fd.as_raw_handle() as _,
68+
Self::Socket(s) => s.as_raw_socket() as _,
69+
}
8370
}
8471
}
8572

86-
impl FromRawFd for std::fs::File {
87-
unsafe fn from_raw_fd(fd: RawFd) -> Self {
88-
Self::from_raw_handle(fd)
73+
impl From<OwnedHandle> for OwnedFd {
74+
fn from(value: OwnedHandle) -> Self {
75+
Self::File(value)
8976
}
9077
}
9178

92-
impl FromRawFd for socket2::Socket {
93-
unsafe fn from_raw_fd(fd: RawFd) -> Self {
94-
Self::from_raw_socket(fd as _)
79+
impl From<std::fs::File> for OwnedFd {
80+
fn from(value: std::fs::File) -> Self {
81+
Self::File(OwnedHandle::from(value))
9582
}
9683
}
9784

98-
impl IntoRawFd for std::fs::File {
99-
fn into_raw_fd(self) -> RawFd {
100-
self.into_raw_handle()
85+
impl From<OwnedSocket> for OwnedFd {
86+
fn from(value: OwnedSocket) -> Self {
87+
Self::Socket(value)
10188
}
10289
}
10390

104-
impl IntoRawFd for socket2::Socket {
105-
fn into_raw_fd(self) -> RawFd {
106-
self.into_raw_socket() as _
91+
impl From<socket2::Socket> for OwnedFd {
92+
fn from(value: socket2::Socket) -> Self {
93+
Self::Socket(OwnedSocket::from(value))
10794
}
10895
}
10996

@@ -298,7 +285,7 @@ impl Driver {
298285

299286
impl AsRawFd for Driver {
300287
fn as_raw_fd(&self) -> RawFd {
301-
self.port.as_raw_handle()
288+
self.port.as_raw_handle() as _
302289
}
303290
}
304291

0 commit comments

Comments
 (0)