Skip to content
This repository was archived by the owner on May 28, 2025. It is now read-only.

Commit a0d104d

Browse files
committed
refactor scheduling of TLS dtors
1 parent f8fbc6d commit a0d104d

File tree

9 files changed

+268
-265
lines changed

9 files changed

+268
-265
lines changed

src/tools/miri/src/concurrency/thread.rs

Lines changed: 80 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
use std::cell::RefCell;
44
use std::collections::hash_map::Entry;
55
use std::num::TryFromIntError;
6+
use std::task::Poll;
67
use std::time::{Duration, SystemTime};
78

89
use log::trace;
@@ -16,6 +17,7 @@ use rustc_target::spec::abi::Abi;
1617

1718
use crate::concurrency::data_race;
1819
use crate::concurrency::sync::SynchronizationState;
20+
use crate::shims::tls;
1921
use crate::*;
2022

2123
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -24,10 +26,8 @@ pub enum SchedulingAction {
2426
ExecuteStep,
2527
/// Execute a timeout callback.
2628
ExecuteTimeoutCallback,
27-
/// Execute destructors of the active thread.
28-
ExecuteDtors,
29-
/// Stop the program.
30-
Stop,
29+
/// Wait for a bit, until there is a timeout to be called.
30+
Sleep(Duration),
3131
}
3232

3333
/// Trait for callbacks that can be executed when some event happens, such as after a timeout.
@@ -41,9 +41,6 @@ type TimeoutCallback<'mir, 'tcx> = Box<dyn MachineCallback<'mir, 'tcx> + 'tcx>;
4141
#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)]
4242
pub struct ThreadId(u32);
4343

