Skip to content

Commit fbd0eed

Browse files
authored
Merge pull request #14 from realratchet/main
Re-raise child exceptions
2 parents 873e0d9 + 84815c6 commit fbd0eed

File tree

3 files changed

+136
-34
lines changed

3 files changed

+136
-34
lines changed

mplite/__init__.py

Lines changed: 113 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,24 @@
66
from tqdm import tqdm as _tqdm
77
import queue
88
from itertools import count
9-
from typing import Callable, Any, Union, Tuple
9+
from typing import Callable, Any, Union, Tuple, Literal
1010
from multiprocessing.context import BaseContext
11+
import tblib.pickling_support as pklex
1112

12-
major, minor, patch = 1, 2, 7
13+
major, minor, patch = 1, 3, 0
1314
__version_info__ = (major, minor, patch)
1415
__version__ = '.'.join(str(i) for i in __version_info__)
1516
default_context = "spawn"
1617

18+
ERR_MODE_STR = "str"
19+
ERR_MODE_EXCEPTION = "exception"
20+
21+
1722
class Task(object):
1823
task_id_counter = count(start=1)
1924

2025
def __init__(self, f, *args, **kwargs) -> None:
21-
26+
2227
if not callable(f):
2328
raise TypeError(f"{f} is not callable")
2429
self.f = f
@@ -37,15 +42,8 @@ def __repr__(self) -> str:
3742
return f"Task(f={self.f.__name__}, *{self.args}, **{self.kwargs})"
3843

3944
def execute(self):
40-
try:
41-
return self.f(*self.args, **self.kwargs)
42-
except Exception as e:
43-
f = io.StringIO()
44-
traceback.print_exc(limit=3, file=f)
45-
f.seek(0)
46-
error = f.read()
47-
f.close()
48-
return error
45+
return self.f(*self.args, **self.kwargs)
46+
4947

5048
class TaskChain(object):
5149
def __init__(self, task: Task, next_task: Callable[[Task, Any], Union[Task, "TaskChain"]] = None) -> None:
@@ -75,13 +73,13 @@ def resolve(self, result):
7573

7674
return task
7775
raise StopIteration()
78-
76+
7977
def __str__(self) -> str:
8078
return repr(self)
8179

8280
def __repr__(self) -> str:
8381
return f"TaskChain(f={self.task.f.__name__}, *{self.task.args}, **{self.task.kwargs}, is_last={self.next is None})"
84-
82+
8583
def execute(self):
8684
""" execute task chain synchronously """
8785
t = self
@@ -96,7 +94,7 @@ def execute(self):
9694

9795

9896
class Worker(object):
99-
def __init__(self, ctx: BaseContext, name: str, tq: multiprocessing.Queue, rq: multiprocessing.Queue, init: Task):
97+
def __init__(self, ctx: BaseContext, name: str, tq: multiprocessing.Queue, rq: multiprocessing.Queue, init: Task, error_mode: Literal["str", "exception"]):
10098
"""
10199
Worker class responsible for executing tasks in parallel, created by TaskManager.
102100
@@ -112,21 +110,25 @@ def __init__(self, ctx: BaseContext, name: str, tq: multiprocessing.Queue, rq: m
112110
Result queue
113111
init: Task
114112
Task executed when worker starts.
113+
error_mode: 'str' | 'exception'
114+
Which error mode to use, 'str' for legacy where exception is returned as string or 'exception' where exception is returned as pickled object.
115115
"""
116+
assert error_mode in (ERR_MODE_STR, ERR_MODE_EXCEPTION), f"Error mode must be in ('{ERR_MODE_STR}', '{ERR_MODE_EXCEPTION}'), got '{error_mode}'"
116117
self.ctx = ctx
117118
self.exit = ctx.Event()
118119
self.tq = tq # workers task queue
119120
self.rq = rq # workers result queue
120121
self.init = init
121122

123+
self.err_mode = error_mode
122124
self.process = ctx.Process(group=None, target=self.update, name=name, daemon=False)
123125

124126
def start(self):
125127
self.process.start()
126128

127129
def is_alive(self):
128130
return self.process.is_alive()
129-
131+
130132
@property
131133
def exitcode(self):
132134
return self.process.exitcode
@@ -135,6 +137,8 @@ def update(self):
135137
if self.init:
136138
self.init.f(*self.init.args, **self.init.kwargs)
137139

