Skip to content

Commit 9b1a90b

Browse files
committed
Optimize grouped first, last.
1. Use flox where possible. 2. Use simple indexing where possible. Closes pydata#9647
1 parent 1c7ee65 commit 9b1a90b

File tree

3 files changed

+104
-5
lines changed

3 files changed

+104
-5
lines changed

xarray/core/groupby.py

Lines changed: 86 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from xarray.core.alignment import align, broadcast
2121
from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
2222
from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
23+
from xarray.core.computation import apply_ufunc
2324
from xarray.core.concat import concat
2425
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
2526
from xarray.core.duck_array_ops import where
@@ -1357,7 +1358,9 @@ def where(self, cond, other=dtypes.NA) -> T_Xarray:
13571358
"""
13581359
return ops.where_method(self, cond, other)
13591360

1360-
def _first_or_last(self, op, skipna, keep_attrs):
1361+
def _first_or_last(self, op: str, skipna: bool | None, keep_attrs: bool | None):
1362+
from xarray.core.dataarray import DataArray
1363+
13611364
if all(
13621365
isinstance(maybe_slice, slice)
13631366
and (maybe_slice.stop == maybe_slice.start + 1)
@@ -1368,17 +1371,95 @@ def _first_or_last(self, op, skipna, keep_attrs):
13681371
return self._obj
13691372
if keep_attrs is None:
13701373
keep_attrs = _get_keep_attrs(default=True)
1371-
return self.reduce(
1372-
op, dim=[self._group_dim], skipna=skipna, keep_attrs=keep_attrs
1374+
1375+
def _groupby_first_last_wrapper(
1376+
values,
1377+
by,
1378+
*,
1379+
op: Literal["first", "last"],
1380+
skipna: bool | None,
1381+
group_indices,
1382+
):
1383+
no_nans = dtypes.isdtype(
1384+
values.dtype, "signed integer"
1385+
) or dtypes.is_string(values.dtype)
1386+
if (skipna or skipna is None) and not no_nans:
1387+
skipna = True
1388+
else:
1389+
skipna = False
1390+
1391+
if TYPE_CHECKING:
1392+
assert isinstance(skipna, bool)
1393+
1394+
if skipna is False or (skipna and no_nans):
1395+
# this is an optimization: when skipna=False, we can simply index
1396+
# the whole object after picking the first/last member of each group
1397+
# in self.encoded.group_indices
1398+
if op == "first":
1399+
indices = [
1400+
(idx.start if isinstance(idx, slice) else idx[0])
1401+
for idx in group_indices
1402+
if idx
1403+
]
1404+
else:
1405+
indices = [
1406+
(idx.stop - 1 if isinstance(idx, slice) else idx[-1])
1407+
for idx in self.encoded.group_indices
1408+
if idx
1409+
]
1410+
return self._obj.isel({self._group_dim: indices})
1411+
1412+
elif (
1413+
skipna
1414+
and module_available("flox", minversion="0.9.14")
1415+
and OPTIONS["use_flox"]
1416+
and contains_only_chunked_or_numpy(self._obj)
1417+
):
1418+
import flox
1419+
1420+
result, *_ = flox.groupby_reduce(
1421+
values, self.group1d.data, axis=-1, func=f"nan{op}"
1422+
)
1423+
return result
1424+
1425+
else:
1426+
return self.reduce(
1427+
getattr(duck_array_ops, op),
1428+
dim=[self._group_dim],
1429+
skipna=skipna,
1430+
keep_attrs=keep_attrs,
1431+
)
1432+
1433+
result = apply_ufunc(
1434+
_groupby_first_last_wrapper,
1435+
self._obj,
1436+
self.group1d,
1437+
input_core_dims=[[self._group_dim], [self._group_dim]],
1438+
output_core_dims=[[self.group1d.name]],
1439+
dask="allowed",
1440+
output_sizes={self.group1d.name: len(self)},
1441+
exclude_dims={self._group_dim},
1442+
keep_attrs=keep_attrs,
1443+
kwargs={
1444+
"op": op,
1445+
"skipna": skipna,
1446+
"group_indices": self.encoded.group_indices,
1447+
},
13731448
)
1449+
result = result.assign_coords(self.encoded.coords)
1450+
result = self._maybe_unstack(result)
1451+
result = self._maybe_restore_empty_groups(result)
1452+
if isinstance(result, DataArray):
1453+
result = self._restore_dim_order(result)
1454+
return result
13741455

13751456
def first(self, skipna: bool | None = None, keep_attrs: bool | None = None):
13761457
"""Return the first element of each group along the group dimension"""
1377-
return self._first_or_last(duck_array_ops.first, skipna, keep_attrs)
1458+
return self._first_or_last("first", skipna, keep_attrs)
13781459

13791460
def last(self, skipna: bool | None = None, keep_attrs: bool | None = None):
13801461
"""Return the last element of each group along the group dimension"""
1381-
return self._first_or_last(duck_array_ops.last, skipna, keep_attrs)
1462+
return self._first_or_last("last", skipna, keep_attrs)
13821463

13831464
def assign_coords(self, coords=None, **coords_kwargs):
13841465
"""Assign coordinates by group.

xarray/core/resample.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,22 @@ def shuffle_to_chunks(self, chunks: T_Chunks = None):
103103
(grouper,) = self.groupers
104104
return self._shuffle_obj(chunks).drop_vars(RESAMPLE_DIM)
105105

106+
def _first_or_last(
107+
self, op: str, skipna: bool | None, keep_attrs: bool | None
108+
) -> T_Xarray:
109+
from xarray.core.dataset import Dataset
110+
111+
result = super()._first_or_last(op=op, skipna=skipna, keep_attrs=keep_attrs)
112+
result = result.rename({RESAMPLE_DIM: self._group_dim})
113+
if isinstance(result, Dataset):
114+
# Can't do this in the base class because group_dim is RESAMPLE_DIM
115+
# which is not present in the original object
116+
for var in result.data_vars:
117+
result._variables[var] = result._variables[var].transpose(
118+
*self._obj._variables[var].dims
119+
)
120+
return result
121+
106122
def _drop_coords(self) -> T_Xarray:
107123
"""Drop non-dimension coordinates along the resampled dimension."""
108124
obj = self._obj

xarray/tests/test_groupby.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1618,6 +1618,8 @@ def test_groupby_first_and_last(self) -> None:
16181618
expected = array # should be a no-op
16191619
assert_identical(expected, actual)
16201620

1621+
# TODO: groupby_bins too
1622+
16211623
def make_groupby_multidim_example_array(self) -> DataArray:
16221624
return DataArray(
16231625
[[[0, 1], [2, 3]], [[5, 10], [15, 20]]],

0 commit comments

Comments
 (0)