diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 1de857032d0..f97f12ffc9f 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -26,6 +26,12 @@ Deprecations
 Bug fixes
 ~~~~~~~~~
 
+- :py:meth:`Dataset.set_xindex` now raises a helpful error when a custom index
+  creates extra variables that don't match the provided coordinate names, instead
+  of silently ignoring them. The error message suggests using the factory method
+  pattern with :py:meth:`xarray.Coordinates.from_xindex` and
+  :py:meth:`Dataset.assign_coords` for advanced use cases (:issue:`10499`).
+  By `Dhruva Kumar Kaushal `_.
 
 Documentation
 ~~~~~~~~~~~~~
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 6de626a159b..ac4bfc32df5 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -4940,6 +4940,21 @@ def set_xindex(
         if isinstance(index, PandasMultiIndex):
             coord_names = [index.dim] + list(coord_names)
 
+        # Check for extra variables that don't match the coordinate names
+        extra_vars = sorted(set(new_coord_vars) - set(coord_names), key=str)
+        if extra_vars:
+            extra_vars_str = ", ".join(f"'{name}'" for name in extra_vars)
+            coord_names_str = ", ".join(f"'{name}'" for name in coord_names)
+            raise ValueError(
+                f"The index created extra variables {extra_vars_str} that are not "
+                f"in the list of coordinates {coord_names_str}. "
+                f"Use a factory method pattern instead:\n"
+                f" coord_vars = {{name: ds[name].variable for name in {list(coord_names)!r}}}\n"
+                f" index = {index_cls.__name__}.from_variables(coord_vars, options={{}})\n"
+                f" coords = xr.Coordinates.from_xindex(index)\n"
+                f" ds = ds.assign_coords(coords)"
+            )
+
         variables: dict[Hashable, Variable]
         indexes: dict[Hashable, Index]
 
diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py
index 2b7900d9c89..9f2eea48260 100644
--- a/xarray/tests/test_indexes.py
+++ b/xarray/tests/test_indexes.py
@@ -729,3 +729,54 @@ def test_restore_dtype_on_multiindexes(dtype: str) -> None:
     foo = xr.Dataset(coords={"bar": ("bar", np.array([0, 1], dtype=dtype))})
     foo = foo.stack(baz=("bar",))
     assert str(foo["bar"].values.dtype) == dtype
+
+
+class IndexWithExtraVariables(Index):
+    @classmethod
+    def from_variables(cls, variables, *, options=None):
+        return cls()
+
+    def create_variables(self, variables=None):
+        if variables is None:
+            # For Coordinates.from_xindex(), return all variables the index can create
+            return {
+                "time": Variable(dims=("time",), data=[1, 2, 3]),
+                "valid_time": Variable(
+                    dims=("time",),
+                    data=[2, 3, 4],  # time + 1
+                    attrs={"description": "time + 1"},
+                ),
+            }
+
+        result = dict(variables)
+        if "time" in variables:
+            result["valid_time"] = Variable(
+                dims=("time",),
+                data=variables["time"].data + 1,
+                attrs={"description": "time + 1"},
+            )
+        return result
+
+
+def test_set_xindex_with_extra_variables() -> None:
+    """Test that set_xindex raises an error when custom index creates extra variables."""
+
+    ds = xr.Dataset(coords={"time": [1, 2, 3]}).reset_index("time")
+
+    # Test that set_xindex raises error for extra variables
+    with pytest.raises(ValueError, match="extra variables 'valid_time'"):
+        ds.set_xindex("time", IndexWithExtraVariables)
+
+
+def test_set_xindex_factory_method_pattern() -> None:
+    ds = xr.Dataset(coords={"time": [1, 2, 3]}).reset_index("time")
+
+    # Test the recommended factory method pattern
+    coord_vars = {"time": ds._variables["time"]}
+    index = IndexWithExtraVariables.from_variables(coord_vars)
+    coords = xr.Coordinates.from_xindex(index)
+    result = ds.assign_coords(coords)
+
+    assert "time" in result.variables
+    assert "valid_time" in result.variables
+    assert_array_equal(result.valid_time.data, result.time.data + 1)