Skip to content

Commit 6ffc7b9

Browse files
authored
[Ray] Fix Ray context GC (#3118)
1 parent 78628b6 commit 6ffc7b9

File tree

3 files changed

+62
-26
lines changed

3 files changed

+62
-26
lines changed

mars/services/task/execution/ray/config.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,19 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
16+
import logging
1517
from typing import Dict, List
1618
from .....resource import Resource
1719
from ..api import ExecutionConfig, register_config_cls
1820
from ..utils import get_band_resources_from_config
1921

22+
23+
logger = logging.getLogger(__name__)
24+
25+
IN_RAY_CI = os.environ.get("MARS_CI_BACKEND", "mars") == "ray"
2026
# The default interval seconds to update progress and collect garbage.
21-
DEFAULT_SUBTASK_MONITOR_INTERVAL = 1
27+
DEFAULT_SUBTASK_MONITOR_INTERVAL = 0 if IN_RAY_CI else 1
2228

2329

2430
@register_config_cls

mars/services/task/execution/ray/executor.py

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import functools
1717
import logging
1818
import operator
19+
import sys
1920
from dataclasses import dataclass
2021
from typing import List, Dict, Any, Set, Callable
2122
from .....core import ChunkGraph, Chunk, TileContext
@@ -49,7 +50,7 @@
4950
ExecutionChunkResult,
5051
register_executor_cls,
5152
)
52-
from .config import RayExecutionConfig
53+
from .config import RayExecutionConfig, IN_RAY_CI
5354
from .context import (
5455
RayExecutionContext,
5556
RayExecutionWorkerContext,
@@ -314,35 +315,54 @@ async def execute_subtask_graph(
314315
) -> Dict[Chunk, ExecutionChunkResult]:
315316
if self._cancelled is True: # pragma: no cover
316317
raise asyncio.CancelledError()
318+
logger.info("Stage %s start.", stage_id)
319+
# Make sure each stage use a clean dict.
320+
self._cur_stage_first_output_object_ref_to_subtask = dict()
317321

318-
def _on_monitor_task_done(fut):
322+
def _on_monitor_aiotask_done(fut):
319323
# Print the error of monitor task.
320324
try:
321325
fut.result()
322326
except asyncio.CancelledError:
323327
pass
328+
except Exception: # pragma: no cover
329+
logger.exception(
330+
"The monitor task of stage %s is done with exception.", stage_id
331+
)
332+
if IN_RAY_CI: # pragma: no cover
333+
logger.warning(
334+
"The process will be exit due to the monitor task exception "
335+
"when MARS_CI_BACKEND=ray."
336+
)
337+
sys.exit(-1)
324338

339+
result_meta_keys = {
340+
chunk.key
341+
for chunk in chunk_graph.result_chunks
342+
if not isinstance(chunk.op, Fetch)
343+
}
325344
# Create a monitor task to update progress and collect garbage.
326-
monitor_task = asyncio.create_task(
345+
monitor_aiotask = asyncio.create_task(
327346
self._update_progress_and_collect_garbage(
328-
subtask_graph, self._config.get_subtask_monitor_interval()
347+
stage_id,
348+
subtask_graph,
349+
result_meta_keys,
350+
self._config.get_subtask_monitor_interval(),
329351
)
330352
)
331-
monitor_task.add_done_callback(_on_monitor_task_done)
353+
monitor_aiotask.add_done_callback(_on_monitor_aiotask_done)
332354

333-
def _on_execute_task_done(fut):
355+
def _on_execute_aiotask_done(_):
334356
# Make sure the monitor task is cancelled.
335-
monitor_task.cancel()
357+
monitor_aiotask.cancel()
336358
# Just use `self._cur_stage_tile_progress` as current stage progress
337359
# because current stage is completed, its progress is 1.0.
338360
self._cur_stage_progress = 1.0
339361
self._pre_all_stages_progress += self._cur_stage_tile_progress
340-
self._cur_stage_first_output_object_ref_to_subtask.clear()
341362

342363
self._execute_subtask_graph_aiotask = asyncio.current_task()
343-
self._execute_subtask_graph_aiotask.add_done_callback(_on_execute_task_done)
364+
self._execute_subtask_graph_aiotask.add_done_callback(_on_execute_aiotask_done)
344365

345-
logger.info("Stage %s start.", stage_id)
346366
task_context = self._task_context
347367
output_meta_object_refs = []
348368
self._pre_all_stages_tile_progress = (
@@ -352,11 +372,6 @@ def _on_execute_task_done(fut):
352372
self._tile_context.get_all_progress() - self._pre_all_stages_tile_progress
353373
)
354374
logger.info("Submitting %s subtasks of stage %s.", len(subtask_graph), stage_id)
355-
result_meta_keys = {
356-
chunk.key
357-
for chunk in chunk_graph.result_chunks
358-
if not isinstance(chunk.op, Fetch)
359-
}
360375
subtask_max_retries = self._config.get_subtask_max_retries()
361376
for subtask in subtask_graph.topological_iter():
362377
subtask_chunk_graph = subtask.chunk_graph
@@ -555,7 +570,11 @@ def _get_subtask_output_keys(chunk_graph: ChunkGraph):
555570
return output_keys.keys()
556571

557572
async def _update_progress_and_collect_garbage(
558-
self, subtask_graph: SubtaskGraph, interval_seconds: float
573+
self,
574+
stage_id: str,
575+
subtask_graph: SubtaskGraph,
576+
result_meta_keys: Set[str],
577+
interval_seconds: float,
559578
):
560579
object_ref_to_subtask = self._cur_stage_first_output_object_ref_to_subtask
561580
total = len(subtask_graph)
@@ -579,7 +598,7 @@ def gc():
579598
# Iterate the completed subtasks once.
580599
subtask = completed_subtasks[i]
581600
i += 1
582-
logger.debug("GC: %s", subtask)
601+
logger.debug("GC[stage=%s]: %s", stage_id, subtask)
583602

584603
# Note: There may be a scenario in which delayed gc occurs.
585604
# When a subtask has more than one predecessor, like A, B,
@@ -595,15 +614,23 @@ def gc():
595614
):
596615
yield
597616
for chunk in pred.chunk_graph.results:
598-
self._task_context.pop(chunk.key, None)
617+
chunk_key = chunk.key
618+
# We need to check the GC chunk key is not in the
619+
# result meta keys, because there are some special
620+
# cases that the result meta keys are not the leaves.
621+
#
622+
# example: test_cut_execution
623+
if chunk_key not in result_meta_keys:
624+
logger.debug("GC[stage=%s]: %s", stage_id, chunk)
625+
self._task_context.pop(chunk_key, None)
599626
gc_subtasks.add(pred)
600627

601628
# TODO(fyrestone): Check the remaining self._task_context.keys()
602629
# in the result subtasks
603630

604631
collect_garbage = gc()
605632

606-
while len(completed_subtasks) != total:
633+
while len(completed_subtasks) < total:
607634
if len(object_ref_to_subtask) <= 0: # pragma: no cover
608635
await asyncio.sleep(interval_seconds)
609636

mars/services/task/execution/ray/tests/test_ray_execution_backend.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,18 +93,21 @@ async def submit_subtask_graph(
9393
subtask_graph: SubtaskGraph,
9494
chunk_graph: ChunkGraph,
9595
):
96-
monitor_task = asyncio.create_task(
97-
self._update_progress_and_collect_garbage(
98-
subtask_graph, self._config.get_subtask_monitor_interval()
99-
)
100-
)
101-
10296
result_meta_keys = {
10397
chunk.key
10498
for chunk in chunk_graph.result_chunks
10599
if not isinstance(chunk.op, Fetch)
106100
}
107101

102+
monitor_task = asyncio.create_task(
103+
self._update_progress_and_collect_garbage(
104+
stage_id,
105+
subtask_graph,
106+
result_meta_keys,
107+
self._config.get_subtask_monitor_interval(),
108+
)
109+
)
110+
108111
for subtask in subtask_graph.topological_iter():
109112
subtask_chunk_graph = subtask.chunk_graph
110113
task_context = self._task_context

0 commit comments

Comments (0)