Skip to content

Commit 7538dbd

Browse files
committed
(refactor): clean reindexing test
1 parent fbcc5a4 commit 7538dbd

File tree

1 file changed

+23
-31
lines changed

1 file changed

+23
-31
lines changed

xarray/tests/test_dataarray.py

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1849,42 +1849,34 @@ def test_reindex_empty_array_dtype(self) -> None:
18491849
"Dtype of reindexed DataArray should remain float32"
18501850
)
18511851

1852-
def test_reindex_extension_array(self) -> None:
1853-
index1 = np.array([1, 2, 3])
1854-
index2 = np.array([1, 2, 4])
1855-
srs = pd.Series(index=index1, data=1).convert_dtypes()
1856-
x = srs.to_xarray()
1857-
y = x.reindex(index=index2) # used to fail (GH #10301)
1858-
assert_array_equal(x, pd.array([1, 1, 1]))
1859-
assert_array_equal(y, pd.array([1, 1, pd.NA]))
1860-
assert x.dtype == y.dtype == pd.Int64Dtype()
1861-
assert x.index.dtype == y.index.dtype == np.dtype("int64")
1862-
1863-
def test_reindex_categorical_index(self) -> None:
1864-
index1 = pd.Categorical(["a", "b", "c"])
1865-
index2 = pd.Categorical(["a", "b", "d"])
1866-
srs = pd.Series(index=index1, data=1).convert_dtypes()
1867-
x = srs.to_xarray()
1868-
y = x.reindex(index=index2)
1869-
assert_array_equal(x, pd.array([1, 1, 1]))
1870-
assert_array_equal(y, pd.array([1, 1, pd.NA]))
1871-
assert x.dtype == y.dtype == pd.Int64Dtype()
1872-
assert isinstance(x.index.dtype, pd.CategoricalDtype)
1873-
assert isinstance(y.index.dtype, pd.CategoricalDtype)
1874-
assert_array_equal(x.index.dtype.categories, np.array(["a", "b", "c"]))
1875-
assert_array_equal(y.index.dtype.categories, np.array(["a", "b", "d"]))
1876-
1877-
def test_reindex_categorical(self) -> None:
1878-
data = pd.Categorical(["a", "b", "c"])
1879-
srs = pd.Series(index=["e", "f", "g"], data=data).convert_dtypes()
1852+
@pytest.mark.parametrize(
1853+
"extension_array",
1854+
[
1855+
pytest.param(pd.Categorical(["a", "b", "c"]), id="categorical"),
1856+
]
1857+
+ [
1858+
pytest.param(
1859+
pd.array([1, 2, 3], dtype="int64[pyarrow]"),
1860+
id="int64[pyarrow]",
1861+
)
1862+
]
1863+
if has_pyarrow
1864+
else [],
1865+
)
1866+
def test_reindex_extension_array(self, extension_array) -> None:
1867+
srs = pd.Series(index=["e", "f", "g"], data=extension_array)
18801868
x = srs.to_xarray()
18811869
y = x.reindex(index=["f", "g", "z"])
1882-
assert_array_equal(x, data)
1870+
assert_array_equal(x, extension_array)
18831871
# TODO: remove .array once the branch is updated with main
18841872
pd.testing.assert_extension_array_equal(
1885-
y.data, pd.Categorical(["b", "c", pd.NA], dtype=data.dtype)
1873+
y.data,
1874+
extension_array._from_sequence(
1875+
[extension_array[1], extension_array[2], pd.NA],
1876+
dtype=extension_array.dtype,
1877+
),
18861878
)
1887-
assert x.dtype == y.dtype == data.dtype
1879+
assert x.dtype == y.dtype == extension_array.dtype
18881880

18891881
@pytest.mark.parametrize(
18901882
"fill_value,extension_array",

0 commit comments

Comments
 (0)