|
33 | 33 | BlockMetadata,
|
34 | 34 | BlockType,
|
35 | 35 | U,
|
| 36 | + BlockColumnAccessor, |
36 | 37 | )
|
37 | 38 | from ray.data.context import DataContext
|
38 | 39 |
|
@@ -329,87 +330,6 @@ def _sample(self, n_samples: int, sort_key: "SortKey") -> "pyarrow.Table":
|
329 | 330 | table = self._table.select(sort_key.get_columns())
|
330 | 331 | return transform_pyarrow.take_table(table, indices)
|
331 | 332 |
|
332 |
def count(self, on: str, ignore_nulls: bool = False) -> Optional[U]:
    """Count the values in column ``on`` of the wrapped Arrow table.

    Args:
        on: Name of the column to count over.
        ignore_nulls: When True only non-null values are counted
            (Arrow ``mode="only_valid"``); when False every value,
            null or not, is counted (Arrow ``mode="all"``).

    Returns:
        The count converted with ``as_py()``, or ``None`` when the block
        has no rows.

    Raises:
        ValueError: If ``on`` is not a string.
    """
    import pyarrow.compute as pac

    if not isinstance(on, str):
        # Trailing space keeps the type from running into "got:".
        raise ValueError(
            "on must be a string when aggregating on Arrow blocks, but got: "
            f"{type(on)}."
        )

    # An empty block reports None (not 0).
    if self.num_rows() == 0:
        return None

    mode = "only_valid" if ignore_nulls else "all"
    return pac.count(self._table[on], mode=mode).as_py()
349 |
| - |
350 |
def _apply_arrow_compute(
    self, compute_fn: Callable, on: str, ignore_nulls: bool
) -> Optional[U]:
    """Apply an aggregation to one column, guarding empty and null-typed input.

    Args:
        compute_fn: Callable invoked as ``compute_fn(column, skip_nulls=...)``;
            its result must expose ``as_py()``.
        on: Name of the column to aggregate.
        ignore_nulls: Forwarded to ``compute_fn`` as ``skip_nulls``.

    Returns:
        The aggregation result converted with ``as_py()``, or ``None`` when
        the block has no rows or the column has the Arrow null type.

    Raises:
        ValueError: If ``on`` is not a string.
    """
    import pyarrow as pa

    if not isinstance(on, str):
        # Trailing space keeps the type from running into "got:".
        raise ValueError(
            "on must be a string when aggregating on Arrow blocks, but got: "
            f"{type(on)}."
        )

    if self.num_rows() == 0:
        return None

    col = self._table[on]
    # A null-typed column is reported as None without calling compute_fn.
    if pa.types.is_null(col.type):
        return None

    return compute_fn(col, skip_nulls=ignore_nulls).as_py()
370 |
| - |
371 |
def sum(self, on: str, ignore_nulls: bool) -> Optional[U]:
    """Aggregate column ``on`` with ``pyarrow.compute.sum``.

    Validation and empty/null handling are delegated to
    ``_apply_arrow_compute``; ``ignore_nulls`` is passed through to it.
    """
    from pyarrow import compute

    aggregate = compute.sum
    return self._apply_arrow_compute(aggregate, on, ignore_nulls)
375 |
| - |
376 |
def min(self, on: str, ignore_nulls: bool) -> Optional[U]:
    """Aggregate column ``on`` with ``pyarrow.compute.min``.

    Validation and empty/null handling are delegated to
    ``_apply_arrow_compute``; ``ignore_nulls`` is passed through to it.
    """
    from pyarrow import compute

    aggregate = compute.min
    return self._apply_arrow_compute(aggregate, on, ignore_nulls)
380 |
| - |
381 |
def max(self, on: str, ignore_nulls: bool) -> Optional[U]:
    """Aggregate column ``on`` with ``pyarrow.compute.max``.

    Validation and empty/null handling are delegated to
    ``_apply_arrow_compute``; ``ignore_nulls`` is passed through to it.
    """
    from pyarrow import compute

    aggregate = compute.max
    return self._apply_arrow_compute(aggregate, on, ignore_nulls)
385 |
| - |
386 |
def mean(self, on: str, ignore_nulls: bool) -> Optional[U]:
    """Aggregate column ``on`` with ``pyarrow.compute.mean``.

    Validation and empty/null handling are delegated to
    ``_apply_arrow_compute``; ``ignore_nulls`` is passed through to it.
    """
    from pyarrow import compute

    aggregate = compute.mean
    return self._apply_arrow_compute(aggregate, on, ignore_nulls)
