Skip to content

Commit 505361b

Browse files
authored
[Ray] Implement gc for ray task executor context (#3061)
1 parent 0263954 commit 505361b

File tree

5 files changed

+324
-62
lines changed

5 files changed

+324
-62
lines changed

mars/deploy/oscar/tests/test_ray_dag.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,16 @@
1414

1515
import copy
1616
import os
17+
import time
1718

1819
import pytest
1920

21+
from .... import get_context
22+
from .... import tensor as mt
2023
from ....tests.core import DICT_NOT_EMPTY, require_ray
2124
from ....utils import lazy_import
2225
from ..local import new_cluster
23-
from ..session import new_session
26+
from ..session import new_session, get_default_async_session
2427
from ..tests import test_local
2528
from ..tests.session import new_test_session
2629
from ..tests.test_local import _cancel_when_tile, _cancel_when_execute
@@ -129,3 +132,38 @@ async def test_session_get_progress(ray_start_regular_shared2, create_cluster):
129132
@pytest.mark.parametrize("test_func", [_cancel_when_execute, _cancel_when_tile])
130133
def test_cancel(ray_start_regular_shared2, create_cluster, test_func):
131134
test_local.test_cancel(create_cluster, test_func)
135+
136+
137+
@require_ray
138+
@pytest.mark.parametrize("config", [{"backend": "ray"}])
139+
def test_executor_context_gc(config):
140+
session = new_session(
141+
backend=config["backend"],
142+
n_cpu=2,
143+
web=False,
144+
use_uvloop=False,
145+
config={"task.execution_config.ray.subtask_monitor_interval": 0},
146+
)
147+
148+
assert session._session.client.web_address is None
149+
assert session.get_web_endpoint() is None
150+
151+
def f1(c):
152+
time.sleep(0.5)
153+
return c
154+
155+
with session:
156+
t1 = mt.random.randint(10, size=(100, 10), chunk_size=100)
157+
t2 = mt.random.randint(10, size=(100, 10), chunk_size=50)
158+
t3 = t2 + t1
159+
t4 = t3.sum(0)
160+
t5 = t4.map_chunk(f1)
161+
r = t5.execute()
162+
result = r.fetch()
163+
assert result is not None
164+
assert len(result) == 10
165+
context = get_context()
166+
assert len(context._task_context) < 5
167+
168+
session.stop_server()
169+
assert get_default_async_session() is None

mars/lib/ordered_set.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444

4545
# SetLike[T] is either a set of elements of type T, or a sequence, which
4646
# we will convert to an OrderedSet by adding its elements in order.
47-
SetLike = Union[AbstractSet[T], Sequence[T]]
47+
SetLike = Union[AbstractSet[T], Sequence[T], Iterable[T]]
4848
OrderedSetInitializer = Union[AbstractSet[T], Sequence[T], Iterable[T]]
4949

5050

mars/services/task/execution/ray/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from ..api import ExecutionConfig, register_config_cls
1818
from ..utils import get_band_resources_from_config
1919

20+
# The default interval seconds to update progress and collect garbage.
21+
DEFAULT_SUBTASK_MONITOR_INTERVAL = 1
22+
2023

2124
@register_config_cls
2225
class RayExecutionConfig(ExecutionConfig):
@@ -55,3 +58,12 @@ def create_task_state_actor_as_needed(self):
5558
# - False:
5659
# Create RayTaskState actor in advance when the RayTaskExecutor is created.
5760
return self._ray_execution_config.get("create_task_state_actor_as_needed", True)
61+
62+
def get_subtask_monitor_interval(self):
63+
"""
64+
The interval seconds for the monitor task to update progress and
65+
collect garbage.
66+
"""
67+
return self._ray_execution_config.get(
68+
"subtask_monitor_interval", DEFAULT_SUBTASK_MONITOR_INTERVAL
69+
)

mars/services/task/execution/ray/executor.py

Lines changed: 134 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
execute,
2929
)
3030
from .....lib.aio import alru_cache
31+
from .....lib.ordered_set import OrderedSet
3132
from .....resource import Resource
3233
from .....serialization import serialize, deserialize
3334
from .....typing import BandType
@@ -96,7 +97,7 @@ def _optimize_subtask_graph(subtask_graph):
9697

