Commit f6347c0

[Data] Avoid unnecessary conversion to Numpy when creating Arrow/Pandas blocks (#51238)
Context
---

This change skips the unnecessary blanket conversion to Numpy (previously applied to every chunk of data) before converting to Pyarrow. That conversion created problems when batches contained Arrow-native `Scalar`s, which as a result were ultimately serialized via the `ArrowPythonObjectType` extension.

Changes
---

We revisit the conversion path and convert the passed-in column values to Numpy only in the following cases:

- The column name is `TENSOR_COLUMN_NAME` (for compatibility).
- The provided column values are already represented by a tensor (numpy, torch, etc.).
- The provided column values are a list of ndarrays (kept for compatibility with the previously existing behavior, in which all column values were blindly converted to Numpy and a list of ndarrays therefore became a tensor).

---------

Signed-off-by: Alexey Kudinkin <ak@anyscale.com>
1 parent 5fd8632 commit f6347c0
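
To make the change concrete, here is a minimal sketch of the new behavior, derived from the tests added in this commit. It assumes a Ray build that includes this change; `convert_to_pyarrow_array` and `DataContext` are the names used in the diffs below.

import pyarrow as pa

from ray.air.util.tensor_extensions.arrow import convert_to_pyarrow_array
from ray.data import DataContext

# Disable the fallback so a failed native conversion surfaces as an error
# instead of silently wrapping values in the ArrowPythonObjectType extension.
DataContext.get_current().enable_fallback_to_arrow_object_ext_type = False

# Nested Python lists are now serialized as native Arrow lists rather than
# being blanket-converted to Numpy first.
pa_arr = convert_to_pyarrow_array([[1, 2], [3, 4]], "a")
assert pa_arr.type == pa.list_(pa.int64())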

15 files changed (+392, −172 lines)


python/ray/air/BUILD

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ py_test(
 py_test(
     name = "test_arrow",
     size = "small",
-    srcs = ["tests/test_arrow.py"],
+    srcs = ["tests/test_arrow.py", "conftest"],
     tags = ["team:ml", "team:data", "ray_data", "exclusive"],
     deps = [":ml_lib"]
 )

python/ray/air/tests/conftest.py

Lines changed: 8 additions & 0 deletions
@@ -13,3 +13,11 @@ def restore_data_context(request):
     original = copy.deepcopy(ray.data.context.DataContext.get_current())
     yield
     ray.data.context.DataContext._set_current(original)
+
+
+@pytest.fixture
+def disable_fallback_to_object_extension(request, restore_data_context):
+    """Disables fallback to ArrowPythonObjectType"""
+    ray.data.context.DataContext.get_current().enable_fallback_to_arrow_object_ext_type = (
+        False
+    )
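
Because the fixture depends on `restore_data_context` (defined just above), the flag is automatically restored after each test. A test opts in simply by naming the fixture as a parameter, as the new test in test_arrow.py below does; a hypothetical minimal usage:

def test_without_object_fallback(disable_fallback_to_object_extension):
    # With the fallback disabled, values that cannot be represented natively
    # are expected to raise (e.g. ArrowConversionError) instead of being
    # wrapped in the ArrowPythonObjectType extension.
    ...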

python/ray/air/tests/test_arrow.py

Lines changed: 93 additions & 14 deletions
@@ -4,14 +4,18 @@
 import numpy as np
 import pyarrow as pa
 import pytest
+from packaging.version import parse as parse_version
 
+from ray._private.utils import _get_pyarrow_version
 from ray.air.util.tensor_extensions.arrow import (
     ArrowConversionError,
     _convert_to_pyarrow_native_array,
     _infer_pyarrow_type,
     convert_to_pyarrow_array,
+    ArrowTensorArray,
 )
 from ray.air.util.tensor_extensions.utils import create_ragged_ndarray
+from ray.data import DataContext
 from ray.tests.conftest import *  # noqa
 
 import psutil
@@ -23,16 +27,56 @@ class UserObj:
 
 
 @pytest.mark.parametrize(
-    "numpy_precision, expected_arrow_type",
+    "input",
+    [
+        # Python native lists
+        [
+            [1, 2],
+            [3, 4],
+        ],
+        # Python native tuples
+        [
+            (1, 2),
+            (3, 4),
+        ],
+        # Lists as PA scalars
+        [
+            pa.scalar([1, 2]),
+            pa.scalar([3, 4]),
+        ],
+    ],
+)
+def test_arrow_native_list_conversion(input, disable_fallback_to_object_extension):
+    """Test asserts that nested lists are represented as native Arrow lists
+    upon serialization into Arrow format (and are NOT converted to numpy
+    tensor using extension)"""
+
+    if isinstance(input[0], pa.Scalar) and parse_version(
+        _get_pyarrow_version()
+    ) <= parse_version("13.0.0"):
+        pytest.skip(
+            "Pyarrow < 13.0 not able to properly infer native types from its own Scalars"
+        )
+
+    pa_arr = convert_to_pyarrow_array(input, "a")
+
+    # Should be able to natively convert back to Pyarrow array,
+    # not using any extensions
+    assert pa_arr.type == pa.list_(pa.int64()), pa_arr.type
+    assert pa.array(input) == pa_arr, pa_arr
+
+
+@pytest.mark.parametrize("arg_type", ["list", "ndarray"])
+@pytest.mark.parametrize(
+    "numpy_precision, expected_arrow_timestamp_type",
     [
         ("ms", pa.timestamp("ms")),
         ("us", pa.timestamp("us")),
         ("ns", pa.timestamp("ns")),
-        # Arrow has a special date32 type for dates.
-        ("D", pa.date32()),
         # The coarsest resolution Arrow supports is seconds.
         ("Y", pa.timestamp("s")),
         ("M", pa.timestamp("s")),
+        ("D", pa.timestamp("s")),
         ("h", pa.timestamp("s")),
         ("m", pa.timestamp("s")),
         ("s", pa.timestamp("s")),
@@ -44,26 +88,61 @@
 )
 def test_convert_datetime_array(
     numpy_precision: str,
-    expected_arrow_type: pa.DataType,
+    expected_arrow_timestamp_type: pa.TimestampType,
+    arg_type: str,
+    restore_data_context,
 ):
-    numpy_array = np.zeros(1, dtype=f"datetime64[{numpy_precision}]")
-
-    pyarrow_array = _convert_to_pyarrow_native_array(numpy_array, "")
-
-    assert pyarrow_array.type == expected_arrow_type
-    assert len(numpy_array) == len(pyarrow_array)
-
-
+    DataContext.get_current().enable_fallback_to_arrow_object_ext_type = False
+
+    ndarray = np.ones(1, dtype=f"datetime64[{numpy_precision}]")
+
+    if arg_type == "ndarray":
+        column_values = ndarray
+    elif arg_type == "list":
+        column_values = [ndarray]
+    else:
+        pytest.fail(f"Unknown type: {arg_type}")
+
+    # Step 1: Convert to PA array
+    converted = convert_to_pyarrow_array(column_values, "")
+
+    if arg_type == "ndarray":
+        expected = pa.array(
+            column_values.astype(f"datetime64[{expected_arrow_timestamp_type.unit}]")
+        )
+    elif arg_type == "list":
+        expected = ArrowTensorArray.from_numpy(
+            [
+                column_values[0].astype(
+                    f"datetime64[{expected_arrow_timestamp_type.unit}]"
+                )
+            ]
+        )
+    else:
+        pytest.fail(f"Unknown type: {arg_type}")
+
+    assert expected.type == converted.type
+    assert expected == converted
+
+
+@pytest.mark.parametrize("arg_type", ["list", "ndarray"])
 @pytest.mark.parametrize("dtype", ["int64", "float64", "datetime64[ns]"])
-def test_infer_type_does_not_leak_memory(dtype):
+def test_infer_type_does_not_leak_memory(arg_type, dtype):
     # Test for https://github.com/apache/arrow/issues/45493.
-    column_values = np.zeros(923040, dtype=dtype)  # A ~7 MiB column
+    ndarray = np.zeros(923040, dtype=dtype)  # A ~7 MiB column
 
     process = psutil.Process()
     gc.collect()
     pa.default_memory_pool().release_unused()
     before = process.memory_info().rss
 
+    if arg_type == "ndarray":
+        column_values = ndarray
+    elif arg_type == "list":
+        column_values = [ndarray]
+    else:
+        pytest.fail(f"Unknown type: {arg_type}")
+
     _infer_pyarrow_type(column_values)
 
     gc.collect()
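
One behavioral change surfaced by the updated parametrization above: numpy day-precision datetimes (`datetime64[D]`) are now expected to convert to second-resolution timestamps rather than Arrow's `date32` type. A minimal sketch of that expectation, using only names that appear in the diff:

import numpy as np
import pyarrow as pa

from ray.air.util.tensor_extensions.arrow import convert_to_pyarrow_array

# Units coarser than seconds (Y, M, D, h, m) are widened to second-resolution
# timestamps, matching the ("D", pa.timestamp("s")) case in the test above.
converted = convert_to_pyarrow_array(np.ones(1, dtype="datetime64[D]"), "")
assert converted.type == pa.timestamp("s")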
