Skip to content

Commit c2dfd9a

Browse files
committed
Refactor out test Index classes
1 parent 7979079 commit c2dfd9a

File tree

2 files changed

+54
-57
lines changed

2 files changed

+54
-57
lines changed

xarray/tests/indexes.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from xarray.core.indexes import Index, PandasIndex
2+
3+
4+
class ScalarIndex(Index):
5+
def __init__(self, value: int):
6+
self.value = value
7+
8+
@classmethod
9+
def from_variables(cls, variables, *, options):
10+
var = next(iter(variables.values()))
11+
return cls(int(var.values))
12+
13+
def equals(self, other, *, exclude=None):
14+
return isinstance(other, ScalarIndex) and other.value == self.value
15+
16+
17+
class XYIndex(Index):
18+
def __init__(self, x: PandasIndex, y: PandasIndex):
19+
self.x: PandasIndex = x
20+
self.y: PandasIndex = y
21+
22+
@classmethod
23+
def from_variables(cls, variables, *, options):
24+
return cls(
25+
x=PandasIndex.from_variables({"x": variables["x"]}, options=options),
26+
y=PandasIndex.from_variables({"y": variables["y"]}, options=options),
27+
)
28+
29+
def equals(self, other, exclude=None):
30+
x_eq = True if self.x.dim in exclude else self.x.equals(other.x)
31+
y_eq = True if self.y.dim in exclude else self.y.equals(other.y)
32+
return x_eq and y_eq
33+
34+
35+
class MultiCoordIndex(Index):
36+
def __init__(self, idx1, idx2):
37+
self.idx1 = idx1
38+
self.idx2 = idx2
39+
40+
@classmethod
41+
def from_variables(cls, variables, *, options=None):
42+
idx1 = PandasIndex.from_variables({"x": variables["x"]}, options=options)
43+
idx2 = PandasIndex.from_variables({"y": variables["y"]}, options=options)
44+
45+
return cls(idx1, idx2)
46+
47+
def create_variables(self, variables=None):
48+
return {**self.idx1.create_variables(), **self.idx2.create_variables()}
49+
50+
def isel(self, indexers):
51+
idx1 = self.idx1.isel({"x": indexers.get("x", slice(None))})
52+
idx2 = self.idx2.isel({"y": indexers.get("y", slice(None))})
53+
return MultiCoordIndex(idx1, idx2)

xarray/tests/test_dataset.py

Lines changed: 1 addition & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
requires_sparse,
7070
source_ndarray,
7171
)
72+
from xarray.tests.indexes import MultiCoordIndex, ScalarIndex, XYIndex
7273

7374
try:
7475
from pandas.errors import UndefinedVariableError
@@ -1599,30 +1600,6 @@ def test_isel_multicoord_index(self) -> None:
15991600
# regression test https://github.com/pydata/xarray/issues/10063
16001601
# isel on a multi-coordinate index should return a unique index associated
16011602
# to each coordinate
1602-
class MultiCoordIndex(xr.Index):
1603-
def __init__(self, idx1, idx2):
1604-
self.idx1 = idx1
1605-
self.idx2 = idx2
1606-
1607-
@classmethod
1608-
def from_variables(cls, variables, *, options=None):
1609-
idx1 = PandasIndex.from_variables(
1610-
{"x": variables["x"]}, options=options
1611-
)
1612-
idx2 = PandasIndex.from_variables(
1613-
{"y": variables["y"]}, options=options
1614-
)
1615-
1616-
return cls(idx1, idx2)
1617-
1618-
def create_variables(self, variables=None):
1619-
return {**self.idx1.create_variables(), **self.idx2.create_variables()}
1620-
1621-
def isel(self, indexers):
1622-
idx1 = self.idx1.isel({"x": indexers.get("x", slice(None))})
1623-
idx2 = self.idx2.isel({"y": indexers.get("y", slice(None))})
1624-
return MultiCoordIndex(idx1, idx2)
1625-
16261603
coords = xr.Coordinates(coords={"x": [0, 1], "y": [1, 2]}, indexes={})
16271604
ds = xr.Dataset(coords=coords).set_xindex(["x", "y"], MultiCoordIndex)
16281605

@@ -2639,18 +2616,6 @@ def test_align_index_var_attrs(self, join) -> None:
26392616
def test_align_scalar_index(self) -> None:
26402617
# ensure that indexes associated with scalar coordinates are not ignored
26412618
# during alignment
2642-
class ScalarIndex(Index):
2643-
def __init__(self, value: int):
2644-
self.value = value
2645-
2646-
@classmethod
2647-
def from_variables(cls, variables, *, options):
2648-
var = next(iter(variables.values()))
2649-
return cls(int(var.values))
2650-
2651-
def equals(self, other, *, exclude=None):
2652-
return isinstance(other, ScalarIndex) and other.value == self.value
2653-
26542619
ds1 = Dataset(coords={"x": 0}).set_xindex("x", ScalarIndex)
26552620
ds2 = Dataset(coords={"x": 0}).set_xindex("x", ScalarIndex)
26562621

@@ -2664,27 +2629,6 @@ def equals(self, other, *, exclude=None):
26642629
xr.align(ds1, ds3, join="exact")
26652630

26662631
def test_align_multi_dim_index_exclude_dims(self) -> None:
2667-
class XYIndex(Index):
2668-
def __init__(self, x: PandasIndex, y: PandasIndex):
2669-
self.x: PandasIndex = x
2670-
self.y: PandasIndex = y
2671-
2672-
@classmethod
2673-
def from_variables(cls, variables, *, options):
2674-
return cls(
2675-
x=PandasIndex.from_variables(
2676-
{"x": variables["x"]}, options=options
2677-
),
2678-
y=PandasIndex.from_variables(
2679-
{"y": variables["y"]}, options=options
2680-
),
2681-
)
2682-
2683-
def equals(self, other, exclude=None):
2684-
x_eq = True if self.x.dim in exclude else self.x.equals(other.x)
2685-
y_eq = True if self.y.dim in exclude else self.y.equals(other.y)
2686-
return x_eq and y_eq
2687-
26882632
ds1 = (
26892633
Dataset(coords={"x": [1, 2], "y": [3, 4]})
26902634
.drop_indexes(["x", "y"])

0 commit comments

Comments
 (0)