Skip to content

Commit f927e0f

Browse files
[Data] Abstracted BlockColumnAccessor (#51326)
## Why are these changes needed?

This component generalizes and abstracts common columnar operations.

Additionally:

- Added tests
- Cleaned up some minor issues
- Enabled `concat` and `concat_and_sort` to do type promotions

---------

Signed-off-by: Alexey Kudinkin <ak@anyscale.com>
1 parent 146590b commit f927e0f

File tree

11 files changed

+572
-195
lines changed

11 files changed

+572
-195
lines changed

python/ray/data/_internal/arrow_block.py

Lines changed: 57 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
BlockMetadata,
3434
BlockType,
3535
U,
36+
BlockColumnAccessor,
3637
)
3738
from ray.data.context import DataContext
3839

@@ -329,87 +330,6 @@ def _sample(self, n_samples: int, sort_key: "SortKey") -> "pyarrow.Table":
329330
table = self._table.select(sort_key.get_columns())
330331
return transform_pyarrow.take_table(table, indices)
331332

332-
def count(self, on: str, ignore_nulls: bool = False) -> Optional[U]:
333-
"""Count the number of non-null values in the provided column."""
334-
import pyarrow.compute as pac
335-
336-
if not isinstance(on, str):
337-
raise ValueError(
338-
"on must be a string when aggregating on Arrow blocks, but got:"
339-
f"{type(on)}."
340-
)
341-
342-
if self.num_rows() == 0:
343-
return None
344-
345-
mode = "only_valid" if ignore_nulls else "all"
346-
347-
col = self._table[on]
348-
return pac.count(col, mode=mode).as_py()
349-
350-
def _apply_arrow_compute(
351-
self, compute_fn: Callable, on: str, ignore_nulls: bool
352-
) -> Optional[U]:
353-
"""Helper providing null handling around applying an aggregation to a column."""
354-
import pyarrow as pa
355-
356-
if not isinstance(on, str):
357-
raise ValueError(
358-
"on must be a string when aggregating on Arrow blocks, but got:"
359-
f"{type(on)}."
360-
)
361-
362-
if self.num_rows() == 0:
363-
return None
364-
365-
col = self._table[on]
366-
if pa.types.is_null(col.type):
367-
return None
368-
else:
369-
return compute_fn(col, skip_nulls=ignore_nulls).as_py()
370-
371-
def sum(self, on: str, ignore_nulls: bool) -> Optional[U]:
372-
import pyarrow.compute as pac
373-
374-
return self._apply_arrow_compute(pac.sum, on, ignore_nulls)
375-
376-
def min(self, on: str, ignore_nulls: bool) -> Optional[U]:
377-
import pyarrow.compute as pac
378-
379-
return self._apply_arrow_compute(pac.min, on, ignore_nulls)
380-
381-
def max(self, on: str, ignore_nulls: bool) -> Optional[U]:
382-
import pyarrow.compute as pac
383-
384-
return self._apply_arrow_compute(pac.max, on, ignore_nulls)
385-
386-
def mean(self, on: str, ignore_nulls: bool) -> Optional[U]:
387-
import pyarrow.compute as pac
388-
389-
return self._apply_arrow_compute(pac.mean, on, ignore_nulls)
390-
391-
def sum_of_squared_diffs_from_mean(
392-
self,
393-
on: str,
394-
ignore_nulls: bool,
395-
mean: Optional[U] = None,
396-
) -> Optional[U]:
397-
import pyarrow.compute as pac
398-
399-
if mean is None:
400-
# If precomputed mean not given, we compute it ourselves.
401-
mean = self.mean(on, ignore_nulls)
402-
if mean is None:
403-
return None
404-
return self._apply_arrow_compute(
405-
lambda col, skip_nulls: pac.sum(
406-
pac.power(pac.subtract(col, mean), 2),
407-
skip_nulls=skip_nulls,
408-
),
409-
on,
410-
ignore_nulls,
411-
)
412-
413333
def sort(self, sort_key: "SortKey") -> Block:
414334
assert (
415335
sort_key.get_columns()
@@ -449,8 +369,63 @@ def merge_sorted_blocks(
449369
# Handle blocks of different types.
450370
blocks = TableBlockAccessor.normalize_block_types(blocks, BlockType.ARROW)
451371
concat_and_sort = get_concat_and_sort_transform(DataContext.get_current())
452-
ret = concat_and_sort(blocks, sort_key)
372+
ret = concat_and_sort(blocks, sort_key, promote_types=True)
453373
return ret, ArrowBlockAccessor(ret).get_metadata(exec_stats=stats.build())
454374

455375
def block_type(self) -> BlockType:
456376
return BlockType.ARROW
377+
378+
379+
class ArrowBlockColumnAccessor(BlockColumnAccessor):
380+
def __init__(self, col: Union["pyarrow.Array", "pyarrow.ChunkedArray"]):
381+
super().__init__(col)
382+
383+
def count(self, *, ignore_nulls: bool, as_py: bool = True) -> Optional[U]:
384+
import pyarrow.compute as pac
385+
386+
res = pac.count(self._column, mode="only_valid" if ignore_nulls else "all")
387+
return res.as_py() if as_py else res
388+
389+
def sum(self, *, ignore_nulls: bool, as_py: bool = True) -> Optional[U]:
390+
import pyarrow.compute as pac
391+
392+
res = pac.sum(self._column, skip_nulls=ignore_nulls)
393+
return res.as_py() if as_py else res
394+
395+
def min(self, *, ignore_nulls: bool, as_py: bool = True) -> Optional[U]:
396+
import pyarrow.compute as pac
397+
398+
res = pac.min(self._column, skip_nulls=ignore_nulls)
399+
return res.as_py() if as_py else res
400+
401+
def max(self, *, ignore_nulls: bool, as_py: bool = True) -> Optional[U]:
402+
import pyarrow.compute as pac
403+
404+
res = pac.max(self._column, skip_nulls=ignore_nulls)
405+
return res.as_py() if as_py else res
406+
407+
def mean(self, *, ignore_nulls: bool, as_py: bool = True) -> Optional[U]:
408+
import pyarrow.compute as pac
409+
410+
res = pac.mean(self._column, skip_nulls=ignore_nulls)
411+
return res.as_py() if as_py else res
412+
413+
def sum_of_squared_diffs_from_mean(
414+
self, ignore_nulls: bool, mean: Optional[U] = None, as_py: bool = True
415+
) -> Optional[U]:
416+
import pyarrow.compute as pac
417+
418+
# Calculate mean if not provided
419+
if mean is None:
420+
mean = self.mean(ignore_nulls=ignore_nulls)
421+
422+
if mean is None:
423+
return None
424+
425+
res = pac.sum(
426+
pac.power(pac.subtract(self._column, mean), 2), skip_nulls=ignore_nulls
427+
)
428+
return res.as_py() if as_py else res
429+
430+
def to_pylist(self):
431+
return self._column.to_pylist()

python/ray/data/_internal/arrow_ops/transform_polars.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def sort(table: "pyarrow.Table", sort_key: "SortKey") -> "pyarrow.Table":
3030

3131

3232
def concat_and_sort(
33-
blocks: List["pyarrow.Table"], sort_key: "SortKey"
33+
blocks: List["pyarrow.Table"], sort_key: "SortKey", *, promote_types: bool = False
3434
) -> "pyarrow.Table":
3535
check_polars_installed()
3636
blocks = [pl.from_arrow(block) for block in blocks]

python/ray/data/_internal/arrow_ops/transform_pyarrow.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ def concat(
454454

455455
if not blocks:
456456
# Short-circuit on empty list of blocks.
457-
return blocks
457+
return pa.table([])
458458

459459
if len(blocks) == 1:
460460
return blocks[0]
@@ -566,12 +566,20 @@ def concat(
566566

567567

568568
def concat_and_sort(
569-
blocks: List["pyarrow.Table"], sort_key: "SortKey"
569+
blocks: List["pyarrow.Table"],
570+
sort_key: "SortKey",
571+
*,
572+
promote_types: bool = False,
570573
) -> "pyarrow.Table":
574+
import pyarrow as pa
571575
import pyarrow.compute as pac
572576

573-
ret = concat(blocks, promote_types=True)
577+
if len(blocks) == 0:
578+
return pa.table([])
579+
580+
ret = concat(blocks, promote_types=promote_types)
574581
indices = pac.sort_indices(ret, sort_keys=sort_key.to_arrow_sort_args())
582+
575583
return take_table(ret, indices)
576584

577585

0 commit comments

Comments
 (0)