Skip to content

Commit 3f44633

Browse files
authored
[Data] Support async callable classes in flat_map() (#51180)
<!-- Thank you for your contribution! Please review https://github.com/ray-project/ray/blob/master/CONTRIBUTING.rst before opening a pull request. --> <!-- Please add a reviewer to the assignee section when you create a PR. If you don't have the access to it, we will shortly find a reviewer and assign them to your PR. --> ## Why are these changes needed? This PR adds async generator support to `flat_map`. The implementation is similar to how #46129 handled async callable classes for `map_batches()`. Changes include: * Generalize the logic in `_generate_transform_fn_for_async_map_batches` (renamed to `_generate_transform_fn_for_async_map`) so it can process both batches and rows * Add a test case for async `flat_map` <!-- Please give a short summary of the change and the problem this solves. --> ## Related issue number #50329 <!-- For example: "Closes #1234" --> ## Checks - [x] I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR. - [x] I've run `scripts/format.sh` to lint the changes in this PR. - [ ] I've included any doc changes needed for https://docs.ray.io/en/master/. - [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file. - [ ] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/ - Testing Strategy - [x] Unit tests - [ ] Release tests - [ ] This PR is not tested :( --------- Signed-off-by: Drice1999 <chenxh267@gmail.com>
1 parent 647b74a commit 3f44633

File tree

2 files changed

+66
-32
lines changed

2 files changed

+66
-32
lines changed

python/ray/data/_internal/planner/plan_udf_map_op.py

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ def _generate_transform_fn_for_map_batches(
376376
) -> MapTransformCallable[DataBatch, DataBatch]:
377377
if inspect.iscoroutinefunction(fn):
378378
# UDF is a callable class with async generator `__call__` method.
379-
transform_fn = _generate_transform_fn_for_async_map_batches(fn)
379+
transform_fn = _generate_transform_fn_for_async_map(fn, _validate_batch_output)
380380

381381
else:
382382

@@ -423,64 +423,66 @@ def transform_fn(
423423
return transform_fn
424424

425425

426-
def _generate_transform_fn_for_async_map_batches(
426+
def _generate_transform_fn_for_async_map(
427427
fn: UserDefinedFunction,
428-
) -> MapTransformCallable[DataBatch, DataBatch]:
429-
def transform_fn(
430-
input_iterable: Iterable[DataBatch], _: TaskContext
431-
) -> Iterable[DataBatch]:
428+
validate_fn,
429+
) -> MapTransformCallable:
430+
# Generates a transform function for asynchronous mapping of items (either batches or rows)
431+
# using a user-defined function (UDF). This consolidated function handles both asynchronous
432+
# batch processing and asynchronous flat mapping (e.g., rows) based on the provided UDF.
433+
def transform_fn(input_iterable: Iterable, _: TaskContext) -> Iterable:
432434
# Use a queue to store outputs from async generator calls.
433-
# We will put output batches into this queue from async
435+
# We will put output items into this queue from async
434436
# generators, and in the main event loop, yield them from
435437
# the queue as they become available.
436-
output_batch_queue = queue.Queue()
438+
output_item_queue = queue.Queue()
437439
# Sentinel object to signal the end of the async generator.
438440
sentinel = object()
439441

440-
async def process_batch(batch: DataBatch):
442+
async def process_item(item):
441443
try:
442-
output_batch_iterator = await fn(batch)
444+
output_item_iterator = await fn(item)
443445
# As soon as results become available from the async generator,
444446
# put them into the result queue so they can be yielded.
445-
async for output_batch in output_batch_iterator:
446-
output_batch_queue.put(output_batch)
447+
async for output_item in output_item_iterator:
448+
output_item_queue.put(output_item)
447449
except Exception as e:
448-
output_batch_queue.put(
450+
output_item_queue.put(
449451
e
450452
) # Put the exception into the queue to signal an error
451453

452-
async def process_all_batches():
454+
async def process_all_items():
453455
try:
454456
loop = ray.data._map_actor_context.udf_map_asyncio_loop
455-
tasks = [loop.create_task(process_batch(x)) for x in input_iterable]
457+
tasks = [loop.create_task(process_item(x)) for x in input_iterable]
456458

457459
ctx = ray.data.DataContext.get_current()
458460
if ctx.execution_options.preserve_order:
459461
for task in tasks:
460-
await task()
462+
await task
461463
else:
462464
for task in asyncio.as_completed(tasks):
463465
await task
464466
finally:
465-
output_batch_queue.put(sentinel)
467+
output_item_queue.put(sentinel)
466468

467-
# Use the existing event loop to create and run Tasks to process each batch
469+
# Use the existing event loop to create and run Tasks to process each item
468470
loop = ray.data._map_actor_context.udf_map_asyncio_loop
469-
asyncio.run_coroutine_threadsafe(process_all_batches(), loop)
471+
asyncio.run_coroutine_threadsafe(process_all_items(), loop)
470472

471473
# Yield results as they become available.
472474
while True:
473-
# Here, `out_batch` is a one-row output batch
475+
# Here, `out_item` is a single output item (a batch or a row)
474476
# from the async generator, corresponding to a
475-
# single row from the input batch.
476-
out_batch = output_batch_queue.get()
477-
if out_batch is sentinel:
477+
# single input item.
478+
out_item = output_item_queue.get()
479+
if out_item is sentinel:
478480
# Break out of the loop when the sentinel is received.
479481
break
480-
if isinstance(out_batch, Exception):
481-
raise out_batch
482-
_validate_batch_output(out_batch)
483-
yield out_batch
482+
if isinstance(out_item, Exception):
483+
raise out_item
484+
validate_fn(out_item)
485+
yield out_item
484486

485487
return transform_fn
486488

@@ -511,11 +513,17 @@ def transform_fn(rows: Iterable[Row], _: TaskContext) -> Iterable[Row]:
511513
def _generate_transform_fn_for_flat_map(
512514
fn: UserDefinedFunction,
513515
) -> MapTransformCallable[Row, Row]:
514-
def transform_fn(rows: Iterable[Row], _: TaskContext) -> Iterable[Row]:
515-
for row in rows:
516-
for out_row in fn(row):
517-
_validate_row_output(out_row)
518-
yield out_row
516+
if inspect.iscoroutinefunction(fn):
517+
# UDF is a callable class with async generator `__call__` method.
518+
transform_fn = _generate_transform_fn_for_async_map(fn, _validate_row_output)
519+
520+
else:
521+
522+
def transform_fn(rows: Iterable[Row], _: TaskContext) -> Iterable[Row]:
523+
for row in rows:
524+
for out_row in fn(row):
525+
_validate_row_output(out_row)
526+
yield out_row
519527

520528
return transform_fn
521529

python/ray/data/tests/test_map.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1687,6 +1687,32 @@ async def __call__(self, batch):
16871687
)
16881688

16891689

1690+
def test_flat_map_async_generator(shutdown_only):
1691+
async def fetch_data(id):
1692+
return {"id": id}
1693+
1694+
class AsyncActor:
1695+
def __init__(self):
1696+
pass
1697+
1698+
async def __call__(self, row):
1699+
id = row["id"]
1700+
task1 = asyncio.create_task(fetch_data(id))
1701+
task2 = asyncio.create_task(fetch_data(id + 1))
1702+
print(f"yield task1: {id}")
1703+
yield await task1
1704+
print(f"sleep: {id}")
1705+
await asyncio.sleep(id % 5)
1706+
print(f"yield task2: {id}")
1707+
yield await task2
1708+
1709+
n = 10
1710+
ds = ray.data.from_items([{"id": i} for i in range(0, n, 2)])
1711+
ds = ds.flat_map(AsyncActor, concurrency=1, max_concurrency=2)
1712+
output = ds.take_all()
1713+
assert sorted(extract_values("id", output)) == list(range(0, n)), output
1714+
1715+
16901716
def test_map_batches_async_exception_propagation(shutdown_only):
16911717
ray.shutdown()
16921718
ray.init(num_cpus=2)

0 commit comments

Comments
 (0)