Skip to content

Commit 29319e8

Browse files
committed
Add test
1 parent bb61abd commit 29319e8

File tree

2 files changed

+73
-1
lines changed

2 files changed

+73
-1
lines changed

xarray/tests/indexes.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
from collections.abc import Hashable, Iterable, Sequence
2+
13
from xarray.core.indexes import Index, PandasIndex
4+
from xarray.core.types import Self
25

36

47
class ScalarIndex(Index):
@@ -26,11 +29,34 @@ def from_variables(cls, variables, *, options):
2629
y=PandasIndex.from_variables({"y": variables["y"]}, options=options),
2730
)
2831

32+
def create_variables(self, variables):
33+
return self.x.create_variables() | self.y.create_variables()
34+
2935
def equals(self, other, exclude=None):
3036
x_eq = True if self.x.dim in exclude else self.x.equals(other.x)
3137
y_eq = True if self.y.dim in exclude else self.y.equals(other.y)
3238
return x_eq and y_eq
3339

40+
@classmethod
41+
def concat(
42+
cls,
43+
indexes: Sequence[Self],
44+
dim: Hashable,
45+
positions: Iterable[Iterable[int]] | None = None,
46+
) -> Self:
47+
first = next(iter(indexes))
48+
if dim == "x":
49+
newx = PandasIndex.concat(
50+
tuple(i.x for i in indexes), dim=dim, positions=positions
51+
)
52+
newy = first.y
53+
elif dim == "y":
54+
newx = first.x
55+
newy = PandasIndex.concat(
56+
tuple(i.y for i in indexes), dim=dim, positions=positions
57+
)
58+
return cls(x=newx, y=newy)
59+
3460

3561
class MultiCoordIndex(Index):
3662
def __init__(self, idx1, idx2):

xarray/tests/test_concat.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pandas as pd
99
import pytest
1010

11-
from xarray import DataArray, Dataset, Variable, concat
11+
from xarray import AlignmentError, DataArray, Dataset, Variable, concat
1212
from xarray.core import dtypes
1313
from xarray.core.coordinates import Coordinates
1414
from xarray.core.indexes import PandasIndex
@@ -22,6 +22,7 @@
2222
assert_identical,
2323
requires_dask,
2424
)
25+
from xarray.tests.indexes import XYIndex
2526
from xarray.tests.test_dataset import create_test_data
2627

2728
if TYPE_CHECKING:
@@ -1379,3 +1380,48 @@ def test_concat_index_not_same_dim() -> None:
13791380
match=r"Cannot concatenate along dimension 'x' indexes with dimensions.*",
13801381
):
13811382
concat([ds1, ds2], dim="x")
1383+
1384+
1385+
def test_concat_multi_dim_index() -> None:
1386+
ds1 = (
1387+
Dataset(
1388+
{"foo": (("x", "y"), np.random.randn(2, 2))},
1389+
coords={"x": [1, 2], "y": [3, 4]},
1390+
)
1391+
.drop_indexes(["x", "y"])
1392+
.set_xindex(["x", "y"], XYIndex)
1393+
)
1394+
ds2 = (
1395+
Dataset(
1396+
{"foo": (("x", "y"), np.random.randn(2, 2))},
1397+
coords={"x": [1, 2], "y": [5, 6]},
1398+
)
1399+
.drop_indexes(["x", "y"])
1400+
.set_xindex(["x", "y"], XYIndex)
1401+
)
1402+
1403+
expected = (
1404+
Dataset(
1405+
{
1406+
"foo": (
1407+
("x", "y"),
1408+
np.concatenate([ds1.foo.data, ds2.foo.data], axis=-1),
1409+
)
1410+
},
1411+
coords={"x": [1, 2], "y": [3, 4, 5, 6]},
1412+
)
1413+
.drop_indexes(["x", "y"])
1414+
.set_xindex(["x", "y"], XYIndex)
1415+
)
1416+
# note: missing 'override'
1417+
for join in ["inner", "outer", "exact", "left", "right"]:
1418+
actual = concat([ds1, ds2], dim="y", join=join)
1419+
assert_identical(actual, expected, check_default_indexes=False)
1420+
1421+
with pytest.raises(AlignmentError):
1422+
actual = concat([ds1, ds2], dim="x", join="exact")
1423+
1424+
# TODO: fix these, or raise better error message
1425+
with pytest.raises(AssertionError):
1426+
for join in ["left", "right"]:
1427+
actual = concat([ds1, ds2], dim="x", join=join)

0 commit comments

Comments
 (0)