12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
import asyncio
15
+ import logging
16
+ import os
15
17
16
18
import numpy as np
17
19
import pytest
35
37
36
38
ray = lazy_import ("ray" )
37
39
40
+ logger = logging .getLogger (__name__ )
41
+
38
42
39
43
@pytest .fixture
40
44
async def speculative_cluster ():
@@ -224,14 +228,13 @@ async def test_auto_scale_in(ray_large_cluster):
224
228
)
225
229
while await autoscaler_ref .get_dynamic_worker_nums () > 2 :
226
230
dynamic_workers = await autoscaler_ref .get_dynamic_workers ()
227
- print (f"Waiting workers { dynamic_workers } to be released." )
231
+ logger . info (f"Waiting %s workers to be released." , dynamic_workers )
228
232
await asyncio .sleep (1 )
229
233
await asyncio .sleep (1 )
230
234
assert await autoscaler_ref .get_dynamic_worker_nums () == 2
231
235
232
236
233
- @pytest .mark .skip ("Enable it when ray ownership bug is fixed" )
234
- @pytest .mark .timeout (timeout = 200 )
237
+ @pytest .mark .timeout (timeout = 150 )
235
238
@pytest .mark .parametrize ("ray_large_cluster" , [{"num_nodes" : 4 }], indirect = True )
236
239
@require_ray
237
240
@pytest .mark .asyncio
@@ -255,23 +258,62 @@ async def test_ownership_when_scale_in(ray_large_cluster):
255
258
uid = AutoscalerActor .default_uid (),
256
259
address = client ._cluster .supervisor_address ,
257
260
)
258
- await asyncio .gather (* [autoscaler_ref .request_worker () for _ in range (2 )])
259
- df = md .DataFrame (mt .random .rand (100 , 4 , chunk_size = 2 ), columns = list ("abcd" ))
260
- print (df .execute ())
261
- assert await autoscaler_ref .get_dynamic_worker_nums () > 1
261
+ num_chunks , chunk_size = 20 , 4
262
+ df = md .DataFrame (
263
+ mt .random .rand (num_chunks * chunk_size , 4 , chunk_size = chunk_size ),
264
+ columns = list ("abcd" ),
265
+ )
266
+ latch_actor = ray .remote (CountDownLatch ).remote (1 )
267
+ pid = os .getpid ()
268
+
269
+ def f (pdf , latch ):
270
+ if os .getpid () != pid :
271
+ # type inference will call this function too
272
+ ray .get (latch .wait .remote ())
273
+ return pdf
274
+
275
+ df = df .map_chunk (
276
+ f ,
277
+ args = (latch_actor ,),
278
+ )
279
+ info = df .execute (wait = False )
280
+ while await autoscaler_ref .get_dynamic_worker_nums () <= 1 :
281
+ logger .info ("Waiting workers to be created." )
282
+ await asyncio .sleep (1 )
283
+ await latch_actor .count_down .remote ()
284
+ await info
285
+ assert info .exception () is None
286
+ assert info .progress () == 1
287
+ logger .info ("df execute succeed." )
288
+
262
289
while await autoscaler_ref .get_dynamic_worker_nums () > 1 :
263
290
dynamic_workers = await autoscaler_ref .get_dynamic_workers ()
264
- print ( f "Waiting workers { dynamic_workers } to be released." )
291
+ logger . info ( "Waiting workers %s to be released." , dynamic_workers )
265
292
await asyncio .sleep (1 )
266
293
# Test data on node of released worker can still be fetched
267
- pd_df = df .to_pandas ()
268
- groupby_sum_df = df .rechunk (40 ).groupby ("a" ).sum ()
269
- print (groupby_sum_df .execute ())
294
+ pd_df = df .fetch ()
295
+ groupby_sum_df = df .rechunk (chunk_size * 2 ).groupby ("a" ).sum ()
296
+ logger . info (groupby_sum_df .execute ())
270
297
while await autoscaler_ref .get_dynamic_worker_nums () > 1 :
271
298
dynamic_workers = await autoscaler_ref .get_dynamic_workers ()
272
- print (f"Waiting workers { dynamic_workers } to be released." )
299
+ logger . info (f"Waiting workers %s to be released." , dynamic_workers )
273
300
await asyncio .sleep (1 )
274
301
assert df .to_pandas ().to_dict () == pd_df .to_dict ()
275
302
assert (
276
303
groupby_sum_df .to_pandas ().to_dict () == pd_df .groupby ("a" ).sum ().to_dict ()
277
304
)
305
+
306
+
307
+ class CountDownLatch :
308
+ def __init__ (self , cnt ):
309
+ self .cnt = cnt
310
+
311
+ def count_down (self ):
312
+ self .cnt -= 1
313
+
314
+ def get_count (self ):
315
+ return self .cnt
316
+
317
+ async def wait (self ):
318
+ while self .cnt != 0 :
319
+ await asyncio .sleep (0.01 )
0 commit comments