Commit 18a3af8

[Ray] fix auto scale-in hang (#3043)
1 parent 6ffc7b9

File tree: 3 files changed (+82 -24 lines)

mars/deploy/oscar/tests/test_ray_scheduling.py

Lines changed: 54 additions & 12 deletions
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import asyncio
+import logging
+import os
 
 import numpy as np
 import pytest
@@ -35,6 +37,8 @@
 
 ray = lazy_import("ray")
 
+logger = logging.getLogger(__name__)
+
 
 @pytest.fixture
 async def speculative_cluster():
@@ -224,14 +228,13 @@ async def test_auto_scale_in(ray_large_cluster):
     )
     while await autoscaler_ref.get_dynamic_worker_nums() > 2:
         dynamic_workers = await autoscaler_ref.get_dynamic_workers()
-        print(f"Waiting workers {dynamic_workers} to be released.")
+        logger.info(f"Waiting %s workers to be released.", dynamic_workers)
         await asyncio.sleep(1)
     await asyncio.sleep(1)
     assert await autoscaler_ref.get_dynamic_worker_nums() == 2
 
 
-@pytest.mark.skip("Enable it when ray ownership bug is fixed")
-@pytest.mark.timeout(timeout=200)
+@pytest.mark.timeout(timeout=150)
 @pytest.mark.parametrize("ray_large_cluster", [{"num_nodes": 4}], indirect=True)
 @require_ray
 @pytest.mark.asyncio
@@ -255,23 +258,62 @@ async def test_ownership_when_scale_in(ray_large_cluster):
         uid=AutoscalerActor.default_uid(),
         address=client._cluster.supervisor_address,
     )
-    await asyncio.gather(*[autoscaler_ref.request_worker() for _ in range(2)])
-    df = md.DataFrame(mt.random.rand(100, 4, chunk_size=2), columns=list("abcd"))
-    print(df.execute())
-    assert await autoscaler_ref.get_dynamic_worker_nums() > 1
+    num_chunks, chunk_size = 20, 4
+    df = md.DataFrame(
+        mt.random.rand(num_chunks * chunk_size, 4, chunk_size=chunk_size),
+        columns=list("abcd"),
+    )
+    latch_actor = ray.remote(CountDownLatch).remote(1)
+    pid = os.getpid()
+
+    def f(pdf, latch):
+        if os.getpid() != pid:
+            # type inference will call this function too
+            ray.get(latch.wait.remote())
+        return pdf
+
+    df = df.map_chunk(
+        f,
+        args=(latch_actor,),
+    )
+    info = df.execute(wait=False)
+    while await autoscaler_ref.get_dynamic_worker_nums() <= 1:
+        logger.info("Waiting workers to be created.")
+        await asyncio.sleep(1)
+    await latch_actor.count_down.remote()
+    await info
+    assert info.exception() is None
+    assert info.progress() == 1
+    logger.info("df execute succeed.")
+
     while await autoscaler_ref.get_dynamic_worker_nums() > 1:
         dynamic_workers = await autoscaler_ref.get_dynamic_workers()
-        print(f"Waiting workers {dynamic_workers} to be released.")
+        logger.info("Waiting workers %s to be released.", dynamic_workers)
         await asyncio.sleep(1)
     # Test data on node of released worker can still be fetched
-    pd_df = df.to_pandas()
-    groupby_sum_df = df.rechunk(40).groupby("a").sum()
-    print(groupby_sum_df.execute())
+    pd_df = df.fetch()
+    groupby_sum_df = df.rechunk(chunk_size * 2).groupby("a").sum()
+    logger.info(groupby_sum_df.execute())
     while await autoscaler_ref.get_dynamic_worker_nums() > 1:
         dynamic_workers = await autoscaler_ref.get_dynamic_workers()
-        print(f"Waiting workers {dynamic_workers} to be released.")
+        logger.info(f"Waiting workers %s to be released.", dynamic_workers)
         await asyncio.sleep(1)
     assert df.to_pandas().to_dict() == pd_df.to_dict()
     assert (
         groupby_sum_df.to_pandas().to_dict() == pd_df.groupby("a").sum().to_dict()
     )
+
+
+class CountDownLatch:
+    def __init__(self, cnt):
+        self.cnt = cnt
+
+    def count_down(self):
+        self.cnt -= 1
+
+    def get_count(self):
+        return self.cnt
+
+    async def wait(self):
+        while self.cnt != 0:
+            await asyncio.sleep(0.01)
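
The revived test relies on the CountDownLatch class above: each map_chunk task parks on the latch until the autoscaler has spawned extra workers, so chunk data is guaranteed to land on workers that will later be scaled in. Below is a minimal standalone sketch of that pattern in plain Ray, outside Mars; the Latch and gated names are illustrative, not part of this commit. Because wait() is a coroutine, Ray runs the latch as an asyncio actor, which is what lets count_down() interleave with callers blocked inside wait().

import asyncio

import ray


@ray.remote
class Latch:
    """Async actor: count_down() can run while callers are parked in wait()."""

    def __init__(self, count):
        self.count = count

    def count_down(self):
        self.count -= 1

    async def wait(self):
        # Poll until released; awaiting keeps the actor responsive.
        while self.count > 0:
            await asyncio.sleep(0.01)


@ray.remote
def gated(x, latch):
    # Each task blocks here until the driver releases the latch.
    ray.get(latch.wait.remote())
    return x * 2


if __name__ == "__main__":
    ray.init()
    latch = Latch.remote(1)
    futures = [gated.remote(i, latch) for i in range(8)]
    # ...this is where the Mars test waits for new workers to join...
    latch.count_down.remote()  # release every gated task at once
    print(ray.get(futures))    # [0, 2, 4, 6, 8, 10, 12, 14]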

mars/oscar/backends/ray/pool.py

Lines changed: 9 additions & 2 deletions
@@ -307,6 +307,7 @@ async def start(self):
         )
         self._set_ray_server(self._actor_pool)
         self._state = RayPoolState.POOL_READY
+        logger.info("Started main pool %s with %s processes.", address, n_process)
 
     async def mark_service_ready(self):
         results = []
@@ -336,20 +337,25 @@ def set_actor_pool_config(self, actor_pool_config):
         self._actor_pool_config = actor_pool_config
 
     async def start(self):
-        logger.info("Start to init sub pool.")
         # create mars pool outside the constructor is to avoid ray actor creation failed.
         # ray can't get the creation exception.
         main_pool_address, process_index = self._args
+        logger.info(
+            "Start to init sub pool %s for main pool %s.",
+            process_index,
+            main_pool_address,
+        )
         main_pool = ray.get_actor(main_pool_address)
         self._check_alive_task = asyncio.create_task(
             self.check_main_pool_alive(main_pool)
         )
         if self._actor_pool_config is None:
             self._actor_pool_config = await main_pool.actor_pool.remote("_config")
         pool_config = self._actor_pool_config.get_pool_config(process_index)
+        sub_pool_address = pool_config["external_address"]
         assert (
             self._state == RayPoolState.INIT
-        ), f"The pool {pool_config['external_address']} is already started, current state is {self._state}"
+        ), f"The pool {sub_pool_address} is already started, current state is {self._state}"
         env = pool_config["env"]
         if env:  # pragma: no cover
             os.environ.update(env)
@@ -363,6 +369,7 @@ async def start(self):
         await self._actor_pool.start()
         asyncio.create_task(self._actor_pool.join())
         self._state = RayPoolState.POOL_READY
+        logger.info("Started sub pool %s.", sub_pool_address)
 
     def mark_service_ready(self):
         self._state = RayPoolState.SERVICE_READY
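
Besides the extra start-up messages, the commit consistently moves diagnostics from print()/f-strings to logger calls that pass %s arguments. A small generic illustration of the difference (the logger name and values here are made up, not code from the commit): with %-style arguments the message is only formatted when a record is actually emitted, and the raw values stay attached to the LogRecord.

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("example.ray.pool")

workers = ["ray://worker-1", "ray://worker-2"]

# f-string: the message is built eagerly, even when this level is disabled.
logger.info(f"Waiting workers {workers} to be released.")

# %-style arguments: formatting is deferred until the record is emitted.
logger.info("Waiting workers %s to be released.", workers)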

mars/services/scheduling/supervisor/autoscale.py

Lines changed: 19 additions & 10 deletions
@@ -26,6 +26,7 @@
 from ....lib.aio import alru_cache
 from ...cluster.api import ClusterAPI
 from ...cluster.core import NodeRole, NodeStatus
+from ..errors import NoAvailableBand
 
 logger = logging.getLogger(__name__)
 
@@ -131,7 +132,10 @@ async def release_worker(address):
             self._dynamic_workers.remove(address)
             logger.info("Released worker %s.", address)
 
-        await asyncio.gather(*[release_worker(address) for address in addresses])
+        # Release workers one by one to ensure others workers which the current is moving data to
+        # is not being releasing.
+        for address in addresses:
+            await release_worker(address)
 
     def get_dynamic_workers(self) -> Set[str]:
         return self._dynamic_workers
@@ -214,7 +218,7 @@ async def _select_target_band(
             if (b[1] == band[1] and b != band and b not in excluded_bands)
         )
         if not bands:  # pragma: no cover
-            raise RuntimeError(
+            raise NoAvailableBand(
                 f"No bands to migrate data to, "
                 f"all available bands is {all_bands}, "
                 f"current band is {band}, "
@@ -400,14 +404,19 @@ async def _scale_in(self):
                 worker_addresses,
                 idle_bands,
             )
-            # Release workers one by one to ensure others workers which the current is moving data to
-            # is not being releasing.
-            await self._autoscaler.release_workers(worker_addresses)
-            logger.info(
-                "Finished offline workers %s in %.4f seconds",
-                worker_addresses,
-                time.time() - start_time,
-            )
+            try:
+                await self._autoscaler.release_workers(worker_addresses)
+                logger.info(
+                    "Finished offline workers %s in %.4f seconds",
+                    worker_addresses,
+                    time.time() - start_time,
+                )
+            except NoAvailableBand as e:  # pragma: no cover
+                logger.warning(
+                    "No enough bands, offline workers %s failed with exception %s.",
+                    worker_addresses,
+                    e,
+                )
 
     async def stop(self):
         self._task.cancel()
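
The hang fix itself lives in the last two hunks: release_workers now drains workers strictly one at a time instead of via asyncio.gather, so a band that a releasing worker is still migrating data to cannot be torn down concurrently, and _scale_in treats NoAvailableBand as a recoverable condition instead of letting the autoscaling loop die on it. A condensed sketch of that control flow follows, using stand-in coroutines and a stand-in exception class rather than the real Mars APIs.

import asyncio


class NoAvailableBand(Exception):
    """Stand-in for the NoAvailableBand error imported from ..errors."""


async def release_worker(address):
    # Stand-in for Autoscaler.release_worker: migrate data off `address`
    # to a surviving band, then drop the worker from the cluster.
    await asyncio.sleep(0.01)


async def release_workers(addresses):
    # Before the fix:
    #     await asyncio.gather(*[release_worker(a) for a in addresses])
    # which let releases overlap, so a worker still serving as another
    # release's migration target could itself be mid-release.
    # After the fix: strictly one release at a time.
    for address in addresses:
        await release_worker(address)


async def scale_in(idle_workers):
    try:
        await release_workers(idle_workers)
    except NoAvailableBand:
        # Recoverable: keep the idle workers for now and let a later
        # scale-in round retry, instead of crashing the autoscale loop.
        pass


asyncio.run(scale_in(["ray://worker-3", "ray://worker-4"]))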
