-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Add CFIntervalIndex #10296
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add CFIntervalIndex #10296
Changes from 5 commits
9282fc4
f71f767
f7041fd
9401774
781d33f
48dc0bd
b424b12
8d80e71
e60a1a4
8918fe8
c722a2e
23fb18b
de4f5d8
e1bf896
06a3b92
80f496f
67d8f6c
a8015aa
5b5cbee
fdc1943
3a8fd3c
edfa435
bc20226
3ec2c65
4cabb7c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
from __future__ import annotations | ||
|
||
from collections.abc import Hashable, Iterable, Mapping, Sequence | ||
from typing import TYPE_CHECKING, Any, cast | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from xarray import Variable | ||
from xarray.core.indexes import Index, PandasIndex | ||
from xarray.core.indexing import IndexSelResult | ||
|
||
if TYPE_CHECKING: | ||
from xarray.core.types import Self | ||
|
||
|
||
class IntervalIndex(Index): | ||
"""Xarray index of 1-dimensional intervals. | ||
|
||
This index is built on top of :py:class:`~xarray.indexes.PandasIndex` and | ||
wraps a :py:class:`pandas.IntervalIndex`. It is associated with two | ||
coordinate variables: | ||
|
||
- a 1-dimensional coordinate where each label represents an interval that is | ||
materialized by its midpoint (i.e., the average of its left and right | ||
boundaries) | ||
|
||
- a 2-dimensional coordinate that represents the left and right boundaries | ||
of each interval. One of the two dimensions is shared with the | ||
aforementioned coordinate and the other one has length 2. | ||
|
||
""" | ||
|
||
_index: PandasIndex | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should save the central values too. |
||
_bounds_name: Hashable | ||
_bounds_dim: str | ||
|
||
def __init__(self, index: PandasIndex, bounds_name: Hashable, bounds_dim: str): | ||
assert isinstance(index.index, pd.IntervalIndex) | ||
self._index = index | ||
self._bounds_name = bounds_name | ||
self._bounds_dim = bounds_dim | ||
|
||
@classmethod | ||
def from_variables( | ||
cls, | ||
variables: Mapping[Any, Variable], | ||
*, | ||
options: Mapping[str, Any], | ||
) -> Self: | ||
# TODO: allow set the index from one variable? Guess bounds like cf_xarray's add_bounds | ||
assert len(variables) == 2 | ||
|
||
for k, v in variables.items(): | ||
if v.ndim == 2: | ||
# TODO: be flexible with dimension order? Check which dim has length 2 | ||
bounds_name, bounds = k, v | ||
elif v.ndim == 1: | ||
dim, _ = k, v | ||
|
||
bounds = bounds.transpose(..., dim) | ||
left, right = bounds.data.tolist() | ||
# TODO: support non-dimension coordinates (pass variable name to pd.IntervalIndex.from_arrays) | ||
# TODO: propagate coordinate dtype (pass it to PandasIndex constructor) | ||
# TODO: add "closed" build option (maybe choose "closed='both'" as default here? to be consistent with | ||
# CF conventions: https://cfconventions.org/cf-conventions/cf-conventions.html#bounds-one-d) | ||
index = PandasIndex(pd.IntervalIndex.from_arrays(left, right), dim) | ||
bounds_dim = (set(bounds.dims) - set(dim)).pop() | ||
|
||
return cls(index, bounds_name, str(bounds_dim)) | ||
|
||
@classmethod | ||
def concat( | ||
cls, | ||
indexes: Sequence[IntervalIndex], | ||
dim: Hashable, | ||
positions: Iterable[Iterable[int]] | None = None, | ||
) -> IntervalIndex: | ||
new_index = PandasIndex.concat( | ||
[idx._index for idx in indexes], dim, positions=positions | ||
) | ||
|
||
if indexes: | ||
bounds_name = indexes[0]._bounds_name | ||
bounds_dim = indexes[0]._bounds_dim | ||
if any( | ||
idx._bounds_name != bounds_name or idx._bounds_dim != bounds_dim | ||
for idx in indexes | ||
): | ||
raise ValueError( | ||
f"Cannot concatenate along dimension {dim!r} indexes with different " | ||
"boundary coordinate or dimension names" | ||
) | ||
else: | ||
bounds_name = new_index.index.name + "_bounds" | ||
bounds_dim = "bnd" | ||
|
||
return cls(new_index, bounds_name, bounds_dim) | ||
|
||
@property | ||
def _pd_index(self) -> pd.IntervalIndex: | ||
# For typing purpose only | ||
# TODO: cleaner to make PandasIndex a generic class, i.e., PandasIndex[pd.IntervalIndex] | ||
# will be easier once PEP 696 is fully supported (starting from Python 3.13) | ||
return cast(pd.IntervalIndex, self._index.index) | ||
|
||
def create_variables( | ||
self, variables: Mapping[Any, Variable] | None = None | ||
) -> dict[Any, Variable]: | ||
if variables is None: | ||
variables = {} | ||
empty_var = Variable((), 0) | ||
bounds_attrs = variables.get(self._bounds_name, empty_var).attrs | ||
mid_attrs = variables.get(self._index.dim, empty_var).attrs | ||
|
||
# TODO: create a PandasIndexingAdapter subclass for the boundary variable | ||
# and wrap it here (avoid data copy) | ||
bounds_var = Variable( | ||
dims=(self._bounds_dim, self._index.dim), | ||
data=np.stack([self._pd_index.left, self._pd_index.right], axis=0), | ||
attrs=bounds_attrs, | ||
) | ||
# TODO: use PandasIndexingAdapter directly (avoid data copy) | ||
# and/or maybe add an index build option to preserve original labels? | ||
# (if those differ from interval midpoints as defined by pd.IntervalIndex) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should always save the central value and return that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed. This is sub-optimal in the case where the central values exactly correspond to what is returned |
||
mid_var = Variable( | ||
dims=(self._index.dim,), | ||
data=self._pd_index.mid, | ||
attrs=mid_attrs, | ||
) | ||
|
||
return {self._index.dim: mid_var, self._bounds_name: bounds_var} | ||
|
||
def should_add_coord_to_array( | ||
self, | ||
name: Hashable, | ||
var: Variable, | ||
dims: set[Hashable], | ||
) -> bool: | ||
# add both the mid and boundary coordinates if the index dimension | ||
# is present in the array dimensions | ||
if self._index.dim in dims: | ||
return True | ||
else: | ||
return False | ||
benbovy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def to_pandas_index(self) -> pd.Index: | ||
dcherian marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return self._pd_index | ||
|
||
def equals(self, other: Index) -> bool: | ||
if not isinstance(other, IntervalIndex): | ||
return False | ||
return self._index.equals(other._index) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again need to check the central value here |
||
|
||
def sel(self, labels: dict[Any, Any], **kwargs) -> IndexSelResult: | ||
return self._index.sel(labels, **kwargs) | ||
|
||
def isel( | ||
self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] | ||
) -> Self | None: | ||
new_index = self._index.isel(indexers) | ||
if new_index is not None: | ||
return type(self)(new_index, self._bounds_name, self._bounds_dim) | ||
else: | ||
return None | ||
|
||
def roll(self, shifts: Mapping[Any, int]) -> Self | None: | ||
new_index = self._index.roll(shifts) | ||
return type(self)(new_index, self._bounds_name, self._bounds_dim) | ||
|
||
def rename( | ||
self, | ||
name_dict: Mapping[Any, Hashable], | ||
dims_dict: Mapping[Any, Hashable], | ||
) -> Self: | ||
new_index = self._index.rename(name_dict, dims_dict) | ||
|
||
bounds_name = name_dict.get(self._bounds_name, self._bounds_name) | ||
bounds_dim = dims_dict.get(self._bounds_dim, self._bounds_dim) | ||
|
||
return type(self)(new_index, bounds_name, str(bounds_dim)) | ||
|
||
def __repr__(self) -> str: | ||
string = f"{self._index!r}" | ||
return string |
Uh oh!
There was an error while loading. Please reload this page.