Skip to content

Commit fb2dad7

Browse files
authored
[Ray] Create RayTaskState actor as needed by default (#3081)
1 parent 992c077 commit fb2dad7

File tree

7 files changed

+154
-106
lines changed

7 files changed

+154
-106
lines changed

.github/workflows/platform-ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ jobs:
9494
rm -fr /tmp/etcd-$ETCD_VER-linux-amd64.tar.gz /tmp/etcd-download-test
9595
fi
9696
if [ -n "$WITH_RAY" ] || [ -n "$WITH_RAY_DAG" ]; then
97-
pip install ray[default]==1.9.2
97+
pip install ray[default]==1.9.2 "protobuf<4"
9898
pip install "xgboost_ray==0.1.5" "xgboost<1.6.0"
9999
fi
100100
if [ -n "$RUN_DASK" ]; then

mars/services/task/execution/ray/config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,11 @@ def get_n_worker(self):
4747

4848
def get_subtask_cancel_timeout(self):
4949
return self._ray_execution_config.get("subtask_cancel_timeout")
50+
51+
def create_task_state_actor_as_needed(self):
52+
# Whether create RayTaskState actor as needed.
53+
# - True (default):
54+
# Create RayTaskState actor only when create_remote_object is called.
55+
# - False:
56+
# Create RayTaskState actor in advance when the RayTaskExecutor is created.
57+
return self._ray_execution_config.get("create_task_state_actor_as_needed", True)

mars/services/task/execution/ray/context.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import inspect
1717
import logging
1818
from dataclasses import asdict
19-
from typing import Union, Dict, List
19+
from typing import Dict, List, Callable
2020

2121
from .....core.context import Context
2222
from .....storage.base import StorageLevel
@@ -77,18 +77,21 @@ def wrap(*args, **kwargs):
7777

7878
class _RayRemoteObjectContext:
7979
def __init__(
80-
self, actor_name_or_handle: Union[str, "ray.actor.ActorHandle"], *args, **kwargs
80+
self,
81+
get_or_create_actor: Callable[[], "ray.actor.ActorHandle"],
82+
*args,
83+
**kwargs
8184
):
8285
super().__init__(*args, **kwargs)
83-
self._actor_name_or_handle = actor_name_or_handle
86+
self._get_or_create_actor = get_or_create_actor
8487
self._task_state_actor = None
8588

8689
def _get_task_state_actor(self) -> "ray.actor.ActorHandle":
90+
# Get the RayTaskState actor, this is more clear and faster than wraps
91+
# the `get_or_create_actor` by lru_cache in __init__ because this method
92+
# is called as needed.
8793
if self._task_state_actor is None:
88-
if isinstance(self._actor_name_or_handle, ray.actor.ActorHandle):
89-
self._task_state_actor = self._actor_name_or_handle
90-
else:
91-
self._task_state_actor = ray.get_actor(self._actor_name_or_handle)
94+
self._task_state_actor = self._get_or_create_actor()
9295
return self._task_state_actor
9396

9497
@implements(Context.create_remote_object)
@@ -124,13 +127,15 @@ def __init__(
124127
config: RayExecutionConfig,
125128
task_context: Dict,
126129
task_chunks_meta: Dict,
130+
worker_addresses: List[str],
127131
*args,
128132
**kwargs
129133
):
130134
super().__init__(*args, **kwargs)
131135
self._config = config
132136
self._task_context = task_context
133137
self._task_chunks_meta = task_chunks_meta
138+
self._worker_addresses = worker_addresses
134139

135140
@implements(Context.get_chunks_result)
136141
def get_chunks_result(self, data_keys: List[str]) -> List:
@@ -158,6 +163,11 @@ def get_total_n_cpu(self) -> int:
158163
# TODO(fyrestone): Support auto scaling.
159164
return self._config.get_n_cpu() * self._config.get_n_worker()
160165

166+
@implements(Context.get_worker_addresses)
167+
def get_worker_addresses(self) -> List[str]:
168+
# Returns virtual worker addresses.
169+
return self._worker_addresses
170+
161171

162172
# TODO(fyrestone): Implement more APIs for Ray.
163173
class RayExecutionWorkerContext(_RayRemoteObjectContext, dict):

mars/services/task/execution/ray/executor.py

Lines changed: 48 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
import asyncio
1616
import functools
1717
import logging
18+
import operator
1819
from dataclasses import dataclass
19-
from typing import List, Dict, Any, Set
20+
from typing import List, Dict, Any, Set, Callable
2021
from .....core import ChunkGraph, Chunk, TileContext
2122
from .....core.context import set_context
2223
from .....core.operand import (
@@ -65,9 +66,23 @@ class _RayChunkMeta:
6566

6667
class RayTaskState(RayRemoteObjectManager):
6768
@classmethod
68-
def gen_name(cls, task_id: str):
69+
def _gen_name(cls, task_id: str):
6970
return f"{cls.__name__}_{task_id}"
7071

72+
@classmethod
73+
def get_handle(cls, task_id: str):
74+
"""Get the RayTaskState actor handle."""
75+
name = cls._gen_name(task_id)
76+
logger.info("Getting %s handle.", name)
77+
return ray.get_actor(name)
78+
79+
@classmethod
80+
def create(cls, task_id: str):
81+
"""Create a RayTaskState actor."""
82+
name = cls._gen_name(task_id)
83+
logger.info("Creating %s.", name)
84+
return ray.remote(cls).options(name=name).remote()
85+
7186

7287
_optimize_physical = None
7388

@@ -109,7 +124,7 @@ def execute_subtask(
109124
subtask_chunk_graph = deserialize(*subtask_chunk_graph)
110125
# inputs = [i[1] for i in inputs]
111126
context = RayExecutionWorkerContext(
112-
RayTaskState.gen_name(task_id), zip(input_keys, inputs)
127+
lambda: RayTaskState.get_handle(task_id), zip(input_keys, inputs)
113128
)
114129
# optimize chunk graph.
115130
subtask_chunk_graph = _optimize_subtask_graph(subtask_chunk_graph)
@@ -152,7 +167,6 @@ def __init__(
152167
tile_context: TileContext,
153168
task_context: Dict[str, "ray.ObjectRef"],
154169
task_chunks_meta: Dict[str, _RayChunkMeta],
155-
task_state_actor: "ray.actor.ActorHandle",
156170
lifecycle_api: LifecycleAPI,
157171
meta_api: MetaAPI,
158172
):
@@ -161,7 +175,6 @@ def __init__(
161175
self._tile_context = tile_context
162176
self._task_context = task_context
163177
self._task_chunks_meta = task_chunks_meta
164-
self._task_state_actor = task_state_actor
165178
self._ray_executor = self._get_ray_executor()
166179

167180
# api
@@ -196,31 +209,39 @@ async def create(
196209
**kwargs,
197210
) -> "RayTaskExecutor":
198211
lifecycle_api, meta_api = await cls._get_apis(session_id, address)
199-
task_state_actor = (
200-
ray.remote(RayTaskState)
201-
.options(name=RayTaskState.gen_name(task.task_id))
202-
.remote()
203-
)
204212
task_context = {}
205213
task_chunks_meta = {}
206-
await cls._init_context(
207-
config,
208-
task_context,
209-
task_chunks_meta,
210-
task_state_actor,
211-
session_id,
212-
address,
213-
)
214-
return cls(
214+
215+
executor = cls(
215216
config,
216217
task,
217218
tile_context,
218219
task_context,
219220
task_chunks_meta,
220-
task_state_actor,
221221
lifecycle_api,
222222
meta_api,
223223
)
224+
available_band_resources = await executor.get_available_band_resources()
225+
worker_addresses = list(
226+
map(operator.itemgetter(0), available_band_resources.keys())
227+
)
228+
if config.create_task_state_actor_as_needed():
229+
create_task_state_actor = lambda: RayTaskState.create( # noqa: E731
230+
task_id=task.task_id
231+
)
232+
else:
233+
actor_handle = RayTaskState.create(task_id=task.task_id)
234+
create_task_state_actor = lambda: actor_handle # noqa: E731
235+
await cls._init_context(
236+
config,
237+
task_context,
238+
task_chunks_meta,
239+
create_task_state_actor,
240+
worker_addresses,
241+
session_id,
242+
address,
243+
)
244+
return executor
224245

225246
# noinspection DuplicatedCode
226247
def destroy(self):
@@ -229,7 +250,6 @@ def destroy(self):
229250
self._tile_context = None
230251
self._task_context = None
231252
self._task_chunks_meta = None
232-
self._task_state_actor = None
233253
self._ray_executor = None
234254

235255
# api
@@ -267,7 +287,8 @@ async def _init_context(
267287
config: RayExecutionConfig,
268288
task_context: Dict[str, "ray.ObjectRef"],
269289
task_chunks_meta: Dict[str, _RayChunkMeta],
270-
task_state_actor: "ray.actor.ActorHandle",
290+
create_task_state_actor: Callable[[], "ray.actor.ActorHandle"],
291+
worker_addresses: List[str],
271292
session_id: str,
272293
address: str,
273294
):
@@ -276,7 +297,8 @@ async def _init_context(
276297
config,
277298
task_context,
278299
task_chunks_meta,
279-
task_state_actor,
300+
worker_addresses,
301+
create_task_state_actor,
280302
session_id,
281303
address,
282304
address,
@@ -426,7 +448,9 @@ async def get_available_band_resources(self) -> Dict[BandType, Resource]:
426448
idx = 0
427449
for band_resource in band_resources:
428450
for band, resource in band_resource.items():
429-
virtual_band_resources[(f"ray_virtual://{idx}", band)] = resource
451+
virtual_band_resources[
452+
(f"ray_virtual_address_{idx}:0", band)
453+
] = resource
430454
idx += 1
431455
self._available_band_resources = virtual_band_resources
432456

mars/services/task/execution/ray/tests/test_ray_execution_backend.py

Lines changed: 71 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from ...... import tensor as mt
2222

2323
from ......core import TileContext
24+
from ......core.context import get_context
2425
from ......core.graph import TileableGraph, TileableGraphBuilder, ChunkGraphBuilder
2526
from ......serialization import serialize
2627
from ......tests.core import require_ray, mock
@@ -34,7 +35,7 @@
3435
RayRemoteObjectManager,
3536
_RayRemoteObjectContext,
3637
)
37-
from ..executor import execute_subtask, RayTaskExecutor
38+
from ..executor import execute_subtask, RayTaskExecutor, RayTaskState
3839
from ..fetcher import RayFetcher
3940

4041
ray = lazy_import("ray")
@@ -51,11 +52,18 @@ def __init__(self, *args, **kwargs):
5152
self._set_attrs = Counter()
5253
super().__init__(*args, **kwargs)
5354

55+
@classmethod
56+
async def _get_apis(cls, session_id: str, address: str):
57+
return None, None
58+
5459
@staticmethod
5560
def _get_ray_executor():
5661
# Export remote function once.
5762
return None
5863

64+
async def get_available_band_resources(self):
65+
return {}
66+
5967
def set_attr_counter(self):
6068
return self._set_attrs
6169

@@ -64,6 +72,54 @@ def __setattr__(self, key, value):
6472
self._set_attrs[key] += 1
6573

6674

75+
@require_ray
76+
@pytest.mark.asyncio
77+
@mock.patch("mars.services.task.execution.ray.executor.RayTaskState.create")
78+
@mock.patch("mars.services.task.execution.ray.context.RayExecutionContext.init")
79+
@mock.patch("ray.get")
80+
async def test_ray_executor_create(
81+
mock_ray_get, mock_execution_context_init, mock_task_state_actor_create
82+
):
83+
task = Task("mock_task", "mock_session")
84+
85+
# Create RayTaskState actor as needed by default.
86+
mock_config = RayExecutionConfig.from_execution_config({"backend": "ray"})
87+
executor = await MockRayTaskExecutor.create(
88+
mock_config,
89+
session_id="mock_session_id",
90+
address="mock_address",
91+
task=task,
92+
tile_context=TileContext(),
93+
)
94+
assert isinstance(executor, MockRayTaskExecutor)
95+
assert mock_task_state_actor_create.call_count == 0
96+
ctx = get_context()
97+
assert isinstance(ctx, RayExecutionContext)
98+
ctx.create_remote_object("abc", lambda: None)
99+
assert mock_ray_get.call_count == 1
100+
assert mock_task_state_actor_create.call_count == 1
101+
102+
# Create RayTaskState actor in advance if create_task_state_actor_as_needed is False
103+
mock_config = RayExecutionConfig.from_execution_config(
104+
{"backend": "ray", "ray": {"create_task_state_actor_as_needed": False}}
105+
)
106+
executor = await MockRayTaskExecutor.create(
107+
mock_config,
108+
session_id="mock_session_id",
109+
address="mock_address",
110+
task=task,
111+
tile_context=TileContext(),
112+
)
113+
assert isinstance(executor, MockRayTaskExecutor)
114+
assert mock_ray_get.call_count == 1
115+
assert mock_task_state_actor_create.call_count == 2
116+
ctx = get_context()
117+
assert isinstance(ctx, RayExecutionContext)
118+
ctx.create_remote_object("abc", lambda: None)
119+
assert mock_ray_get.call_count == 2
120+
assert mock_task_state_actor_create.call_count == 2
121+
122+
67123
def test_ray_executor_destroy():
68124
task = Task("mock_task", "mock_session")
69125
mock_config = RayExecutionConfig.from_execution_config({"backend": "ray"})
@@ -73,7 +129,6 @@ def test_ray_executor_destroy():
73129
tile_context=TileContext(),
74130
task_context={},
75131
task_chunks_meta={},
76-
task_state_actor=None,
77132
lifecycle_api=None,
78133
meta_api=None,
79134
)
@@ -163,8 +218,8 @@ async def bar(self, a, b):
163218
await manager.call_remote_object(name, "foo", 3, 4)
164219

165220
# Test _RayRemoteObjectContext
166-
remote_manager = ray.remote(RayRemoteObjectManager).remote()
167-
context = _RayRemoteObjectContext(remote_manager)
221+
test_task_id = "test_task_id"
222+
context = _RayRemoteObjectContext(lambda: RayTaskState.create(test_task_id))
168223
context.create_remote_object(name, _TestRemoteObject, 2)
169224
remote_object = context.get_remote_object(name)
170225
r = remote_object.foo(3, 4)
@@ -185,9 +240,12 @@ def __init__(self):
185240
with pytest.raises(MyException):
186241
context.create_remote_object(name, _ErrorRemoteObject)
187242

243+
handle = RayTaskState.get_handle(test_task_id)
244+
assert handle is not None
245+
188246

189247
@require_ray
190-
def test_get_chunks_result(ray_start_regular_shared2):
248+
def test_ray_execution_context(ray_start_regular_shared2):
191249
value = 123
192250
o = ray.put(value)
193251

@@ -196,13 +254,19 @@ def fake_init(self):
196254

197255
with mock.patch.object(ThreadedServiceContext, "__init__", new=fake_init):
198256
mock_config = RayExecutionConfig.from_execution_config({"backend": "ray"})
199-
context = RayExecutionContext(mock_config, {"abc": o}, {}, None)
257+
mock_worker_addresses = ["mock_worker_address"]
258+
context = RayExecutionContext(
259+
mock_config, {"abc": o}, {}, mock_worker_addresses, lambda: None
260+
)
200261
r = context.get_chunks_result(["abc"])
201262
assert r == [value]
202263

264+
r = context.get_worker_addresses()
265+
assert r == mock_worker_addresses
266+
203267

204268
def test_ray_execution_worker_context():
205-
context = RayExecutionWorkerContext(None)
269+
context = RayExecutionWorkerContext(lambda: None)
206270
with pytest.raises(NotImplementedError):
207271
context.set_running_operand_key("mock_session_id", "mock_op_key")
208272
with pytest.raises(NotImplementedError):

0 commit comments

Comments
 (0)