Skip to content

Commit a46e1f8

Browse files
committed
Use pointers instead of &self in Latch::set
`Latch::set` can invalidate its own `&self`, because it releases the owning thread to continue execution, which may then invalidate the latch by deallocation, reuse, etc. We've known about this problem when it comes to accessing latch fields too late, but the possibly dangling reference was still a problem, like rust-lang/rust#55005. The result of that was rust-lang/rust#98017, omitting the LLVM attribute `dereferenceable` on references to `!Freeze` types -- those containing `UnsafeCell`. However, Miri's Stacked Borrows implementation is finer-grained than that, only relaxing for the cell itself in the `!Freeze` type. For rayon, that solves the dangling reference in atomic calls, but it remains a problem for other fields of a `Latch`. The easiest fix for rayon is to use a raw pointer instead of `&self`. We still end up with some temporary references for stuff like atomics, but those should be fine with the rules above.
1 parent ed98853 commit a46e1f8

File tree

4 files changed

+60
-51
lines changed

4 files changed

+60
-51
lines changed

rayon-core/src/job.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ where
112112
let abort = unwind::AbortIfPanic;
113113
let func = (*this.func.get()).take().unwrap();
114114
(*this.result.get()) = JobResult::call(func);
115-
this.latch.set();
115+
Latch::set(&this.latch);
116116
mem::forget(abort);
117117
}
118118
}

rayon-core/src/latch.rs

+33-24
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,15 @@ pub(super) trait Latch {
3737
///
3838
/// Setting a latch triggers other threads to wake up and (in some
3939
/// cases) complete. This may, in turn, cause memory to be
40-
/// allocated and so forth. One must be very careful about this,
40+
/// deallocated and so forth. One must be very careful about this,
4141
/// and it's typically better to read all the fields you will need
4242
/// to access *before* a latch is set!
43-
fn set(&self);
43+
///
44+
/// This function operates on `*const Self` instead of `&self` to allow it
45+
/// to become dangling during this call. The caller must ensure that the
46+
/// pointer is valid upon entry, and not invalidated during the call by any
47+
/// actions other than `set` itself.
48+
unsafe fn set(this: *const Self);
4449
}
4550

