Skip to content

Commit ddd4cdb

Browse files
andersy005 and dcherian committed
add .oindex and .vindex to BackendArray (#8885)
* add .oindex and .vindex to BackendArray
* Add support for .oindex and .vindex in H5NetCDFArrayWrapper
* Add support for .oindex and .vindex in NetCDF4ArrayWrapper, PydapArrayWrapper, NioArrayWrapper, and ZarrArrayWrapper
* add deprecation warning
* Fix deprecation warning message formatting
* add tests
* Update xarray/core/indexing.py

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>

* Update ZarrArrayWrapper class in xarray/backends/zarr.py

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>

---------

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
1 parent 3029943 commit ddd4cdb

File tree

8 files changed

+182
-36
lines changed

8 files changed

+182
-36
lines changed

xarray/backends/common.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,24 @@ def get_duck_array(self, dtype: np.typing.DTypeLike = None):
210210
key = indexing.BasicIndexer((slice(None),) * self.ndim)
211211
return self[key] # type: ignore [index]
212212

213+
def _oindex_get(self, key: indexing.OuterIndexer):
    """Placeholder for outer indexing; always raises.

    Subclasses are expected to override this method. The error message
    names the concrete class so the missing override is easy to locate.

    Raises
    ------
    NotImplementedError
        Always.
    """
    owner = self.__class__.__name__
    raise NotImplementedError(f"{owner}._oindex_get method should be overridden")
217+
218+
def _vindex_get(self, key: indexing.VectorizedIndexer):
    """Placeholder for vectorized indexing; always raises.

    Subclasses are expected to override this method. The error message
    names the concrete class so the missing override is easy to locate.

    Raises
    ------
    NotImplementedError
        Always.
    """
    owner = self.__class__.__name__
    raise NotImplementedError(f"{owner}._vindex_get method should be overridden")
222+
223+
@property
def oindex(self) -> indexing.IndexCallable:
    """Return an ``indexing.IndexCallable`` built around ``self._oindex_get``.

    A new wrapper object is created on every attribute access.
    """
    getter = self._oindex_get
    return indexing.IndexCallable(getter)
226+
227+
@property
def vindex(self) -> indexing.IndexCallable:
    """Return an ``indexing.IndexCallable`` built around ``self._vindex_get``.

    A new wrapper object is created on every attribute access.
    """
    getter = self._vindex_get
    return indexing.IndexCallable(getter)
230+
213231

214232
class AbstractDataStore:
215233
__slots__ = ()

xarray/backends/h5netcdf_.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,17 @@ def get_array(self, needs_lock=True):
4848
ds = self.datastore._acquire(needs_lock)
4949
return ds.variables[self.variable_name]
5050

51-
def __getitem__(self, key):
51+
def _oindex_get(self, key: indexing.OuterIndexer):
    """Serve an outer-indexing request for ``key``.

    Delegates to ``indexing.explicit_indexing_adapter``, passing this
    wrapper's shape, the ``OUTER_1VECTOR`` support level and
    ``self._getitem`` as the raw indexing method.
    """
    support = indexing.IndexingSupport.OUTER_1VECTOR
    return indexing.explicit_indexing_adapter(key, self.shape, support, self._getitem)
55+
56+
def _vindex_get(self, key: indexing.VectorizedIndexer):
    """Serve a vectorized-indexing request for ``key``.

    Delegates to ``indexing.explicit_indexing_adapter``, passing this
    wrapper's shape, the ``OUTER_1VECTOR`` support level and
    ``self._getitem`` as the raw indexing method.
    """
    support = indexing.IndexingSupport.OUTER_1VECTOR
    return indexing.explicit_indexing_adapter(key, self.shape, support, self._getitem)
60+
61+
def __getitem__(self, key: indexing.BasicIndexer):
    """Serve a basic-indexing request for ``key``.

    Delegates to ``indexing.explicit_indexing_adapter``, passing this
    wrapper's shape, the ``OUTER_1VECTOR`` support level and
    ``self._getitem`` as the raw indexing method.
    """
    support = indexing.IndexingSupport.OUTER_1VECTOR
    return indexing.explicit_indexing_adapter(key, self.shape, support, self._getitem)

xarray/backends/netCDF4_.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,17 @@ def get_array(self, needs_lock=True):
9797
variable.set_auto_chartostring(False)
9898
return variable
9999

100-
def __getitem__(self, key):
100+
def _oindex_get(self, key: indexing.OuterIndexer):
    """Serve an outer-indexing request for ``key``.

    Delegates to ``indexing.explicit_indexing_adapter``, passing this
    wrapper's shape, the ``OUTER`` support level and ``self._getitem`` as
    the raw indexing method.
    """
    support = indexing.IndexingSupport.OUTER
    return indexing.explicit_indexing_adapter(key, self.shape, support, self._getitem)
104+
105+
def _vindex_get(self, key: indexing.VectorizedIndexer):
    """Serve a vectorized-indexing request for ``key``.

    Delegates to ``indexing.explicit_indexing_adapter``, passing this
    wrapper's shape, the ``OUTER`` support level and ``self._getitem`` as
    the raw indexing method.
    """
    support = indexing.IndexingSupport.OUTER
    return indexing.explicit_indexing_adapter(key, self.shape, support, self._getitem)
109+
110+
def __getitem__(self, key: indexing.BasicIndexer):
    """Serve a basic-indexing request for ``key``.

    Delegates to ``indexing.explicit_indexing_adapter``, passing this
    wrapper's shape, the ``OUTER`` support level and ``self._getitem`` as
    the raw indexing method.
    """
    support = indexing.IndexingSupport.OUTER
    return indexing.explicit_indexing_adapter(key, self.shape, support, self._getitem)

xarray/backends/pydap_.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,17 @@ def shape(self) -> tuple[int, ...]:
4343
def dtype(self):
4444
return self.array.dtype
4545

46-
def __getitem__(self, key):
46+
def _oindex_get(self, key: indexing.OuterIndexer):
    """Serve an outer-indexing request for ``key``.

    Delegates to ``indexing.explicit_indexing_adapter``, passing this
    wrapper's shape, the ``BASIC`` support level and ``self._getitem`` as
    the raw indexing method.
    """
    support = indexing.IndexingSupport.BASIC
    return indexing.explicit_indexing_adapter(key, self.shape, support, self._getitem)
50+
51+
def _vindex_get(self, key: indexing.VectorizedIndexer):
    """Serve a vectorized-indexing request for ``key``.

    Delegates to ``indexing.explicit_indexing_adapter``, passing this
    wrapper's shape, the ``BASIC`` support level and ``self._getitem`` as
    the raw indexing method.
    """
    support = indexing.IndexingSupport.BASIC
    return indexing.explicit_indexing_adapter(key, self.shape, support, self._getitem)
55+
56+
def __getitem__(self, key: indexing.BasicIndexer):
    """Serve a basic-indexing request for ``key``.

    Delegates to ``indexing.explicit_indexing_adapter``, passing this
    wrapper's shape, the ``BASIC`` support level and ``self._getitem`` as
    the raw indexing method.
    """
    support = indexing.IndexingSupport.BASIC
    return indexing.explicit_indexing_adapter(key, self.shape, support, self._getitem)

xarray/backends/scipy_.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,7 @@ def get_variable(self, needs_lock=True):
6767
ds = self.datastore._manager.acquire(needs_lock)
6868
return ds.variables[self.variable_name]
6969

70-
def _getitem(self, key):
71-
with self.datastore.lock:
72-
data = self.get_variable(needs_lock=False).data
73-
return data[key]
74-
75-
def __getitem__(self, key):
76-
data = indexing.explicit_indexing_adapter(
77-
key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem
78-
)
70+
def _finalize_result(self, data):
7971
# Copy data if the source file is mmapped. This makes things consistent
8072
# with the netCDF4 library by ensuring we can safely read arrays even
8173
# after closing associated files.
@@ -88,6 +80,29 @@ def __getitem__(self, key):
8880

8981
return np.array(data, dtype=self.dtype, copy=copy)
9082

83+
def _getitem(self, key):
    """Index the variable's ``.data`` with ``key`` while holding the datastore lock.

    The variable is fetched with ``needs_lock=False`` because
    ``self.datastore.lock`` is already held for the whole read.
    """
    with self.datastore.lock:
        variable = self.get_variable(needs_lock=False)
        return variable.data[key]
87+
88+
def _vindex_get(self, key: indexing.VectorizedIndexer):
    """Serve a vectorized-indexing request for ``key``.

    The selection is produced by ``indexing.explicit_indexing_adapter``
    (``OUTER_1VECTOR`` support, ``self._getitem`` as the raw indexing
    method) and then passed through ``self._finalize_result``.
    """
    return self._finalize_result(
        indexing.explicit_indexing_adapter(
            key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem
        )
    )
93+
94+
def _oindex_get(self, key: indexing.OuterIndexer):
    """Serve an outer-indexing request for ``key``.

    The selection is produced by ``indexing.explicit_indexing_adapter``
    (``OUTER_1VECTOR`` support, ``self._getitem`` as the raw indexing
    method) and then passed through ``self._finalize_result``.
    """
    return self._finalize_result(
        indexing.explicit_indexing_adapter(
            key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem
        )
    )
99+
100+
def __getitem__(self, key):
    """Serve an indexing request made through plain ``[]`` access.

    The selection is produced by ``indexing.explicit_indexing_adapter``
    (``OUTER_1VECTOR`` support, ``self._getitem`` as the raw indexing
    method) and then passed through ``self._finalize_result``.
    """
    return self._finalize_result(
        indexing.explicit_indexing_adapter(
            key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem
        )
    )
105+
91106
def __setitem__(self, key, value):
92107
with self.datastore.lock:
93108
data = self.get_variable(needs_lock=False)

xarray/backends/zarr.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -85,25 +85,38 @@ def __init__(self, zarr_array):
8585
def get_array(self):
8686
return self._array
8787

88-
def _oindex(self, key):
89-
return self._array.oindex[key]
90-
91-
def _vindex(self, key):
92-
return self._array.vindex[key]
93-
94-
def _getitem(self, key):
95-
return self._array[key]
96-
97-
def __getitem__(self, key):
98-
array = self._array
99-
if isinstance(key, indexing.BasicIndexer):
100-
method = self._getitem
101-
elif isinstance(key, indexing.VectorizedIndexer):
102-
method = self._vindex
103-
elif isinstance(key, indexing.OuterIndexer):
104-
method = self._oindex
88+
def _oindex_get(self, key: indexing.OuterIndexer):
    """Serve an outer-indexing request for ``key``.

    ``indexing.explicit_indexing_adapter`` is given the wrapped array's
    shape, the ``VECTORIZED`` support level and a raw indexing method that
    reads through the wrapped array's ``.oindex`` accessor.
    """
    zarr_array = self._array

    def read_outer(raw_key):
        return zarr_array.oindex[raw_key]

    return indexing.explicit_indexing_adapter(
        key, zarr_array.shape, indexing.IndexingSupport.VECTORIZED, read_outer
    )
98+
99+
def _vindex_get(self, key: indexing.VectorizedIndexer):
    """Serve a vectorized-indexing request for ``key``.

    ``indexing.explicit_indexing_adapter`` is given the wrapped array's
    shape, the ``VECTORIZED`` support level and a raw indexing method that
    reads through the wrapped array's ``.vindex`` accessor.
    """
    zarr_array = self._array

    def read_vectorized(raw_key):
        return zarr_array.vindex[raw_key]

    return indexing.explicit_indexing_adapter(
        key, zarr_array.shape, indexing.IndexingSupport.VECTORIZED, read_vectorized
    )
110+
111+
def __getitem__(self, key: indexing.BasicIndexer):
    """Serve a basic-indexing request for ``key``.

    ``indexing.explicit_indexing_adapter`` is given the wrapped array's
    shape, the ``VECTORIZED`` support level and a raw indexing method that
    indexes the wrapped array directly with ``[]``.
    """
    zarr_array = self._array

    def read_basic(raw_key):
        return zarr_array[raw_key]

    return indexing.explicit_indexing_adapter(
        key, zarr_array.shape, indexing.IndexingSupport.VECTORIZED, read_basic
    )
108121

109122
# if self.ndim == 0:

xarray/core/indexing.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import enum
44
import functools
55
import operator
6+
import warnings
67
from collections import Counter, defaultdict
78
from collections.abc import Hashable, Iterable, Mapping
89
from contextlib import suppress
@@ -588,6 +589,14 @@ def __getitem__(self, key: Any):
588589
return result
589590

590591

592+
# Template for the warning emitted when indexing falls back to ``__getitem__``.
# Both ``{0}`` and ``{1}`` are filled by ``str.format``; the callers in this
# module pass the array's class name for each.
BackendArray_fallback_warning_message = (
    "The array `{0}` does not support indexing using the .vindex and .oindex properties. "
    "The __getitem__ method is being used instead. "
    "This fallback behavior will be removed in a future version. "
    "Please ensure that the backend array `{1}` implements support for the "
    ".vindex and .oindex properties to avoid potential issues."
)
598+
599+
591600
class LazilyIndexedArray(ExplicitlyIndexedNDArrayMixin):
592601
"""Wrap an array to make basic and outer indexing lazy."""
593602

@@ -639,11 +648,18 @@ def shape(self) -> _Shape:
639648
return tuple(shape)
640649

641650
def get_duck_array(self):
642-
if isinstance(self.array, ExplicitlyIndexedNDArrayMixin):
651+
try:
643652
array = apply_indexer(self.array, self.key)
644-
else:
653+
except NotImplementedError as _:
645654
# If the array is not an ExplicitlyIndexedNDArrayMixin,
646-
# it may wrap a BackendArray so use its __getitem__
655+
# it may wrap a BackendArray subclass that doesn't implement .oindex and .vindex, so use its __getitem__
656+
warnings.warn(
657+
BackendArray_fallback_warning_message.format(
658+
self.array.__class__.__name__, self.array.__class__.__name__
659+
),
660+
category=DeprecationWarning,
661+
stacklevel=2,
662+
)
647663
array = self.array[self.key]
648664

649665
# self.array[self.key] is now a numpy array when
@@ -715,12 +731,20 @@ def shape(self) -> _Shape:
715731
return np.broadcast(*self.key.tuple).shape
716732

717733
def get_duck_array(self):
718-
if isinstance(self.array, ExplicitlyIndexedNDArrayMixin):
734+
try:
719735
array = apply_indexer(self.array, self.key)
720-
else:
736+
except NotImplementedError as _:
721737
# If the array is not an ExplicitlyIndexedNDArrayMixin,
722-
# it may wrap a BackendArray so use its __getitem__
738+
# it may wrap a BackendArray subclass that doesn't implement .oindex and .vindex, so use its __getitem__
739+
warnings.warn(
740+
BackendArray_fallback_warning_message.format(
741+
self.array.__class__.__name__, self.array.__class__.__name__
742+
),
743+
category=PendingDeprecationWarning,
744+
stacklevel=2,
745+
)
723746
array = self.array[self.key]
747+
724748
# self.array[self.key] is now a numpy array when
725749
# self.array is a BackendArray subclass
726750
# and self.key is BasicIndexer((slice(None, None, None),))

xarray/tests/test_backends.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5787,3 +5787,49 @@ def test_zarr_region_chunk_partial_offset(tmp_path):
57875787
# This write is unsafe, and should raise an error, but does not.
57885788
# with pytest.raises(ValueError):
57895789
# da.isel(x=slice(5, 25)).chunk(x=(10, 10)).to_zarr(store, region="auto")
5790+
5791+
5792+
def test_backend_array_deprecation_warning(capsys):
    """Check the fallback warning for a backend array that only defines ``__getitem__``.

    ``CustomBackendArray`` does not override ``_oindex_get`` / ``_vindex_get``.
    Vectorized indexing through a lazy wrapper is expected to emit exactly one
    ``PendingDeprecationWarning`` carrying the fallback message, and to write
    nothing to stdout.
    """

    class CustomBackendArray(xr.backends.common.BackendArray):
        def __init__(self):
            backing = self.get_array()
            self.shape = backing.shape
            self.dtype = backing.dtype

        def get_array(self):
            return np.arange(10)

        def __getitem__(self, key):
            return xr.core.indexing.explicit_indexing_adapter(
                key, self.shape, xr.core.indexing.IndexingSupport.BASIC, self._getitem
            )

        def _getitem(self, key):
            return self.get_array()[key]

    backend_array = CustomBackendArray()
    indexer = xr.core.indexing.VectorizedIndexer(key=(np.array([0]),))
    lazy_array = xr.core.indexing.LazilyIndexedArray(backend_array, indexer)

    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, including ones that default filters would hide.
        warnings.simplefilter("always")
        lazy_array.vindex[indexer].get_duck_array()

        streams = capsys.readouterr()
        assert len(caught) == 1

        last = caught[-1]
        assert issubclass(last.category, PendingDeprecationWarning)

        text = str(last.message)
        assert (
            "The array `CustomBackendArray` does not support indexing using the .vindex and .oindex properties."
            in text
        )
        assert "The __getitem__ method is being used instead." in text
        assert "This fallback behavior will be removed in a future version." in text
        assert (
            "Please ensure that the backend array `CustomBackendArray` implements support for the .vindex and .oindex properties to avoid potential issues."
            in text
        )
        assert streams.out == ""

0 commit comments

Comments
 (0)