Skip to content

Commit d85185b

Browse files
dhruvak001DHRUVA KUMAR KAUSHALpre-commit-ci[bot]
authored
Raise if Index.create_variables returns more variables than passed in through set_xindex (#10503)
Co-authored-by: DHRUVA KUMAR KAUSHAL <sanjay@MacBook-Air.local> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 3679a5d commit d85185b

File tree

3 files changed

+71
-0
lines changed

3 files changed

+71
-0
lines changed

doc/whats-new.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ Deprecations
2626
Bug fixes
2727
~~~~~~~~~
2828

29+
- :py:meth:`Dataset.set_xindex` now raises a helpful error when a custom index
30+
creates extra variables that don't match the provided coordinate names, instead
31+
of silently ignoring them. The error message suggests using the factory method
32+
pattern with :py:meth:`xarray.Coordinates.from_xindex` and
33+
:py:meth:`Dataset.assign_coords` for advanced use cases (:issue:`10499`).
34+
By `Dhruva Kumar Kaushal <https://github.com/dhruvak001>`_.
2935

3036
Documentation
3137
~~~~~~~~~~~~~

xarray/core/dataset.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4940,6 +4940,20 @@ def set_xindex(
49404940
if isinstance(index, PandasMultiIndex):
49414941
coord_names = [index.dim] + list(coord_names)
49424942

4943+
# Check for extra variables that don't match the coordinate names
4944+
extra_vars = set(new_coord_vars) - set(coord_names)
4945+
if extra_vars:
4946+
extra_vars_str = ", ".join(f"'{name}'" for name in extra_vars)
4947+
coord_names_str = ", ".join(f"'{name}'" for name in coord_names)
4948+
raise ValueError(
4949+
f"The index created extra variables {extra_vars_str} that are not "
4950+
f"in the list of coordinates {coord_names_str}. "
4951+
f"Use a factory method pattern instead:\n"
4952+
f" index = {index_cls.__name__}.from_variables(ds, {list(coord_names)!r})\n"
4953+
f" coords = xr.Coordinates.from_xindex(index)\n"
4954+
f" ds = ds.assign_coords(coords)"
4955+
)
4956+
49434957
variables: dict[Hashable, Variable]
49444958
indexes: dict[Hashable, Index]
49454959

xarray/tests/test_indexes.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,3 +729,54 @@ def test_restore_dtype_on_multiindexes(dtype: str) -> None:
729729
foo = xr.Dataset(coords={"bar": ("bar", np.array([0, 1], dtype=dtype))})
730730
foo = foo.stack(baz=("bar",))
731731
assert str(foo["bar"].values.dtype) == dtype
732+
733+
734+
class IndexWithExtraVariables(Index):
735+
@classmethod
736+
def from_variables(cls, variables, *, options=None):
737+
return cls()
738+
739+
def create_variables(self, variables=None):
740+
if variables is None:
741+
# For Coordinates.from_xindex(), return all variables the index can create
742+
return {
743+
"time": Variable(dims=("time",), data=[1, 2, 3]),
744+
"valid_time": Variable(
745+
dims=("time",),
746+
data=[2, 3, 4], # time + 1
747+
attrs={"description": "time + 1"},
748+
),
749+
}
750+
751+
result = dict(variables)
752+
if "time" in variables:
753+
result["valid_time"] = Variable(
754+
dims=("time",),
755+
data=variables["time"].data + 1,
756+
attrs={"description": "time + 1"},
757+
)
758+
return result
759+
760+
761+
def test_set_xindex_with_extra_variables() -> None:
762+
"""Test that set_xindex raises an error when custom index creates extra variables."""
763+
764+
ds = xr.Dataset(coords={"time": [1, 2, 3]}).reset_index("time")
765+
766+
# Test that set_xindex raises error for extra variables
767+
with pytest.raises(ValueError, match="extra variables 'valid_time'"):
768+
ds.set_xindex("time", IndexWithExtraVariables)
769+
770+
771+
def test_set_xindex_factory_method_pattern() -> None:
772+
ds = xr.Dataset(coords={"time": [1, 2, 3]}).reset_index("time")
773+
774+
# Test the recommended factory method pattern
775+
coord_vars = {"time": ds._variables["time"]}
776+
index = IndexWithExtraVariables.from_variables(coord_vars)
777+
coords = xr.Coordinates.from_xindex(index)
778+
result = ds.assign_coords(coords)
779+
780+
assert "time" in result.variables
781+
assert "valid_time" in result.variables
782+
assert_array_equal(result.valid_time.data, result.time.data + 1)

0 commit comments

Comments
 (0)