140+
do_task = _do_task_exception_mode if self.err_mode == ERR_MODE_EXCEPTION else _do_task_str_mode
141+
138142
while True:
139143
try:
140144
task = self.tq.get_nowait()
@@ -147,23 +151,43 @@ def update(self):
147151
break
148152

149153
elif isinstance(task, Task):
150-
result = task.execute()
151-
self.rq.put((task.id, result))
154+
self.rq.put((task.id, do_task(task)))
152155
else:
153156
time.sleep(0.01)
154157

155158

156159
class TaskManager(object):
157-
def __init__(self, cpu_count: int = None, context=default_context, worker_init: Task = None) -> None:
160+
def __init__(self, cpu_count: int = None, context=default_context, worker_init: Task = None, error_mode: Literal["str", "exception"] = ERR_MODE_STR) -> None:
161+
"""
162+
Class responsible for managing worker processes and tasks.
163+
164+
OPTIONAL
165+
--------
166+
cpu_count: int
167+
Number of worker processes to use.
168+
Default: {cpu core count}.
169+
ctx: BaseContext
170+
Process spawning context ForkContext/SpawnContext. Note: Windows cannot fork.
171+
Default: "spawn"
172+
worker_init: Task | None
173+
Task executed when worker starts.
174+
Default: None
175+
error_mode: 'str' | 'exception'
176+
Which error mode to use, 'str' for legacy where exception is returned as string or 'exception' where exception is returned as pickled object.
177+
Default: 'str'
178+
"""
179+
180+
assert error_mode in (ERR_MODE_STR, ERR_MODE_EXCEPTION), f"Error mode must be in ('{ERR_MODE_STR}', '{ERR_MODE_EXCEPTION}'), got '{error_mode}'"
181+
assert worker_init is None or isinstance(worker_init, Task), "Init is not (None, type[Task])"
182+
158183
self._ctx = multiprocessing.get_context(context)
159184
self._cpus = multiprocessing.cpu_count() if cpu_count is None else cpu_count
160185
self.tq = self._ctx.Queue()
161186
self.rq = self._ctx.Queue()
162187
self.pool: list[Worker] = []
163-
self._open_tasks = 0
164-
165-
assert worker_init is None or isinstance(worker_init, Task)
188+
self._open_tasks: list[int] = []
166189

190+
self.error_mode = error_mode
167191
self.worker_init = worker_init
168192

169193
def __enter__(self):
@@ -175,13 +199,13 @@ def __exit__(self, exc_type, exc_val, exc_tb): # signature requires these, thou
175199

176200
def start(self):
177201
for i in range(self._cpus): # create workers
178-
worker = Worker(self._ctx, name=str(i), tq=self.tq, rq=self.rq, init=self.worker_init)
202+
worker = Worker(self._ctx, name=str(i), tq=self.tq, rq=self.rq, init=self.worker_init, error_mode=self.error_mode)
179203
self.pool.append(worker)
180204
worker.start()
181205
while not all(p.is_alive() for p in self.pool):
182206
time.sleep(0.01)
183207

184-
def execute(self, tasks: "list[Union[Task, TaskChain]]", tqdm=_tqdm, pbar: _tqdm=None):
208+
def execute(self, tasks: "list[Union[Task, TaskChain]]", tqdm=_tqdm, pbar: _tqdm = None):
185209
"""
186210
Execute tasks using mplite
187211
@@ -207,7 +231,8 @@ def execute(self, tasks: "list[Union[Task, TaskChain]]", tqdm=_tqdm, pbar: _tqdm
207231
if None is provided, progress bar will be created using tqdm callable provided by tqdm parameter.
208232
"""
209233
task_count = len(tasks)
210-
self._open_tasks += task_count
234+
tasks_running = [t.id for t in tasks]
235+
self._open_tasks.extend(tasks_running)
211236
task_indices: dict[int, Tuple[int, Union[Task, TaskChain]]] = {}
212237

