Skip to content

Commit 62a3a9e

Browse files
authored
Merge pull request #247 from Berrysoft/refactor/shared-fd-generic
refactor(driver): generic SharedFd
2 parents c0164a9 + 73824eb commit 62a3a9e

File tree

18 files changed

+374
-344
lines changed

18 files changed

+374
-344
lines changed

compio-driver/src/fd.rs

Lines changed: 42 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
#[cfg(unix)]
22
use std::os::fd::FromRawFd;
33
#[cfg(windows)]
4-
use std::os::windows::io::{
5-
FromRawHandle, FromRawSocket, OwnedHandle, OwnedSocket, RawHandle, RawSocket,
6-
};
4+
use std::os::windows::io::{FromRawHandle, FromRawSocket, RawHandle, RawSocket};
75
use std::{
86
future::{poll_fn, Future},
97
mem::ManuallyDrop,
8+
ops::Deref,
109
panic::RefUnwindSafe,
1110
sync::{
1211
atomic::{AtomicBool, Ordering},
@@ -17,35 +16,35 @@ use std::{
1716

1817
use futures_util::task::AtomicWaker;
1918

20-
use crate::{AsRawFd, OwnedFd, RawFd};
19+
use crate::{AsRawFd, RawFd};
2120

2221
#[derive(Debug)]
23-
struct Inner {
24-
fd: OwnedFd,
22+
struct Inner<T> {
23+
fd: T,
2524
// whether there is a future waiting
2625
waits: AtomicBool,
2726
waker: AtomicWaker,
2827
}
2928

30-
impl RefUnwindSafe for Inner {}
29+
impl<T> RefUnwindSafe for Inner<T> {}
3130

3231
/// A shared fd. It is passed to the operations to make sure the fd won't be
3332
/// closed before the operations complete.
34-
#[derive(Debug, Clone)]
35-
pub struct SharedFd(Arc<Inner>);
33+
#[derive(Debug)]
34+
pub struct SharedFd<T>(Arc<Inner<T>>);
3635

37-
impl SharedFd {
36+
impl<T> SharedFd<T> {
3837
/// Create the shared fd from an owned fd.
39-
pub fn new(fd: impl Into<OwnedFd>) -> Self {
38+
pub fn new(fd: T) -> Self {
4039
Self(Arc::new(Inner {
41-
fd: fd.into(),
40+
fd,
4241
waits: AtomicBool::new(false),
4342
waker: AtomicWaker::new(),
4443
}))
4544
}
4645

4746
/// Try to take the inner owned fd.
48-
pub fn try_unwrap(self) -> Result<OwnedFd, Self> {
47+
pub fn try_unwrap(self) -> Result<T, Self> {
4948
let this = ManuallyDrop::new(self);
5049
if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
5150
Ok(fd)
@@ -55,7 +54,7 @@ impl SharedFd {
5554
}
5655

5756
// SAFETY: if `Some` is returned, the method should not be called again.
58-
unsafe fn try_unwrap_inner(this: &ManuallyDrop<Self>) -> Option<OwnedFd> {
57+
unsafe fn try_unwrap_inner(this: &ManuallyDrop<Self>) -> Option<T> {
5958
let ptr = ManuallyDrop::new(std::ptr::read(&this.0));
6059
// The ptr is duplicated without increasing the strong count, should forget.
6160
match Arc::try_unwrap(ManuallyDrop::into_inner(ptr)) {
@@ -68,7 +67,7 @@ impl SharedFd {
6867
}
6968

7069
/// Wait and take the inner owned fd.
71-
pub fn take(self) -> impl Future<Output = Option<OwnedFd>> {
70+
pub fn take(self) -> impl Future<Output = Option<T>> {
7271
let this = ManuallyDrop::new(self);
7372
async move {
7473
if !this.0.waits.swap(true, Ordering::AcqRel) {
@@ -93,7 +92,7 @@ impl SharedFd {
9392
}
9493
}
9594

96-
impl Drop for SharedFd {
95+
impl<T> Drop for SharedFd<T> {
9796
fn drop(&mut self) {
9897
// It's OK to wake multiple times.
9998
if Arc::strong_count(&self.0) == 2 {
@@ -102,71 +101,61 @@ impl Drop for SharedFd {
102101
}
103102
}
104103

105-
#[cfg(windows)]
106-
#[doc(hidden)]
107-
impl SharedFd {
108-
pub unsafe fn to_file(&self) -> ManuallyDrop<std::fs::File> {
109-
ManuallyDrop::new(std::fs::File::from_raw_handle(self.as_raw_fd() as _))
110-
}
111-
112-
pub unsafe fn to_socket(&self) -> ManuallyDrop<socket2::Socket> {
113-
ManuallyDrop::new(socket2::Socket::from_raw_socket(self.as_raw_fd() as _))
114-
}
115-
}
116-
117-
#[cfg(unix)]
118-
#[doc(hidden)]
119-
impl SharedFd {
120-
pub unsafe fn to_file(&self) -> ManuallyDrop<std::fs::File> {
121-
ManuallyDrop::new(std::fs::File::from_raw_fd(self.as_raw_fd() as _))
122-
}
123-
124-
pub unsafe fn to_socket(&self) -> ManuallyDrop<socket2::Socket> {
125-
ManuallyDrop::new(socket2::Socket::from_raw_fd(self.as_raw_fd() as _))
126-
}
127-
}
128-
129-
impl AsRawFd for SharedFd {
104+
impl<T: AsRawFd> AsRawFd for SharedFd<T> {
130105
fn as_raw_fd(&self) -> RawFd {
131106
self.0.fd.as_raw_fd()
132107
}
133108
}
134109

135110
#[cfg(windows)]
136-
impl FromRawHandle for SharedFd {
111+
impl<T: FromRawHandle> FromRawHandle for SharedFd<T> {
137112
unsafe fn from_raw_handle(handle: RawHandle) -> Self {
138-
Self::new(OwnedFd::File(OwnedHandle::from_raw_handle(handle)))
113+
Self::new(T::from_raw_handle(handle))
139114
}
140115
}
141116

142117
#[cfg(windows)]
143-
impl FromRawSocket for SharedFd {
118+
impl<T: FromRawSocket> FromRawSocket for SharedFd<T> {
144119
unsafe fn from_raw_socket(sock: RawSocket) -> Self {
145-
Self::new(OwnedFd::Socket(OwnedSocket::from_raw_socket(sock)))
120+
Self::new(T::from_raw_socket(sock))
146121
}
147122
}
148123

149124
#[cfg(unix)]
150-
impl FromRawFd for SharedFd {
125+
impl<T: FromRawFd> FromRawFd for SharedFd<T> {
151126
unsafe fn from_raw_fd(fd: RawFd) -> Self {
152-
Self::new(OwnedFd::from_raw_fd(fd))
127+
Self::new(T::from_raw_fd(fd))
153128
}
154129
}
155130

156-
impl From<OwnedFd> for SharedFd {
157-
fn from(value: OwnedFd) -> Self {
131+
impl<T> From<T> for SharedFd<T> {
132+
fn from(value: T) -> Self {
158133
Self::new(value)
159134
}
160135
}
161136

137+
impl<T> Clone for SharedFd<T> {
138+
fn clone(&self) -> Self {
139+
Self(self.0.clone())
140+
}
141+
}
142+
143+
impl<T> Deref for SharedFd<T> {
144+
type Target = T;
145+
146+
fn deref(&self) -> &Self::Target {
147+
&self.0.fd
148+
}
149+
}
150+
162151
/// Get a clone of [`SharedFd`].
163-
pub trait ToSharedFd {
152+
pub trait ToSharedFd<T> {
164153
/// Return a cloned [`SharedFd`].
165-
fn to_shared_fd(&self) -> SharedFd;
154+
fn to_shared_fd(&self) -> SharedFd<T>;
166155
}
167156

168-
impl ToSharedFd for SharedFd {
169-
fn to_shared_fd(&self) -> SharedFd {
157+
impl<T> ToSharedFd<T> for SharedFd<T> {
158+
fn to_shared_fd(&self) -> SharedFd<T> {
170159
self.clone()
171160
}
172161
}

compio-driver/src/fusion/op.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ pub use crate::unix::op::*;
88
use crate::SharedFd;
99

1010
macro_rules! op {
11-
(<$($ty:ident: $trait:ident),* $(,)?> $name:ident( $($arg:ident: $arg_t:ident),* $(,)? )) => {
11+
(<$($ty:ident: $trait:ident),* $(,)?> $name:ident( $($arg:ident: $arg_t:ty),* $(,)? )) => {
1212
::paste::paste!{
1313
enum [< $name Inner >] <$($ty: $trait),*> {
1414
Poll(poll::$name<$($ty),*>),
@@ -92,9 +92,9 @@ mod iour { pub use crate::sys::iour::{op::*, OpCode}; }
9292
#[rustfmt::skip]
9393
mod poll { pub use crate::sys::poll::{op::*, OpCode}; }
9494

95-
op!(<T: IoBufMut> RecvFrom(fd: SharedFd, buffer: T));
96-
op!(<T: IoBuf> SendTo(fd: SharedFd, buffer: T, addr: SockAddr));
97-
op!(<T: IoVectoredBufMut> RecvFromVectored(fd: SharedFd, buffer: T));
98-
op!(<T: IoVectoredBuf> SendToVectored(fd: SharedFd, buffer: T, addr: SockAddr));
99-
op!(<> FileStat(fd: SharedFd));
95+
op!(<T: IoBufMut, S: AsRawFd> RecvFrom(fd: SharedFd<S>, buffer: T));
96+
op!(<T: IoBuf, S: AsRawFd> SendTo(fd: SharedFd<S>, buffer: T, addr: SockAddr));
97+
op!(<T: IoVectoredBufMut, S: AsRawFd> RecvFromVectored(fd: SharedFd<S>, buffer: T));
98+
op!(<T: IoVectoredBuf, S: AsRawFd> SendToVectored(fd: SharedFd<S>, buffer: T, addr: SockAddr));
99+
op!(<S: AsRawFd> FileStat(fd: SharedFd<S>));
100100
op!(<> PathStat(path: CString, follow_symlink: bool));

compio-driver/src/iocp/mod.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,36 @@ impl AsRawFd for OwnedFd {
7070
}
7171
}
7272

73+
impl AsRawFd for RawFd {
74+
fn as_raw_fd(&self) -> RawFd {
75+
*self
76+
}
77+
}
78+
79+
impl AsRawFd for std::fs::File {
80+
fn as_raw_fd(&self) -> RawFd {
81+
self.as_raw_handle() as _
82+
}
83+
}
84+
85+
impl AsRawFd for OwnedHandle {
86+
fn as_raw_fd(&self) -> RawFd {
87+
self.as_raw_handle() as _
88+
}
89+
}
90+
91+
impl AsRawFd for socket2::Socket {
92+
fn as_raw_fd(&self) -> RawFd {
93+
self.as_raw_socket() as _
94+
}
95+
}
96+
97+
impl AsRawFd for OwnedSocket {
98+
fn as_raw_fd(&self) -> RawFd {
99+
self.as_raw_socket() as _
100+
}
101+
}
102+
73103
impl From<OwnedHandle> for OwnedFd {
74104
fn from(value: OwnedHandle) -> Self {
75105
Self::File(value)

0 commit comments

Comments
 (0)