Skip to content

Commit 833253e

Browse files
committed
(chore): remove non-reindex fixes
1 parent 1db86e8 commit 833253e

File tree

3 files changed

+16
-104
lines changed

3 files changed

+16
-104
lines changed

xarray/core/extension_array.py

Lines changed: 10 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def __extension_duck_array__astype(
6262
casting: str = "unsafe",
6363
subok: bool = True,
6464
copy: bool = True,
65-
device: str = None,
65+
device: str | None = None,
6666
) -> T_ExtensionArray:
6767
if (
6868
not (
@@ -209,55 +209,6 @@ def __init__(self, array: T_ExtensionArray):
209209
raise TypeError(f"{array} is not an pandas ExtensionArray.")
210210
self.array = array
211211

212-
self._add_ops_dunders()
213-
214-
def _add_ops_dunders(self):
215-
"""Delegate all operators to pd.Series"""
216-
217-
def create_dunder(name: str) -> Callable:
218-
def binary_dunder(self, other):
219-
self, other = replace_duck_with_series((self, other))
220-
res = getattr(pd.Series, name)(self, other)
221-
if isinstance(res, pd.Series):
222-
res = PandasExtensionArray(res.array)
223-
return res
224-
225-
return binary_dunder
226-
227-
# see pandas.core.arraylike.OpsMixin
228-
binary_operators = [
229-
"__eq__",
230-
"__ne__",
231-
"__lt__",
232-
"__le__",
233-
"__gt__",
234-
"__ge__",
235-
"__and__",
236-
"__rand__",
237-
"__or__",
238-
"__ror__",
239-
"__xor__",
240-
"__rxor__",
241-
"__add__",
242-
"__radd__",
243-
"__sub__",
244-
"__rsub__",
245-
"__mul__",
246-
"__rmul__",
247-
"__truediv__",
248-
"__rtruediv__",
249-
"__floordiv__",
250-
"__rfloordiv__",
251-
"__mod__",
252-
"__rmod__",
253-
"__divmod__",
254-
"__rdivmod__",
255-
"__pow__",
256-
"__rpow__",
257-
]
258-
for method_name in binary_operators:
259-
setattr(self.__class__, method_name, create_dunder(method_name))
260-
261212
def __array_function__(self, func, types, args, kwargs):
262213
args = replace_duck_with_extension_array(args)
263214
if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS:
@@ -268,17 +219,7 @@ def __array_function__(self, func, types, args, kwargs):
268219
return res
269220

270221
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
271-
if first_ea := next(
272-
(x for x in inputs if isinstance(x, PandasExtensionArray)), None
273-
):
274-
inputs = replace_duck_with_series(inputs)
275-
res = first_ea.__array_ufunc__(ufunc, method, *inputs, **kwargs)
276-
if isinstance(res, pd.Series):
277-
arr = res.array
278-
return type(self)[type(arr)](arr)
279-
return res
280-
281-
return getattr(ufunc, method)(*inputs, **kwargs)
222+
return ufunc(*inputs, **kwargs)
282223

283224
def __repr__(self):
284225
return f"PandasExtensionArray(array={self.array!r})"
@@ -299,3 +240,11 @@ def __setitem__(self, key, val):
299240

300241
def __len__(self):
301242
return len(self.array)
243+
244+
def __eq__(self, other):
245+
if isinstance(other, PandasExtensionArray):
246+
return self.array == other.array
247+
return self.array == other
248+
249+
def __ne__(self, other):
250+
return ~(self == other)

xarray/tests/test_dataarray.py

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from xarray.core import dtypes
3535
from xarray.core.common import full_like
3636
from xarray.core.coordinates import Coordinates
37-
from xarray.core.extension_array import PandasExtensionArray
3837
from xarray.core.indexes import Index, PandasIndex, filter_indexes_from_coords
3938
from xarray.core.types import QueryEngineOptions, QueryParserOptions
4039
from xarray.core.utils import is_scalar
@@ -1793,12 +1792,12 @@ def test_reindex_empty_array_dtype(self) -> None:
17931792
x = xr.DataArray([], dims=("x",), coords={"x": []}).astype("float32")
17941793
y = x.reindex(x=[1.0, 2.0])
17951794

1796-
assert (
1797-
x.dtype == y.dtype
1798-
), "Dtype of reindexed DataArray should match dtype of the original DataArray"
1799-
assert (
1800-
y.dtype == np.float32
1801-
), "Dtype of reindexed DataArray should remain float32"
1795+
assert x.dtype == y.dtype, (
1796+
"Dtype of reindexed DataArray should match dtype of the original DataArray"
1797+
)
1798+
assert y.dtype == np.float32, (
1799+
"Dtype of reindexed DataArray should remain float32"
1800+
)
18021801

18031802
def test_reindex_extension_array(self) -> None:
18041803
index1 = np.array([1, 2, 3])
@@ -7289,24 +7288,6 @@ def test_from_series_regression() -> None:
72897288
srs = pd.Series(index=[1, 2, 3], data=pd.array([1, 1, pd.NA]))
72907289
arr = srs.to_xarray()
72917290

7292-
# binary operator
7293-
res = arr * 5
7294-
assert_array_equal(res, np.array([5, 5, np.nan]))
7295-
assert res.dtype == pd.Int64Dtype()
7296-
assert isinstance(res, xr.DataArray)
7297-
7298-
# NEP-13 ufunc
7299-
res = np.add(3, arr)
7300-
assert_array_equal(np.add(2, arr), np.array([3, 3, np.nan]))
7301-
assert res.dtype == pd.Int64Dtype()
7302-
assert isinstance(res, xr.DataArray)
7303-
7304-
# NEP-18 array_function
7305-
res = np.astype(arr.data, pd.Int32Dtype())
7306-
assert_array_equal(res, arr)
7307-
assert res.dtype == pd.Int32Dtype()
7308-
assert isinstance(res, PandasExtensionArray)
7309-
73107291
# xarray ufunc
73117292
res = arr.fillna(0)
73127293
assert_array_equal(res, np.array([1, 1, 0]))

xarray/tests/test_duck_array_ops.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,21 +1143,3 @@ def test_extension_array_result_type_mixed(int1, categorical1):
11431143
assert np.dtype("object") == np.result_type(
11441144
PandasExtensionArray(int1), dt.datetime.now()
11451145
)
1146-
1147-
1148-
def test_extension_array_astype(int1):
1149-
res = np.astype(PandasExtensionArray(int1), float)
1150-
assert res.dtype == np.dtype("float64")
1151-
assert_array_equal(res, np.array([np.nan, 2, 3, np.nan, np.nan], dtype="float32"))
1152-
1153-
res = np.astype(PandasExtensionArray(int1), pd.Float64Dtype())
1154-
assert res.dtype == pd.Float64Dtype()
1155-
assert_array_equal(
1156-
res, pd.array([pd.NA, np.float64(2), np.float64(3), pd.NA, pd.NA])
1157-
)
1158-
1159-
res = np.astype(
1160-
PandasExtensionArray(pd.array([1, 2], dtype="int8")), pd.Int16Dtype()
1161-
)
1162-
assert res.dtype == pd.Int16Dtype()
1163-
assert_array_equal(res, pd.array([1, 2], dtype=pd.Int16Dtype()))

0 commit comments

Comments
 (0)