6
6
from tqdm import tqdm as _tqdm
7
7
import queue
8
8
from itertools import count
9
- from typing import Callable , Any , Union , Tuple
9
+ from typing import Callable , Any , Union , Tuple , Literal
10
10
from multiprocessing .context import BaseContext
11
+ import tblib .pickling_support as pklex
11
12
12
- major , minor , patch = 1 , 2 , 7
13
+ major , minor , patch = 1 , 3 , 0
13
14
__version_info__ = (major , minor , patch )
14
15
__version__ = '.' .join (str (i ) for i in __version_info__ )
15
16
default_context = "spawn"
16
17
18
+ ERR_MODE_STR = "str"
19
+ ERR_MODE_EXCEPTION = "exception"
20
+
21
+
17
22
class Task (object ):
18
23
task_id_counter = count (start = 1 )
19
24
20
25
def __init__ (self , f , * args , ** kwargs ) -> None :
21
-
26
+
22
27
if not callable (f ):
23
28
raise TypeError (f"{ f } is not callable" )
24
29
self .f = f
@@ -37,15 +42,8 @@ def __repr__(self) -> str:
37
42
return f"Task(f={ self .f .__name__ } , *{ self .args } , **{ self .kwargs } )"
38
43
39
44
def execute(self):
    """
    Run the wrapped callable with the stored positional and keyword arguments.

    Returns whatever `self.f` returns. Exceptions are deliberately not caught
    here: they propagate to the caller, and the worker-side helpers
    (`_do_task_str_mode` / `_do_task_exception_mode`) decide how to report them.
    """
    return self.f(*self.args, **self.kwargs)
46
+
49
47
50
48
class TaskChain (object ):
51
49
def __init__ (self , task : Task , next_task : Callable [[Task , Any ], Union [Task , "TaskChain" ]] = None ) -> None :
@@ -75,13 +73,13 @@ def resolve(self, result):
75
73
76
74
return task
77
75
raise StopIteration ()
78
-
76
+
79
77
def __str__ (self ) -> str :
80
78
return repr (self )
81
79
82
80
def __repr__ (self ) -> str :
83
81
return f"TaskChain(f={ self .task .f .__name__ } , *{ self .task .args } , **{ self .task .kwargs } , is_last={ self .next is None } )"
84
-
82
+
85
83
def execute (self ):
86
84
""" execute task chain synchronously """
87
85
t = self
@@ -96,7 +94,7 @@ def execute(self):
96
94
97
95
98
96
class Worker (object ):
99
- def __init__ (self , ctx : BaseContext , name : str , tq : multiprocessing .Queue , rq : multiprocessing .Queue , init : Task ):
97
+ def __init__ (self , ctx : BaseContext , name : str , tq : multiprocessing .Queue , rq : multiprocessing .Queue , init : Task , error_mode : Literal [ "str" , "exception" ] ):
100
98
"""
101
99
Worker class responsible for executing tasks in parallel, created by TaskManager.
102
100
@@ -112,21 +110,25 @@ def __init__(self, ctx: BaseContext, name: str, tq: multiprocessing.Queue, rq: m
112
110
Result queue
113
111
init: Task
114
112
Task executed when worker starts.
113
+ error_mode: 'str' | 'exception'
114
+ Which error mode to use, 'str' for legacy where exception is returned as string or 'exception' where exception is returned as pickled object.
115
115
"""
116
+ assert error_mode in (ERR_MODE_STR , ERR_MODE_EXCEPTION ), f"Error mode must be in ('{ ERR_MODE_STR } ', '{ ERR_MODE_EXCEPTION } '), got '{ error_mode } '"
116
117
self .ctx = ctx
117
118
self .exit = ctx .Event ()
118
119
self .tq = tq # workers task queue
119
120
self .rq = rq # workers result queue
120
121
self .init = init
121
122
123
+ self .err_mode = error_mode
122
124
self .process = ctx .Process (group = None , target = self .update , name = name , daemon = False )
123
125
124
126
def start (self ):
125
127
self .process .start ()
126
128
127
129
def is_alive (self ):
128
130
return self .process .is_alive ()
129
-
131
+
130
132
@property
131
133
def exitcode (self ):
132
134
return self .process .exitcode
@@ -135,6 +137,8 @@ def update(self):
135
137
if self .init :
136
138
self .init .f (* self .init .args , ** self .init .kwargs )
137
139
140
+ do_task = _do_task_exception_mode if self .err_mode == ERR_MODE_EXCEPTION else _do_task_str_mode
141
+
138
142
while True :
139
143
try :
140
144
task = self .tq .get_nowait ()
@@ -147,23 +151,43 @@ def update(self):
147
151
break
148
152
149
153
elif isinstance (task , Task ):
150
- result = task .execute ()
151
- self .rq .put ((task .id , result ))
154
+ self .rq .put ((task .id , do_task (task )))
152
155
else :
153
156
time .sleep (0.01 )
154
157
155
158
156
159
class TaskManager (object ):
157
def __init__(self, cpu_count: int = None, context=default_context, worker_init: Task = None, error_mode: Literal["str", "exception"] = ERR_MODE_STR) -> None:
    """
    Class responsible for managing worker processes and tasks.

    OPTIONAL
    --------
    cpu_count: int
        Number of worker processes to use.
        Default: {cpu core count}.
    context: str
        Name of the process start method, passed to multiprocessing.get_context. Note: Windows cannot fork.
        Default: "spawn"
    worker_init: Task | None
        Task executed when worker starts.
        Default: None
    error_mode: 'str' | 'exception'
        Which error mode to use, 'str' for legacy where exception is returned as string or 'exception' where exception is returned as pickled object.
        Default: 'str'
    """

    assert error_mode in (ERR_MODE_STR, ERR_MODE_EXCEPTION), f"Error mode must be in ('{ERR_MODE_STR}', '{ERR_MODE_EXCEPTION}'), got '{error_mode}'"
    assert worker_init is None or isinstance(worker_init, Task), "Init is not (None, type[Task])"

    self._ctx = multiprocessing.get_context(context)
    self._cpus = multiprocessing.cpu_count() if cpu_count is None else cpu_count
    self.tq = self._ctx.Queue()  # task queue shared with all workers
    self.rq = self._ctx.Queue()  # result queue shared with all workers
    self.pool: list[Worker] = []
    # ids of tasks handed out whose final result has not been collected yet
    self._open_tasks: list[int] = []

    self.error_mode = error_mode
    self.worker_init = worker_init