4651
pub(super) trait AsCoreLatch {
@@ -123,8 +128,8 @@ impl CoreLatch {
123128
/// doing some wakeups; those are encapsulated in the surrounding
124129
/// latch code.
125130
#[inline]
126-
fn set(&self) -> bool {
127-
let old_state = self.state.swap(SET, Ordering::AcqRel);
131+
unsafe fn set(this: *const Self) -> bool {
132+
let old_state = (*this).state.swap(SET, Ordering::AcqRel);
128133
old_state == SLEEPING
129134
}
130135

@@ -186,29 +191,29 @@ impl<'r> AsCoreLatch for SpinLatch<'r> {
186191

187192
impl<'r> Latch for SpinLatch<'r> {
188193
#[inline]
189-
fn set(&self) {
194+
unsafe fn set(this: *const Self) {
190195
let cross_registry;
191196

192-
let registry: &Registry = if self.cross {
197+
let registry: &Registry = if (*this).cross {
193198
// Ensure the registry stays alive while we notify it.
194199
// Otherwise, it would be possible that we set the spin
195200
// latch and the other thread sees it and exits, causing
196201
// the registry to be deallocated, all before we get a
197202
// chance to invoke `registry.notify_worker_latch_is_set`.
198-
cross_registry = Arc::clone(self.registry);
203+
cross_registry = Arc::clone((*this).registry);
199204
&cross_registry
200205
} else {
201206
// If this is not a "cross-registry" spin-latch, then the
202207
// thread which is performing `set` is itself ensuring
203208
// that the registry stays alive. However, that doesn't
204209
// include this *particular* `Arc` handle if the waiting
205210
// thread then exits, so we must completely dereference it.
206-
self.registry
211+
(*this).registry
207212
};
208-
let target_worker_index = self.target_worker_index;
213+
let target_worker_index = (*this).target_worker_index;
209214

210-
// NOTE: Once we `set`, the target may proceed and invalidate `&self`!
211-
if self.core_latch.set() {
215+
// NOTE: Once we `set`, the target may proceed and invalidate `this`!
216+
if CoreLatch::set(&(*this).core_latch) {
212217
// Subtle: at this point, we can no longer read from
213218
// `self`, because the thread owning this spin latch may
214219
// have awoken and deallocated the latch. Therefore, we
@@ -255,10 +260,10 @@ impl LockLatch {
255260

256261
impl Latch for LockLatch {
257262
#[inline]
258-
fn set(&self) {
259-
let mut guard = self.m.lock().unwrap();
263+
unsafe fn set(this: *const Self) {
264+
let mut guard = (*this).m.lock().unwrap();
260265
*guard = true;
261-
self.v.notify_all();
266+
(*this).v.notify_all();
262267
}
263268
}
264269

@@ -307,9 +312,9 @@ impl CountLatch {
307312
/// count, then the latch is **set**, and calls to `probe()` will
308313
/// return true. Returns whether the latch was set.
309314
#[inline]
310-
pub(super) fn set(&self) -> bool {
311-
if self.counter.fetch_sub(1, Ordering::SeqCst) == 1 {
312-
self.core_latch.set();
315+
pub(super) unsafe fn set(this: *const Self) -> bool {
316+
if (*this).counter.fetch_sub(1, Ordering::SeqCst) == 1 {
317+
CoreLatch::set(&(*this).core_latch);
313318
true
314319
} else {
315320
false
@@ -320,8 +325,12 @@ impl CountLatch {
320325
/// the latch is set, then the specific worker thread is tickled,
321326
/// which should be the one that owns this latch.
322327
#[inline]
323-
pub(super) fn set_and_tickle_one(&self, registry: &Registry, target_worker_index: usize) {
324-
if self.set() {
328+
pub(super) unsafe fn set_and_tickle_one(
329+
this: *const Self,
330+
registry: &Registry,
331+
target_worker_index: usize,
332+
) {
333+
if Self::set(this) {
325334
registry.notify_worker_latch_is_set(target_worker_index);
326335
}
327336
}
@@ -362,9 +371,9 @@ impl CountLockLatch {
362371

363372
impl Latch for CountLockLatch {
364373
#[inline]
365-
fn set(&self) {
366-
if self.counter.fetch_sub(1, Ordering::SeqCst) == 1 {
367-
self.lock_latch.set();
374+
unsafe fn set(this: *const Self) {
375+
if (*this).counter.fetch_sub(1, Ordering::SeqCst) == 1 {
376+
LockLatch::set(&(*this).lock_latch);
368377
}
369378
}
370379
}
@@ -374,7 +383,7 @@ where
374383
L: Latch,
375384
{
376385
#[inline]
377-
fn set(&self) {
378-
L::set(self);
386+
unsafe fn set(this: *const Self) {
387+
L::set(&**this);
379388
}
380389
}

rayon-core/src/registry.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,7 @@ impl Registry {
575575
pub(super) fn terminate(&self) {
576576
if self.terminate_count.fetch_sub(1, Ordering::AcqRel) == 1 {
577577
for (i, thread_info) in self.thread_infos.iter().enumerate() {
578-
thread_info.terminate.set_and_tickle_one(self, i);
578+
unsafe { CountLatch::set_and_tickle_one(&thread_info.terminate, self, i) };
579579
}
580580
}
581581
}
@@ -869,7 +869,7 @@ unsafe fn main_loop(
869869
let registry = &*worker_thread.registry;
870870

871871
// let registry know we are ready to do work
872-
registry.thread_infos[index].primed.set();
872+
Latch::set(&registry.thread_infos[index].primed);
873873

874874
// Worker threads should not panic. If they do, just abort, as the
875875
// internal state of the threadpool is corrupted. Note that if
@@ -892,7 +892,7 @@ unsafe fn main_loop(
892892
debug_assert!(worker_thread.take_local_job().is_none());
893893

894894
// let registry know we are done
895-
registry.thread_infos[index].stopped.set();
895+
Latch::set(&registry.thread_infos[index].stopped);
896896

897897
// Normal termination, do not abort.
898898
mem::forget(abort_guard);

rayon-core/src/scope/mod.rs

+23-23
Original file line numberDiff line numberDiff line change
@@ -540,10 +540,10 @@ impl<'scope> Scope<'scope> {
540540
BODY: FnOnce(&Scope<'scope>) + Send + 'scope,
541541
{
542542
let scope_ptr = ScopePtr(self);
543-
let job = HeapJob::new(move || {
543+
let job = HeapJob::new(move || unsafe {
544544
// SAFETY: this job will execute before the scope ends.
545-
let scope = unsafe { scope_ptr.as_ref() };
546-
scope.base.execute_job(move || body(scope))
545+
let scope = scope_ptr.as_ref();
546+
ScopeBase::execute_job(&scope.base, move || body(scope))
547547
});
548548
let job_ref = self.base.heap_job_ref(job);
549549

@@ -562,12 +562,12 @@ impl<'scope> Scope<'scope> {
562562
BODY: Fn(&Scope<'scope>, BroadcastContext<'_>) + Send + Sync + 'scope,
563563
{
564564
let scope_ptr = ScopePtr(self);
565-
let job = ArcJob::new(move || {
565+
let job = ArcJob::new(move || unsafe {
566566
// SAFETY: this job will execute before the scope ends.
567-
let scope = unsafe { scope_ptr.as_ref() };
567+
let scope = scope_ptr.as_ref();
568568
let body = &body;
569569
let func = move || BroadcastContext::with(move |ctx| body(scope, ctx));
570-
scope.base.execute_job(func);
570+
ScopeBase::execute_job(&scope.base, func)
571571
});
572572
self.base.inject_broadcast(job)
573573
}
@@ -600,10 +600,10 @@ impl<'scope> ScopeFifo<'scope> {
600600
BODY: FnOnce(&ScopeFifo<'scope>) + Send + 'scope,
601601
{
602602
let scope_ptr = ScopePtr(self);
603-
let job = HeapJob::new(move || {
603+
let job = HeapJob::new(move || unsafe {
604604
// SAFETY: this job will execute before the scope ends.
605-
let scope = unsafe { scope_ptr.as_ref() };
606-
scope.base.execute_job(move || body(scope))
605+
let scope = scope_ptr.as_ref();
606+
ScopeBase::execute_job(&scope.base, move || body(scope))
607607
});
608608
let job_ref = self.base.heap_job_ref(job);
609609

@@ -628,12 +628,12 @@ impl<'scope> ScopeFifo<'scope> {
628628
BODY: Fn(&ScopeFifo<'scope>, BroadcastContext<'_>) + Send + Sync + 'scope,
629629
{
630630
let scope_ptr = ScopePtr(self);
631-
let job = ArcJob::new(move || {
631+
let job = ArcJob::new(move || unsafe {
632632
// SAFETY: this job will execute before the scope ends.
633-
let scope = unsafe { scope_ptr.as_ref() };
633+
let scope = scope_ptr.as_ref();
634634
let body = &body;
635635
let func = move || BroadcastContext::with(move |ctx| body(scope, ctx));
636-
scope.base.execute_job(func);
636+
ScopeBase::execute_job(&scope.base, func)
637637
});
638638
self.base.inject_broadcast(job)
639639
}
@@ -688,36 +688,36 @@ impl<'scope> ScopeBase<'scope> {
688688
where
689689
FUNC: FnOnce() -> R,
690690
{
691-
let result = self.execute_job_closure(func);
691+
let result = unsafe { Self::execute_job_closure(self, func) };
692692
self.job_completed_latch.wait(owner);
693693
self.maybe_propagate_panic();
694694
result.unwrap() // only None if `op` panicked, and that would have been propagated
695695
}
696696

697697
/// Executes `func` as a job, either aborting or executing as
698698
/// appropriate.
699-
fn execute_job<FUNC>(&self, func: FUNC)
699+
unsafe fn execute_job<FUNC>(this: *const Self, func: FUNC)
700700
where
701701
FUNC: FnOnce(),
702702
{
703-
let _: Option<()> = self.execute_job_closure(func);
703+
let _: Option<()> = Self::execute_job_closure(this, func);
704704
}
705705

706706
/// Executes `func` as a job in scope. Adjusts the "job completed"
707707
/// counters and also catches any panic and stores it into
708708
/// `scope`.
709-
fn execute_job_closure<FUNC, R>(&self, func: FUNC) -> Option<R>
709+
unsafe fn execute_job_closure<FUNC, R>(this: *const Self, func: FUNC) -> Option<R>
710710
where
711711
FUNC: FnOnce() -> R,
712712
{
713713
match unwind::halt_unwinding(func) {
714714
Ok(r) => {
715-
self.job_completed_latch.set();
715+
Latch::set(&(*this).job_completed_latch);
716716
Some(r)
717717
}
718718
Err(err) => {
719-
self.job_panicked(err);
720-
self.job_completed_latch.set();
719+
(*this).job_panicked(err);
720+
Latch::set(&(*this).job_completed_latch);
721721
None
722722
}
723723
}
@@ -797,14 +797,14 @@ impl ScopeLatch {
797797
}
798798

799799
impl Latch for ScopeLatch {
800-
fn set(&self) {
801-
match self {
800+
unsafe fn set(this: *const Self) {
801+
match &*this {
802802
ScopeLatch::Stealing {
803803
latch,
804804
registry,
805805
worker_index,
806-
} => latch.set_and_tickle_one(registry, *worker_index),
807-
ScopeLatch::Blocking { latch } => latch.set(),
806+
} => CountLatch::set_and_tickle_one(latch, registry, *worker_index),
807+
ScopeLatch::Blocking { latch } => Latch::set(latch),
808808
}
809809
}
810810
}

0 commit comments

Comments
 (0)