diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 994fc70339c..615be8e019f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -34,6 +34,9 @@ New Features By `Benoit Bovy `_. - Support reading to `GPU memory with Zarr `_ (:pull:`10078`). By `Deepak Cherian `_. +- Allow assigning index coordinates with non-array dimension(s) in a :py:class:`DataArray`, enabling + support for CF boundaries coordinate (e.g., ``time(time)`` and ``time_bnds(time, nbnd)``) in a DataArray (:pull:`10116`). + By `Benoit Bovy `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 47773ddfbb6..4eb089314a7 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -486,7 +486,7 @@ def identical(self, other: Self) -> bool: return self.to_dataset().identical(other.to_dataset()) def _update_coords( - self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: # redirect to DatasetCoordinates._update_coords self._data.coords._update_coords(coords, indexes) @@ -780,7 +780,7 @@ def to_dataset(self) -> Dataset: return self._data._copy_listed(names) def _update_coords( - self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: variables = self._data._variables.copy() variables.update(coords) @@ -880,7 +880,7 @@ def to_dataset(self) -> Dataset: return self._data.dataset._copy_listed(self._names) def _update_coords( - self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: from xarray.core.datatree import check_alignment @@ -964,22 +964,23 @@ def __getitem__(self, key: Hashable) -> T_DataArray: return self._data._getitem_coord(key) def _update_coords( - self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index] ) -> None: + from xarray.core.dataarray import check_dataarray_coords + coords_plus_data = coords.copy() coords_plus_data[_THIS_ARRAY] = self._data.variable - dims = calculate_dimensions(coords_plus_data) - if not set(dims) <= set(self.dims): - raise ValueError( - "cannot add coordinates with new dimensions to a DataArray" + coords_dims = set(calculate_dimensions(coords_plus_data)) + obj_dims = set(self.dims) + + if coords_dims > obj_dims: + # need more checks + check_dataarray_coords( + self._data.shape, Coordinates(coords, indexes), self.dims ) - self._data._coords = coords - # TODO(shoyer): once ._indexes is always populated by a dict, modify - # it to update inplace instead. - original_indexes = dict(self._data.xindexes) - original_indexes.update(indexes) - self._data._indexes = original_indexes + self._data._coords = coords + self._data._indexes = indexes def _drop_coords(self, coord_names): # should drop indexed coordinates only diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 4324a4587b3..578ce63e789 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -130,22 +130,59 @@ T_XarrayOther = TypeVar("T_XarrayOther", bound="DataArray" | Dataset) -def _check_coords_dims(shape, coords, dim): - sizes = dict(zip(dim, shape, strict=True)) - for k, v in coords.items(): - if any(d not in dim for d in v.dims): +def check_dataarray_coords( + shape: tuple[int, ...], coords: Coordinates, dims: tuple[Hashable, ...] +): + """Check that ``coords`` dimension names and sizes do not conflict with + array ``shape`` and dimensions ``dims``. + + If a coordinate is associated with an index, the coordinate may have any + arbitrary dimension(s) as long as the index's dimensions (i.e., the union of + the dimensions of all coordinates associated with this index) intersects the + array dimensions. + + If a coordinate has no index, then its dimensions much match (or be a subset + of) the array dimensions. Scalar coordinates are also allowed. + + """ + indexes = coords.xindexes + skip_check_coord_name: set[Hashable] = set() + skip_check_dim_size: set[Hashable] = set() + + # check dimension names + for name, var in coords.items(): + if name in skip_check_coord_name: + continue + elif name in indexes: + index_dims = indexes.get_all_dims(name) + if any(d in dims for d in index_dims): + # can safely skip checking index's non-array dimensions + # and index's other coordinates since those must be all + # included in the dataarray so the index is not corrupted + skip_check_coord_name.update(indexes.get_all_coords(name)) + skip_check_dim_size.update(d for d in index_dims if d not in dims) + raise_error = False + else: + raise_error = True + else: + raise_error = any(d not in dims for d in var.dims) + if raise_error: raise ValueError( - f"coordinate {k} has dimensions {v.dims}, but these " + f"coordinate {name} has dimensions {var.dims}, but these " "are not a subset of the DataArray " - f"dimensions {dim}" + f"dimensions {dims}" ) - for d, s in v.sizes.items(): - if s != sizes[d]: + # check dimension sizes + sizes = dict(zip(dims, shape, strict=True)) + + for name, var in coords.items(): + for d, s in var.sizes.items(): + if d not in skip_check_dim_size and s != sizes[d]: raise ValueError( f"conflicting sizes for dimension {d!r}: " f"length {sizes[d]} on the data but length {s} on " - f"coordinate {k!r}" + f"coordinate {name!r}" ) @@ -212,8 +249,6 @@ def _infer_coords_and_dims( var.dims = (dim,) new_coords[dim] = var.to_index_variable() - _check_coords_dims(shape, new_coords, dims_tuple) - return new_coords, dims_tuple @@ -487,6 +522,7 @@ def __init__( if not isinstance(coords, Coordinates): coords = create_coords_with_default_indexes(coords) + check_dataarray_coords(data.shape, coords, dims) indexes = dict(coords.xindexes) coords = {k: v.copy() for k, v in coords.variables.items()} diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 79a2dde3444..0c1decbf3a9 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1210,10 +1210,20 @@ def _construct_dataarray(self, name: Hashable) -> DataArray: needed_dims = set(variable.dims) coords: dict[Hashable, Variable] = {} + temp_indexes = self.xindexes # preserve ordering for k in self._variables: - if k in self._coord_names and set(self._variables[k].dims) <= needed_dims: - coords[k] = self._variables[k] + if k in self._coord_names: + if ( + k not in coords + and k in temp_indexes + and set(temp_indexes.get_all_dims(k)) & needed_dims + ): + # add all coordinates of each index that shares at least one dimension + # with the dimensions of the extracted variable + coords.update(temp_indexes.get_all_coords(k)) + elif set(self._variables[k].dims) <= needed_dims: + coords[k] = self._variables[k] indexes = filter_indexes_from_coords(self._indexes, set(coords)) diff --git a/xarray/core/dataset_variables.py b/xarray/core/dataset_variables.py index 6521da61444..8f5b7442c7a 100644 --- a/xarray/core/dataset_variables.py +++ b/xarray/core/dataset_variables.py @@ -40,7 +40,7 @@ def __getitem__(self, key: Hashable) -> "DataArray": raise KeyError(key) def __repr__(self) -> str: - return formatting.data_vars_repr(self) + return formatting.data_vars_repr(self.variables) @property def variables(self) -> Mapping[Hashable, Variable]: diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 993cddf2b57..6302ee00210 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -429,7 +429,7 @@ def coords_repr(coords: AbstractCoordinates, col_width=None, max_rows=None): if col_width is None: col_width = _calculate_col_width(coords) return _mapping_repr( - coords, + coords.variables, title="Coordinates", summarizer=summarize_variable, expand_option_name="display_expand_coords", @@ -743,7 +743,9 @@ def dataset_repr(ds): if unindexed_dims_str: summary.append(unindexed_dims_str) - summary.append(data_vars_repr(ds.data_vars, col_width=col_width, max_rows=max_rows)) + summary.append( + data_vars_repr(ds.data_vars.variables, col_width=col_width, max_rows=max_rows) + ) display_default_indexes = _get_boolean_with_default( "display_default_indexes", False diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py index eb9073cd869..69c128ae5fd 100644 --- a/xarray/core/formatting_html.py +++ b/xarray/core/formatting_html.py @@ -113,10 +113,11 @@ def summarize_variable(name, var, is_index=False, dtype=None) -> str: ) -def summarize_coords(variables) -> str: +def summarize_coords(coords) -> str: li_items = [] - for k, v in variables.items(): - li_content = summarize_variable(k, v, is_index=k in variables.xindexes) + indexes = coords.xindexes + for k, v in coords.variables.items(): + li_content = summarize_variable(k, v, is_index=k in indexes) li_items.append(f"
  • {li_content}
  • ") vars_li = "".join(li_items) @@ -339,7 +340,7 @@ def dataset_repr(ds) -> str: sections = [ dim_section(ds), coord_section(ds.coords), - datavar_section(ds.data_vars), + datavar_section(ds.data_vars.variables), index_section(_get_indexes_dict(ds.xindexes)), attr_section(ds.attrs), ] @@ -415,7 +416,7 @@ def datatree_node_repr(group_title: str, node: DataTree, show_inherited=False) - sections.append(inherited_coord_section(inherited_coords)) sections += [ - datavar_section(ds.data_vars), + datavar_section(ds.data_vars.variables), attr_section(ds.attrs), ] diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index c2bc8b94f3f..09d685f296f 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1568,6 +1568,7 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]): """ _index_type: type[Index] | type[pd.Index] + _index_dims: dict[Hashable, Mapping[Hashable, int]] _indexes: dict[Any, T_PandasOrXarrayIndex] _variables: dict[Any, Variable] @@ -1576,6 +1577,7 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]): "__id_coord_names", "__id_index", "_dims", + "_index_dims", "_index_type", "_indexes", "_variables", @@ -1619,6 +1621,7 @@ def __init__( ) self._index_type = index_type + self._index_dims = {} self._indexes = dict(**indexes) self._variables = dict(**variables) @@ -1737,7 +1740,13 @@ def get_all_dims( """ from xarray.core.variable import calculate_dimensions - return calculate_dimensions(self.get_all_coords(key, errors=errors)) + if key in self._index_dims: + return self._index_dims[key] + else: + dims = calculate_dimensions(self.get_all_coords(key, errors=errors)) + if dims: + self._index_dims[key] = dims + return dims def group_by_index( self, diff --git a/xarray/testing/assertions.py b/xarray/testing/assertions.py index 8a2dba9261f..ec7b4fdd410 100644 --- a/xarray/testing/assertions.py +++ b/xarray/testing/assertions.py @@ -330,10 +330,13 @@ def _assert_indexes_invariants_checks( k: type(v) for k, v in indexes.items() } - index_vars = { - k for k, v in possible_coord_variables.items() if isinstance(v, IndexVariable) - } - assert indexes.keys() <= index_vars, (set(indexes), index_vars) + if check_default: + index_vars = { + k + for k, v in possible_coord_variables.items() + if isinstance(v, IndexVariable) + } + assert indexes.keys() <= index_vars, (set(indexes), index_vars) # check pandas index wrappers vs. coordinate data adapters for k, index in indexes.items(): @@ -399,9 +402,14 @@ def _assert_dataarray_invariants(da: DataArray, check_default_indexes: bool): da.dims, {k: v.dims for k, v in da._coords.items()}, ) - assert all( - isinstance(v, IndexVariable) for (k, v) in da._coords.items() if v.dims == (k,) - ), {k: type(v) for k, v in da._coords.items()} + + if check_default_indexes: + assert all( + isinstance(v, IndexVariable) + for (k, v) in da._coords.items() + if v.dims == (k,) + ), {k: type(v) for k, v in da._coords.items()} + for k, v in da._coords.items(): _assert_variable_invariants(v, k) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 75d6d919e19..1883b9eb407 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -529,6 +529,26 @@ class CustomIndex(Index): ... # test coordinate variables copied assert da.coords["x"] is not coords.variables["x"] + def test_constructor_extra_dim_index_coord(self) -> None: + class AnyIndex(Index): + # This test only requires that the coordinates to assign have an + # index, whatever its type. + pass + + idx = AnyIndex() + coords = Coordinates( + coords={ + "x": ("x", [1, 2]), + "x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]), + }, + indexes={"x": idx, "x_bounds": idx}, + ) + + actual = DataArray([1.0, 2.0], coords=coords, dims="x") + + assert_identical(actual.coords, coords, check_default_indexes=False) + assert "x_bnds" not in actual.dims + def test_equals_and_identical(self) -> None: orig = DataArray(np.arange(5.0), {"a": 42}, dims="x") @@ -1634,6 +1654,27 @@ def test_assign_coords_no_default_index(self) -> None: assert_identical(actual.coords, coords, check_default_indexes=False) assert "y" not in actual.xindexes + def test_assign_coords_extra_dim_index_coord(self) -> None: + class AnyIndex(Index): + # This test only requires that the coordinates to assign have an + # index, whatever its type. + pass + + idx = AnyIndex() + coords = Coordinates( + coords={ + "x": ("x", [1, 2]), + "x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]), + }, + indexes={"x": idx, "x_bounds": idx}, + ) + + da = DataArray([1.0, 2.0], dims="x") + actual = da.assign_coords(coords) + + assert_identical(actual.coords, coords, check_default_indexes=False) + assert "x_bnds" not in actual.dims + def test_coords_alignment(self) -> None: lhs = DataArray([1, 2, 3], [("x", [0, 1, 2])]) rhs = DataArray([2, 3, 4], [("x", [1, 2, 3])]) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index bdae9daf758..b47599c7cd0 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4206,6 +4206,27 @@ def test_getitem_multiple_dtype(self) -> None: dataset = Dataset({key: ("dim0", range(1)) for key in keys}) assert_identical(dataset, dataset[keys]) + def test_getitem_extra_dim_index_coord(self) -> None: + class AnyIndex(Index): + # This test only requires that the coordinates to assign have an + # index, whatever its type. + pass + + idx = AnyIndex() + coords = Coordinates( + coords={ + "x": ("x", [1, 2]), + "x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]), + }, + indexes={"x": idx, "x_bounds": idx}, + ) + + ds = Dataset({"foo": (("x"), [1.0, 2.0])}, coords=coords) + actual = ds["foo"] + + assert_identical(actual.coords, coords, check_default_indexes=False) + assert "x_bnds" not in actual.dims + def test_virtual_variables_default_coords(self) -> None: dataset = Dataset({"foo": ("x", range(10))}) expected1 = DataArray(range(10), dims="x", name="x")