Skip to content

DataArray: propagate index coordinates with non-array dimensions #10116

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ New Features
By `Benoit Bovy <https://github.com/benbovy>`_.
- Support reading to `GPU memory with Zarr <https://zarr.readthedocs.io/en/stable/user-guide/gpu.html>`_ (:pull:`10078`).
By `Deepak Cherian <https://github.com/dcherian>`_.
- Allow assigning index coordinates with non-array dimension(s) in a :py:class:`DataArray`, enabling
support for CF boundary coordinates (e.g., ``time(time)`` and ``time_bnds(time, nbnd)``) in a DataArray (:pull:`10116`).
By `Benoit Bovy <https://github.com/benbovy>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
38 changes: 24 additions & 14 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def identical(self, other: Self) -> bool:
return self.to_dataset().identical(other.to_dataset())

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
# redirect to DatasetCoordinates._update_coords
self._data.coords._update_coords(coords, indexes)
Expand Down Expand Up @@ -780,7 +780,7 @@ def to_dataset(self) -> Dataset:
return self._data._copy_listed(names)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
variables = self._data._variables.copy()
variables.update(coords)
Expand Down Expand Up @@ -880,7 +880,7 @@ def to_dataset(self) -> Dataset:
return self._data.dataset._copy_listed(self._names)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
from xarray.core.datatree import check_alignment

Expand Down Expand Up @@ -964,22 +964,32 @@ def __getitem__(self, key: Hashable) -> T_DataArray:
return self._data._getitem_coord(key)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
coords_plus_data = coords.copy()
coords_plus_data[_THIS_ARRAY] = self._data.variable
dims = calculate_dimensions(coords_plus_data)
if not set(dims) <= set(self.dims):
raise ValueError(
"cannot add coordinates with new dimensions to a DataArray"
)
self._data._coords = coords

# TODO(shoyer): once ._indexes is always populated by a dict, modify
# it to update inplace instead.
original_indexes = dict(self._data.xindexes)
original_indexes.update(indexes)
self._data._indexes = original_indexes
if set(dims) > set(self.dims):
for k, v in coords.items():
if any(d not in self.dims for d in v.dims):
# allow any coordinate associated with an index that shares at least
# one of dataarray's dimensions
temp_indexes = Indexes(
indexes, {k: v for k, v in coords.items() if k in indexes}
)
if k in indexes:
index_dims = temp_indexes.get_all_dims(k)
if any(d in self.dims for d in index_dims):
continue
raise ValueError(
f"coordinate {k} has dimensions {v.dims}, but these "
"are not a subset of the DataArray "
f"dimensions {self.dims}"
)

self._data._coords = coords
self._data._indexes = indexes

def _drop_coords(self, coord_names):
# should drop indexed coordinates only
Expand Down
19 changes: 15 additions & 4 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,18 +130,30 @@
T_XarrayOther = TypeVar("T_XarrayOther", bound="DataArray" | Dataset)


def _check_coords_dims(shape, coords, dim):
def _check_coords_dims(
shape: tuple[int, ...], coords: Coordinates, dim: tuple[Hashable, ...]
):
sizes = dict(zip(dim, shape, strict=True))
extra_index_dims: set[str] = set()

for k, v in coords.items():
if any(d not in dim for d in v.dims):
# allow any coordinate associated with an index that shares at least
# one of dataarray's dimensions
indexes = coords.xindexes
if k in indexes:
index_dims = indexes.get_all_dims(k)
if any(d in dim for d in index_dims):
extra_index_dims.update(d for d in v.dims if d not in dim)
continue
raise ValueError(
f"coordinate {k} has dimensions {v.dims}, but these "
"are not a subset of the DataArray "
f"dimensions {dim}"
)

for d, s in v.sizes.items():
if s != sizes[d]:
if d not in extra_index_dims and s != sizes[d]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if d not in extra_index_dims and s != sizes[d]:
if d not in extra_index_dims or s != sizes[d]:

?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be and.

extra_index_dims corresponds to all non-array dimensions of index coordinates to include in the dataarray (sizes[d] would raise a KeyError for those dimensions).

I renamed it and added a comment in 695fb86.

raise ValueError(
f"conflicting sizes for dimension {d!r}: "
f"length {sizes[d]} on the data but length {s} on "
Expand Down Expand Up @@ -212,8 +224,6 @@ def _infer_coords_and_dims(
var.dims = (dim,)
new_coords[dim] = var.to_index_variable()

_check_coords_dims(shape, new_coords, dims_tuple)

return new_coords, dims_tuple


Expand Down Expand Up @@ -487,6 +497,7 @@ def __init__(

if not isinstance(coords, Coordinates):
coords = create_coords_with_default_indexes(coords)
_check_coords_dims(data.shape, coords, dims)
indexes = dict(coords.xindexes)
coords = {k: v.copy() for k, v in coords.variables.items()}

Expand Down
14 changes: 12 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,10 +1210,20 @@ def _construct_dataarray(self, name: Hashable) -> DataArray:
needed_dims = set(variable.dims)

coords: dict[Hashable, Variable] = {}
temp_indexes = self.xindexes
# preserve ordering
for k in self._variables:
if k in self._coord_names and set(self._variables[k].dims) <= needed_dims:
coords[k] = self._variables[k]
if k in self._coord_names:
if (
k not in coords
and k in temp_indexes
and set(temp_indexes.get_all_dims(k)) & needed_dims
):
# add all coordinates of each index that shares at least one dimension
# with the dimensions of the extracted variable
coords.update(temp_indexes.get_all_coords(k))
Comment on lines +1217 to +1224
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this use a separate loop after the existing loop instead? e.g.,

for k in self._indexes:
    if k in coords:
        coords.update(self.xindexes.get_all_coords(k))

Or if we allow indexes without a coordinate of the same name:

for k in self._indexes:
    if set(self.xindexes.get_all_dims(k)) & needed_dims:
        coords.update(self.xindexes.get_all_coords(k))

Ideally, I would like the logic here to be just as simple as the words describing how it works, so a comment is not necessary!

elif set(self._variables[k].dims) <= needed_dims:
coords[k] = self._variables[k]

indexes = filter_indexes_from_coords(self._indexes, set(coords))

Expand Down
11 changes: 10 additions & 1 deletion xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,6 +1568,7 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]):
"""

_index_type: type[Index] | type[pd.Index]
_index_dims: dict[Hashable, Mapping[Hashable, int]]
_indexes: dict[Any, T_PandasOrXarrayIndex]
_variables: dict[Any, Variable]

Expand All @@ -1576,6 +1577,7 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]):
"__id_coord_names",
"__id_index",
"_dims",
"_index_dims",
"_index_type",
"_indexes",
"_variables",
Expand Down Expand Up @@ -1619,6 +1621,7 @@ def __init__(
)

self._index_type = index_type
self._index_dims = {}
self._indexes = dict(**indexes)
self._variables = dict(**variables)

Expand Down Expand Up @@ -1737,7 +1740,13 @@ def get_all_dims(
"""
from xarray.core.variable import calculate_dimensions

return calculate_dimensions(self.get_all_coords(key, errors=errors))
if key in self._index_dims:
return self._index_dims[key]
else:
dims = calculate_dimensions(self.get_all_coords(key, errors=errors))
if dims:
self._index_dims[key] = dims
return dims

def group_by_index(
self,
Expand Down
49 changes: 49 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,30 @@ class CustomIndex(Index): ...
# test coordinate variables copied
assert da.coords["x"] is not coords.variables["x"]

def test_constructor_extra_dim_index_coord(self) -> None:
class AnyIndex(Index):
# This test only requires that the coordinates to assign have an
# index, whatever its type.
pass

idx = AnyIndex()
coords = Coordinates(
coords={
"x": ("x", [1, 2]),
"x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]),
},
indexes={"x": idx, "x_bounds": idx},
)

actual = DataArray([1.0, 2.0], coords=coords, dims="x")

# cannot use `assert_identical()` test utility function here yet
# (indexes invariant check is still based on IndexVariable, which
# doesn't work with AnyIndex coordinate variables here)
assert actual.coords.to_dataset().equals(coords.to_dataset())
assert list(actual.coords.xindexes) == list(coords.xindexes)
assert "x_bnds" not in actual.dims

def test_equals_and_identical(self) -> None:
orig = DataArray(np.arange(5.0), {"a": 42}, dims="x")

Expand Down Expand Up @@ -1634,6 +1658,31 @@ def test_assign_coords_no_default_index(self) -> None:
assert_identical(actual.coords, coords, check_default_indexes=False)
assert "y" not in actual.xindexes

def test_assign_coords_extra_dim_index_coord(self) -> None:
class AnyIndex(Index):
# This test only requires that the coordinates to assign have an
# index, whatever its type.
pass

idx = AnyIndex()
coords = Coordinates(
coords={
"x": ("x", [1, 2]),
"x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]),
},
indexes={"x": idx, "x_bounds": idx},
)

da = DataArray([1.0, 2.0], dims="x")
actual = da.assign_coords(coords)

# cannot use `assert_identical()` test utility function here yet
# (indexes invariant check is still based on IndexVariable, which
# doesn't work with AnyIndex coordinate variables here)
assert actual.coords.to_dataset().equals(coords.to_dataset())
assert list(actual.coords.xindexes) == list(coords.xindexes)
assert "x_bnds" not in actual.dims

def test_coords_alignment(self) -> None:
lhs = DataArray([1, 2, 3], [("x", [0, 1, 2])])
rhs = DataArray([2, 3, 4], [("x", [1, 2, 3])])
Expand Down
25 changes: 25 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4206,6 +4206,31 @@ def test_getitem_multiple_dtype(self) -> None:
dataset = Dataset({key: ("dim0", range(1)) for key in keys})
assert_identical(dataset, dataset[keys])

def test_getitem_extra_dim_index_coord(self) -> None:
class AnyIndex(Index):
# This test only requires that the coordinates to assign have an
# index, whatever its type.
pass

idx = AnyIndex()
coords = Coordinates(
coords={
"x": ("x", [1, 2]),
"x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]),
},
indexes={"x": idx, "x_bounds": idx},
)

ds = Dataset({"foo": (("x"), [1.0, 2.0])}, coords=coords)
actual = ds["foo"]

# cannot use `assert_identical()` test utility function here yet
# (indexes invariant check is still based on IndexVariable, which
# doesn't work with AnyIndex coordinate variables here)
assert actual.coords.to_dataset().equals(coords.to_dataset())
assert list(actual.coords.xindexes) == list(coords.xindexes)
assert "x_bnds" not in actual.dims

def test_virtual_variables_default_coords(self) -> None:
dataset = Dataset({"foo": ("x", range(10))})
expected1 = DataArray(range(10), dims="x", name="x")
Expand Down
Loading