9798

9899
async def _cancel_ray_task(obj_ref, kill_timeout: int = 3):
99-
ray.cancel(obj_ref, force=False)
100+
await asyncio.to_thread(ray.cancel, obj_ref, force=False)
100101
try:
101102
await asyncio.to_thread(ray.get, obj_ref, timeout=kill_timeout)
102103
except ray.exceptions.TaskCancelledError: # pragma: no cover
@@ -108,7 +109,7 @@ async def _cancel_ray_task(obj_ref, kill_timeout: int = 3):
108109
e,
109110
obj_ref,
110111
)
111-
ray.cancel(obj_ref, force=True)
112+
await asyncio.to_thread(ray.cancel, obj_ref, force=True)
112113

113114

114115
def execute_subtask(
@@ -183,17 +184,12 @@ def __init__(
183184

184185
self._available_band_resources = None
185186

186-
# For progress
187+
# For progress and task cancel
187188
self._pre_all_stages_progress = 0.0
188-
self._pre_all_stages_tile_progress = 0
189-
self._cur_stage_tile_progress = 0
190-
self._cur_stage_output_object_refs = []
191-
# This list records the output object ref number of subtasks, so with
192-
# `self._cur_stage_output_object_refs` we can just call `ray.cancel`
193-
# with one object ref to cancel a subtask instead of cancel all object
194-
# refs. In this way we can reduce a lot of unnecessary calls of ray.
195-
self._output_object_refs_nums = []
196-
# For meta and data gc
189+
self._pre_all_stages_tile_progress = 0.0
190+
self._cur_stage_progress = 0.0
191+
self._cur_stage_tile_progress = 0.0
192+
self._cur_stage_first_output_object_ref_to_subtask = dict()
197193
self._execute_subtask_graph_aiotask = None
198194
self._cancelled = False
199195

@@ -258,12 +254,12 @@ def destroy(self):
258254

259255
self._available_band_resources = None
260256

261-
# For progress
262-
self._pre_all_stages_progress = 1
263-
self._pre_all_stages_tile_progress = 1
264-
self._cur_stage_tile_progress = 1
265-
self._cur_stage_output_object_refs = []
266-
self._output_object_refs_nums = []
257+
# For progress and task cancel
258+
self._pre_all_stages_progress = 1.0
259+
self._pre_all_stages_tile_progress = 1.0
260+
self._cur_stage_progress = 1.0
261+
self._cur_stage_tile_progress = 1.0
262+
self._cur_stage_first_output_object_ref_to_subtask = dict()
267263
self._execute_subtask_graph_aiotask = None
268264
self._cancelled = None
269265

@@ -318,7 +314,33 @@ async def execute_subtask_graph(
318314
) -> Dict[Chunk, ExecutionChunkResult]:
319315
if self._cancelled is True: # pragma: no cover
320316
raise asyncio.CancelledError()
317+
318+
def _on_monitor_task_done(fut):
319+
# Print the error of monitor task.
320+
try:
321+
fut.result()
322+
except asyncio.CancelledError:
323+
pass
324+
325+
# Create a monitor task to update progress and collect garbage.
326+
monitor_task = asyncio.create_task(
327+
self._update_progress_and_collect_garbage(
328+
subtask_graph, self._config.get_subtask_monitor_interval()
329+
)
330+
)
331+
monitor_task.add_done_callback(_on_monitor_task_done)
332+
333+
def _on_execute_task_done(fut):
334+
# Make sure the monitor task is cancelled.
335+
monitor_task.cancel()
336+
# Just use `self._cur_stage_tile_progress` as current stage progress
337+
# because current stage is completed, its progress is 1.0.
338+
self._cur_stage_progress = 1.0
339+
self._pre_all_stages_progress += self._cur_stage_tile_progress
340+
self._cur_stage_first_output_object_ref_to_subtask.clear()
341+
321342
self._execute_subtask_graph_aiotask = asyncio.current_task()
343+
self._execute_subtask_graph_aiotask.add_done_callback(_on_execute_task_done)
322344

323345
logger.info("Stage %s start.", stage_id)
324346
task_context = self._task_context
@@ -359,8 +381,9 @@ async def execute_subtask_graph(
359381
continue
360382
elif output_count == 1:
361383
output_object_refs = [output_object_refs]
362-
self._cur_stage_output_object_refs.extend(output_object_refs)
363-
self._output_object_refs_nums.append(len(output_object_refs))
384+
self._cur_stage_first_output_object_ref_to_subtask[
385+
output_object_refs[0]
386+
] = subtask
364387
if output_meta_keys:
365388
meta_object_ref, *output_object_refs = output_object_refs
366389
# TODO(fyrestone): Fetch(not get) meta object here.
@@ -395,16 +418,16 @@ async def execute_subtask_graph(
395418
logger.info("Waiting for stage %s complete.", stage_id)
396419
# Patched the asyncio.to_thread for Python < 3.9 at mars/lib/aio/__init__.py
397420
await asyncio.to_thread(ray.wait, list(output_object_refs), fetch_local=False)
398-
# Just use `self._cur_stage_tile_progress` as current stage progress
399-
# because current stage is finished, its progress is 1.
400-
self._pre_all_stages_progress += self._cur_stage_tile_progress
401-
self._cur_stage_output_object_refs.clear()
402-
self._output_object_refs_nums.clear()
421+
403422
logger.info("Stage %s is complete.", stage_id)
404423
return chunk_to_meta
405424

406425
async def __aexit__(self, exc_type, exc_val, exc_tb):
407426
if exc_type is not None:
427+
try:
428+
await self.cancel()
429+
except BaseException: # noqa: E722 # nosec # pylint: disable=bare-except
430+
pass
408431
return
409432

410433
# Update info if no exception occurs.
@@ -458,19 +481,7 @@ async def get_available_band_resources(self) -> Dict[BandType, Resource]:
458481

459482
async def get_progress(self) -> float:
460483
"""Get the execution progress."""
461-
stage_progress = 0.0
462-
total = len(self._cur_stage_output_object_refs)
463-
if total > 0:
464-
finished_objects, _ = ray.wait(
465-
self._cur_stage_output_object_refs,
466-
num_returns=total,
467-
timeout=0, # Avoid blocking the asyncio loop.
468-
fetch_local=False,
469-
)
470-
stage_progress = (
471-
len(finished_objects) / total * self._cur_stage_tile_progress
472-
)
473-
return self._pre_all_stages_progress + stage_progress
484+
return self._cur_stage_progress
474485

475486
async def cancel(self):
476487
"""
@@ -480,26 +491,17 @@ async def cancel(self):
480491
2. Try to cancel the submitted subtasks by `ray.cancel`
481492
"""
482493
logger.info("Start to cancel task %s.", self._task)
483-
if self._task is None:
494+
if self._task is None or self._cancelled is True:
484495
return
485496
self._cancelled = True
486-
if (
487-
self._execute_subtask_graph_aiotask is not None
488-
and not self._execute_subtask_graph_aiotask.cancelled()
489-
):
497+
if self._execute_subtask_graph_aiotask is not None:
490498
self._execute_subtask_graph_aiotask.cancel()
491499
timeout = self._config.get_subtask_cancel_timeout()
492-
subtask_num = len(self._output_object_refs_nums)
493-
if subtask_num > 0:
494-
pos = 0
495-
obj_refs_to_be_cancelled_ = []
496-
for i in range(0, subtask_num):
497-
if i > 0:
498-
pos += self._output_object_refs_nums[i - 1]
499-
obj_refs_to_be_cancelled_.append(
500-
_cancel_ray_task(self._cur_stage_output_object_refs[pos], timeout)
501-
)
502-
await asyncio.gather(*obj_refs_to_be_cancelled_)
500+
to_be_cancelled_coros = [
501+
_cancel_ray_task(object_ref, timeout)
502+
for object_ref in self._cur_stage_first_output_object_ref_to_subtask.keys()
503+
]
504+
await asyncio.gather(*to_be_cancelled_coros)
503505

504506
async def _load_subtask_inputs(
505507
self, stage_id: str, subtask: Subtask, chunk_graph: ChunkGraph, context: Dict
@@ -551,3 +553,81 @@ def _get_subtask_output_keys(chunk_graph: ChunkGraph):
551553
else:
552554
output_keys[chunk.key] = 1
553555
return output_keys.keys()
556+
557+
async def _update_progress_and_collect_garbage(
558+
self, subtask_graph: SubtaskGraph, interval_seconds: float
559+
):
560+
object_ref_to_subtask = self._cur_stage_first_output_object_ref_to_subtask
561+
total = len(subtask_graph)
562+
completed_subtasks = OrderedSet()
563+
564+
def gc():
565+
"""
566+
Consume the completed subtasks and collect garbage.
567+
568+
GC the output object refs of the subtask which successors are submitted
569+
(not completed as above) can reduce the memory peaks, but we can't cancel
570+
and rerun slow subtasks because the input object refs of running subtasks
571+
may be deleted.
572+
"""
573+
i = 0
574+
gc_subtasks = set()
575+
576+
while i < total:
577+
while i >= len(completed_subtasks):
578+
yield
579+
# Iterate the completed subtasks once.
580+
subtask = completed_subtasks[i]
581+
i += 1
582+
logger.debug("GC: %s", subtask)
583+
584+
# Note: There may be a scenario in which delayed gc occurs.
585+
# When a subtask has more than one predecessor, like A, B,
586+
# and in the `for ... in ...` loop we get A firstly while
587+
# B's successors are completed, A's not. Then we cannot remove
588+
# B's results chunks before A's.
589+
for pred in subtask_graph.iter_predecessors(subtask):
590+
if pred in gc_subtasks:
591+
continue
592+
while not all(
593+
succ in completed_subtasks
594+
for succ in subtask_graph.iter_successors(pred)
595+
):
596+
yield
597+
for chunk in pred.chunk_graph.results:
598+
self._task_context.pop(chunk.key, None)
599+
gc_subtasks.add(pred)
600+
601+
# TODO(fyrestone): Check the remaining self._task_context.keys()
602+
# in the result subtasks
603+
604+
collect_garbage = gc()
605+
606+
while len(completed_subtasks) != total:
607+
if len(object_ref_to_subtask) <= 0: # pragma: no cover
608+
await asyncio.sleep(interval_seconds)
609+
610+
# Only wait for unready subtask object refs.
611+
ready_objects, _ = await asyncio.to_thread(
612+
ray.wait,
613+
list(object_ref_to_subtask.keys()),
614+
num_returns=len(object_ref_to_subtask),
615+
timeout=0,
616+
fetch_local=False,
617+
)
618+
if len(ready_objects) == 0:
619+
await asyncio.sleep(interval_seconds)
620+
continue
621+
622+
# Pop the completed subtasks from object_ref_to_subtask.
623+
completed_subtasks.update(map(object_ref_to_subtask.pop, ready_objects))
624+
# Update progress.
625+
stage_progress = (
626+
len(completed_subtasks) / total * self._cur_stage_tile_progress
627+
)
628+
self._cur_stage_progress = self._pre_all_stages_progress + stage_progress
629+
# Collect garbage, use `for ... in ...` to avoid raising StopIteration.
630+
for _ in collect_garbage:
631+
break
632+
# Fast to next loop and give it a chance to update object_ref_to_subtask.
633+
await asyncio.sleep(0)

0 commit comments

Comments
 (0)