390 |
| - |
391 |
def sum_of_squared_diffs_from_mean(
    self,
    on: str,
    ignore_nulls: bool,
    mean: Optional[U] = None,
) -> Optional[U]:
    """Sum the squared differences between column ``on`` and its mean.

    Args:
        on: Name of the column to aggregate.
        ignore_nulls: Passed to the mean computation and to the final sum.
        mean: Precomputed mean; when ``None`` it is computed via
            ``self.mean(on, ignore_nulls)``.

    Returns:
        The sum of squared differences, or ``None`` when no mean is
        available or ``_apply_arrow_compute`` returns ``None``.
    """
    import pyarrow.compute as pac

    # Fall back to our own mean only when the caller supplied none.
    center = self.mean(on, ignore_nulls) if mean is None else mean
    if center is None:
        return None

    def _squared_deviation_total(col, skip_nulls):
        deviations = pac.subtract(col, center)
        return pac.sum(pac.power(deviations, 2), skip_nulls=skip_nulls)

    return self._apply_arrow_compute(_squared_deviation_total, on, ignore_nulls)
412 |
| - |
413 | 333 | def sort(self, sort_key: "SortKey") -> Block:
|
414 | 334 | assert (
|
415 | 335 | sort_key.get_columns()
|
@@ -449,8 +369,63 @@ def merge_sorted_blocks(
|
449 | 369 | # Handle blocks of different types.
|
450 | 370 | blocks = TableBlockAccessor.normalize_block_types(blocks, BlockType.ARROW)
|
451 | 371 | concat_and_sort = get_concat_and_sort_transform(DataContext.get_current())
|
452 |
| - ret = concat_and_sort(blocks, sort_key) |
| 372 | + ret = concat_and_sort(blocks, sort_key, promote_types=True) |
453 | 373 | return ret, ArrowBlockAccessor(ret).get_metadata(exec_stats=stats.build())
|
454 | 374 |
|
def block_type(self) -> BlockType:
    """Return the block type handled by this accessor: ``BlockType.ARROW``."""
    arrow_kind = BlockType.ARROW
    return arrow_kind
|
| 377 | + |
| 378 | + |
class ArrowBlockColumnAccessor(BlockColumnAccessor):
    """Column-level aggregations over a PyArrow ``Array`` / ``ChunkedArray``.

    Every aggregation reads ``self._column`` (presumably stored by the base
    class initializer -- confirm against ``BlockColumnAccessor``) and returns
    either the raw Arrow result or, when ``as_py`` is true, the result of
    calling ``as_py()`` on it.
    """

    def __init__(self, col: Union["pyarrow.Array", "pyarrow.ChunkedArray"]):
        super().__init__(col)

    def count(self, *, ignore_nulls: bool, as_py: bool = True) -> Optional[U]:
        """Count values; only non-null ones when ``ignore_nulls`` is true."""
        from pyarrow import compute

        counting_mode = "only_valid" if ignore_nulls else "all"
        counted = compute.count(self._column, mode=counting_mode)
        if as_py:
            return counted.as_py()
        return counted

    def sum(self, *, ignore_nulls: bool, as_py: bool = True) -> Optional[U]:
        """Sum the column, passing ``ignore_nulls`` as Arrow's ``skip_nulls``."""
        from pyarrow import compute

        total = compute.sum(self._column, skip_nulls=ignore_nulls)
        if as_py:
            return total.as_py()
        return total

    def min(self, *, ignore_nulls: bool, as_py: bool = True) -> Optional[U]:
        """Minimum of the column, passing ``ignore_nulls`` as ``skip_nulls``."""
        from pyarrow import compute

        smallest = compute.min(self._column, skip_nulls=ignore_nulls)
        if as_py:
            return smallest.as_py()
        return smallest

    def max(self, *, ignore_nulls: bool, as_py: bool = True) -> Optional[U]:
        """Maximum of the column, passing ``ignore_nulls`` as ``skip_nulls``."""
        from pyarrow import compute

        largest = compute.max(self._column, skip_nulls=ignore_nulls)
        if as_py:
            return largest.as_py()
        return largest

    def mean(self, *, ignore_nulls: bool, as_py: bool = True) -> Optional[U]:
        """Mean of the column, passing ``ignore_nulls`` as ``skip_nulls``."""
        from pyarrow import compute

        average = compute.mean(self._column, skip_nulls=ignore_nulls)
        if as_py:
            return average.as_py()
        return average

    def sum_of_squared_diffs_from_mean(
        self, ignore_nulls: bool, mean: Optional[U] = None, as_py: bool = True
    ) -> Optional[U]:
        """Sum the squared differences between the column and its mean.

        Args:
            ignore_nulls: Passed to the mean computation and to the final sum.
            mean: Precomputed mean; when ``None`` it is obtained from
                ``self.mean(ignore_nulls=ignore_nulls)``.
            as_py: When true the Arrow result is converted with ``as_py()``.

        Returns:
            The aggregated value, or ``None`` when no mean is available.
        """
        from pyarrow import compute

        # Only compute the mean ourselves when the caller did not supply one.
        center = self.mean(ignore_nulls=ignore_nulls) if mean is None else mean
        if center is None:
            return None

        deviations = compute.subtract(self._column, center)
        total = compute.sum(compute.power(deviations, 2), skip_nulls=ignore_nulls)
        if as_py:
            return total.as_py()
        return total

    def to_pylist(self):
        """Return the column's values via its ``to_pylist()`` method."""
        values = self._column.to_pylist()
        return values
0 commit comments