168
192
169
193
def __enter__ (self ):
@@ -175,13 +199,13 @@ def __exit__(self, exc_type, exc_val, exc_tb): # signature requires these, thou
175
199
176
200
def start(self):
    """
    Create and start `self._cpus` workers, then wait until all are alive.

    Every worker shares the manager's task queue, result queue, init task
    and error mode. The wait polls `is_alive()` every 10 ms.
    """
    for i in range(self._cpus):  # create workers
        worker = Worker(self._ctx, name=str(i), tq=self.tq, rq=self.rq, init=self.worker_init, error_mode=self.error_mode)
        self.pool.append(worker)
        worker.start()
    # NOTE(review): this loop only ends once every worker reports alive; a
    # worker that exits right after starting would keep it spinning.
    while not all(p.is_alive() for p in self.pool):
        time.sleep(0.01)
183
207
184
- def execute (self , tasks : "list[Union[Task, TaskChain]]" , tqdm = _tqdm , pbar : _tqdm = None ):
208
+ def execute (self , tasks : "list[Union[Task, TaskChain]]" , tqdm = _tqdm , pbar : _tqdm = None ):
185
209
"""
186
210
Execute tasks using mplite
187
211
@@ -207,7 +231,8 @@ def execute(self, tasks: "list[Union[Task, TaskChain]]", tqdm=_tqdm, pbar: _tqdm
207
231
if None is provided, progress bar will be created using tqdm callable provided by tqdm parameter.
208
232
"""
209
233
task_count = len (tasks )
210
- self ._open_tasks += task_count
234
+ tasks_running = [t .id for t in tasks ]
235
+ self ._open_tasks .extend (tasks_running )
211
236
task_indices : dict [int , Tuple [int , Union [Task , TaskChain ]]] = {}
212
237
213
238
for i , t in enumerate (tasks ):
@@ -217,14 +242,20 @@ def execute(self, tasks: "list[Union[Task, TaskChain]]", tqdm=_tqdm, pbar: _tqdm
217
242
218
243
if pbar is None :
219
244
""" if pbar object was not passed, create a new tqdm compatible object """
220
- pbar = tqdm (total = self . _open_tasks , unit = 'tasks' )
245
+ pbar = tqdm (total = task_count , unit = 'tasks' )
221
246
222
- while self . _open_tasks != 0 :
247
+ while len ( tasks_running ) > 0 :
223
248
try :
224
- task_key , res = self .rq .get_nowait ()
249
+ task_key , (success , res ) = self .rq .get_nowait ()
250
+
251
+ if not success and self .error_mode == ERR_MODE_EXCEPTION :
252
+ [self ._open_tasks .remove (idx ) for idx in tasks_running ]
253
+ raise unpickle_exception (res )
254
+
225
255
idx , t = task_indices [task_key ]
226
256
if isinstance (t , Task ) or t .next is None :
227
- self ._open_tasks -= 1
257
+ self ._open_tasks .remove (t .id )
258
+ tasks_running .remove (t .id )
228
259
results [idx ] = res
229
260
pbar .update (1 )
230
261
else :
@@ -248,21 +279,26 @@ def submit(self, task: Task):
248
279
""" permits asynchronous submission of tasks. """
249
280
if not isinstance (task , Task ):
250
281
raise TypeError (f"expected mplite.Task, not { type (task )} " )
251
- self ._open_tasks += 1
282
+ self ._open_tasks . append ( task . id )
252
283
self .tq .put (task )
253
284
254
285
def take(self):
    """
    permits asynchronous retrieval of results

    Reads one entry from the result queue without blocking.

    Returns the task's result, or None when no result is available right now.
    For a failed task the behaviour depends on the error mode: in 'exception'
    mode the task's exception is rebuilt and raised, otherwise the traceback
    text produced by the worker is returned.
    """
    # Only the queue read is guarded. If the raise below sat inside this try,
    # a task whose own exception is queue.Empty would be swallowed and
    # reported as "no result".
    try:
        task_id, (success, result) = self.rq.get_nowait()
    except queue.Empty:
        return None

    self._open_tasks.remove(task_id)

    if not success and self.error_mode == ERR_MODE_EXCEPTION:
        raise unpickle_exception(result)

    return result
262
298
263
299
@property
def open_tasks(self):
    """ number of tasks handed out whose result has not been collected yet """
    return len(self._open_tasks)
266
302
267
303
def stop (self ):
268
304
for _ in range (self ._cpus ):
@@ -274,3 +310,47 @@ def stop(self):
274
310
_ = self .tq .get_nowait ()
275
311
while not self .rq .empty :
276
312
_ = self .rq .get_nowait ()
313
+
314
+
315
def pickle_exception(e: Exception):
    """
    Reduce an exception to a `(rebuild_fn, args)` pair for the result queue.

    The fourth entry of `args` carries the traceback in reduced form (or None
    when the exception has no traceback); `unpickle_exception` reverses this.

    Side effect: `e.__traceback__` is set to None when a traceback was present.
    """
    live_tb = e.__traceback__
    reduced_tb = None
    if live_tb is not None:
        reduced_tb = pklex.pickle_traceback(live_tb)
        # presumably detached so the raw traceback object never reaches the
        # reduced payload below -- TODO confirm against tblib
        e.__traceback__ = None

    rebuild, (cls, args, cause, _, *extra) = pklex.pickle_exception(e)

    # substitute the reduced traceback for whatever tblib put in that slot
    return rebuild, (cls, args, cause, reduced_tb, *extra)
325
+
326
+
327
def unpickle_exception(e):
    """
    Rebuild an exception from the `(rebuild_fn, args)` pair made by `pickle_exception`.

    When the traceback slot (fourth entry of `args`) is not None, it is itself
    a `(fn, fn_args)` pair and is expanded by calling `fn(*fn_args)` first.
    Returns the result of calling `rebuild_fn` with the restored arguments.
    """
    rebuild, payload = e
    cls, args, cause, tb_spec, *extra = payload

    if tb_spec is None:
        restored_tb = None
    else:
        make_tb, tb_args = tb_spec
        restored_tb = make_tb(*tb_args)

    return rebuild(cls, args, cause, restored_tb, *extra)
335
+
336
+
337
def _do_task_exception_mode(task: Task):
    """
    execute task in exception mode

    Returns `(True, result)` on success, or `(False, payload)` where payload
    is the exception reduced by `pickle_exception`.
    """
    try:
        outcome = task.execute()
    except Exception as exc:
        return False, pickle_exception(exc)
    return True, outcome
343
+
344
+
345
def _do_task_str_mode(task: Task):
    """
    execute task in legacy string mode

    Returns `(True, result)` on success, or `(False, text)` where text is the
    formatted traceback of the failure, limited to 3 stack entries.
    """
    try:
        return True, task.execute()
    except Exception:
        # format_exc returns the same text print_exc would write, so no
        # hand-managed StringIO buffer (and no manual close) is needed.
        return False, traceback.format_exc(limit=3)
0 commit comments