 import functools
 import logging
 import operator
+import sys
 from dataclasses import dataclass
 from typing import List, Dict, Any, Set, Callable
 from .....core import ChunkGraph, Chunk, TileContext
@@ ... @@
     ExecutionChunkResult,
     register_executor_cls,
 )
-from .config import RayExecutionConfig
+from .config import RayExecutionConfig, IN_RAY_CI
 from .context import (
     RayExecutionContext,
     RayExecutionWorkerContext,
@@ -314,35 +315,54 @@ async def execute_subtask_graph(
     ) -> Dict[Chunk, ExecutionChunkResult]:
         if self._cancelled is True:  # pragma: no cover
             raise asyncio.CancelledError()
+        logger.info("Stage %s start.", stage_id)
+        # Make sure each stage use a clean dict.
+        self._cur_stage_first_output_object_ref_to_subtask = dict()

-        def _on_monitor_task_done(fut):
+        def _on_monitor_aiotask_done(fut):
             # Print the error of monitor task.
             try:
                 fut.result()
             except asyncio.CancelledError:
                 pass
+            except Exception:  # pragma: no cover
+                logger.exception(
+                    "The monitor task of stage %s is done with exception.", stage_id
+                )
+                if IN_RAY_CI:  # pragma: no cover
+                    logger.warning(
+                        "The process will be exit due to the monitor task exception "
+                        "when MARS_CI_BACKEND=ray."
+                    )
+                    sys.exit(-1)

+        result_meta_keys = {
+            chunk.key
+            for chunk in chunk_graph.result_chunks
+            if not isinstance(chunk.op, Fetch)
+        }
         # Create a monitor task to update progress and collect garbage.
-        monitor_task = asyncio.create_task(
+        monitor_aiotask = asyncio.create_task(
             self._update_progress_and_collect_garbage(
-                subtask_graph, self._config.get_subtask_monitor_interval()
+                stage_id,
+                subtask_graph,
+                result_meta_keys,
+                self._config.get_subtask_monitor_interval(),
             )
         )
-        monitor_task.add_done_callback(_on_monitor_task_done)
+        monitor_aiotask.add_done_callback(_on_monitor_aiotask_done)

-        def _on_execute_task_done(fut):
+        def _on_execute_aiotask_done(_):
             # Make sure the monitor task is cancelled.
-            monitor_task.cancel()
+            monitor_aiotask.cancel()
             # Just use `self._cur_stage_tile_progress` as current stage progress
             # because current stage is completed, its progress is 1.0.
             self._cur_stage_progress = 1.0
             self._pre_all_stages_progress += self._cur_stage_tile_progress
-            self._cur_stage_first_output_object_ref_to_subtask.clear()

         self._execute_subtask_graph_aiotask = asyncio.current_task()
-        self._execute_subtask_graph_aiotask.add_done_callback(_on_execute_task_done)
+        self._execute_subtask_graph_aiotask.add_done_callback(_on_execute_aiotask_done)

-        logger.info("Stage %s start.", stage_id)
         task_context = self._task_context
         output_meta_object_refs = []
         self._pre_all_stages_tile_progress = (
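The hunk above ties the lifetime of the progress/GC monitor to the stage execution task: the monitor coroutine is wrapped in an asyncio task, its failures are surfaced through a done callback, and it is cancelled from the done callback of the enclosing task. Below is a minimal, self-contained sketch of that asyncio pattern, with hypothetical names (not the Mars API):

    import asyncio
    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    async def monitor(interval_seconds: float):
        # Stand-in for the periodic progress/GC coroutine: runs until cancelled.
        while True:
            await asyncio.sleep(interval_seconds)

    def on_monitor_done(fut):
        # Surface monitor failures; cancellation is the normal shutdown path.
        try:
            fut.result()
        except asyncio.CancelledError:
            pass
        except Exception:
            logger.exception("monitor task failed")

    async def run_stage():
        monitor_task = asyncio.create_task(monitor(0.5))
        monitor_task.add_done_callback(on_monitor_done)
        # Cancel the monitor when this task finishes, whether it succeeds or fails.
        asyncio.current_task().add_done_callback(lambda _: monitor_task.cancel())
        await asyncio.sleep(1)  # stand-in for submitting and awaiting subtasks

    asyncio.run(run_stage())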
@@ -352,11 +372,6 @@ def _on_execute_task_done(fut):
             self._tile_context.get_all_progress() - self._pre_all_stages_tile_progress
         )
         logger.info("Submitting %s subtasks of stage %s.", len(subtask_graph), stage_id)
-        result_meta_keys = {
-            chunk.key
-            for chunk in chunk_graph.result_chunks
-            if not isinstance(chunk.op, Fetch)
-        }
         subtask_max_retries = self._config.get_subtask_max_retries()
         for subtask in subtask_graph.topological_iter():
             subtask_chunk_graph = subtask.chunk_graph
@@ -555,7 +570,11 @@ def _get_subtask_output_keys(chunk_graph: ChunkGraph):
         return output_keys.keys()

     async def _update_progress_and_collect_garbage(
-        self, subtask_graph: SubtaskGraph, interval_seconds: float
+        self,
+        stage_id: str,
+        subtask_graph: SubtaskGraph,
+        result_meta_keys: Set[str],
+        interval_seconds: float,
     ):
         object_ref_to_subtask = self._cur_stage_first_output_object_ref_to_subtask
         total = len(subtask_graph)
@@ -579,7 +598,7 @@ def gc():
                 # Iterate the completed subtasks once.
                 subtask = completed_subtasks[i]
                 i += 1
-                logger.debug("GC: %s", subtask)
+                logger.debug("GC[stage=%s]: %s", stage_id, subtask)

                 # Note: There may be a scenario in which delayed gc occurs.
                 # When a subtask has more than one predecessor, like A, B,
@@ -595,15 +614,23 @@ def gc():
                     ):
                         yield
                     for chunk in pred.chunk_graph.results:
-                        self._task_context.pop(chunk.key, None)
+                        chunk_key = chunk.key
+                        # We need to check the GC chunk key is not in the
+                        # result meta keys, because there are some special
+                        # cases that the result meta keys are not the leaves.
+                        #
+                        # example: test_cut_execution
+                        if chunk_key not in result_meta_keys:
+                            logger.debug("GC[stage=%s]: %s", stage_id, chunk)
+                            self._task_context.pop(chunk_key, None)
                     gc_subtasks.add(pred)

             # TODO(fyrestone): Check the remaining self._task_context.keys()
             # in the result subtasks

         collect_garbage = gc()

-        while len(completed_subtasks) != total:
+        while len(completed_subtasks) < total:
             if len(object_ref_to_subtask) <= 0:  # pragma: no cover
                 await asyncio.sleep(interval_seconds)
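The last hunk changes what the collector may free: gc() is a generator that yields whenever it has to wait for a predecessor's successors to finish, the polling loop pumps it as subtasks complete, and chunk keys listed in result_meta_keys are now kept in the task context instead of being popped. A minimal sketch of this generator-driven GC pattern, with hypothetical names (not the Mars API):

    def make_gc(successors, outputs, completed, task_context, result_meta_keys):
        # Generator-based collector: yields whenever it must wait for more
        # subtasks to complete; never frees keys that belong to stage results.
        def gc():
            for subtask, succs in successors.items():
                while not all(s in completed for s in succs):
                    yield
                for key in outputs[subtask]:
                    if key not in result_meta_keys:
                        task_context.pop(key, None)
        return gc()

    successors = {"a": {"b", "c"}, "b": set(), "c": set()}
    outputs = {"a": ["k1"], "b": ["k2"], "c": ["k3"]}
    completed = set()
    task_context = {"k1": 1, "k2": 2, "k3": 3}
    collector = make_gc(successors, outputs, completed, task_context, {"k2", "k3"})

    for done in ("b", "c"):
        completed.add(done)
        next(collector, None)  # pump the collector after each completion

    print(task_context)  # {'k2': 2, 'k3': 3}: "a"'s output freed, results kept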