Skip to content

Commit f2d678f

Browse files
committed
fix(driver): safety of try_unwrap_impl
1 parent f958731 commit f2d678f

File tree

4 files changed

+24
-20
lines changed

4 files changed

+24
-20
lines changed

compio-driver/src/fd.rs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use std::{
2-
future::{poll_fn, ready, Future},
2+
future::{poll_fn, Future},
33
mem::ManuallyDrop,
44
panic::RefUnwindSafe,
55
sync::{
@@ -9,7 +9,7 @@ use std::{
99
task::Poll,
1010
};
1111

12-
use futures_util::{future::Either, task::AtomicWaker};
12+
use futures_util::task::AtomicWaker;
1313

1414
use crate::{AsRawFd, OwnedFd, RawFd};
1515

@@ -38,18 +38,18 @@ impl SharedFd {
3838
}
3939

4040
/// Try to take the inner owned fd.
41-
pub fn try_owned(self) -> Result<OwnedFd, Self> {
41+
pub fn try_unwrap(self) -> Result<OwnedFd, Self> {
4242
let this = ManuallyDrop::new(self);
43-
if let Some(fd) = Self::try_owned_inner(&this) {
43+
if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
4444
Ok(fd)
4545
} else {
4646
Err(ManuallyDrop::into_inner(this))
4747
}
4848
}
4949

50-
fn try_owned_inner(this: &ManuallyDrop<Self>) -> Option<OwnedFd> {
51-
// SAFETY: see ManuallyDrop::take
52-
let ptr = ManuallyDrop::new(unsafe { std::ptr::read(&this.0) });
50+
// SAFETY: if `Some` is returned, the method should not be called again.
51+
unsafe fn try_unwrap_inner(this: &ManuallyDrop<Self>) -> Option<OwnedFd> {
52+
let ptr = ManuallyDrop::new(std::ptr::read(&this.0));
5353
// The ptr is duplicated without increasing the strong count, should forget.
5454
match Arc::try_unwrap(ManuallyDrop::into_inner(ptr)) {
5555
Ok(inner) => Some(inner.fd),
@@ -62,26 +62,26 @@ impl SharedFd {
6262

6363
/// Wait and take the inner owned fd.
6464
pub fn take(self) -> impl Future<Output = Option<OwnedFd>> {
65-
if self.0.waits.fetch_add(1, Ordering::AcqRel) == 0 {
66-
let this = ManuallyDrop::new(self);
67-
Either::Left(async move {
68-
poll_fn(|cx| {
69-
if let Some(fd) = Self::try_owned_inner(&this) {
65+
let this = ManuallyDrop::new(self);
66+
async move {
67+
if this.0.waits.fetch_add(1, Ordering::AcqRel) == 0 {
68+
poll_fn(move |cx| {
69+
if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
7070
return Poll::Ready(Some(fd));
7171
}
7272

7373
this.0.waker.register(cx.waker());
7474

75-
if let Some(fd) = Self::try_owned_inner(&this) {
75+
if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
7676
Poll::Ready(Some(fd))
7777
} else {
7878
Poll::Pending
7979
}
8080
})
8181
.await
82-
})
83-
} else {
84-
Either::Right(ready(None))
82+
} else {
83+
None
84+
}
8585
}
8686
}
8787
}

compio-driver/tests/file.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ fn cancel_before_poll() {
7373

7474
assert!(res.is_ok() || res.unwrap_err().kind() == io::ErrorKind::TimedOut);
7575

76-
let op = CloseFile::new(fd.try_owned().unwrap());
76+
let op = CloseFile::new(fd.try_unwrap().unwrap());
7777
push_and_wait(&mut driver, op).unwrap();
7878
}
7979

@@ -119,7 +119,7 @@ fn register_multiple() {
119119
driver.cancel(entry);
120120
}
121121

122-
let op = CloseFile::new(fd.try_owned().unwrap());
122+
let op = CloseFile::new(fd.try_unwrap().unwrap());
123123
push_and_wait(&mut driver, op).unwrap();
124124
}
125125

compio-net/src/socket.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,13 +156,17 @@ impl Socket {
156156

157157
#[cfg(windows)]
158158
pub async fn accept(&self) -> io::Result<(Self, SockAddr)> {
159+
use std::panic::resume_unwind;
160+
159161
let domain = self.local_addr()?.domain();
160162
// We should allow users sending this accepted socket to a new thread.
161163
let this_socket = unsafe { self.socket.to_socket() };
162164
let ty = this_socket.r#type()?;
163165
let protocol = this_socket.protocol()?;
164166
let accept_sock =
165-
compio_runtime::spawn_blocking(move || Socket2::new(domain, ty, protocol)).await?;
167+
compio_runtime::spawn_blocking(move || Socket2::new(domain, ty, protocol))
168+
.await
169+
.unwrap_or_else(|e| resume_unwind(e))?;
166170
let op = Accept::new(self.to_shared_fd(), accept_sock);
167171
let BufResult(res, op) = Runtime::current().submit(op).await;
168172
res?;

compio/examples/driver.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,6 @@ fn main() {
7575
}
7676
println!("{}", String::from_utf8(buffer).unwrap());
7777

78-
let op = CloseFile::new(fd.try_owned().unwrap());
78+
let op = CloseFile::new(fd.try_unwrap().unwrap());
7979
push_and_wait(&mut driver, op);
8080
}

0 commit comments

Comments
 (0)