Skip to content

Commit ae5923e

Browse files
authored
fix: support revertable for concatenate in pyarrow logic (#2889)
* wip: fix revertable for concatenate
* test: ensure behavior!
* test: fix importorskip arrow
1 parent bf7e37f commit ae5923e

File tree

2 files changed

+81
-2
lines changed

2 files changed

+81
-2
lines changed

src/awkward/_connect/pyarrow.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -919,8 +919,15 @@ def direct_Content_subclass_name(node):
919919
return out.__name__
920920

921921

922+
def is_revertable(akarray):
    """Report whether *akarray* carries a ``__pyarrow_original`` attribute.

    Returns ``True`` when the attribute is present (whatever its value) and
    ``False`` otherwise.
    """
    # NOTE(review): the attribute is presumably attached by `revertable`;
    # its definition is not visible here -- confirm.
    # A private sentinel distinguishes "attribute missing" from an attribute
    # whose value happens to be None or otherwise falsy.
    missing = object()
    return getattr(akarray, "__pyarrow_original", missing) is not missing
924+
925+
922926
def remove_optiontype(akarray):
    """Return the value recorded in ``akarray.__pyarrow_original``.

    If the stored value is callable it is treated as a deferred computation:
    it is invoked with no arguments and its result is returned. Otherwise
    the stored value itself is returned.

    Raises ``AttributeError`` if *akarray* has no ``__pyarrow_original``
    attribute.
    """
    original = akarray.__pyarrow_original
    # The stored value may be a zero-argument callable that produces the
    # array on demand (callers in this module pass a lambda for this).
    return original() if callable(original) else original
924931

925932

926933
def form_remove_optiontype(akform):
@@ -944,6 +951,17 @@ def handle_arrow(obj, generate_bitmasks=False, pass_empty_field=False):
944951

945952
if len(layouts) == 1:
946953
return layouts[0]
954+
elif any(is_revertable(arr) for arr in layouts):
955+
assert all(is_revertable(arr) for arr in layouts)
956+
# TODO: the callable argument to revertable is a premature(?) optimisation.
957+
# it would be better to obviate the need to compute both revertable and non revertable branches
958+
# e.g. by requesting a particular layout kind from the next `frombuffers` operation
959+
return revertable(
960+
ak.operations.concatenate(layouts, highlevel=False),
961+
lambda: ak.operations.concatenate(
962+
[remove_optiontype(x) for x in layouts], highlevel=False
963+
),
964+
)
947965
else:
948966
return ak.operations.concatenate(layouts, highlevel=False)
949967

@@ -1044,7 +1062,19 @@ def handle_arrow(obj, generate_bitmasks=False, pass_empty_field=False):
10441062
for batch in batches
10451063
if len(batch) > 0
10461064
]
1047-
return ak.operations.concatenate(arrays, highlevel=False)
1065+
if any(is_revertable(arr) for arr in arrays):
1066+
assert all(is_revertable(arr) for arr in arrays)
1067+
# TODO: the callable argument to revertable is a premature(?) optimisation.
1068+
# it would be better to obviate the need to compute both revertable and non revertable branches
1069+
# e.g. by requesting a particular layout kind from the next `frombuffers` operation
1070+
return revertable(
1071+
ak.operations.concatenate(arrays, highlevel=False),
1072+
lambda: ak.operations.concatenate(
1073+
[remove_optiontype(x) for x in arrays], highlevel=False
1074+
),
1075+
)
1076+
else:
1077+
return ak.operations.concatenate(arrays, highlevel=False)
10481078

10491079
elif (
10501080
isinstance(obj, Iterable)

tests/test_2889_test_chunked_array.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
2+
3+
from __future__ import annotations
4+
5+
import pytest
6+
7+
import awkward as ak
8+
9+
pa = pytest.importorskip("pyarrow")
10+
11+
12+
def test_strings():
    """A null-free chunked string array converts to a plain string type of length 4."""
    chunked = pa.chunked_array([["foo", "bar"], ["blah", "bleh"]])
    converted = ak.from_arrow(chunked)

    char_type = ak.types.NumpyType("uint8", parameters={"__array__": "char"})
    string_type = ak.types.ListType(char_type, parameters={"__array__": "string"})
    expected = ak.types.ArrayType(string_type, 4)

    assert converted.type == expected
22+
23+
24+
def test_strings_option():
    """A chunked string array containing a null converts to an option-of-string type of length 5."""
    chunked = pa.chunked_array([["foo", "bar"], ["blah", "bleh", None]])
    converted = ak.from_arrow(chunked)

    char_type = ak.types.NumpyType("uint8", parameters={"__array__": "char"})
    string_type = ak.types.ListType(char_type, parameters={"__array__": "string"})
    expected = ak.types.ArrayType(ak.types.OptionType(string_type), 5)

    assert converted.type == expected
36+
37+
38+
def test_numbers():
    """A null-free chunked integer array converts to a plain int64 type of length 5."""
    chunked = pa.chunked_array([[1, 2, 3], [4, 5]])
    converted = ak.from_arrow(chunked)

    expected = ak.types.ArrayType(ak.types.NumpyType("int64"), 5)

    assert converted.type == expected
42+
43+
44+
def test_numbers_option():
    """A chunked integer array containing a null converts to an option-of-int64 type of length 6."""
    chunked = pa.chunked_array([[1, 2, 3], [4, 5, None]])
    converted = ak.from_arrow(chunked)

    optional_int64 = ak.types.OptionType(ak.types.NumpyType("int64"))
    expected = ak.types.ArrayType(optional_int64, 6)

    assert converted.type == expected

0 commit comments

Comments
 (0)