
Commit 0263954

[Ray] Fix ray worker failover (#3080)
* Make failover work with the latest Ray master
* Fix max_task_retries
* Fix _get_actor
* Fix compatibility
* Fix retry of the actor state task
* Fix sub pool restart
* Skip test_ownership_when_scale_in
* Revert the alive check interval
* Lint
Parent: fb2dad7

6 files changed (+105, -12 lines)

.github/workflows/platform-ci.yml (1 addition, 1 deletion)

@@ -144,7 +144,7 @@ jobs:
           coverage combine build/ && coverage report
         fi
         if [ -n "$WITH_RAY" ]; then
-          pytest $PYTEST_CONFIG --durations=0 --timeout=600 -v -s -m ray
+          pytest $PYTEST_CONFIG --durations=0 --timeout=200 -v -s -m ray
           coverage report
         fi
         if [ -n "$WITH_RAY_DAG" ]; then

mars/deploy/oscar/ray.py (5 additions, 2 deletions)

@@ -37,7 +37,7 @@
     AbstractClusterBackend,
 )
 from ...services import NodeRole
-from ...utils import lazy_import
+from ...utils import lazy_import, retry_callable
 from ..utils import (
     load_config,
     get_third_party_modules_from_config,
@@ -274,7 +274,10 @@ async def reconstruct_worker(self, address: str):
         async def _reconstruct_worker():
             logger.info("Reconstruct worker %s", address)
             actor = ray.get_actor(address)
-            state = await actor.state.remote()
+            # ray call will error when actor is restarting
+            state = await retry_callable(
+                actor.state.remote, ex_type=ray.exceptions.RayActorError, sync=False
+            )()
             if state == RayPoolState.SERVICE_READY:
                 logger.info("Worker %s is service ready.")
                 return
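
Note on the hunk above: calling a method on a Ray actor that is being restarted raises ray.exceptions.RayActorError, so reconstruct_worker now keeps polling actor.state.remote() until the actor responds again instead of failing on the first call. The standalone sketch below is illustrative only (the Pool actor, wait_for_state helper and main driver are not part of the commit); it shows the same pattern with an explicit retry loop equivalent to retry_callable(..., ex_type=ray.exceptions.RayActorError, sync=False):

import asyncio

import ray


@ray.remote(max_restarts=-1)
class Pool:
    def state(self):
        return "SERVICE_READY"


async def wait_for_state(handle, interval=1.0):
    # Keep retrying while the actor is dead or restarting; any other error
    # propagates, mirroring retry_callable with ex_type=RayActorError.
    while True:
        try:
            return await handle.state.remote()
        except ray.exceptions.RayActorError:
            await asyncio.sleep(interval)


async def main():
    ray.init()
    pool = Pool.options(name="pool").remote()
    print(await wait_for_state(pool))  # "SERVICE_READY" once the actor is up


if __name__ == "__main__":
    asyncio.run(main())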

mars/deploy/oscar/tests/test_ray.py (3 additions, 2 deletions)

@@ -578,7 +578,7 @@ async def remote(self):
     class FakeActor:
         state = FakeActorMethod()

-    def _get_actor(*args):
+    def _get_actor(*args, **kwargs):
         return FakeActor

     async def _stop_worker(*args):
@@ -677,7 +677,8 @@ async def test_auto_scale_in(ray_large_cluster):
     assert await autoscaler_ref.get_dynamic_worker_nums() == 2


-@pytest.mark.timeout(timeout=1000)
+@pytest.mark.skip("Enable it when ray ownership bug is fixed")
+@pytest.mark.timeout(timeout=200)
 @pytest.mark.parametrize("ray_large_cluster", [{"num_nodes": 4}], indirect=True)
 @require_ray
 @pytest.mark.asyncio

mars/oscar/backends/ray/pool.py (24 additions, 7 deletions)

@@ -28,7 +28,7 @@

 from ... import ServerClosed
 from ....serialization.ray import register_ray_serializers
-from ....utils import lazy_import, ensure_coverage
+from ....utils import lazy_import, ensure_coverage, retry_callable
 from ..config import ActorPoolConfig
 from ..message import CreateActorMessage
 from ..pool import (
@@ -130,14 +130,27 @@ async def start_sub_pool(
             f"process_index {process_index} is not consistent with index {_process_index} "
             f"in external_address {external_address}"
         )
+        actor_handle = config["kwargs"]["sub_pool_handles"][external_address]
+        state = await retry_callable(
+            actor_handle.state.remote, ex_type=ray.exceptions.RayActorError, sync=False
+        )()
+        if state is RayPoolState.SERVICE_READY:  # pragma: no cover
+            logger.info("Ray sub pool %s is alive, kill it first.", external_address)
+            await kill_and_wait(actor_handle, no_restart=False)
+            # Wait sub pool process restarted.
+            await retry_callable(
+                actor_handle.state.remote,
+                ex_type=ray.exceptions.RayActorError,
+                sync=False,
+            )()
         logger.info("Start to start ray sub pool %s.", external_address)
         create_sub_pool_timeout = 120
-        actor_handle = config["kwargs"]["sub_pool_handles"][external_address]
-        done, _ = await asyncio.wait(
-            [actor_handle.set_actor_pool_config.remote(actor_pool_config)],
-            timeout=create_sub_pool_timeout,
-        )
-        if not done:  # pragma: no cover
+        try:
+            await asyncio.wait_for(
+                actor_handle.set_actor_pool_config.remote(actor_pool_config),
+                timeout=create_sub_pool_timeout,
+            )
+        except asyncio.TimeoutError:  # pragma: no cover
             msg = (
                 f"Can not start ray sub pool {external_address} in {create_sub_pool_timeout} seconds.",
             )
@@ -153,6 +166,10 @@ async def wait_sub_pools_ready(cls, create_pool_tasks: List[asyncio.Task]):

     async def recover_sub_pool(self, address: str):
         process = self.sub_processes[address]
+        # ray call will error when actor is restarting
+        await retry_callable(
+            process.state.remote, ex_type=ray.exceptions.RayActorError, sync=False
+        )()
         await process.start.remote()

         if self._auto_recover == "actor":
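
Note on the timeout handling above: asyncio.wait() does not raise when the timeout expires; it returns the (done, pending) sets, which is why the old code had to check "if not done". asyncio.wait_for() instead cancels the awaitable and raises asyncio.TimeoutError, which the new code catches. A minimal stdlib-only sketch of the difference (the coroutine name slow is illustrative, not from the commit):

import asyncio


async def slow():
    await asyncio.sleep(10)
    return "ready"


async def main():
    # Old style: no exception on timeout, inspect the returned sets instead.
    done, pending = await asyncio.wait([asyncio.create_task(slow())], timeout=0.1)
    print(len(done), len(pending))  # 0 1
    for task in pending:
        task.cancel()

    # New style: the timeout surfaces as an exception.
    try:
        await asyncio.wait_for(slow(), timeout=0.1)
    except asyncio.TimeoutError:
        print("timed out")


asyncio.run(main())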

mars/tests/test_utils.py (34 additions, 0 deletions)

@@ -616,3 +616,37 @@ def __call__(self, *args, **kwargs):
 def test_gen_random_id(id_length):
     rnd_id = utils.new_random_id(id_length)
     assert len(rnd_id) == id_length
+
+
+@pytest.mark.asyncio
+async def test_retry_callable():
+    assert utils.retry_callable(lambda x: x)(1) == 1
+    assert utils.retry_callable(lambda x: 0)(1) == 0
+
+    class CustomException(BaseException):
+        pass
+
+    def f1(x):
+        nonlocal num_retried
+        num_retried += 1
+        if num_retried == 3:
+            return x
+        raise CustomException
+
+    num_retried = 0
+    with pytest.raises(CustomException):
+        utils.retry_callable(f1)(1)
+    assert utils.retry_callable(f1, ex_type=CustomException)(1) == 1
+    num_retried = 0
+    with pytest.raises(CustomException):
+        utils.retry_callable(f1, max_retries=2, ex_type=CustomException)(1)
+    num_retried = 0
+    assert utils.retry_callable(f1, max_retries=3, ex_type=CustomException)(1) == 1
+
+    async def f2(x):
+        return f1(x)
+
+    num_retried = 0
+    with pytest.raises(CustomException):
+        await utils.retry_callable(f2)(1)
+    assert await utils.retry_callable(f2, ex_type=CustomException)(1) == 1

mars/utils.py (38 additions, 0 deletions)

@@ -1698,3 +1698,41 @@ def ensure_coverage():
             pass
         else:
             cleanup_on_sigterm()
+
+
+def retry_callable(
+    callable_,
+    ex_type: type = Exception,
+    wait_interval=1,
+    max_retries=-1,
+    sync: bool = None,
+):
+    if inspect.iscoroutinefunction(callable_) or sync is False:
+
+        @functools.wraps(callable)
+        async def retry_call(*args, **kwargs):
+            num_retried = 0
+            while max_retries < 0 or num_retried < max_retries:
+                num_retried += 1
+                try:
+                    return await callable_(*args, **kwargs)
+                except ex_type:
+                    await asyncio.sleep(wait_interval)
+
+    else:
+
+        @functools.wraps(callable)
+        def retry_call(*args, **kwargs):
+            num_retried = 0
+            ex = None
+            while max_retries < 0 or num_retried < max_retries:
+                num_retried += 1
+                try:
+                    return callable_(*args, **kwargs)
+                except ex_type as e:
+                    ex = e
+                    time.sleep(wait_interval)
+            assert ex is not None
+            raise ex  # pylint: disable-msg=E0702
+
+    return retry_call
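
Usage sketch for the new retry_callable helper (assumes a Mars checkout that includes this commit; the Flaky class and flaky_async coroutine below are made up for illustration). One behavioural detail visible in the diff: the async branch simply falls out of its loop once max_retries is exhausted and returns None rather than re-raising, which is why the Ray call sites above leave max_retries at its default of -1.

import asyncio

from mars.utils import retry_callable


class Flaky:
    def __init__(self):
        self.calls = 0

    def ping(self):
        self.calls += 1
        if self.calls < 3:
            raise ConnectionError("not ready yet")
        return "pong"


flaky = Flaky()
# Sync callable: retried every 0.1s until it stops raising ConnectionError.
assert retry_callable(flaky.ping, ex_type=ConnectionError, wait_interval=0.1)() == "pong"


async def flaky_async():
    raise ConnectionError("still down")


async def main():
    # Async callable: awaited by the wrapper; after two failed attempts the
    # loop ends and the wrapper returns None instead of raising.
    result = await retry_callable(
        flaky_async, ex_type=ConnectionError, wait_interval=0.1, max_retries=2
    )()
    assert result is None


asyncio.run(main())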
