Skip to content

Commit 9fe816e

Browse files
authored
Optimize idxmin, idxmax with dask (#9800)
* Optimize idxmin, idxmax with dask

  Closes #9425
* use map_blocks instead
* small edits
* fix typing
* try again
* Migrate to DaskIndexingAdapter
* cleanup
* fix?
* finish
* fix types
* review comments
1 parent 8afed74 commit 9fe816e

File tree

5 files changed

+115
-30
lines changed

5 files changed

+115
-30
lines changed

xarray/core/computation.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@
3232
from xarray.core.merge import merge_attrs, merge_coordinates_without_align
3333
from xarray.core.options import OPTIONS, _get_keep_attrs
3434
from xarray.core.types import Dims, T_DataArray
35-
from xarray.core.utils import is_dict_like, is_scalar, parse_dims_as_set, result_name
35+
from xarray.core.utils import (
36+
is_dict_like,
37+
is_scalar,
38+
parse_dims_as_set,
39+
result_name,
40+
)
3641
from xarray.core.variable import Variable
3742
from xarray.namedarray.parallelcompat import get_chunked_array_type
3843
from xarray.namedarray.pycompat import is_chunked_array
@@ -2166,19 +2171,17 @@ def _calc_idxminmax(
21662171
indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna)
21672172

21682173
# Handle chunked arrays (e.g. dask).
2174+
coord = array[dim]._variable.to_base_variable()
21692175
if is_chunked_array(array.data):
21702176
chunkmanager = get_chunked_array_type(array.data)
2171-
chunks = dict(zip(array.dims, array.chunks, strict=True))
2172-
dask_coord = chunkmanager.from_array(array[dim].data, chunks=chunks[dim])
2173-
data = dask_coord[duck_array_ops.ravel(indx.data)]
2177+
coord_array = chunkmanager.from_array(
2178+
array[dim].data, chunks=((array.sizes[dim],),)
2179+
)
2180+
coord = coord.copy(data=coord_array)
21742181
else:
2175-
arr_coord = to_like_array(array[dim].data, array.data)
2176-
data = arr_coord[duck_array_ops.ravel(indx.data)]
2182+
coord = coord.copy(data=to_like_array(array[dim].data, array.data))
21772183

2178-
# rebuild like the argmin/max output, and rename as the dim name
2179-
data = duck_array_ops.reshape(data, indx.shape)
2180-
res = indx.copy(data=data)
2181-
res.name = dim
2184+
res = indx._replace(coord[(indx.variable,)]).rename(dim)
21822185

21832186
if skipna or (skipna is None and array.dtype.kind in na_dtypes):
21842187
# Put the NaN values back in after removing them

xarray/core/indexing.py

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import enum
44
import functools
5+
import math
56
import operator
67
from collections import Counter, defaultdict
78
from collections.abc import Callable, Hashable, Iterable, Mapping
@@ -472,12 +473,6 @@ def __init__(self, key: tuple[slice | np.ndarray[Any, np.dtype[np.generic]], ...
472473
for k in key:
473474
if isinstance(k, slice):
474475
k = as_integer_slice(k)
475-
elif is_duck_dask_array(k):
476-
raise ValueError(
477-
"Vectorized indexing with Dask arrays is not supported. "
478-
"Please pass a numpy array by calling ``.compute``. "
479-
"See https://github.com/dask/dask/issues/8958."
480-
)
481476
elif is_duck_array(k):
482477
if not np.issubdtype(k.dtype, np.integer):
483478
raise TypeError(
@@ -1508,6 +1503,7 @@ def _oindex_get(self, indexer: OuterIndexer):
15081503
return self.array[key]
15091504

15101505
def _vindex_get(self, indexer: VectorizedIndexer):
1506+
_assert_not_chunked_indexer(indexer.tuple)
15111507
array = NumpyVIndexAdapter(self.array)
15121508
return array[indexer.tuple]
15131509

@@ -1607,6 +1603,28 @@ def transpose(self, order):
16071603
return xp.permute_dims(self.array, order)
16081604

16091605

1606+
def _apply_vectorized_indexer_dask_wrapper(indices, coord):
1607+
from xarray.core.indexing import (
1608+
VectorizedIndexer,
1609+
apply_indexer,
1610+
as_indexable,
1611+
)
1612+
1613+
return apply_indexer(
1614+
as_indexable(coord), VectorizedIndexer((indices.squeeze(axis=-1),))
1615+
)
1616+
1617+
1618+
def _assert_not_chunked_indexer(idxr: tuple[Any, ...]) -> None:
1619+
if any(is_chunked_array(i) for i in idxr):
1620+
raise ValueError(
1621+
"Cannot index with a chunked array indexer. "
1622+
"Please chunk the array you are indexing first, "
1623+
"and drop any indexed dimension coordinate variables. "
1624+
"Alternatively, call `.compute()` on any chunked arrays in the indexer."
1625+
)
1626+
1627+
16101628
class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
16111629
"""Wrap a dask array to support explicit indexing."""
16121630

@@ -1630,7 +1648,35 @@ def _oindex_get(self, indexer: OuterIndexer):
16301648
return value
16311649

16321650
def _vindex_get(self, indexer: VectorizedIndexer):
1633-
return self.array.vindex[indexer.tuple]
1651+
try:
1652+
return self.array.vindex[indexer.tuple]
1653+
except IndexError as e:
1654+
# TODO: upstream to dask
1655+
has_dask = any(is_duck_dask_array(i) for i in indexer.tuple)
1656+
# this only works for "small" 1d coordinate arrays with one chunk
1657+
# it is intended for idxmin, idxmax, and allows indexing with
1658+
# the nD array output of argmin, argmax
1659+
if (
1660+
not has_dask
1661+
or len(indexer.tuple) > 1
1662+
or math.prod(self.array.numblocks) > 1
1663+
or self.array.ndim > 1
1664+
):
1665+
raise e
1666+
(idxr,) = indexer.tuple
1667+
if idxr.ndim == 0:
1668+
return self.array[idxr.data]
1669+
else:
1670+
import dask.array
1671+
1672+
return dask.array.map_blocks(
1673+
_apply_vectorized_indexer_dask_wrapper,
1674+
idxr[..., np.newaxis],
1675+
self.array,
1676+
chunks=idxr.chunks,
1677+
drop_axis=-1,
1678+
dtype=self.array.dtype,
1679+
)
16341680

16351681
def __getitem__(self, indexer: ExplicitIndexer):
16361682
self._check_and_raise_if_non_basic_indexer(indexer)
@@ -1770,6 +1816,7 @@ def _vindex_get(
17701816
| np.datetime64
17711817
| np.timedelta64
17721818
):
1819+
_assert_not_chunked_indexer(indexer.tuple)
17731820
key = self._prepare_key(indexer.tuple)
17741821

17751822
if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional

xarray/tests/test_dask.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1815,3 +1815,16 @@ def test_minimize_graph_size():
18151815
# all the other dimensions.
18161816
# e.g. previously for 'x', actual == numchunks['y'] * numchunks['z']
18171817
assert actual == numchunks[var], (actual, numchunks[var])
1818+
1819+
1820+
def test_idxmin_chunking():
1821+
# GH9425
1822+
x, y, t = 100, 100, 10
1823+
rang = np.arange(t * x * y)
1824+
da = xr.DataArray(
1825+
rang.reshape(t, x, y), coords={"time": range(t), "x": range(x), "y": range(y)}
1826+
)
1827+
da = da.chunk(dict(time=-1, x=25, y=25))
1828+
actual = da.idxmin("time")
1829+
assert actual.chunksizes == {k: da.chunksizes[k] for k in ["x", "y"]}
1830+
assert_identical(actual, da.compute().idxmin("time"))

xarray/tests/test_dataarray.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4949,7 +4949,15 @@ def test_argmax(
49494949

49504950
assert_identical(result2, expected2)
49514951

4952-
@pytest.mark.parametrize("use_dask", [True, False])
4952+
@pytest.mark.parametrize(
4953+
"use_dask",
4954+
[
4955+
pytest.param(
4956+
True, marks=pytest.mark.skipif(not has_dask, reason="no dask")
4957+
),
4958+
False,
4959+
],
4960+
)
49534961
def test_idxmin(
49544962
self,
49554963
x: np.ndarray,
@@ -4958,16 +4966,11 @@ def test_idxmin(
49584966
nanindex: int | None,
49594967
use_dask: bool,
49604968
) -> None:
4961-
if use_dask and not has_dask:
4962-
pytest.skip("requires dask")
4963-
if use_dask and x.dtype.kind == "M":
4964-
pytest.xfail("dask operation 'argmin' breaks when dtype is datetime64 (M)")
49654969
ar0_raw = xr.DataArray(
49664970
x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs
49674971
)
4968-
49694972
if use_dask:
4970-
ar0 = ar0_raw.chunk({})
4973+
ar0 = ar0_raw.chunk()
49714974
else:
49724975
ar0 = ar0_raw
49734976

xarray/tests/test_indexing.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,7 @@ def test_indexing_1d_object_array() -> None:
974974

975975

976976
@requires_dask
977-
def test_indexing_dask_array():
977+
def test_indexing_dask_array() -> None:
978978
import dask.array
979979

980980
da = DataArray(
@@ -988,34 +988,53 @@ def test_indexing_dask_array():
988988

989989

990990
@requires_dask
991-
def test_indexing_dask_array_scalar():
991+
def test_indexing_dask_array_scalar() -> None:
992992
# GH4276
993993
import dask.array
994994

995995
a = dask.array.from_array(np.linspace(0.0, 1.0))
996996
da = DataArray(a, dims="x")
997997
x_selector = da.argmax(dim=...)
998+
assert not isinstance(x_selector, DataArray)
998999
with raise_if_dask_computes():
9991000
actual = da.isel(x_selector)
10001001
expected = da.isel(x=-1)
10011002
assert_identical(actual, expected)
10021003

10031004

10041005
@requires_dask
1005-
def test_vectorized_indexing_dask_array():
1006+
def test_vectorized_indexing_dask_array() -> None:
10061007
# https://github.com/pydata/xarray/issues/2511#issuecomment-563330352
10071008
darr = DataArray(data=[0.2, 0.4, 0.6], coords={"z": range(3)}, dims=("z",))
10081009
indexer = DataArray(
10091010
data=np.random.randint(0, 3, 8).reshape(4, 2).astype(int),
10101011
coords={"y": range(4), "x": range(2)},
10111012
dims=("y", "x"),
10121013
)
1013-
with pytest.raises(ValueError, match="Vectorized indexing with Dask arrays"):
1014-
darr[indexer.chunk({"y": 2})]
1014+
expected = darr[indexer]
1015+
1016+
# fails because we can't index pd.Index lazily (yet).
1017+
# We could make this succeed by auto-chunking the values
1018+
# and constructing a lazy index variable, and not automatically
1019+
# creating an index for it.
1020+
with pytest.raises(ValueError, match="Cannot index with"):
1021+
with raise_if_dask_computes():
1022+
darr.chunk()[indexer.chunk({"y": 2})]
1023+
with pytest.raises(ValueError, match="Cannot index with"):
1024+
with raise_if_dask_computes():
1025+
actual = darr[indexer.chunk({"y": 2})]
1026+
1027+
with raise_if_dask_computes():
1028+
actual = darr.drop_vars("z").chunk()[indexer.chunk({"y": 2})]
1029+
assert_identical(actual, expected.drop_vars("z"))
1030+
1031+
with raise_if_dask_computes():
1032+
actual_variable = darr.variable.chunk()[indexer.variable.chunk({"y": 2})]
1033+
assert_identical(actual_variable, expected.variable)
10151034

10161035

10171036
@requires_dask
1018-
def test_advanced_indexing_dask_array():
1037+
def test_advanced_indexing_dask_array() -> None:
10191038
# GH4663
10201039
import dask.array as da
10211040

0 commit comments

Comments
 (0)