
Commit 25cf7d1

Auto merge of #143035 - ywxt:less-work-steal, r=oli-obk
Only work-steal in the main loop for rustc_thread_pool

This PR is a replica of <rust-lang/rustc-rayon#12>, retaining work-stealing only in the main loop of rustc_thread_pool.

r? `@oli-obk` cc `@SparrowLii` `@Zoxc` `@cuviper`

Updates #113349
2 parents 8df4a58 + 36462f9 commit 25cf7d1
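
The change replaces the open-ended `wait_until` calls in the `join` and broadcast waiting paths with a `wait_for_jobs` helper: while a worker waits for specific jobs, it only pops and runs work from its own local deque, using a started flag and a `JobRefId` to recognize the jobs it is waiting on, and general work-stealing is left to the worker's main loop. The registry-side definition of `wait_for_jobs` is not shown on this page, so the following is only a minimal std-only sketch of that waiting strategy, with made-up names (`LocalQueue`, `wait_without_stealing`), not the crate's actual code.

```rust
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

// A boxed unit of work, standing in for the pool's job objects.
type Job = Box<dyn FnOnce()>;

struct LocalQueue {
    deque: VecDeque<Job>,
}

impl LocalQueue {
    // Wait for `done` while running only jobs from our own deque. Once the
    // awaited work has started elsewhere (`all_started`), or our deque is
    // empty, we just wait -- crucially, we never steal from other workers.
    fn wait_without_stealing(&mut self, done: &AtomicBool, all_started: impl Fn() -> bool) {
        while !done.load(Ordering::Acquire) {
            if all_started() || self.deque.is_empty() {
                std::hint::spin_loop();
                continue;
            }
            if let Some(job) = self.deque.pop_back() {
                job(); // run one of *our* pending jobs
            }
        }
    }
}

fn main() {
    let done = Arc::new(AtomicBool::new(false));
    let mut queue = LocalQueue { deque: VecDeque::new() };
    let flag = Arc::clone(&done);
    // The job we are waiting on still sits in our own deque; draining the
    // deque completes the wait without touching any other worker's queue.
    queue.deque.push_back(Box::new(move || flag.store(true, Ordering::Release)));
    queue.wait_without_stealing(&done, || false);
    assert!(done.load(Ordering::Acquire));
}
```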

16 files changed: +309 -112 lines


Cargo.lock

Lines changed: 1 addition & 0 deletions
@@ -4623,6 +4623,7 @@ dependencies = [
  "rand 0.9.1",
  "rand_xorshift",
  "scoped-tls",
+ "smallvec",
 ]

 [[package]]

compiler/rustc_thread_pool/Cargo.toml

Lines changed: 5 additions & 2 deletions
@@ -1,8 +1,10 @@
 [package]
 name = "rustc_thread_pool"
 version = "0.0.0"
-authors = ["Niko Matsakis <niko@alum.mit.edu>",
-           "Josh Stone <cuviper@gmail.com>"]
+authors = [
+    "Niko Matsakis <niko@alum.mit.edu>",
+    "Josh Stone <cuviper@gmail.com>",
+]
 description = "Core APIs for Rayon - fork for rustc"
 license = "MIT OR Apache-2.0"
 rust-version = "1.63"
@@ -14,6 +16,7 @@ categories = ["concurrency"]
 [dependencies]
 crossbeam-deque = "0.8"
 crossbeam-utils = "0.8"
+smallvec = "1.8.1"

 [dev-dependencies]
 rand = "0.9"

compiler/rustc_thread_pool/src/broadcast/mod.rs

Lines changed: 21 additions & 3 deletions
@@ -1,6 +1,7 @@
 use std::fmt;
 use std::marker::PhantomData;
 use std::sync::Arc;
+use std::sync::atomic::{AtomicBool, Ordering};

 use crate::job::{ArcJob, StackJob};
 use crate::latch::{CountLatch, LatchRef};
@@ -97,13 +98,22 @@ where
     OP: Fn(BroadcastContext<'_>) -> R + Sync,
     R: Send,
 {
+    let current_thread = WorkerThread::current();
+    let current_thread_addr = current_thread.expose_provenance();
+    let started = &AtomicBool::new(false);
     let f = move |injected: bool| {
         debug_assert!(injected);
+
+        // Mark as started if we are the thread that initiated that broadcast.
+        if current_thread_addr == WorkerThread::current().expose_provenance() {
+            started.store(true, Ordering::Relaxed);
+        }
+
         BroadcastContext::with(&op)
     };

     let n_threads = registry.num_threads();
-    let current_thread = unsafe { WorkerThread::current().as_ref() };
+    let current_thread = unsafe { current_thread.as_ref() };
     let tlv = crate::tlv::get();
     let latch = CountLatch::with_count(n_threads, current_thread);
     let jobs: Vec<_> =
@@ -112,8 +122,16 @@ where

     registry.inject_broadcast(job_refs);

+    let current_thread_job_id = current_thread
+        .and_then(|worker| (registry.id() == worker.registry.id()).then(|| worker))
+        .map(|worker| unsafe { jobs[worker.index()].as_job_ref() }.id());
+
     // Wait for all jobs to complete, then collect the results, maybe propagating a panic.
-    latch.wait(current_thread);
+    latch.wait(
+        current_thread,
+        || started.load(Ordering::Relaxed),
+        |job| Some(job.id()) == current_thread_job_id,
+    );
     jobs.into_iter().map(|job| unsafe { job.into_result() }).collect()
 }

@@ -129,7 +147,7 @@ where
 {
     let job = ArcJob::new({
         let registry = Arc::clone(registry);
-        move || {
+        move |_| {
             registry.catch_unwind(|| BroadcastContext::with(&op));
             registry.terminate(); // (*) permit registry to terminate now
         }
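
The broadcast path now captures an `AtomicBool` in the broadcast closure so the copy of the job belonging to the initiating thread can mark itself as started, and `CountLatch::wait` receives that flag plus a `JobRefId` comparison to recognize the thread's own pending broadcast job. A std-only sketch of the started-flag handshake (simplified structure, not the crate's API):

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread;

fn main() {
    let started = AtomicBool::new(false);
    thread::scope(|s| {
        s.spawn(|| {
            // Counterpart of the broadcast closure marking itself as started
            // before doing its work.
            started.store(true, Ordering::Relaxed);
            // ... broadcast body would run here ...
        });
        // Counterpart of the `|| started.load(Ordering::Relaxed)` closure that
        // is handed to `latch.wait(...)` in the diff above.
        while !started.load(Ordering::Relaxed) {
            std::hint::spin_loop();
        }
    });
}
```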

compiler/rustc_thread_pool/src/broadcast/tests.rs

Lines changed: 4 additions & 0 deletions
@@ -64,7 +64,9 @@ fn spawn_broadcast_self() {
     assert!(v.into_iter().eq(0..7));
 }

+// FIXME: We should fix or remove this ignored test.
 #[test]
+#[ignore]
 #[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
 fn broadcast_mutual() {
     let count = AtomicUsize::new(0);
@@ -98,7 +100,9 @@ fn spawn_broadcast_mutual() {
     assert_eq!(rx.into_iter().count(), 3 * 7);
 }

+// FIXME: We should fix or remove this ignored test.
 #[test]
+#[ignore]
 #[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
 fn broadcast_mutual_sleepy() {
     let count = AtomicUsize::new(0);

compiler/rustc_thread_pool/src/job.rs

Lines changed: 26 additions & 14 deletions
@@ -27,6 +27,11 @@ pub(super) trait Job {
     unsafe fn execute(this: *const ());
 }

+#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
+pub(super) struct JobRefId {
+    pointer: usize,
+}
+
 /// Effectively a Job trait object. Each JobRef **must** be executed
 /// exactly once, or else data may leak.
 ///
@@ -52,11 +57,9 @@ impl JobRef {
         JobRef { pointer: data as *const (), execute_fn: <T as Job>::execute }
     }

-    /// Returns an opaque handle that can be saved and compared,
-    /// without making `JobRef` itself `Copy + Eq`.
     #[inline]
-    pub(super) fn id(&self) -> impl Eq {
-        (self.pointer, self.execute_fn)
+    pub(super) fn id(&self) -> JobRefId {
+        JobRefId { pointer: self.pointer.expose_provenance() }
     }

     #[inline]
@@ -100,8 +103,15 @@ where
         unsafe { JobRef::new(self) }
     }

-    pub(super) unsafe fn run_inline(self, stolen: bool) -> R {
-        self.func.into_inner().unwrap()(stolen)
+    pub(super) unsafe fn run_inline(&self, stolen: bool) {
+        unsafe {
+            let func = (*self.func.get()).take().unwrap();
+            *(self.result.get()) = match unwind::halt_unwinding(|| func(stolen)) {
+                Ok(x) => JobResult::Ok(x),
+                Err(x) => JobResult::Panic(x),
+            };
+            Latch::set(&self.latch);
+        }
     }

     pub(super) unsafe fn into_result(self) -> R {
@@ -138,15 +148,15 @@ where
 /// (Probably `StackJob` should be refactored in a similar fashion.)
 pub(super) struct HeapJob<BODY>
 where
-    BODY: FnOnce() + Send,
+    BODY: FnOnce(JobRefId) + Send,
 {
     job: BODY,
     tlv: Tlv,
 }

 impl<BODY> HeapJob<BODY>
 where
-    BODY: FnOnce() + Send,
+    BODY: FnOnce(JobRefId) + Send,
 {
     pub(super) fn new(tlv: Tlv, job: BODY) -> Box<Self> {
         Box::new(HeapJob { job, tlv })
@@ -170,27 +180,28 @@ where

 impl<BODY> Job for HeapJob<BODY>
 where
-    BODY: FnOnce() + Send,
+    BODY: FnOnce(JobRefId) + Send,
 {
     unsafe fn execute(this: *const ()) {
+        let pointer = this.expose_provenance();
         let this = unsafe { Box::from_raw(this as *mut Self) };
         tlv::set(this.tlv);
-        (this.job)();
+        (this.job)(JobRefId { pointer });
     }
 }

 /// Represents a job stored in an `Arc` -- like `HeapJob`, but may
 /// be turned into multiple `JobRef`s and called multiple times.
 pub(super) struct ArcJob<BODY>
 where
-    BODY: Fn() + Send + Sync,
+    BODY: Fn(JobRefId) + Send + Sync,
 {
     job: BODY,
 }

 impl<BODY> ArcJob<BODY>
 where
-    BODY: Fn() + Send + Sync,
+    BODY: Fn(JobRefId) + Send + Sync,
 {
     pub(super) fn new(job: BODY) -> Arc<Self> {
         Arc::new(ArcJob { job })
@@ -214,11 +225,12 @@ where

 impl<BODY> Job for ArcJob<BODY>
 where
-    BODY: Fn() + Send + Sync,
+    BODY: Fn(JobRefId) + Send + Sync,
 {
     unsafe fn execute(this: *const ()) {
+        let pointer = this.expose_provenance();
         let this = unsafe { Arc::from_raw(this as *mut Self) };
-        (this.job)();
+        (this.job)(JobRefId { pointer });
     }
 }

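
`JobRefId` replaces the old opaque `impl Eq` handle with a plain, copyable identity taken from the job pointer's exposed address, so the same id can be rebuilt inside `execute` and handed to the job body. A small standalone illustration of that idea (`JobId` and `id_of` are made-up names, not the crate's types):

```rust
// A copyable identity derived from a pointer's exposed address.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct JobId {
    pointer: usize,
}

fn id_of<T>(job: *const T) -> JobId {
    // `expose_provenance` yields a plain usize address, so the id can be
    // stored, copied, and compared long after the pointer itself is gone.
    JobId { pointer: job.expose_provenance() }
}

fn main() {
    let a = Box::into_raw(Box::new("job a"));
    let b = Box::into_raw(Box::new("job b"));
    assert_eq!(id_of(a), id_of(a));
    assert_ne!(id_of(a), id_of(b));
    // Reclaim the allocations so the example does not leak.
    unsafe {
        drop(Box::from_raw(a));
        drop(Box::from_raw(b));
    }
}
```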

compiler/rustc_thread_pool/src/join/mod.rs

Lines changed: 24 additions & 56 deletions
@@ -1,10 +1,8 @@
-use std::any::Any;
+use std::sync::atomic::{AtomicBool, Ordering};

 use crate::job::StackJob;
 use crate::latch::SpinLatch;
-use crate::registry::{self, WorkerThread};
-use crate::tlv::{self, Tlv};
-use crate::{FnContext, unwind};
+use crate::{FnContext, registry, tlv, unwind};

 #[cfg(test)]
 mod tests;
@@ -134,68 +132,38 @@ where
         // Create virtual wrapper for task b; this all has to be
         // done here so that the stack frame can keep it all live
         // long enough.
-        let job_b = StackJob::new(tlv, call_b(oper_b), SpinLatch::new(worker_thread));
+        let job_b_started = AtomicBool::new(false);
+        let job_b = StackJob::new(
+            tlv,
+            |migrated| {
+                job_b_started.store(true, Ordering::Relaxed);
+                call_b(oper_b)(migrated)
+            },
+            SpinLatch::new(worker_thread),
+        );
         let job_b_ref = job_b.as_job_ref();
         let job_b_id = job_b_ref.id();
         worker_thread.push(job_b_ref);

         // Execute task a; hopefully b gets stolen in the meantime.
         let status_a = unwind::halt_unwinding(call_a(oper_a, injected));
-        let result_a = match status_a {
-            Ok(v) => v,
-            Err(err) => join_recover_from_panic(worker_thread, &job_b.latch, err, tlv),
-        };
-
-        // Now that task A has finished, try to pop job B from the
-        // local stack. It may already have been popped by job A; it
-        // may also have been stolen. There may also be some tasks
-        // pushed on top of it in the stack, and we will have to pop
-        // those off to get to it.
-        while !job_b.latch.probe() {
-            if let Some(job) = worker_thread.take_local_job() {
-                if job_b_id == job.id() {
-                    // Found it! Let's run it.
-                    //
-                    // Note that this could panic, but it's ok if we unwind here.

-                    // Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
-                    tlv::set(tlv);
-
-                    let result_b = job_b.run_inline(injected);
-                    return (result_a, result_b);
-                } else {
-                    worker_thread.execute(job);
-                }
-            } else {
-                // Local deque is empty. Time to steal from other
-                // threads.
-                worker_thread.wait_until(&job_b.latch);
-                debug_assert!(job_b.latch.probe());
-                break;
-            }
-        }
+        worker_thread.wait_for_jobs::<_, false>(
+            &job_b.latch,
+            || job_b_started.load(Ordering::Relaxed),
+            |job| job.id() == job_b_id,
+            |job| {
+                debug_assert_eq!(job.id(), job_b_id);
+                job_b.run_inline(injected);
+            },
+        );

         // Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
         tlv::set(tlv);

+        let result_a = match status_a {
+            Ok(v) => v,
+            Err(err) => unwind::resume_unwinding(err),
+        };
         (result_a, job_b.into_result())
     })
 }
-
-/// If job A panics, we still cannot return until we are sure that job
-/// B is complete. This is because it may contain references into the
-/// enclosing stack frame(s).
-#[cold] // cold path
-unsafe fn join_recover_from_panic(
-    worker_thread: &WorkerThread,
-    job_b_latch: &SpinLatch<'_>,
-    err: Box<dyn Any + Send>,
-    tlv: Tlv,
-) -> ! {
-    unsafe { worker_thread.wait_until(job_b_latch) };
-
-    // Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
-    tlv::set(tlv);
-
-    unwind::resume_unwinding(err)
-}
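
With the hand-rolled pop-and-execute loop and `join_recover_from_panic` removed, a panic raised by task A is now held until the wait for task B has finished and only then resumed. A std-only sketch of that ordering (the helper name is illustrative, and the real task B runs on the thread pool rather than inline):

```rust
use std::panic;

fn join_sketch<RA, RB>(a: impl FnOnce() -> RA, b: impl FnOnce() -> RB) -> (RA, RB) {
    // Capture a potential panic from task A instead of unwinding immediately.
    let status_a = panic::catch_unwind(panic::AssertUnwindSafe(a));
    // Make sure task B is fully finished before deciding what to do about A:
    // B may borrow from the enclosing stack frame, so we must not unwind past
    // it while it could still be running.
    let result_b = b();
    let result_a = match status_a {
        Ok(v) => v,
        Err(err) => panic::resume_unwind(err),
    };
    (result_a, result_b)
}

fn main() {
    let (x, y) = join_sketch(|| 1 + 1, || "two");
    assert_eq!((x, y), (2, "two"));
}
```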

compiler/rustc_thread_pool/src/join/tests.rs

Lines changed: 2 additions & 0 deletions
@@ -96,7 +96,9 @@ fn join_context_both() {
     assert!(b_migrated);
 }

+// FIXME: We should fix or remove this ignored test.
 #[test]
+#[ignore]
 #[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
 fn join_context_neither() {
     // If we're already in a 1-thread pool, neither job should be stolen.

compiler/rustc_thread_pool/src/latch.rs

Lines changed: 10 additions & 7 deletions
@@ -3,6 +3,7 @@ use std::ops::Deref;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::{Arc, Condvar, Mutex};

+use crate::job::JobRef;
 use crate::registry::{Registry, WorkerThread};

 /// We define various kinds of latches, which are all a primitive signaling
@@ -166,11 +167,6 @@ impl<'r> SpinLatch<'r> {
     pub(super) fn cross(thread: &'r WorkerThread) -> SpinLatch<'r> {
         SpinLatch { cross: true, ..SpinLatch::new(thread) }
     }
-
-    #[inline]
-    pub(super) fn probe(&self) -> bool {
-        self.core_latch.probe()
-    }
 }

 impl<'r> AsCoreLatch for SpinLatch<'r> {
@@ -368,13 +364,20 @@ impl CountLatch {
         debug_assert!(old_counter != 0);
     }

-    pub(super) fn wait(&self, owner: Option<&WorkerThread>) {
+    pub(super) fn wait(
+        &self,
+        owner: Option<&WorkerThread>,
+        all_jobs_started: impl FnMut() -> bool,
+        is_job: impl FnMut(&JobRef) -> bool,
+    ) {
         match &self.kind {
             CountLatchKind::Stealing { latch, registry, worker_index } => unsafe {
                 let owner = owner.expect("owner thread");
                 debug_assert_eq!(registry.id(), owner.registry().id());
                 debug_assert_eq!(*worker_index, owner.index());
-                owner.wait_until(latch);
+                owner.wait_for_jobs::<_, true>(latch, all_jobs_started, is_job, |job| {
+                    owner.execute(job);
+                });
             },
             CountLatchKind::Blocking { latch } => latch.wait(),
         }
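
`CountLatch::wait` now threads two caller-supplied closures down to `wait_for_jobs::<_, true>`, so the owning worker keeps servicing its own relevant jobs while the count drains instead of stealing arbitrary work. A toy counted latch with the same callback shape (stand-in types and names, not the crate's):

```rust
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

struct ToyCountLatch {
    counter: AtomicUsize,
}

impl ToyCountLatch {
    fn new(count: usize) -> Self {
        Self { counter: AtomicUsize::new(count) }
    }

    fn set(&self) {
        self.counter.fetch_sub(1, Ordering::SeqCst);
    }

    // `all_jobs_started` and `try_run_local` play the roles of the two new
    // closure parameters: they let the waiting thread keep making progress on
    // its own relevant jobs instead of stealing unrelated work.
    fn wait(&self, mut all_jobs_started: impl FnMut() -> bool, mut try_run_local: impl FnMut() -> bool) {
        while self.counter.load(Ordering::SeqCst) > 0 {
            if !all_jobs_started() && try_run_local() {
                continue; // made progress on our own work
            }
            std::hint::spin_loop(); // otherwise just wait; no stealing here
        }
    }
}

fn main() {
    let latch = ToyCountLatch::new(1);
    let started = AtomicBool::new(false);
    latch.wait(
        || started.load(Ordering::Relaxed),
        || {
            // Our single local job: mark it started and complete the latch.
            started.store(true, Ordering::Relaxed);
            latch.set();
            true
        },
    );
    assert_eq!(latch.counter.load(Ordering::SeqCst), 0);
}
```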
