
Commit 26a8a77

[Data] Abstracting common shuffle utility (#51237)
1. Abstracted common `shuffle` util
2. Rebased `ArrowBlockAccessor.random_shuffle` to use it
3. Cleaned up `RetryingPyFileSystem`
4. Added `RowToBlockMapTransformFn`

---------

Signed-off-by: Alexey Kudinkin <ak@anyscale.com>
1 parent 3ee2ff0 commit 26a8a77

File tree

11 files changed: +91 -35 lines changed


python/ray/_private/utils.py

Lines changed: 4 additions & 0 deletions
@@ -51,6 +51,10 @@
 if TYPE_CHECKING:
     from ray.runtime_env import RuntimeEnv
 
+
+INT32_MAX = (2**31) - 1
+
+
 pwd = None
 if sys.platform != "win32":
     import pwd

python/ray/data/_internal/arrow_block.py

Lines changed: 2 additions & 8 deletions
@@ -22,6 +22,7 @@
     pyarrow_table_from_pydict,
 )
 from ray.data._internal.arrow_ops import transform_polars, transform_pyarrow
+from ray.data._internal.arrow_ops.transform_pyarrow import shuffle
 from ray.data._internal.row import TableRow
 from ray.data._internal.table_block import TableBlockAccessor, TableBlockBuilder
 from ray.data._internal.util import find_partitions
@@ -215,14 +216,7 @@ def slice(self, start: int, end: int, copy: bool = False) -> "pyarrow.Table":
         return view
 
     def random_shuffle(self, random_seed: Optional[int]) -> "pyarrow.Table":
-        num_rows = self.num_rows()
-        if num_rows == 0:
-            return pyarrow.table([])
-        random = np.random.RandomState(random_seed)
-        shuffled_indices = np.arange(num_rows)
-        # Shuffle all rows in-place
-        random.shuffle(shuffled_indices)
-        return self.take(pyarrow.array(shuffled_indices))
+        return shuffle(self._table, random_seed)
 
     def schema(self) -> "pyarrow.lib.Schema":
         return self._table.schema
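
With this change, `ArrowBlockAccessor.random_shuffle` is a thin wrapper over the shared `transform_pyarrow.shuffle` helper. A minimal sketch of exercising it through the block accessor (the table contents and seed below are illustrative, not from this diff):

import pyarrow as pa

from ray.data.block import BlockAccessor

# Illustrative table; any Arrow table works.
table = pa.table({"id": list(range(5))})

# random_shuffle now delegates to transform_pyarrow.shuffle under the hood.
shuffled = BlockAccessor.for_block(table).random_shuffle(random_seed=42)

assert shuffled.num_rows == table.num_rows  # same rows, permuted order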

python/ray/data/_internal/arrow_ops/transform_pyarrow.py

Lines changed: 15 additions & 2 deletions
@@ -1,5 +1,5 @@
 import logging
-from typing import TYPE_CHECKING, List, Union, Dict
+from typing import TYPE_CHECKING, List, Union, Dict, Optional
 
 import numpy as np
 from packaging.version import parse as parse_version
@@ -59,7 +59,7 @@ def sort(table: "pyarrow.Table", sort_key: "SortKey") -> "pyarrow.Table":
 
 def take_table(
     table: "pyarrow.Table",
-    indices: Union[List[int], "pyarrow.Array", "pyarrow.ChunkedArray"],
+    indices: Union[List[int], np.ndarray, "pyarrow.Array", "pyarrow.ChunkedArray"],
 ) -> "pyarrow.Table":
     """Select rows from the table.
 
@@ -421,6 +421,19 @@ def _align_struct_fields(
     return aligned_blocks
 
 
+def shuffle(block: "pyarrow.Table", seed: Optional[int] = None) -> "pyarrow.Table":
+    """Shuffles provided Arrow table"""
+
+    if len(block) == 0:
+        return block
+
+    indices = np.arange(block.num_rows)
+    # Shuffle indices
+    np.random.RandomState(seed).shuffle(indices)
+
+    return take_table(block, indices)
+
+
 def concat(
     blocks: List["pyarrow.Table"], *, promote_types: bool = False
 ) -> "pyarrow.Table":
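
The new `shuffle` helper is now the single code path for seeded row shuffles on Arrow tables. A short usage sketch (the column and seeds are illustrative; the exact permutation is whatever `np.random.RandomState(seed)` yields):

import pyarrow as pa

from ray.data._internal.arrow_ops.transform_pyarrow import shuffle

t = pa.table({"value": list(range(100))})

# The same seed always produces the same permutation.
assert shuffle(t, seed=0).equals(shuffle(t, seed=0))

# Empty tables are returned unchanged rather than re-materialized.
empty = pa.table({"value": pa.array([], type=pa.int64())})
assert shuffle(empty).num_rows == 0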

python/ray/data/_internal/datasource/parquet_datasource.py

Lines changed: 2 additions & 1 deletion
@@ -202,7 +202,8 @@ def __init__(
         self._unresolved_paths = paths
         paths, self._filesystem = _resolve_paths_and_filesystem(paths, filesystem)
         filesystem = RetryingPyFileSystem.wrap(
-            self._filesystem, context=DataContext.get_current()
+            self._filesystem,
+            retryable_errors=DataContext.get_current().retried_io_errors,
         )
 
         # HACK: PyArrow's `ParquetDataset` errors if input paths contain non-parquet

python/ray/data/_internal/execution/operators/map_transformer.py

Lines changed: 24 additions & 0 deletions
@@ -336,6 +336,30 @@ def __eq__(self, other):
         )
 
 
+class RowToBlockMapTransformFn(MapTransformFn):
+    """A Row-to-Batch MapTransformFn."""
+
+    def __init__(
+        self, transform_fn: MapTransformCallable[Row, Block], is_udf: bool = False
+    ):
+        self._transform_fn = transform_fn
+        super().__init__(
+            MapTransformFnDataType.Row,
+            MapTransformFnDataType.Block,
+            category=MapTransformFnCategory.DataProcess,
+            is_udf=is_udf,
+        )
+
+    def __call__(self, input: Iterable[Row], ctx: TaskContext) -> Iterable[Block]:
+        yield from self._transform_fn(input, ctx)
+
+    def __eq__(self, other):
+        return (
+            isinstance(other, RowToBlockMapTransformFn)
+            and self._transform_fn == other._transform_fn
+        )
+
+
 class BlockMapTransformFn(MapTransformFn):
     """A block-to-block MapTransformFn."""
 
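
`RowToBlockMapTransformFn` sits alongside the existing `BlockMapTransformFn`, letting a transform consume rows and emit whole blocks. A rough sketch of a transform function it could wrap, assuming rows arrive as dicts and blocks are Arrow tables (the function name and batching policy below are illustrative, not part of this diff):

from typing import Any, Dict, Iterable

import pyarrow as pa

from ray.data._internal.execution.interfaces import TaskContext
from ray.data._internal.execution.operators.map_transformer import (
    RowToBlockMapTransformFn,
)
from ray.data.block import Block


def rows_to_single_block(
    rows: Iterable[Dict[str, Any]], ctx: TaskContext
) -> Iterable[Block]:
    # Buffer all incoming rows and emit a single Arrow block; a production
    # transform would typically cap block size and yield multiple blocks.
    buffered = list(rows)
    if buffered:
        yield pa.Table.from_pylist(buffered)


# The wrapped fn would then typically be composed into a MapTransformer.
transform_fn = RowToBlockMapTransformFn(rows_to_single_block)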

python/ray/data/_internal/util.py

Lines changed: 9 additions & 11 deletions
@@ -1143,8 +1143,8 @@ def __init__(self, handler: "RetryingPyFileSystemHandler"):
         super().__init__(handler)
 
     @property
-    def data_context(self):
-        return self.handler.data_context
+    def retryable_errors(self) -> List[str]:
+        return self.handler._retryable_errors
 
     def unwrap(self):
         return self.handler.unwrap()
@@ -1153,13 +1153,15 @@ def unwrap(self):
     def wrap(
         cls,
         fs: "pyarrow.fs.FileSystem",
-        context: DataContext,
+        retryable_errors: List[str],
         max_attempts: int = 10,
         max_backoff_s: int = 32,
     ):
         if isinstance(fs, RetryingPyFileSystem):
             return fs
-        handler = RetryingPyFileSystemHandler(fs, context, max_attempts, max_backoff_s)
+        handler = RetryingPyFileSystemHandler(
+            fs, retryable_errors, max_attempts, max_backoff_s
+        )
         return cls(handler)
 
     def __reduce__(self):
@@ -1182,7 +1184,7 @@ class RetryingPyFileSystemHandler(pyarrow.fs.FileSystemHandler):
     def __init__(
         self,
         fs: "pyarrow.fs.FileSystem",
-        context: DataContext,
+        retryable_errors: List[str] = tuple(),
         max_attempts: int = 10,
         max_backoff_s: int = 32,
     ):
@@ -1198,20 +1200,16 @@ def __init__(
             fs, RetryingPyFileSystem
         ), "Cannot wrap a RetryingPyFileSystem"
         self._fs = fs
-        self._data_context = context
+        self._retryable_errors = retryable_errors
         self._max_attempts = max_attempts
         self._max_backoff_s = max_backoff_s
 
-    @property
-    def data_context(self):
-        return self._data_context
-
     def _retry_operation(self, operation: Callable, description: str):
         """Execute an operation with retries."""
         return call_with_retry(
             operation,
             description=description,
-            match=self._data_context.retried_io_errors,
+            match=self._retryable_errors,
             max_attempts=self._max_attempts,
             max_backoff_s=self._max_backoff_s,
         )
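
`RetryingPyFileSystem` no longer carries a `DataContext`; callers pass the retryable error patterns directly, decoupling the wrapper from the full context object. A minimal sketch of the new call-site shape (the local filesystem below is illustrative):

import pyarrow.fs

from ray.data._internal.util import RetryingPyFileSystem
from ray.data.context import DataContext

base_fs = pyarrow.fs.LocalFileSystem()

# Only the list of retryable error patterns is needed now, typically
# sourced from DataContext.retried_io_errors at the call site.
fs = RetryingPyFileSystem.wrap(
    base_fs,
    retryable_errors=DataContext.get_current().retried_io_errors,
)

assert isinstance(fs, RetryingPyFileSystem)
assert fs.retryable_errors == DataContext.get_current().retried_io_errors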

python/ray/data/datasource/file_based_datasource.py

Lines changed: 15 additions & 10 deletions
@@ -142,7 +142,7 @@ def __init__(
         self._include_paths = include_paths
         paths, self._filesystem = _resolve_paths_and_filesystem(paths, filesystem)
         self._filesystem = RetryingPyFileSystem.wrap(
-            self._filesystem, context=self._data_context
+            self._filesystem, retryable_errors=self._data_context.retried_io_errors
         )
         paths, file_sizes = map(
             list,
@@ -475,28 +475,31 @@ def _wrap_s3_serialization_workaround(filesystem: "pyarrow.fs.FileSystem"):
 
     wrap_retries = False
     fs_to_be_wrapped = filesystem  # Only unwrap for S3FileSystemWrapper
-    context = None
+    retryable_errors = []
    if isinstance(fs_to_be_wrapped, RetryingPyFileSystem):
         wrap_retries = True
-        context = fs_to_be_wrapped.data_context
+        retryable_errors = fs_to_be_wrapped.retryable_errors
         fs_to_be_wrapped = fs_to_be_wrapped.unwrap()
     if isinstance(fs_to_be_wrapped, pa.fs.S3FileSystem):
         return _S3FileSystemWrapper(
-            fs_to_be_wrapped, wrap_retries=wrap_retries, context=context
+            fs_to_be_wrapped,
+            wrap_retries=wrap_retries,
+            retryable_errors=retryable_errors,
         )
     return filesystem
 
 
 def _unwrap_s3_serialization_workaround(
     filesystem: Union["pyarrow.fs.FileSystem", "_S3FileSystemWrapper"],
-    context: Optional[DataContext] = None,
 ):
     if isinstance(filesystem, _S3FileSystemWrapper):
         wrap_retries = filesystem._wrap_retries
-        context = filesystem._context
+        retryable_errors = filesystem._retryable_erros
         filesystem = filesystem.unwrap()
         if wrap_retries:
-            filesystem = RetryingPyFileSystem.wrap(filesystem, context=context)
+            filesystem = RetryingPyFileSystem.wrap(
+                filesystem, retryable_errors=retryable_errors
+            )
     return filesystem
 
 
@@ -505,11 +508,11 @@ def __init__(
         self,
         fs: "pyarrow.fs.S3FileSystem",
         wrap_retries: bool = False,
-        context: Optional[DataContext] = None,
+        retryable_errors: List[str] = tuple(),
     ):
         self._fs = fs
         self._wrap_retries = wrap_retries
-        self._context = context
+        self._retryable_erros = retryable_errors
 
     def unwrap(self):
         return self._fs
@@ -548,7 +551,9 @@ def _resolve_kwargs(
     return kwargs
 
 
-def _validate_shuffle_arg(shuffle: Optional[str]) -> None:
+def _validate_shuffle_arg(
+    shuffle: Union[Literal["files"], FileShuffleConfig, None]
+) -> None:
     if not (
         shuffle is None or shuffle == "files" or isinstance(shuffle, FileShuffleConfig)
     ):
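
For reference, the widened `_validate_shuffle_arg` signature spells out the values the `shuffle` argument already accepts at read time. A hedged usage sketch (`read_images`, the bucket path, and the seed are illustrative assumptions, not part of this diff):

import ray
from ray.data import FileShuffleConfig

# `shuffle` may be None (no shuffling), the string "files", or a
# FileShuffleConfig carrying a seed for deterministic file-order shuffling.
ds = ray.data.read_images(
    "s3://example-bucket/images/", shuffle=FileShuffleConfig(seed=42)
)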

python/ray/data/datasource/file_datasink.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def __init__(
         self.unresolved_path = path
         paths, self.filesystem = _resolve_paths_and_filesystem(path, filesystem)
         self.filesystem = RetryingPyFileSystem.wrap(
-            self.filesystem, context=self._data_context
+            self.filesystem, retryable_errors=self._data_context.retried_io_errors
         )
         assert len(paths) == 1, len(paths)
         self.path = paths[0]

python/ray/data/tests/test_image.py

Lines changed: 3 additions & 2 deletions
@@ -166,11 +166,12 @@ def test_random_shuffle(self, ray_start_regular_shared, restore_data_context):
             shuffle="files",
         )
 
-        # Execute 10 times to get a set of output paths.
+        # Execute 5 times to get a set of output paths.
         output_paths_list = []
-        for _ in range(10):
+        for _ in range(5):
             paths = [row["path"][-len(file_paths[0]) :] for row in ds.take_all()]
             output_paths_list.append(paths)
+
         all_paths_matched = [
             file_paths == output_paths for output_paths in output_paths_list
         ]

python/ray/data/tests/test_transform_pyarrow.py

Lines changed: 15 additions & 0 deletions
@@ -16,6 +16,7 @@
     try_combine_chunked_columns,
     unify_schemas,
     MIN_PYARROW_VERSION_TYPE_PROMOTION,
+    shuffle,
 )
 from ray.data.block import BlockAccessor
 from ray.data.extensions import (
@@ -46,6 +47,20 @@ def test_try_defragment_table():
     assert dt == t
 
 
+def test_shuffle():
+    t = pa.Table.from_pydict(
+        {
+            "index": pa.array(list(range(10))),
+        }
+    )
+
+    shuffled = shuffle(t, seed=0xDEED)
+
+    assert shuffled == pa.Table.from_pydict(
+        {"index": pa.array([4, 3, 6, 8, 7, 1, 5, 2, 9, 0])}
+    )
+
+
 def test_arrow_concat_empty():
     # Test empty.
     assert concat([]) == []
