Skip to content

Commit e1272e7

Browse files
authored
Merge pull request #261 from Berrysoft/faster-send-wrapper
feat(runtime): faster SendWrapper with cached thread id
2 parents 456b47b + 947af15 commit e1272e7

File tree

3 files changed

+119
-7
lines changed

3 files changed

+119
-7
lines changed

compio-runtime/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ crossbeam-queue = { workspace = true }
4040
futures-util = { workspace = true }
4141
once_cell = { workspace = true }
4242
scoped-tls = "1.0.1"
43-
send_wrapper = "0.6.0"
4443
slab = { workspace = true, optional = true }
4544
smallvec = "1.11.1"
4645
socket2 = { workspace = true }

compio-runtime/src/runtime/mod.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@ use compio_driver::{
2020
use compio_log::{debug, instrument};
2121
use crossbeam_queue::SegQueue;
2222
use futures_util::{future::Either, FutureExt};
23-
use send_wrapper::SendWrapper;
2423
use smallvec::SmallVec;
2524

2625
pub(crate) mod op;
2726
#[cfg(feature = "time")]
2827
pub(crate) mod time;
2928

29+
mod send_wrapper;
30+
use send_wrapper::SendWrapper;
31+
3032
#[cfg(feature = "time")]
3133
use crate::runtime::time::{TimerFuture, TimerRuntime};
3234
use crate::{runtime::op::OpFuture, BufResult};
@@ -120,8 +122,8 @@ impl Runtime {
120122
.handle()
121123
.expect("cannot create notify handle of the proactor");
122124
let schedule = move |runnable| {
123-
if local_runnables.valid() {
124-
local_runnables.borrow_mut().push_back(runnable);
125+
if let Some(runnables) = local_runnables.get() {
126+
runnables.borrow_mut().push_back(runnable);
125127
} else {
126128
sync_runnables.push(runnable);
127129
handle.notify().ok();
@@ -136,9 +138,8 @@ impl Runtime {
136138
///
137139
/// Run the scheduled tasks.
138140
pub fn run(&self) {
139-
use std::ops::Deref;
140-
141-
let local_runnables = self.local_runnables.deref().deref();
141+
// SAFETY: self is !Send + !Sync.
142+
let local_runnables = unsafe { self.local_runnables.get_unchecked() };
142143
loop {
143144
let next_task = local_runnables.borrow_mut().pop_front();
144145
let has_local_task = next_task.is_some();
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
// Copyright 2017 Thomas Keh.
2+
// Copyright 2024 compio-rs
3+
//
4+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
7+
// option. This file may not be copied, modified, or distributed
8+
// except according to those terms.
9+
10+
use std::{
11+
cell::Cell,
12+
mem::{self, ManuallyDrop},
13+
thread::{self, ThreadId},
14+
};
15+
16+
thread_local! {
17+
static THREAD_ID: Cell<ThreadId> = Cell::new(thread::current().id());
18+
}
19+
20+
/// A wrapper that copied from `send_wrapper` crate, with our own optimizations.
21+
pub struct SendWrapper<T> {
22+
data: ManuallyDrop<T>,
23+
thread_id: ThreadId,
24+
}
25+
26+
impl<T> SendWrapper<T> {
27+
/// Create a `SendWrapper<T>` wrapper around a value of type `T`.
28+
/// The wrapper takes ownership of the value.
29+
#[inline]
30+
pub fn new(data: T) -> SendWrapper<T> {
31+
SendWrapper {
32+
data: ManuallyDrop::new(data),
33+
thread_id: THREAD_ID.get(),
34+
}
35+
}
36+
37+
/// Returns `true` if the value can be safely accessed from within the
38+
/// current thread.
39+
#[inline]
40+
pub fn valid(&self) -> bool {
41+
self.thread_id == THREAD_ID.get()
42+
}
43+
44+
/// Returns a reference to the contained value.
45+
///
46+
/// # Safety
47+
///
48+
/// The caller should be in the same thread as the creator.
49+
#[inline]
50+
pub unsafe fn get_unchecked(&self) -> &T {
51+
&self.data
52+
}
53+
54+
/// Returns a reference to the contained value, if valid.
55+
#[inline]
56+
pub fn get(&self) -> Option<&T> {
57+
if self.valid() { Some(&self.data) } else { None }
58+
}
59+
}
60+
61+
unsafe impl<T> Send for SendWrapper<T> {}
62+
unsafe impl<T> Sync for SendWrapper<T> {}
63+
64+
impl<T> Drop for SendWrapper<T> {
65+
/// Drops the contained value.
66+
///
67+
/// # Panics
68+
///
69+
/// Dropping panics if it is done from a different thread than the one the
70+
/// `SendWrapper<T>` instance has been created with.
71+
///
72+
/// Exceptions:
73+
/// - There is no extra panic if the thread is already panicking/unwinding.
74+
/// This is because otherwise there would be double panics (usually
75+
/// resulting in an abort) when dereferencing from a wrong thread.
76+
/// - If `T` has a trivial drop ([`needs_drop::<T>()`] is false) then this
77+
/// method never panics.
78+
///
79+
/// [`needs_drop::<T>()`]: std::mem::needs_drop
80+
#[track_caller]
81+
fn drop(&mut self) {
82+
// If the drop is trivial (`needs_drop` = false), then dropping `T` can't access
83+
// it and so it can be safely dropped on any thread.
84+
if !mem::needs_drop::<T>() || self.valid() {
85+
unsafe {
86+
// Drop the inner value
87+
//
88+
// Safety:
89+
// - We've just checked that it's valid to drop `T` on this thread
90+
// - We only move out from `self.data` here and in drop, so `self.data` is
91+
// present
92+
ManuallyDrop::drop(&mut self.data);
93+
}
94+
} else {
95+
invalid_drop()
96+
}
97+
}
98+
}
99+
100+
#[cold]
101+
#[inline(never)]
102+
#[track_caller]
103+
fn invalid_drop() {
104+
const DROP_ERROR: &str = "Dropped SendWrapper<T> variable from a thread different to the one \
105+
it has been created with.";
106+
107+
if !thread::panicking() {
108+
// panic because of dropping from wrong thread
109+
// only do this while not unwinding (could be caused by deref from wrong thread)
110+
panic!("{}", DROP_ERROR)
111+
}
112+
}

0 commit comments

Comments
 (0)