Skip to content

Commit 9e240c5

Browse files
authored
(fix): equality check against singleton PandasExtensionArray (#9032)
1 parent 12123be commit 9e240c5

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

xarray/core/extension_array.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,6 @@ def __setitem__(self, key, val):
123123
self.array[key] = val
124124

125125
def __eq__(self, other):
126-
if np.isscalar(other):
127-
other = type(self)(type(self.array)([other]))
128126
if isinstance(other, PandasExtensionArray):
129127
return self.array == other.array
130128
return self.array == other

xarray/tests/test_duck_array_ops.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def test_concatenate_extension_duck_array(self, categorical1, categorical2):
186186
).all()
187187

188188
@requires_pyarrow
189-
def test_duck_extension_array_pyarrow_concatenate(self, arrow1, arrow2):
189+
def test_extension_array_pyarrow_concatenate(self, arrow1, arrow2):
190190
concatenated = concatenate(
191191
(PandasExtensionArray(arrow1), PandasExtensionArray(arrow2))
192192
)
@@ -1024,19 +1024,24 @@ def test_push_dask():
10241024
np.testing.assert_equal(actual, expected)
10251025

10261026

1027-
def test_duck_extension_array_equality(categorical1, int1):
1027+
def test_extension_array_equality(categorical1, int1):
10281028
int_duck_array = PandasExtensionArray(int1)
10291029
categorical_duck_array = PandasExtensionArray(categorical1)
10301030
assert (int_duck_array != categorical_duck_array).all()
10311031
assert (categorical_duck_array == categorical1).all()
10321032
assert (int_duck_array[0:2] == int1[0:2]).all()
10331033

10341034

1035-
def test_duck_extension_array_repr(int1):
1035+
def test_extension_array_singleton_equality(categorical1):
1036+
categorical_duck_array = PandasExtensionArray(categorical1)
1037+
assert (categorical_duck_array != "cat3").all()
1038+
1039+
1040+
def test_extension_array_repr(int1):
10361041
int_duck_array = PandasExtensionArray(int1)
10371042
assert repr(int1) in repr(int_duck_array)
10381043

10391044

1040-
def test_duck_extension_array_attr(int1):
1045+
def test_extension_array_attr(int1):
10411046
int_duck_array = PandasExtensionArray(int1)
10421047
assert (~int_duck_array.fillna(10)).all()

0 commit comments

Comments
 (0)