44-
/// The main thread. When it terminates, the whole application terminates.
45-
const MAIN_THREAD: ThreadId = ThreadId(0);
46-
4744
impl ThreadId {
4845
pub fn to_u32(self) -> u32 {
4946
self.0
@@ -118,6 +115,12 @@ pub struct Thread<'mir, 'tcx> {
118115
/// The virtual call stack.
119116
stack: Vec<Frame<'mir, 'tcx, Provenance, FrameData<'tcx>>>,
120117

118+
/// The function to call when the stack ran empty, to figure out what to do next.
119+
/// Conceptually, this is the interpreter implementation of the things that happen 'after' the
120+
/// Rust language entry point for this thread returns (usually implemented by the C or OS runtime).
121+
/// (`None` is an error, it means the callback has not been set up yet or is actively running.)
122+
pub(crate) on_stack_empty: Option<StackEmptyCallback<'mir, 'tcx>>,
123+
121124
/// The index of the topmost user-relevant frame in `stack`. This field must contain
122125
/// the value produced by `get_top_user_relevant_frame`.
123126
/// The `None` state here represents
@@ -137,19 +140,10 @@ pub struct Thread<'mir, 'tcx> {
137140
pub(crate) last_error: Option<MPlaceTy<'tcx, Provenance>>,
138141
}
139142

140-
impl<'mir, 'tcx> Thread<'mir, 'tcx> {
141-
/// Check if the thread is done executing (no more stack frames). If yes,
142-
/// change the state to terminated and return `true`.
143-
fn check_terminated(&mut self) -> bool {
144-
if self.state == ThreadState::Enabled {
145-
if self.stack.is_empty() {
146-
self.state = ThreadState::Terminated;
147-
return true;
148-
}
149-
}
150-
false
151-
}
143+
pub type StackEmptyCallback<'mir, 'tcx> =
144+
Box<dyn FnMut(&mut MiriInterpCx<'mir, 'tcx>) -> InterpResult<'tcx, Poll<()>>>;
152145

146+
impl<'mir, 'tcx> Thread<'mir, 'tcx> {
153147
/// Get the name of the current thread, or `<unnamed>` if it was not set.
154148
fn thread_name(&self) -> &[u8] {
155149
if let Some(ref thread_name) = self.thread_name { thread_name } else { b"<unnamed>" }
@@ -202,28 +196,21 @@ impl<'mir, 'tcx> std::fmt::Debug for Thread<'mir, 'tcx> {
202196
}
203197
}
204198

205-
impl<'mir, 'tcx> Default for Thread<'mir, 'tcx> {
206-
fn default() -> Self {
199+
impl<'mir, 'tcx> Thread<'mir, 'tcx> {
200+
fn new(name: Option<&str>, on_stack_empty: Option<StackEmptyCallback<'mir, 'tcx>>) -> Self {
207201
Self {
208202
state: ThreadState::Enabled,
209-
thread_name: None,
203+
thread_name: name.map(|name| Vec::from(name.as_bytes())),
210204
stack: Vec::new(),
211205
top_user_relevant_frame: None,
212206
join_status: ThreadJoinStatus::Joinable,
213207
panic_payload: None,
214208
last_error: None,
209+
on_stack_empty,
215210
}
216211
}
217212
}
218213

219-
impl<'mir, 'tcx> Thread<'mir, 'tcx> {
220-
fn new(name: &str) -> Self {
221-
let mut thread = Thread::default();
222-
thread.thread_name = Some(Vec::from(name.as_bytes()));
223-
thread
224-
}
225-
}
226-
227214
impl VisitTags for Thread<'_, '_> {
228215
fn visit_tags(&self, visit: &mut dyn FnMut(SbTag)) {
229216
let Thread {
@@ -234,6 +221,7 @@ impl VisitTags for Thread<'_, '_> {
234221
state: _,
235222
thread_name: _,
236223
join_status: _,
224+
on_stack_empty: _, // we assume the closure captures no GC-relevant state
237225
} = self;
238226

239227
panic_payload.visit_tags(visit);
@@ -327,22 +315,6 @@ pub struct ThreadManager<'mir, 'tcx> {
327315
timeout_callbacks: FxHashMap<ThreadId, TimeoutCallbackInfo<'mir, 'tcx>>,
328316
}
329317

330-
impl<'mir, 'tcx> Default for ThreadManager<'mir, 'tcx> {
331-
fn default() -> Self {
332-
let mut threads = IndexVec::new();
333-
// Create the main thread and add it to the list of threads.
334-
threads.push(Thread::new("main"));
335-
Self {
336-
active_thread: ThreadId::new(0),
337-
threads,
338-
sync: SynchronizationState::default(),
339-
thread_local_alloc_ids: Default::default(),
340-
yield_active_thread: false,
341-
timeout_callbacks: FxHashMap::default(),
342-
}
343-
}
344-
}
345-
346318
impl VisitTags for ThreadManager<'_, '_> {
347319
fn visit_tags(&self, visit: &mut dyn FnMut(SbTag)) {
348320
let ThreadManager {
@@ -367,8 +339,28 @@ impl VisitTags for ThreadManager<'_, '_> {
367339
}
368340
}
369341

342+
impl<'mir, 'tcx> Default for ThreadManager<'mir, 'tcx> {
343+
fn default() -> Self {
344+
let mut threads = IndexVec::new();
345+
// Create the main thread and add it to the list of threads.
346+
threads.push(Thread::new(Some("main"), None));
347+
Self {
348+
active_thread: ThreadId::new(0),
349+
threads,
350+
sync: SynchronizationState::default(),
351+
thread_local_alloc_ids: Default::default(),
352+
yield_active_thread: false,
353+
timeout_callbacks: FxHashMap::default(),
354+
}
355+
}
356+
}
357+
370358
impl<'mir, 'tcx: 'mir> ThreadManager<'mir, 'tcx> {
371-
pub(crate) fn init(ecx: &mut MiriInterpCx<'mir, 'tcx>) {
359+
pub(crate) fn init(
360+
ecx: &mut MiriInterpCx<'mir, 'tcx>,
361+
on_main_stack_empty: StackEmptyCallback<'mir, 'tcx>,
362+
) {
363+
ecx.machine.threads.threads[ThreadId::new(0)].on_stack_empty = Some(on_main_stack_empty);
372364
if ecx.tcx.sess.target.os.as_ref() != "windows" {
373365
// The main thread can *not* be joined on except on windows.
374366
ecx.machine.threads.threads[ThreadId::new(0)].join_status = ThreadJoinStatus::Detached;
@@ -411,9 +403,9 @@ impl<'mir, 'tcx: 'mir> ThreadManager<'mir, 'tcx> {
411403
}
412404

413405
/// Create a new thread and returns its id.
414-
fn create_thread(&mut self) -> ThreadId {
406+
fn create_thread(&mut self, on_stack_empty: StackEmptyCallback<'mir, 'tcx>) -> ThreadId {
415407
let new_thread_id = ThreadId::new(self.threads.len());
416-
self.threads.push(Default::default());
408+
self.threads.push(Thread::new(None, Some(on_stack_empty)));
417409
new_thread_id
418410
}
419411

@@ -458,6 +450,7 @@ impl<'mir, 'tcx: 'mir> ThreadManager<'mir, 'tcx> {
458450
}
459451

460452
/// Get a mutable borrow of the currently active thread.
453+
/// (Private for a bit of protection.)
461454
fn active_thread_mut(&mut self) -> &mut Thread<'mir, 'tcx> {
462455
&mut self.threads[self.active_thread]
463456
}
@@ -669,37 +662,25 @@ impl<'mir, 'tcx: 'mir> ThreadManager<'mir, 'tcx> {
669662
/// long as we can and switch only when we have to (the active thread was
670663
/// blocked, terminated, or has explicitly asked to be preempted).
671664
fn schedule(&mut self, clock: &Clock) -> InterpResult<'tcx, SchedulingAction> {
672-
// Check whether the thread has **just** terminated (`check_terminated`
673-
// checks whether the thread has popped all its stack and if yes, sets
674-
// the thread state to terminated).
675-
if self.threads[self.active_thread].check_terminated() {
676-
return Ok(SchedulingAction::ExecuteDtors);
677-
}
678-
// If we get here again and the thread is *still* terminated, there are no more dtors to run.
679-
if self.threads[MAIN_THREAD].state == ThreadState::Terminated {
680-
// The main thread terminated; stop the program.
681-
// We do *not* run TLS dtors of remaining threads, which seems to match rustc behavior.
682-
return Ok(SchedulingAction::Stop);
683-
}
684665
// This thread and the program can keep going.
685666
if self.threads[self.active_thread].state == ThreadState::Enabled
686667
&& !self.yield_active_thread
687668
{
688669
// The currently active thread is still enabled, just continue with it.
689670
return Ok(SchedulingAction::ExecuteStep);
690671
}
691-
// The active thread yielded. Let's see if there are any timeouts to take care of. We do
692-
// this *before* running any other thread, to ensure that timeouts "in the past" fire before
693-
// any other thread can take an action. This ensures that for `pthread_cond_timedwait`, "an
694-
// error is returned if [...] the absolute time specified by abstime has already been passed
695-
// at the time of the call".
672+
// The active thread yielded or got terminated. Let's see if there are any timeouts to take
673+
// care of. We do this *before* running any other thread, to ensure that timeouts "in the
674+
// past" fire before any other thread can take an action. This ensures that for
675+
// `pthread_cond_timedwait`, "an error is returned if [...] the absolute time specified by
676+
// abstime has already been passed at the time of the call".
696677
// <https://pubs.opengroup.org/onlinepubs/9699919799/functions/pthread_cond_timedwait.html>
697678
let potential_sleep_time =
698679
self.timeout_callbacks.values().map(|info| info.call_time.get_wait_time(clock)).min();
699680
if potential_sleep_time == Some(Duration::new(0, 0)) {
700681
return Ok(SchedulingAction::ExecuteTimeoutCallback);
701682
}
702-
// No callbacks scheduled, pick a regular thread to execute.
683+
// No callbacks immediately scheduled, pick a regular thread to execute.
703684
// The active thread blocked or yielded. So we go search for another enabled thread.
704685
// Crucially, we start searching at the current active thread ID, rather than at 0, since we
705686
// want to avoid always scheduling threads 0 and 1 without ever making progress in thread 2.
@@ -730,9 +711,7 @@ impl<'mir, 'tcx: 'mir> ThreadManager<'mir, 'tcx> {
730711
// All threads are currently blocked, but we have unexecuted
731712
// timeout_callbacks, which may unblock some of the threads. Hence,
732713
// sleep until the first callback.
733-
734-
clock.sleep(sleep_time);
735-
Ok(SchedulingAction::ExecuteTimeoutCallback)
714+
Ok(SchedulingAction::Sleep(sleep_time))
736715
} else {
737716
throw_machine_stop!(TerminationInfo::Deadlock);
738717
}
@@ -773,18 +752,9 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
773752
}
774753
}
775754

755+
/// Start a regular (non-main) thread.
776756
#[inline]
777-
fn create_thread(&mut self) -> ThreadId {
778-
let this = self.eval_context_mut();
779-
let id = this.machine.threads.create_thread();
780-
if let Some(data_race) = &mut this.machine.data_race {
781-
data_race.thread_created(&this.machine.threads, id);
782-
}
783-
id
784-
}
785-
786-
#[inline]
787-
fn start_thread(
757+
fn start_regular_thread(
788758
&mut self,
789759
thread: Option<MPlaceTy<'tcx, Provenance>>,
790760
start_routine: Pointer<Option<Provenance>>,
@@ -795,7 +765,13 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
795765
let this = self.eval_context_mut();
796766

797767
// Create the new thread
798-
let new_thread_id = this.create_thread();
768+
let new_thread_id = this.machine.threads.create_thread({
769+
let mut state = tls::TlsDtorsState::default();
770+
Box::new(move |m| state.on_stack_empty(m))
771+
});
772+
if let Some(data_race) = &mut this.machine.data_race {
773+
data_race.thread_created(&this.machine.threads, new_thread_id);
774+
}
799775

800776
// Write the current thread-id, switch to the next thread later
801777
// to treat this write operation as occuring on the current thread.
@@ -888,12 +864,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
888864
this.machine.threads.get_total_thread_count()
889865
}
890866

891-
#[inline]
892-
fn has_terminated(&self, thread_id: ThreadId) -> bool {
893-
let this = self.eval_context_ref();
894-
this.machine.threads.has_terminated(thread_id)
895-
}
896-
897867
#[inline]
898868
fn have_all_terminated(&self) -> bool {
899869
let this = self.eval_context_ref();
@@ -943,26 +913,22 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
943913
where
944914
'mir: 'c,
945915
{
946-
let this = self.eval_context_ref();
947-
this.machine.threads.get_thread_name(thread)
916+
self.eval_context_ref().machine.threads.get_thread_name(thread)
948917
}
949918

950919
#[inline]
951920
fn block_thread(&mut self, thread: ThreadId) {
952-
let this = self.eval_context_mut();
953-
this.machine.threads.block_thread(thread);
921+
self.eval_context_mut().machine.threads.block_thread(thread);
954922
}
955923

956924
#[inline]
957925
fn unblock_thread(&mut self, thread: ThreadId) {
958-
let this = self.eval_context_mut();
959-
this.machine.threads.unblock_thread(thread);
926+
self.eval_context_mut().machine.threads.unblock_thread(thread);
960927
}
961928

962929
#[inline]
963930
fn yield_active_thread(&mut self) {
964-
let this = self.eval_context_mut();
965-
this.machine.threads.yield_active_thread();
931+
self.eval_context_mut().machine.threads.yield_active_thread();
966932
}
967933

968934
#[inline]
@@ -1024,6 +990,19 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
1024990
Ok(())
1025991
}
1026992

993+
#[inline]
994+
fn run_on_stack_empty(&mut self) -> InterpResult<'tcx, Poll<()>> {
995+
let this = self.eval_context_mut();
996+
let mut callback = this
997+
.active_thread_mut()
998+
.on_stack_empty
999+
.take()
1000+
.expect("`on_stack_empty` not set up, or already running");
1001+
let res = callback(this)?;
1002+
this.active_thread_mut().on_stack_empty = Some(callback);
1003+
Ok(res)
1004+
}
1005+
10271006
/// Decide which action to take next and on which thread.
10281007
#[inline]
10291008
fn schedule(&mut self) -> InterpResult<'tcx, SchedulingAction> {
@@ -1034,10 +1013,14 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
10341013
/// Handles thread termination of the active thread: wakes up threads joining on this one,
10351014
/// and deallocated thread-local statics.
10361015
///
1037-
/// This is called from `tls.rs` after handling the TLS dtors.
1016+
/// This is called by the eval loop when a thread's on_stack_empty returns `Ready`.
10381017
#[inline]
10391018
fn thread_terminated(&mut self) -> InterpResult<'tcx> {
10401019
let this = self.eval_context_mut();
1020+
let thread = this.active_thread_mut();
1021+
assert!(thread.stack.is_empty(), "only threads with an empty stack can be terminated");
1022+
thread.state = ThreadState::Terminated;
1023+
10411024
for ptr in this.machine.threads.thread_terminated(this.machine.data_race.as_mut()) {
10421025
this.deallocate_ptr(ptr.into(), None, MiriMemoryKind::Tls.into())?;
10431026
}

0 commit comments

Comments
 (0)