Skip to content

Commit 568a837

Browse files
authored
Fix non-abort on slurm tests (#925)
1 parent aefe26d commit 568a837

File tree

1 file changed

+39
-9
lines changed

1 file changed

+39
-9
lines changed

toolchain/mfc/sched.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
11
import time, typing, threading, dataclasses
22
import rich, rich.progress
3+
import traceback
34

45
from .printer import cons
56

6-
7-
8-
97
class WorkerThread(threading.Thread):
    """Thread that records the outcome of its target instead of losing it.

    Attributes:
        exc: The exception raised by the target, or None if none was raised.
        exc_info: Formatted traceback text (as produced by
            ``traceback.format_exc()``) for ``exc``, or None. Despite the
            name this is a string, not a ``sys.exc_info()`` tuple.
        completed_successfully: True once the target has returned without
            raising an exception.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the result fields, then forward all arguments to Thread."""
        self.exc = None
        self.exc_info = None
        self.completed_successfully = False

        super().__init__(*args, **kwargs)

    def run(self):
        """Invoke the target and capture any ``Exception`` it raises.

        The exception is stored on the thread (``exc`` / ``exc_info``) rather
        than propagated, so whoever joins the thread can inspect it.
        """
        try:
            if self._target:
                self._target(*self._args, **self._kwargs)
                self.completed_successfully = True
        except Exception as exc:
            self.exc = exc
            # Capture the traceback now, while the exception is being handled.
            self.exc_info = traceback.format_exc()
        finally:
            # Same cleanup as threading.Thread.run: drop the references to the
            # target and its arguments so a finished worker does not keep them
            # (or a reference cycle through them) alive.
            del self._target, self._args, self._kwargs
2124

2225

2326
@dataclasses.dataclass
@@ -35,7 +38,6 @@ class Task:
3538
args: typing.List[typing.Any]
3639
load: float
3740

38-
3941
def sched(tasks: typing.List[Task], nThreads: int, devices: typing.Set[int] = None) -> None:
4042
nAvailable: int = nThreads
4143
threads: typing.List[WorkerThreadHolder] = []
@@ -46,9 +48,39 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
4648
nonlocal threads, nAvailable
4749

4850
for threadID, threadHolder in enumerate(threads):
49-
if not threadHolder.thread.is_alive():
51+
# Check whether the thread has finished (no running-time limit is checked here)
52+
thread_not_alive = not threadHolder.thread.is_alive()
53+
54+
if thread_not_alive:
55+
# Properly join the thread with timeout to prevent infinite hangs
56+
try:
57+
threadHolder.thread.join(timeout=30.0) # 30 second timeout
58+
59+
# Double-check that thread actually finished joining
60+
if threadHolder.thread.is_alive():
61+
# Thread didn't finish within timeout - this is a serious issue
62+
raise RuntimeError(f"Thread {threadID} failed to join within 30 seconds timeout. "
63+
f"Thread may be hung or in an inconsistent state.")
64+
65+
except Exception as join_exc:
66+
# Handle join-specific exceptions with more context
67+
raise RuntimeError(f"Failed to join thread {threadID}: {join_exc}. "
68+
f"This may indicate a system threading issue or hung test case.") from join_exc
69+
70+
# Check for and propagate any exceptions that occurred in the worker thread
71+
# But only if the worker function didn't complete successfully
72+
# (This allows test failures to be handled gracefully by handle_case)
5073
if threadHolder.thread.exc is not None:
51-
raise threadHolder.thread.exc
74+
if threadHolder.thread.completed_successfully:
75+
# Test framework handled the exception gracefully (e.g., test failure)
76+
# Don't re-raise - this is expected behavior
77+
pass
78+
# Unhandled exception - this indicates a real problem
79+
elif hasattr(threadHolder.thread, 'exc_info') and threadHolder.thread.exc_info:
80+
error_msg = f"Worker thread {threadID} failed with unhandled exception:\n{threadHolder.thread.exc_info}"
81+
raise RuntimeError(error_msg) from threadHolder.thread.exc
82+
else:
83+
raise threadHolder.thread.exc
5284

5385
nAvailable += threadHolder.ppn
5486
for device in threadHolder.devices or set():
@@ -60,7 +92,6 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
6092

6193
break
6294

63-
6495
with rich.progress.Progress(console=cons.raw, transient=True) as progress:
6596
queue_tracker = progress.add_task("Queued ", total=len(tasks))
6697
complete_tracker = progress.add_task("Completed", total=len(tasks))
@@ -99,8 +130,7 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
99130

100131
threads.append(WorkerThreadHolder(thread, task.ppn, task.load, use_devices))
101132

102-
103-
# Wait for the lasts tests to complete
133+
# Wait for the last tests to complete (inside the progress context)
104134
while len(threads) != 0:
105135
# Keep track of threads that are done
106136
join_first_dead_thread(progress, complete_tracker)

0 commit comments

Comments
 (0)