diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 78b5e92ed6..179516abe3 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -68,6 +68,8 @@ ) from zarr.core.config import config as zarr_config from zarr.core.indexing import ( + AsyncOIndex, + AsyncVIndex, BasicIndexer, BasicSelection, BlockIndex, @@ -1358,6 +1360,42 @@ async def getitem( ) return await self._get_selection(indexer, prototype=prototype) + async def get_orthogonal_selection( + self, + selection: OrthogonalSelection, + *, + out: NDBuffer | None = None, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, + ) -> NDArrayLike: + if prototype is None: + prototype = default_buffer_prototype() + indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid) + return await self._get_selection( + indexer=indexer, out=out, fields=fields, prototype=prototype + ) + + @_deprecate_positional_args + async def get_coordinate_selection( + self, + selection: CoordinateSelection, + *, + out: NDBuffer | None = None, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, + ) -> NDArrayLikeOrScalar: + if prototype is None: + prototype = default_buffer_prototype() + indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid) + out_array = await self._get_selection( + indexer=indexer, out=out, fields=fields, prototype=prototype + ) + + if hasattr(out_array, "shape"): + # restore shape + out_array = np.array(out_array).reshape(indexer.sel_shape) + return out_array + async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = False) -> None: """ Asynchronously save the array metadata. @@ -1488,6 +1526,16 @@ async def setitem( ) return await self._set_selection(indexer, value, prototype=prototype) + @property + def oindex(self) -> AsyncOIndex: + """Shortcut for orthogonal (outer) indexing, see :func:`get_orthogonal_selection` and + :func:`set_orthogonal_selection` for documentation and examples.""" + return AsyncOIndex(self) + + @property + def vindex(self) -> AsyncVIndex: + return AsyncVIndex(self) + async def resize(self, new_shape: ShapeLike, delete_outside_chunks: bool = True) -> None: """ Asynchronously resize the array to a new shape. diff --git a/src/zarr/core/indexing.py b/src/zarr/core/indexing.py index 998fe156a1..32c694a3cc 100644 --- a/src/zarr/core/indexing.py +++ b/src/zarr/core/indexing.py @@ -28,7 +28,7 @@ from zarr.core.common import product if TYPE_CHECKING: - from zarr.core.array import Array + from zarr.core.array import Array, AsyncArray from zarr.core.buffer import NDArrayLikeOrScalar from zarr.core.chunk_grids import ChunkGrid from zarr.core.common import ChunkCoords @@ -960,6 +960,25 @@ def __setitem__(self, selection: OrthogonalSelection, value: npt.ArrayLike) -> N ) +@dataclass(frozen=True) +class AsyncOIndex: + array: AsyncArray + + async def getitem(self, selection: OrthogonalSelection | Array) -> NDArrayLike: + from zarr.core.array import Array + + # if input is a Zarr array, we materialize it now. + if isinstance(selection, Array): + selection = _zarr_array_to_int_or_bool_array(selection) + + fields, new_selection = pop_fields(selection) + new_selection = ensure_tuple(new_selection) + new_selection = replace_lists(new_selection) + return await self.array.get_orthogonal_selection( + cast(OrthogonalSelection, new_selection), fields=fields + ) + + @dataclass(frozen=True) class BlockIndexer(Indexer): dim_indexers: list[SliceDimIndexer] @@ -1268,6 +1287,30 @@ def __setitem__( raise VindexInvalidSelectionError(new_selection) +@dataclass(frozen=True) +class AsyncVIndex: + array: AsyncArray + + # TODO: develop Array generic and move zarr.Array[np.intp] | zarr.Array[np.bool_] to ArrayOfIntOrBool + async def getitem( + self, selection: CoordinateSelection | MaskSelection | Array + ) -> NDArrayLikeOrScalar: + from zarr.core.array import Array + + # if input is a Zarr array, we materialize it now. + if isinstance(selection, Array): + selection = _zarr_array_to_int_or_bool_array(selection) + fields, new_selection = pop_fields(selection) + new_selection = ensure_tuple(new_selection) + new_selection = replace_lists(new_selection) + if is_coordinate_selection(new_selection, self.array.shape): + return await self.array.get_coordinate_selection(new_selection, fields=fields) + elif is_mask_selection(new_selection, self.array.shape): + return self.array.get_mask_selection(new_selection, fields=fields) + else: + raise VindexInvalidSelectionError(new_selection) + + def check_fields(fields: Fields | None, dtype: np.dtype[Any]) -> np.dtype[Any]: # early out if fields is None: