Skip to content

Commit 8bc7bea

Browse files
committed
implement async vectorized indexing
1 parent b9e8e06 commit 8bc7bea

File tree

4 files changed

+44
-8
lines changed

4 files changed

+44
-8
lines changed

xarray/backends/zarr.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -222,6 +222,10 @@ async def _async_oindex(self, key):
222222
async_array = self._array._async_array
223223
return await async_array.oindex.getitem(key)
224224

225+
async def _async_vindex(self, key):
226+
async_array = self._array._async_array
227+
return await async_array.vindex.getitem(key)
228+
225229
def __getitem__(self, key):
226230
array = self._array
227231
if isinstance(key, indexing.BasicIndexer):
@@ -242,14 +246,11 @@ async def async_getitem(self, key):
242246
if isinstance(key, indexing.BasicIndexer):
243247
method = self._async_getitem
244248
elif isinstance(key, indexing.VectorizedIndexer):
245-
# method = self._vindex
246-
raise NotImplementedError("async lazy vectorized indexing is not supported")
249+
method = self._async_vindex
247250
elif isinstance(key, indexing.OuterIndexer):
248251
method = self._async_oindex
249-
250-
print("did an async get")
251252
return await indexing.async_explicit_indexing_adapter(
252-
key, array.shape, indexing.IndexingSupport.OUTER, method
253+
key, array.shape, indexing.IndexingSupport.VECTORIZED, method
253254
)
254255

255256

xarray/core/indexing.py

Lines changed: 36 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -525,7 +525,7 @@ def get_duck_array(self):
525525

526526
async def async_get_duck_array(self):
527527
key = BasicIndexer((slice(None),) * self.ndim)
528-
return self[key]
528+
return await self.async_getitem(key)
529529

530530
def _oindex_get(self, indexer: OuterIndexer):
531531
raise NotImplementedError(
@@ -756,6 +756,22 @@ def get_duck_array(self):
756756
array = array.get_duck_array()
757757
return _wrap_numpy_scalars(array)
758758

759+
async def async_get_duck_array(self):
760+
print("inside LazilyVectorizedIndexedArray.async_get_duck_array")
761+
if isinstance(self.array, ExplicitlyIndexedNDArrayMixin):
762+
array = apply_indexer(self.array, self.key)
763+
else:
764+
# If the array is not an ExplicitlyIndexedNDArrayMixin,
765+
# it may wrap a BackendArray so use its __getitem__
766+
array = await self.array.async_getitem(self.key)
767+
# self.array[self.key] is now a numpy array when
768+
# self.array is a BackendArray subclass
769+
# and self.key is BasicIndexer((slice(None, None, None),))
770+
# so we need the explicit check for ExplicitlyIndexed
771+
if isinstance(array, ExplicitlyIndexed):
772+
array = await array.async_get_duck_array()
773+
return _wrap_numpy_scalars(array)
774+
759775
def _updated_key(self, new_key: ExplicitIndexer):
760776
return _combine_indexers(self.key, self.shape, new_key)
761777

@@ -1608,6 +1624,16 @@ def __getitem__(self, indexer: ExplicitIndexer):
16081624
key = indexer.tuple + (Ellipsis,)
16091625
return array[key]
16101626

1627+
async def async_getitem(self, indexer: ExplicitIndexer):
1628+
self._check_and_raise_if_non_basic_indexer(indexer)
1629+
1630+
array = self.array
1631+
# We want 0d slices rather than scalars. This is achieved by
1632+
# appending an ellipsis (see
1633+
# https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes).
1634+
key = indexer.tuple + (Ellipsis,)
1635+
return array[key]
1636+
16111637
def _safe_setitem(self, array, key: tuple[Any, ...], value: Any) -> None:
16121638
try:
16131639
array[key] = value
@@ -1855,6 +1881,15 @@ def get_duck_array(self) -> np.ndarray | PandasExtensionArray:
18551881
return PandasExtensionArray(self.array.array)
18561882
return np.asarray(self)
18571883

1884+
async def async_get_duck_array(self) -> np.ndarray | PandasExtensionArray:
1885+
# TODO this must surely be wrong - it's not async yet
1886+
print("in PandasIndexingAdapter")
1887+
if pd.api.types.is_extension_array_dtype(self.array):
1888+
from xarray.core.extension_array import PandasExtensionArray
1889+
1890+
return PandasExtensionArray(self.array.array)
1891+
return np.asarray(self)
1892+
18581893
@property
18591894
def shape(self) -> _Shape:
18601895
return (len(self.array),)

xarray/core/variable.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -959,7 +959,6 @@ def load(self, **kwargs):
959959
return self
960960

961961
async def load_async(self, **kwargs):
962-
print("async inside Variable")
963962
self._data = await async_to_duck_array(self._data, **kwargs)
964963
return self
965964

xarray/tests/test_async.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -184,7 +184,7 @@ async def test_concurrent_load_multiple_objects(self, xr_obj) -> None:
184184
("sel", {"x": 2}),
185185
("sel", {"x": [2, 3]}),
186186
(
187-
"isel",
187+
"sel",
188188
{
189189
"x": xr.DataArray([2, 3], dims="points"),
190190
"y": xr.DataArray([2, 3], dims="points"),
@@ -198,6 +198,7 @@ async def test_indexing(self, memorystore, method, indexer) -> None:
198198
latencystore = LatencyStore(memorystore, latency=0.0)
199199
ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None)
200200

201+
# TODO we're not actually testing that these indexing methods are not blocking...
201202
result = await getattr(ds, method)(**indexer).load_async()
202203
expected = getattr(ds, method)(**indexer).load()
203204
xrt.assert_identical(result, expected)

0 commit comments

Comments
 (0)