Skip to content

Commit 3457723

Browse files
committed
switch direct to thread_local!, fix wasm CI errors
1 parent 14cf95e commit 3457723

File tree

7 files changed

+177
-169
lines changed

7 files changed

+177
-169
lines changed

crates/bevy_app/src/app.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,8 @@ impl App {
180180
// this loop never exits because multiple copies of sender exist
181181
let event = recv.recv().unwrap();
182182
match event {
183-
AppEvent::Task(f) => {
184-
f(&mut tls.lock());
183+
AppEvent::Task(task) => {
184+
task();
185185
}
186186
AppEvent::Exit(x) => {
187187
sub_apps = x;
@@ -519,14 +519,14 @@ impl App {
519519
/// .insert_non_send_resource(MyCounter { counter: 0 });
520520
/// ```
521521
pub fn insert_non_send_resource<R: ThreadLocalResource>(&mut self, resource: R) -> &mut Self {
522-
self.tls.lock().insert_resource(resource);
522+
self.tls.insert_resource(resource);
523523
self
524524
}
525525

526526
/// Inserts the [`!Send`](Send) resource into the app, initialized with its default value,
527527
/// if there is no existing instance of `R`.
528528
pub fn init_non_send_resource<R: ThreadLocalResource + Default>(&mut self) -> &mut Self {
529-
self.tls.lock().init_resource::<R>();
529+
self.tls.init_resource::<R>();
530530
self
531531
}
532532

@@ -918,7 +918,7 @@ fn run_once(mut app: App) {
918918
}
919919

920920
// disassemble
921-
let (mut sub_apps, tls, _) = app.into_parts();
921+
let (mut sub_apps, _, _) = app.into_parts();
922922

923923
#[cfg(not(target_arch = "wasm32"))]
924924
{
@@ -937,8 +937,8 @@ fn run_once(mut app: App) {
937937
loop {
938938
let event = recv.recv().unwrap();
939939
match event {
940-
AppEvent::Task(f) => {
941-
f(&mut tls.lock());
940+
AppEvent::Task(task) => {
941+
task();
942942
}
943943
AppEvent::Exit(_) => {
944944
thread.join().unwrap();

crates/bevy_app/src/schedule_runner.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ impl Plugin for ScheduleRunnerPlugin {
116116
};
117117

118118
// disassemble
119-
let (mut sub_apps, tls, _) = app.into_parts();
119+
let (mut sub_apps, _, _) = app.into_parts();
120120

121121
#[cfg(not(target_arch = "wasm32"))]
122122
{
@@ -140,8 +140,8 @@ impl Plugin for ScheduleRunnerPlugin {
140140
loop {
141141
let event = recv.recv().unwrap();
142142
match event {
143-
AppEvent::Task(f) => {
144-
f(&mut tls.lock());
143+
AppEvent::Task(task) => {
144+
task();
145145
}
146146
AppEvent::Exit(_) => {
147147
thread.join().unwrap();

crates/bevy_ecs/src/storage/resource_non_send.rs

Lines changed: 77 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::any::TypeId;
2+
use std::cell::RefCell;
23
use std::marker::PhantomData;
3-
use std::sync::{Arc, Mutex, Weak};
44

55
use bevy_ptr::{OwningPtr, Ptr};
66

@@ -12,6 +12,10 @@ use crate::storage::{ResourceData, Resources};
1212
use crate::system::{Resource, SystemParam};
1313
use crate::world::{unsafe_world_cell::UnsafeWorldCell, World};
1414

15+
thread_local! {
16+
static TLS: RefCell<ThreadLocals> = RefCell::new(ThreadLocals::new());
17+
}
18+
1519
/// A type that can be inserted into [`ThreadLocals`]. Unlike [`Resource`], this does not require
1620
/// [`Send`] or [`Sync`].
1721
pub trait ThreadLocalResource: 'static {}
@@ -322,32 +326,8 @@ impl ThreadLocals {
322326
}
323327
}
324328

325-
/// A "scoped lock" on [`ThreadLocals`], which is protected by a mutex and can be accessed
326-
/// through this guard via its [`Deref`] and [`DerefMut`] implementations.
327-
///
328-
/// When this guard is dropped, the lock will be unlocked.
329-
#[doc(hidden)]
330-
pub struct ThreadLocalsGuard<'a> {
331-
guard: std::sync::MutexGuard<'a, ThreadLocals>,
332-
// needed to decrement the strong reference count once dropped
333-
_arc: Arc<Mutex<ThreadLocals>>,
334-
}
335-
336-
impl std::ops::Deref for ThreadLocalsGuard<'_> {
337-
type Target = ThreadLocals;
338-
fn deref(&self) -> &Self::Target {
339-
&self.guard
340-
}
341-
}
342-
343-
impl std::ops::DerefMut for ThreadLocalsGuard<'_> {
344-
fn deref_mut(&mut self) -> &mut Self::Target {
345-
&mut self.guard
346-
}
347-
}
348-
349329
/// Type alias for tasks that access thread-local data.
350-
pub type ThreadLocalTask = Box<dyn FnOnce(&mut ThreadLocals) + Send + 'static>;
330+
pub type ThreadLocalTask = Box<dyn FnOnce() + Send + 'static>;
351331

352332
/// An error returned from the [`ThreadLocalTaskSender::send_task`] function.
353333
///
@@ -369,9 +349,8 @@ pub trait ThreadLocalTaskSender: Send + 'static {
369349
/// A [`Resource`] that enables the use of the [`ThreadLocal`] system parameter.
370350
#[derive(Resource)]
371351
struct ThreadLocalChannel {
372-
owning_thread: std::thread::ThreadId,
373-
direct: Weak<Mutex<ThreadLocals>>,
374-
indirect: Box<dyn ThreadLocalTaskSender>,
352+
thread: std::thread::ThreadId,
353+
sender: Box<dyn ThreadLocalTaskSender>,
375354
}
376355

377356
// SAFETY: The pointer to the thread-local storage is only dereferenced in its owning thread.
@@ -381,28 +360,18 @@ unsafe impl Send for ThreadLocalChannel {}
381360
// Likewise, all operations require an exclusive reference, so there can be no races.
382361
unsafe impl Sync for ThreadLocalChannel {}
383362

384-
/// A mutex-guarded instance of [`ThreadLocals`].
363+
/// A guard to access [`ThreadLocals`].
385364
pub struct ThreadLocalStorage {
386365
thread: std::thread::ThreadId,
387-
locals: Arc<Mutex<ThreadLocals>>,
366+
// !Send + !Sync
367+
_marker: PhantomData<*const ()>,
388368
}
389369

390370
impl Default for ThreadLocalStorage {
391371
fn default() -> Self {
392372
Self {
393373
thread: std::thread::current().id(),
394-
// Need a reference-counted pointer that can exist on multiple threads. Only the local
395-
// thread is able to deference it.
396-
#[allow(clippy::arc_with_non_send_sync)]
397-
locals: Arc::new(Mutex::new(ThreadLocals::new())),
398-
}
399-
}
400-
}
401-
402-
impl Drop for ThreadLocalStorage {
403-
fn drop(&mut self) {
404-
if Arc::strong_count(&self.locals) > 1 {
405-
panic!("`ThreadLocalStorage` was dropped while there was an active borrow of `ThreadLocals`.")
374+
_marker: PhantomData,
406375
}
407376
}
408377
}
@@ -413,16 +382,39 @@ impl ThreadLocalStorage {
413382
Self::default()
414383
}
415384

416-
/// Returns an exclusive reference to the underlying [`ThreadLocals`].
385+
/// Inserts a new resource with its default value.
417386
///
418-
/// # Panics
387+
/// If the resource already exists, nothing happens.
388+
#[inline]
389+
pub fn init_resource<R: ThreadLocalResource + Default>(&mut self) {
390+
TLS.with_borrow_mut(|tls| {
391+
tls.init_resource::<R>();
392+
});
393+
}
394+
395+
/// Inserts a new resource with the given `value`.
419396
///
420-
/// This function will panic if an exclusive reference cannot be acquired.
421-
pub fn lock(&self) -> ThreadLocalsGuard<'_> {
422-
ThreadLocalsGuard {
423-
guard: self.locals.try_lock().unwrap(),
424-
_arc: self.locals.clone(),
425-
}
397+
/// Resources are "unique" data of a given type. If you insert a resource of a type that already
398+
/// exists, you will overwrite any existing data.
399+
#[inline]
400+
pub fn insert_resource<R: ThreadLocalResource>(&mut self, value: R) {
401+
TLS.with_borrow_mut(|tls| {
402+
tls.insert_resource(value);
403+
});
404+
}
405+
406+
/// Removes the resource of a given type and returns it, if it exists.
407+
#[inline]
408+
pub fn remove_resource<R: ThreadLocalResource>(&mut self) -> Option<R> {
409+
TLS.with_borrow_mut(|tls| tls.remove_resource::<R>())
410+
}
411+
412+
/// Temporarily removes `R` from the [`ThreadLocals`], then re-inserts it before returning.
413+
pub fn resource_scope<R: ThreadLocalResource, T>(
414+
&mut self,
415+
f: impl FnOnce(&mut ThreadLocals, Mut<R>) -> T,
416+
) -> T {
417+
TLS.with_borrow_mut(|tls| tls.resource_scope(f))
426418
}
427419

428420
/// Inserts a channel into `world` that systems in `world` (via [`ThreadLocal`]) can use to
@@ -432,9 +424,8 @@ impl ThreadLocalStorage {
432424
S: ThreadLocalTaskSender,
433425
{
434426
let channel = ThreadLocalChannel {
435-
owning_thread: self.thread,
436-
direct: Arc::downgrade(&self.locals),
437-
indirect: Box::new(sender),
427+
thread: self.thread,
428+
sender: Box::new(sender),
438429
};
439430

440431
world.insert_resource(channel);
@@ -448,7 +439,7 @@ impl ThreadLocalStorage {
448439
}
449440

450441
enum ThreadLocalAccess<'a> {
451-
Direct(ThreadLocalsGuard<'a>),
442+
Direct,
452443
Indirect(&'a mut dyn ThreadLocalTaskSender),
453444
}
454445

@@ -472,7 +463,7 @@ impl ThreadLocal<'_, '_> {
472463
T: Send + 'static,
473464
{
474465
match self.access {
475-
ThreadLocalAccess::Direct(_) => self.run_direct(f),
466+
ThreadLocalAccess::Direct => self.run_direct(f),
476467
ThreadLocalAccess::Indirect(_) => self.run_indirect(f),
477468
}
478469
}
@@ -482,18 +473,16 @@ impl ThreadLocal<'_, '_> {
482473
F: FnOnce(&mut ThreadLocals) -> T + Send,
483474
T: Send + 'static,
484475
{
485-
let ThreadLocalAccess::Direct(ref mut tls) = self.access else {
486-
unreachable!()
487-
};
488-
489-
tls.update_change_tick();
490-
let saved = std::mem::replace(&mut tls.last_tick, *self.last_run);
491-
let result = f(&mut *tls);
492-
tls.last_tick = saved;
476+
debug_assert!(matches!(self.access, ThreadLocalAccess::Direct));
493477

494-
*self.last_run = tls.last_tick;
495-
496-
result
478+
TLS.with_borrow_mut(|tls| {
479+
tls.update_change_tick();
480+
let saved = std::mem::replace(&mut tls.last_tick, *self.last_run);
481+
let result = f(tls);
482+
tls.last_tick = saved;
483+
*self.last_run = tls.curr_tick;
484+
result
485+
})
497486
}
498487

499488
fn run_indirect<F, T>(&mut self, f: F) -> T
@@ -507,22 +496,23 @@ impl ThreadLocal<'_, '_> {
507496

508497
let system_tick = *self.last_run;
509498
let (result_tx, result_rx) = std::sync::mpsc::sync_channel(1);
510-
let task = move |tls: &mut ThreadLocals| {
511-
tls.update_change_tick();
512-
let saved = std::mem::replace(&mut tls.last_tick, system_tick);
513-
// we want to propagate to caller instead of panicking in the main thread
514-
let result =
515-
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| (f(tls), tls.last_tick)));
516-
tls.last_tick = saved;
517-
518-
result_tx.send(result).unwrap();
499+
let task = move || {
500+
TLS.with_borrow_mut(|tls| {
501+
tls.update_change_tick();
502+
let saved = std::mem::replace(&mut tls.last_tick, system_tick);
503+
// we want to propagate to caller instead of panicking in the main thread
504+
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
505+
(f(tls), tls.curr_tick)
506+
}));
507+
tls.last_tick = saved;
508+
result_tx.send(result).unwrap();
509+
});
519510
};
520511

521-
let task: Box<dyn FnOnce(&mut ThreadLocals) + Send> = Box::new(task);
522-
let task: Box<dyn FnOnce(&mut ThreadLocals) + Send + 'static> =
523-
// SAFETY: This function will block the calling thread until `f` completes,
524-
// so any captured references in `f` will remain valid until then.
525-
unsafe { std::mem::transmute(task) };
512+
let task: Box<dyn FnOnce() + Send> = Box::new(task);
513+
// SAFETY: This function will block the calling thread until `f` completes,
514+
// so any captured references in `f` will remain valid until then.
515+
let task: Box<dyn FnOnce() + Send + 'static> = unsafe { std::mem::transmute(task) };
526516

527517
// Send task to the main thread.
528518
sender
@@ -531,8 +521,8 @@ impl ThreadLocal<'_, '_> {
531521

532522
// Wait to receive result back from the main thread.
533523
match result_rx.recv().unwrap() {
534-
Ok((result, last_run)) => {
535-
*self.last_run = last_run;
524+
Ok((result, tls_tick)) => {
525+
*self.last_run = tls_tick;
536526
result
537527
}
538528
Err(payload) => {
@@ -568,33 +558,17 @@ unsafe impl SystemParam for ThreadLocal<'_, '_> {
568558
world: UnsafeWorldCell<'world>,
569559
curr_tick: Tick,
570560
) -> Self::Item<'world, 'state> {
571-
let mut accessor = crate::system::ResMut::<ThreadLocalChannel>::get_param(
561+
let accessor = crate::system::ResMut::<ThreadLocalChannel>::get_param(
572562
&mut state.component_id,
573563
system_meta,
574564
world,
575565
curr_tick,
576566
);
577567

578-
let access = if std::thread::current().id() == accessor.owning_thread {
579-
let arc = accessor
580-
.direct
581-
.upgrade()
582-
.expect("pointer to `ThreadLocals` should be valid");
583-
584-
// Use a raw pointer so we can hold onto the `Arc` and satisfy the borrow checker.
585-
let ptr = Arc::as_ptr(&arc);
586-
587-
// SAFETY: This pointer is valid since we're still holding the `Arc`.
588-
let mutex = unsafe { &*ptr };
589-
let guard = mutex
590-
.try_lock()
591-
.expect("lock on `ThreadLocals` should be available");
592-
593-
ThreadLocalAccess::Direct(ThreadLocalsGuard { guard, _arc: arc })
568+
let access = if std::thread::current().id() == accessor.thread {
569+
ThreadLocalAccess::Direct
594570
} else {
595-
let ptr: *mut dyn ThreadLocalTaskSender = &mut *accessor.indirect;
596-
// SAFETY: The pointer is valid. We have to do this to satisfy the borrow checker.
597-
ThreadLocalAccess::Indirect(unsafe { &mut *ptr })
571+
ThreadLocalAccess::Indirect(&mut *accessor.into_inner().sender)
598572
};
599573

600574
ThreadLocal {

0 commit comments

Comments
 (0)