     execute,
 )
 from .....lib.aio import alru_cache
+from .....lib.ordered_set import OrderedSet
 from .....resource import Resource
 from .....serialization import serialize, deserialize
 from .....typing import BandType
@@ -96,7 +97,7 @@ def _optimize_subtask_graph(subtask_graph):
 
 
 async def _cancel_ray_task(obj_ref, kill_timeout: int = 3):
-    ray.cancel(obj_ref, force=False)
+    await asyncio.to_thread(ray.cancel, obj_ref, force=False)
     try:
         await asyncio.to_thread(ray.get, obj_ref, timeout=kill_timeout)
     except ray.exceptions.TaskCancelledError:  # pragma: no cover
@@ -108,7 +109,7 @@ async def _cancel_ray_task(obj_ref, kill_timeout: int = 3):
             e,
             obj_ref,
         )
-        ray.cancel(obj_ref, force=True)
+        await asyncio.to_thread(ray.cancel, obj_ref, force=True)
 
 
 def execute_subtask(
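
Both `ray.cancel` calls in `_cancel_ray_task` are now pushed off the event loop with `asyncio.to_thread`, matching the existing `ray.get` call, so a slow cancel RPC cannot stall other coroutines. A minimal, self-contained sketch of the same escalation pattern (the `slow_task` remote function and its timings are illustrative, not part of this diff):

import asyncio
import time

import ray


@ray.remote
def slow_task():
    time.sleep(60)  # Stands in for a long-running subtask.
    return "done"


async def cancel_without_blocking(obj_ref, kill_timeout: int = 3):
    # Request a graceful cancel in a worker thread so the event loop stays responsive.
    await asyncio.to_thread(ray.cancel, obj_ref, force=False)
    try:
        # Give the task a short window to acknowledge the cancellation.
        await asyncio.to_thread(ray.get, obj_ref, timeout=kill_timeout)
    except ray.exceptions.TaskCancelledError:
        return
    except Exception:
        # Escalate: force-kill the worker, again without blocking the loop.
        await asyncio.to_thread(ray.cancel, obj_ref, force=True)


ray.init()
asyncio.run(cancel_without_blocking(slow_task.remote()))
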
@@ -183,17 +184,12 @@ def __init__(
 
         self._available_band_resources = None
 
-        # For progress
+        # For progress and task cancel
         self._pre_all_stages_progress = 0.0
-        self._pre_all_stages_tile_progress = 0
-        self._cur_stage_tile_progress = 0
-        self._cur_stage_output_object_refs = []
-        # This list records the output object ref number of subtasks, so with
-        # `self._cur_stage_output_object_refs` we can just call `ray.cancel`
-        # with one object ref to cancel a subtask instead of cancel all object
-        # refs. In this way we can reduce a lot of unnecessary calls of ray.
-        self._output_object_refs_nums = []
-        # For meta and data gc
+        self._pre_all_stages_tile_progress = 0.0
+        self._cur_stage_progress = 0.0
+        self._cur_stage_tile_progress = 0.0
+        self._cur_stage_first_output_object_ref_to_subtask = dict()
         self._execute_subtask_graph_aiotask = None
         self._cancelled = False
 
@@ -258,12 +254,12 @@ def destroy(self):
 
         self._available_band_resources = None
 
-        # For progress
-        self._pre_all_stages_progress = 1
-        self._pre_all_stages_tile_progress = 1
-        self._cur_stage_tile_progress = 1
-        self._cur_stage_output_object_refs = []
-        self._output_object_refs_nums = []
+        # For progress and task cancel
+        self._pre_all_stages_progress = 1.0
+        self._pre_all_stages_tile_progress = 1.0
+        self._cur_stage_progress = 1.0
+        self._cur_stage_tile_progress = 1.0
+        self._cur_stage_first_output_object_ref_to_subtask = dict()
         self._execute_subtask_graph_aiotask = None
         self._cancelled = None
 
@@ -318,7 +314,33 @@ async def execute_subtask_graph(
     ) -> Dict[Chunk, ExecutionChunkResult]:
         if self._cancelled is True:  # pragma: no cover
             raise asyncio.CancelledError()
+
+        def _on_monitor_task_done(fut):
+            # Surface any error raised by the monitor task.
+            try:
+                fut.result()
+            except asyncio.CancelledError:
+                pass
+
+        # Create a monitor task to update progress and collect garbage.
+        monitor_task = asyncio.create_task(
+            self._update_progress_and_collect_garbage(
+                subtask_graph, self._config.get_subtask_monitor_interval()
+            )
+        )
+        monitor_task.add_done_callback(_on_monitor_task_done)
+
+        def _on_execute_task_done(fut):
+            # Make sure the monitor task is cancelled.
+            monitor_task.cancel()
+            # Just use `self._cur_stage_tile_progress` as the current stage progress,
+            # because the current stage is completed and its progress is 1.0.
+            self._cur_stage_progress = 1.0
+            self._pre_all_stages_progress += self._cur_stage_tile_progress
+            self._cur_stage_first_output_object_ref_to_subtask.clear()
+
         self._execute_subtask_graph_aiotask = asyncio.current_task()
+        self._execute_subtask_graph_aiotask.add_done_callback(_on_execute_task_done)
 
         logger.info("Stage %s start.", stage_id)
         task_context = self._task_context
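
The block above ties a long-running monitor coroutine to the stage execution with two done callbacks: one retrieves the monitor's result so failures are not silently dropped, the other guarantees the monitor is cancelled and the progress fields are finalized once the stage task finishes for any reason. A minimal sketch of that pairing in isolation (the `monitor` and `work` coroutines are illustrative stand-ins):

import asyncio


async def monitor(interval: float):
    # Periodically observe progress until cancelled by the work task's callback.
    while True:
        print("checking progress...")
        await asyncio.sleep(interval)


async def work():
    await asyncio.sleep(1)  # Stands in for executing the subtask graph.
    return "stage result"


def _on_monitor_done(fut: asyncio.Task):
    # Retrieve the result so unexpected monitor errors are surfaced, not swallowed.
    try:
        fut.result()
    except asyncio.CancelledError:
        pass  # Normal teardown path.


async def main():
    monitor_task = asyncio.create_task(monitor(0.2))
    monitor_task.add_done_callback(_on_monitor_done)

    work_task = asyncio.create_task(work())
    # Ensure the monitor never outlives the work it observes.
    work_task.add_done_callback(lambda _: monitor_task.cancel())

    print(await work_task)


asyncio.run(main())
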
@@ -359,8 +381,9 @@ async def execute_subtask_graph(
                 continue
             elif output_count == 1:
                 output_object_refs = [output_object_refs]
-            self._cur_stage_output_object_refs.extend(output_object_refs)
-            self._output_object_refs_nums.append(len(output_object_refs))
+            self._cur_stage_first_output_object_ref_to_subtask[
+                output_object_refs[0]
+            ] = subtask
             if output_meta_keys:
                 meta_object_ref, *output_object_refs = output_object_refs
                 # TODO(fyrestone): Fetch(not get) meta object here.
@@ -395,16 +418,16 @@ async def execute_subtask_graph(
         logger.info("Waiting for stage %s complete.", stage_id)
         # Patched the asyncio.to_thread for Python < 3.9 at mars/lib/aio/__init__.py
         await asyncio.to_thread(ray.wait, list(output_object_refs), fetch_local=False)
-        # Just use `self._cur_stage_tile_progress` as current stage progress
-        # because current stage is finished, its progress is 1.
-        self._pre_all_stages_progress += self._cur_stage_tile_progress
-        self._cur_stage_output_object_refs.clear()
-        self._output_object_refs_nums.clear()
+
         logger.info("Stage %s is complete.", stage_id)
         return chunk_to_meta
 
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         if exc_type is not None:
+            try:
+                await self.cancel()
+            except BaseException:  # noqa: E722  # nosec  # pylint: disable=bare-except
+                pass
             return
 
         # Update info if no exception occurs.
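
With this change, an exception leaving the stage context also triggers a best-effort `cancel()`, and any error raised by the cleanup itself is swallowed so that the original exception still propagates. A compact sketch of the same async context-manager behavior (the `StageContext` class and its prints are illustrative, not the executor's real interface):

import asyncio


class StageContext:
    async def __aenter__(self):
        return self

    async def cancel(self):
        # Stands in for cancelling the submitted Ray subtasks.
        print("cancelling submitted work...")

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            try:
                await self.cancel()
            except BaseException:
                pass  # Never mask the original error with a cleanup error.
            return  # Returning a falsy value re-raises the original exception.
        print("stage finished, updating metadata...")


async def main():
    try:
        async with StageContext():
            raise RuntimeError("subtask failed")
    except RuntimeError as e:
        print("caught:", e)


asyncio.run(main())
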
@@ -458,19 +481,7 @@ async def get_available_band_resources(self) -> Dict[BandType, Resource]:
 
     async def get_progress(self) -> float:
         """Get the execution progress."""
-        stage_progress = 0.0
-        total = len(self._cur_stage_output_object_refs)
-        if total > 0:
-            finished_objects, _ = ray.wait(
-                self._cur_stage_output_object_refs,
-                num_returns=total,
-                timeout=0,  # Avoid blocking the asyncio loop.
-                fetch_local=False,
-            )
-            stage_progress = (
-                len(finished_objects) / total * self._cur_stage_tile_progress
-            )
-        return self._pre_all_stages_progress + stage_progress
+        return self._cur_stage_progress
 
     async def cancel(self):
         """
@@ -480,26 +491,17 @@ async def cancel(self):
         2. Try to cancel the submitted subtasks by `ray.cancel`
         """
         logger.info("Start to cancel task %s.", self._task)
-        if self._task is None:
+        if self._task is None or self._cancelled is True:
             return
         self._cancelled = True
-        if (
-            self._execute_subtask_graph_aiotask is not None
-            and not self._execute_subtask_graph_aiotask.cancelled()
-        ):
+        if self._execute_subtask_graph_aiotask is not None:
             self._execute_subtask_graph_aiotask.cancel()
         timeout = self._config.get_subtask_cancel_timeout()
-        subtask_num = len(self._output_object_refs_nums)
-        if subtask_num > 0:
-            pos = 0
-            obj_refs_to_be_cancelled_ = []
-            for i in range(0, subtask_num):
-                if i > 0:
-                    pos += self._output_object_refs_nums[i - 1]
-                obj_refs_to_be_cancelled_.append(
-                    _cancel_ray_task(self._cur_stage_output_object_refs[pos], timeout)
-                )
-            await asyncio.gather(*obj_refs_to_be_cancelled_)
+        to_be_cancelled_coros = [
+            _cancel_ray_task(object_ref, timeout)
+            for object_ref in self._cur_stage_first_output_object_ref_to_subtask.keys()
+        ]
+        await asyncio.gather(*to_be_cancelled_coros)
 
     async def _load_subtask_inputs(
         self, stage_id: str, subtask: Subtask, chunk_graph: ChunkGraph, context: Dict
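
Because each subtask is keyed by its first output object ref in `self._cur_stage_first_output_object_ref_to_subtask`, cancellation now issues exactly one `ray.cancel` per submitted subtask; cancelling any one output ref cancels the whole remote call, so the old per-subtask ref-count bookkeeping is no longer needed. A minimal sketch of the gather-based fan-out (the refs, names, and the `cancel_one` stub are illustrative):

import asyncio


async def cancel_one(object_ref, timeout: float):
    # Stand-in for the _cancel_ray_task helper above.
    print(f"cancelling {object_ref} (timeout={timeout}s)")
    await asyncio.sleep(0)


async def cancel_stage(first_output_ref_to_subtask: dict, timeout: float):
    # One cancel coroutine per subtask, all awaited concurrently.
    await asyncio.gather(
        *(cancel_one(ref, timeout) for ref in first_output_ref_to_subtask)
    )


refs = {"ObjectRef(a)": "subtask-1", "ObjectRef(b)": "subtask-2"}
asyncio.run(cancel_stage(refs, timeout=3))
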
@@ -551,3 +553,81 @@ def _get_subtask_output_keys(chunk_graph: ChunkGraph):
             else:
                 output_keys[chunk.key] = 1
         return output_keys.keys()
+
+    async def _update_progress_and_collect_garbage(
+        self, subtask_graph: SubtaskGraph, interval_seconds: float
+    ):
+        object_ref_to_subtask = self._cur_stage_first_output_object_ref_to_subtask
+        total = len(subtask_graph)
+        completed_subtasks = OrderedSet()
+
+        def gc():
+            """
+            Consume the completed subtasks and collect garbage.
+
+            GC-ing the output object refs of subtasks whose successors are submitted
+            (not completed, as above) could reduce memory peaks, but then we couldn't
+            cancel and rerun slow subtasks, because the input object refs of running
+            subtasks may already be deleted.
+            """
+            i = 0
+            gc_subtasks = set()
+
+            while i < total:
+                while i >= len(completed_subtasks):
+                    yield
+                # Iterate the completed subtasks once.
+                subtask = completed_subtasks[i]
+                i += 1
+                logger.debug("GC: %s", subtask)
+
+                # Note: There may be a scenario in which delayed gc occurs.
+                # When a subtask has more than one predecessor, like A and B,
+                # and the `for ... in ...` loop reaches A first while B's
+                # successors are completed but A's are not, we cannot remove
+                # B's result chunks before A's.
+                for pred in subtask_graph.iter_predecessors(subtask):
+                    if pred in gc_subtasks:
+                        continue
+                    while not all(
+                        succ in completed_subtasks
+                        for succ in subtask_graph.iter_successors(pred)
+                    ):
+                        yield
+                    for chunk in pred.chunk_graph.results:
+                        self._task_context.pop(chunk.key, None)
+                    gc_subtasks.add(pred)
+
+            # TODO(fyrestone): Check the remaining self._task_context.keys()
+            # in the result subtasks.
+
+        collect_garbage = gc()
+
+        while len(completed_subtasks) != total:
+            if len(object_ref_to_subtask) <= 0:  # pragma: no cover
+                await asyncio.sleep(interval_seconds)
+
+            # Only wait for unready subtask object refs.
+            ready_objects, _ = await asyncio.to_thread(
+                ray.wait,
+                list(object_ref_to_subtask.keys()),
+                num_returns=len(object_ref_to_subtask),
+                timeout=0,
+                fetch_local=False,
+            )
+            if len(ready_objects) == 0:
+                await asyncio.sleep(interval_seconds)
+                continue
+
+            # Pop the completed subtasks from object_ref_to_subtask.
+            completed_subtasks.update(map(object_ref_to_subtask.pop, ready_objects))
+            # Update progress.
+            stage_progress = (
+                len(completed_subtasks) / total * self._cur_stage_tile_progress
+            )
+            self._cur_stage_progress = self._pre_all_stages_progress + stage_progress
+            # Collect garbage, use `for ... in ...` to avoid raising StopIteration.
+            for _ in collect_garbage:
+                break
+            # Yield control quickly so object_ref_to_subtask gets a chance to be updated.
+            await asyncio.sleep(0)
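
The monitor combines two patterns: `ray.wait(..., timeout=0)` run in a worker thread as a non-blocking readiness poll, and a generator (`gc()`) that yields whenever it must wait for more completed subtasks, so each loop iteration advances garbage collection by at most one step via `for _ in collect_garbage: break`. A stripped-down sketch of the generator-driven cleanup alone, with a plain list standing in for the completed-subtask set (all names here are illustrative):

import asyncio


def incremental_gc(completed: list, total: int):
    # A resumable cleanup routine: yield whenever there is nothing to do yet,
    # so the caller can drive it one step at a time without blocking.
    i = 0
    while i < total:
        while i >= len(completed):
            yield  # Nothing new to clean up; resume on the next drive.
        item = completed[i]
        i += 1
        print(f"collected garbage for {item}")


async def monitor():
    completed, total = [], 3
    gc_steps = incremental_gc(completed, total)
    for name in ("subtask-a", "subtask-b", "subtask-c"):
        await asyncio.sleep(0.1)  # Stands in for polling ray.wait(timeout=0).
        completed.append(name)    # This subtask's first output ref became ready.
        for _ in gc_steps:        # Advance the cleanup one step; `for` absorbs StopIteration.
            break


asyncio.run(monitor())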