Skip to content

Commit 3081fd2

Browse files
jorisvandenbossche, rhshadrach, mroeschke
authored
[backport 2.3.x] API (string dtype): implement hierarchy (NA > NaN, pyarrow > python) for consistent comparisons between different string dtypes (#61138) (#61649)
Co-authored-by: Richard Shadrach <rhshadrach@gmail.com> Co-authored-by: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com>
1 parent 47645fe commit 3081fd2

File tree

6 files changed

+131
-22
lines changed

6 files changed

+131
-22
lines changed

doc/source/whatsnew/v2.3.0.rst

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,20 @@ Notable bug fixes
3838

3939
These are bug fixes that might have notable behavior changes.
4040

41-
.. _whatsnew_230.notable_bug_fixes.notable_bug_fix1:
41+
.. _whatsnew_230.notable_bug_fixes.string_comparisons:
4242

43-
notable_bug_fix1
44-
^^^^^^^^^^^^^^^^
43+
Comparisons between different string dtypes
44+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
45+
46+
In previous versions, comparing Series of different string dtypes (e.g. ``pd.StringDtype("pyarrow", na_value=pd.NA)`` against ``pd.StringDtype("python", na_value=np.nan)``) would produce an inconsistent result dtype or incorrectly raise an error. pandas will now use the hierarchy
47+
48+
object < (python, NaN) < (pyarrow, NaN) < (python, NA) < (pyarrow, NA)
49+
50+
to determine the result dtype when different string dtypes are compared. Some examples:
51+
52+
- When ``pd.StringDtype("pyarrow", na_value=pd.NA)`` is compared against any other string dtype, the result will always be ``bool[pyarrow]``.
53+
- When ``pd.StringDtype("python", na_value=pd.NA)`` is compared against ``pd.StringDtype("pyarrow", na_value=np.nan)``, the result will be ``boolean``, the NumPy-backed nullable extension array.
54+
- When ``pd.StringDtype("python", na_value=pd.NA)`` is compared against ``pd.StringDtype("python", na_value=np.nan)``, the result will be ``boolean``, the NumPy-backed nullable extension array.
4555

4656
In previous versions, comparing :class:`Series` of different string dtypes (e.g. ``pd.StringDtype("pyarrow", na_value=pd.NA)`` against ``pd.StringDtype("python", na_value=np.nan)``) would produce an inconsistent result dtype or incorrectly raise an error. pandas will now use the hierarchy
4757

pandas/core/arrays/arrow/array.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
infer_dtype_from_scalar,
3838
)
3939
from pandas.core.dtypes.common import (
40-
CategoricalDtype,
4140
is_array_like,
4241
is_bool_dtype,
4342
is_float_dtype,
@@ -725,9 +724,7 @@ def __setstate__(self, state) -> None:
725724

726725
def _cmp_method(self, other, op):
727726
pc_func = ARROW_CMP_FUNCS[op.__name__]
728-
if isinstance(
729-
other, (ArrowExtensionArray, np.ndarray, list, BaseMaskedArray)
730-
) or isinstance(getattr(other, "dtype", None), CategoricalDtype):
727+
if isinstance(other, (ExtensionArray, np.ndarray, list)):
731728
try:
732729
result = pc_func(self._pa_array, self._box_pa(other))
733730
except pa.ArrowNotImplementedError:

pandas/core/arrays/string_.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1014,7 +1014,30 @@ def searchsorted(
10141014
return super().searchsorted(value=value, side=side, sorter=sorter)
10151015

10161016
def _cmp_method(self, other, op):
1017-
from pandas.arrays import BooleanArray
1017+
from pandas.arrays import (
1018+
ArrowExtensionArray,
1019+
BooleanArray,
1020+
)
1021+
1022+
if (
1023+
isinstance(other, BaseStringArray)
1024+
and self.dtype.na_value is not libmissing.NA
1025+
and other.dtype.na_value is libmissing.NA
1026+
):
1027+
# NA has priority of NaN semantics
1028+
return NotImplemented
1029+
1030+
if isinstance(other, ArrowExtensionArray):
1031+
if isinstance(other, BaseStringArray):
1032+
# pyarrow storage has priority over python storage
1033+
# (except if we have NA semantics and other not)
1034+
if not (
1035+
self.dtype.na_value is libmissing.NA
1036+
and other.dtype.na_value is not libmissing.NA
1037+
):
1038+
return NotImplemented
1039+
else:
1040+
return NotImplemented
10181041

10191042
if isinstance(other, StringArray):
10201043
other = other._ndarray

pandas/core/arrays/string_arrow.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,14 @@ def value_counts(self, dropna: bool = True) -> Series:
469469
return result
470470

471471
def _cmp_method(self, other, op):
472+
if (
473+
isinstance(other, (BaseStringArray, ArrowExtensionArray))
474+
and self.dtype.na_value is not libmissing.NA
475+
and other.dtype.na_value is libmissing.NA
476+
):
477+
# NA has priority over NaN semantics
478+
return NotImplemented
479+
472480
result = super()._cmp_method(other, op)
473481
if self.dtype.na_value is np.nan:
474482
if op == operator.ne:

pandas/tests/arrays/string_/test_string.py

Lines changed: 77 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99

1010
from pandas._config import using_string_dtype
1111

12+
from pandas.compat import HAS_PYARROW
1213
from pandas.compat.pyarrow import (
1314
pa_version_under12p0,
1415
pa_version_under19p0,
1516
)
17+
import pandas.util._test_decorators as td
1618

1719
from pandas.core.dtypes.common import is_dtype_equal
1820

@@ -44,6 +46,25 @@ def cls(dtype):
4446
return dtype.construct_array_type()
4547

4648

49+
def string_dtype_highest_priority(dtype1, dtype2):
50+
if HAS_PYARROW:
51+
DTYPE_HIERARCHY = [
52+
pd.StringDtype("python", na_value=np.nan),
53+
pd.StringDtype("pyarrow", na_value=np.nan),
54+
pd.StringDtype("python", na_value=pd.NA),
55+
pd.StringDtype("pyarrow", na_value=pd.NA),
56+
]
57+
else:
58+
DTYPE_HIERARCHY = [
59+
pd.StringDtype("python", na_value=np.nan),
60+
pd.StringDtype("python", na_value=pd.NA),
61+
]
62+
63+
h1 = DTYPE_HIERARCHY.index(dtype1)
64+
h2 = DTYPE_HIERARCHY.index(dtype2)
65+
return DTYPE_HIERARCHY[max(h1, h2)]
66+
67+
4768
def test_dtype_constructor():
4869
pytest.importorskip("pyarrow")
4970

@@ -318,25 +339,75 @@ def test_comparison_methods_scalar_not_string(comparison_op, dtype):
318339
tm.assert_extension_array_equal(result, expected)
319340

320341

321-
def test_comparison_methods_array(comparison_op, dtype):
342+
def test_comparison_methods_array(comparison_op, dtype, dtype2):
322343
op_name = f"__{comparison_op.__name__}__"
323344

324345
a = pd.array(["a", None, "c"], dtype=dtype)
325-
other = [None, None, "c"]
326-
result = getattr(a, op_name)(other)
327-
if dtype.na_value is np.nan:
346+
other = pd.array([None, None, "c"], dtype=dtype2)
347+
result = comparison_op(a, other)
348+
349+
# ensure operation is commutative
350+
result2 = comparison_op(other, a)
351+
tm.assert_equal(result, result2)
352+
353+
if dtype.na_value is np.nan and dtype2.na_value is np.nan:
328354
if operator.ne == comparison_op:
329355
expected = np.array([True, True, False])
330356
else:
331357
expected = np.array([False, False, False])
332358
expected[-1] = getattr(other[-1], op_name)(a[-1])
333359
tm.assert_numpy_array_equal(result, expected)
334360

335-
result = getattr(a, op_name)(pd.NA)
361+
else:
362+
max_dtype = string_dtype_highest_priority(dtype, dtype2)
363+
if max_dtype.storage == "python":
364+
expected_dtype = "boolean"
365+
else:
366+
expected_dtype = "bool[pyarrow]"
367+
368+
expected = np.full(len(a), fill_value=None, dtype="object")
369+
expected[-1] = getattr(other[-1], op_name)(a[-1])
370+
expected = pd.array(expected, dtype=expected_dtype)
371+
tm.assert_extension_array_equal(result, expected)
372+
373+
374+
@td.skip_if_no("pyarrow")
375+
def test_comparison_methods_array_arrow_extension(comparison_op, dtype2):
376+
# Test pd.ArrowDtype(pa.string()) against other string arrays
377+
import pyarrow as pa
378+
379+
op_name = f"__{comparison_op.__name__}__"
380+
dtype = pd.ArrowDtype(pa.string())
381+
a = pd.array(["a", None, "c"], dtype=dtype)
382+
other = pd.array([None, None, "c"], dtype=dtype2)
383+
result = comparison_op(a, other)
384+
385+
# ensure operation is commutative
386+
result2 = comparison_op(other, a)
387+
tm.assert_equal(result, result2)
388+
389+
expected = pd.array([None, None, True], dtype="bool[pyarrow]")
390+
expected[-1] = getattr(other[-1], op_name)(a[-1])
391+
tm.assert_extension_array_equal(result, expected)
392+
393+
394+
def test_comparison_methods_list(comparison_op, dtype):
395+
op_name = f"__{comparison_op.__name__}__"
396+
397+
a = pd.array(["a", None, "c"], dtype=dtype)
398+
other = [None, None, "c"]
399+
result = comparison_op(a, other)
400+
401+
# ensure operation is commutative
402+
result2 = comparison_op(other, a)
403+
tm.assert_equal(result, result2)
404+
405+
if dtype.na_value is np.nan:
336406
if operator.ne == comparison_op:
337-
expected = np.array([True, True, True])
407+
expected = np.array([True, True, False])
338408
else:
339409
expected = np.array([False, False, False])
410+
expected[-1] = getattr(other[-1], op_name)(a[-1])
340411
tm.assert_numpy_array_equal(result, expected)
341412

342413
else:
@@ -346,10 +417,6 @@ def test_comparison_methods_array(comparison_op, dtype):
346417
expected = pd.array(expected, dtype=expected_dtype)
347418
tm.assert_extension_array_equal(result, expected)
348419

349-
result = getattr(a, op_name)(pd.NA)
350-
expected = pd.array([None, None, None], dtype=expected_dtype)
351-
tm.assert_extension_array_equal(result, expected)
352-
353420

354421
def test_constructor_raises(cls):
355422
if cls is pd.arrays.StringArray:

pandas/tests/extension/test_string.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from pandas.api.types import is_string_dtype
3131
from pandas.core.arrays import ArrowStringArray
3232
from pandas.core.arrays.string_ import StringDtype
33+
from pandas.tests.arrays.string_.test_string import string_dtype_highest_priority
3334
from pandas.tests.extension import base
3435

3536

@@ -206,10 +207,13 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
206207
dtype = cast(StringDtype, tm.get_dtype(obj))
207208
if op_name in ["__add__", "__radd__"]:
208209
cast_to = dtype
210+
dtype_other = tm.get_dtype(other) if not isinstance(other, str) else None
211+
if isinstance(dtype_other, StringDtype):
212+
cast_to = string_dtype_highest_priority(dtype, dtype_other)
209213
elif dtype.na_value is np.nan:
210214
cast_to = np.bool_ # type: ignore[assignment]
211215
elif dtype.storage == "pyarrow":
212-
cast_to = "boolean[pyarrow]" # type: ignore[assignment]
216+
cast_to = "bool[pyarrow]" # type: ignore[assignment]
213217
else:
214218
cast_to = "boolean" # type: ignore[assignment]
215219
return pointwise_result.astype(cast_to)
@@ -237,10 +241,10 @@ def test_arith_series_with_array(
237241
if (
238242
using_infer_string
239243
and all_arithmetic_operators == "__radd__"
240-
and (
241-
(dtype.na_value is pd.NA) or (dtype.storage == "python" and HAS_PYARROW)
242-
)
244+
and dtype.na_value is pd.NA
245+
and (HAS_PYARROW or dtype.storage == "pyarrow")
243246
):
247+
# TODO(infer_string)
244248
mark = pytest.mark.xfail(
245249
reason="The pointwise operation result will be inferred to "
246250
"string[nan, pyarrow], which does not match the input dtype"

0 commit comments

Comments
 (0)