213238
for i, t in enumerate(tasks):
@@ -217,14 +242,20 @@ def execute(self, tasks: "list[Union[Task, TaskChain]]", tqdm=_tqdm, pbar: _tqdm
217242

218243
if pbar is None:
219244
""" if pbar object was not passed, create a new tqdm compatible object """
220-
pbar = tqdm(total=self._open_tasks, unit='tasks')
245+
pbar = tqdm(total=task_count, unit='tasks')
221246

222-
while self._open_tasks != 0:
247+
while len(tasks_running) > 0:
223248
try:
224-
task_key, res = self.rq.get_nowait()
249+
task_key, (success, res) = self.rq.get_nowait()
250+
251+
if not success and self.error_mode == ERR_MODE_EXCEPTION:
252+
[self._open_tasks.remove(idx) for idx in tasks_running]
253+
raise unpickle_exception(res)
254+
225255
idx, t = task_indices[task_key]
226256
if isinstance(t, Task) or t.next is None:
227-
self._open_tasks -= 1
257+
self._open_tasks.remove(t.id)
258+
tasks_running.remove(t.id)
228259
results[idx] = res
229260
pbar.update(1)
230261
else:
@@ -248,21 +279,26 @@ def submit(self, task: Task):
248279
""" permits asynchronous submission of tasks. """
249280
if not isinstance(task, Task):
250281
raise TypeError(f"expected mplite.Task, not {type(task)}")
251-
self._open_tasks += 1
282+
self._open_tasks.append(task.id)
252283
self.tq.put(task)
253284

254285
def take(self):
255286
""" permits asynchronous retrieval of results """
256287
try:
257-
_, result = self.rq.get_nowait()
258-
self._open_tasks -= 1
288+
task_id, (success, result) = self.rq.get_nowait()
289+
290+
self._open_tasks.remove(task_id)
291+
292+
if not success and self.error_mode == ERR_MODE_EXCEPTION:
293+
raise unpickle_exception(result)
294+
259295
except queue.Empty:
260296
result = None
261297
return result
262298

263299
@property
264300
def open_tasks(self):
265-
return self._open_tasks
301+
return len(self._open_tasks)
266302

267303
def stop(self):
268304
for _ in range(self._cpus):
@@ -274,3 +310,47 @@ def stop(self):
274310
_ = self.tq.get_nowait()
275311
while not self.rq.empty:
276312
_ = self.rq.get_nowait()
313+
314+
315+
def pickle_exception(e: Exception):
316+
if e.__traceback__ is not None:
317+
tback = pklex.pickle_traceback(e.__traceback__)
318+
e.__traceback__ = None
319+
else:
320+
tback = None
321+
322+
fn_ex, (ex_cls, ex_txt, ex_rsn, _, *others) = pklex.pickle_exception(e)
323+
324+
return fn_ex, (ex_cls, ex_txt, ex_rsn, tback, *others)
325+
326+
327+
def unpickle_exception(e):
328+
fn_ex, (ex_cls, ex_txt, ex_rsn, tback, *others) = e
329+
330+
if tback is not None:
331+
fn_tback, args_tback = tback
332+
tback = fn_tback(*args_tback)
333+
334+
return fn_ex(ex_cls, ex_txt, ex_rsn, tback, *others)
335+
336+
337+
def _do_task_exception_mode(task: Task):
338+
""" execute task in exception mode"""
339+
try:
340+
return True, task.execute()
341+
except Exception as e:
342+
return False, pickle_exception(e)
343+
344+
345+
def _do_task_str_mode(task: Task):
346+
""" execute task in legacy string mode """
347+
try:
348+
return True, task.execute()
349+
except Exception:
350+
f = io.StringIO()
351+
traceback.print_exc(limit=3, file=f)
352+
f.seek(0)
353+
error = f.read()
354+
f.close()
355+
356+
return False, error

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
tqdm>=4.63.0
1+
tqdm>=4.63.0
2+
tblib

tests/test_basics.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import signal
44
from mplite import TaskManager, Task, TaskChain
55
import time
6+
import traceback
67
import random
78

89
def test_alpha():
@@ -204,5 +205,25 @@ def post_1(prev, res):
204205

205206
assert res == [3, 3, 3, 3, 3]
206207

208+
def task_exception(i):
209+
if i == 4:
210+
raise ValueError(f"my exception: {i}")
211+
212+
return i
213+
214+
def test_exception_mode():
215+
tasks = [Task(task_exception, i) for i in range(10)]
216+
217+
with TaskManager(10, error_mode="exception") as tm:
218+
try:
219+
[k for k, *_ in tm.execute(tasks)]
220+
assert False
221+
except Exception as e:
222+
assert tm.open_tasks == 0, "there should be no left-over tasks"
223+
assert str(e) == "my exception: 4", "wrong exception"
224+
assert isinstance(e, ValueError), "wrong exception type"
225+
assert type(e.__traceback__).__name__ == "traceback", "not a traceback"
226+
assert 'in task_exception\n raise ValueError(f"my exception: {i}")\n' in traceback.format_tb(e.__traceback__)[-1], "wrong callstack"
227+
207228
if __name__ == "__main__":
208229
test_task_order()

0 commit comments

Comments
 (0)