Skip to content

Commit d86bec1

Browse files
committed
simplify
1 parent 9b1a90b commit d86bec1

File tree

2 files changed

+15
-83
lines changed

2 files changed

+15
-83
lines changed

xarray/core/groupby.py

Lines changed: 15 additions & 82 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,6 @@
2020
from xarray.core.alignment import align, broadcast
2121
from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
2222
from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
23-
from xarray.core.computation import apply_ufunc
2423
from xarray.core.concat import concat
2524
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
2625
from xarray.core.duck_array_ops import where
@@ -1359,8 +1358,6 @@ def where(self, cond, other=dtypes.NA) -> T_Xarray:
13591358
return ops.where_method(self, cond, other)
13601359

13611360
def _first_or_last(self, op: str, skipna: bool | None, keep_attrs: bool | None):
1362-
from xarray.core.dataarray import DataArray
1363-
13641361
if all(
13651362
isinstance(maybe_slice, slice)
13661363
and (maybe_slice.stop == maybe_slice.start + 1)
@@ -1371,86 +1368,22 @@ def _first_or_last(self, op: str, skipna: bool | None, keep_attrs: bool | None):
13711368
return self._obj
13721369
if keep_attrs is None:
13731370
keep_attrs = _get_keep_attrs(default=True)
1374-
1375-
def _groupby_first_last_wrapper(
1376-
values,
1377-
by,
1378-
*,
1379-
op: Literal["first", "last"],
1380-
skipna: bool | None,
1381-
group_indices,
1371+
if (
1372+
skipna
1373+
and module_available("flox", minversion="0.9.16")
1374+
and OPTIONS["use_flox"]
1375+
and contains_only_chunked_or_numpy(self._obj)
13821376
):
1383-
no_nans = dtypes.isdtype(
1384-
values.dtype, "signed integer"
1385-
) or dtypes.is_string(values.dtype)
1386-
if (skipna or skipna is None) and not no_nans:
1387-
skipna = True
1388-
else:
1389-
skipna = False
1390-
1391-
if TYPE_CHECKING:
1392-
assert isinstance(skipna, bool)
1393-
1394-
if skipna is False or (skipna and no_nans):
1395-
# this is an optimization: when skipna=False, we can simply index
1396-
# the whole object after picking the first/last member of each group
1397-
# in self.encoded.group_indices
1398-
if op == "first":
1399-
indices = [
1400-
(idx.start if isinstance(idx, slice) else idx[0])
1401-
for idx in group_indices
1402-
if idx
1403-
]
1404-
else:
1405-
indices = [
1406-
(idx.stop - 1 if isinstance(idx, slice) else idx[-1])
1407-
for idx in self.encoded.group_indices
1408-
if idx
1409-
]
1410-
return self._obj.isel({self._group_dim: indices})
1411-
1412-
elif (
1413-
skipna
1414-
and module_available("flox", minversion="0.9.14")
1415-
and OPTIONS["use_flox"]
1416-
and contains_only_chunked_or_numpy(self._obj)
1417-
):
1418-
import flox
1419-
1420-
result, *_ = flox.groupby_reduce(
1421-
values, self.group1d.data, axis=-1, func=f"nan{op}"
1422-
)
1423-
return result
1424-
1425-
else:
1426-
return self.reduce(
1427-
getattr(duck_array_ops, op),
1428-
dim=[self._group_dim],
1429-
skipna=skipna,
1430-
keep_attrs=keep_attrs,
1431-
)
1432-
1433-
result = apply_ufunc(
1434-
_groupby_first_last_wrapper,
1435-
self._obj,
1436-
self.group1d,
1437-
input_core_dims=[[self._group_dim], [self._group_dim]],
1438-
output_core_dims=[[self.group1d.name]],
1439-
dask="allowed",
1440-
output_sizes={self.group1d.name: len(self)},
1441-
exclude_dims={self._group_dim},
1442-
keep_attrs=keep_attrs,
1443-
kwargs={
1444-
"op": op,
1445-
"skipna": skipna,
1446-
"group_indices": self.encoded.group_indices,
1447-
},
1448-
)
1449-
result = result.assign_coords(self.encoded.coords)
1450-
result = self._maybe_unstack(result)
1451-
result = self._maybe_restore_empty_groups(result)
1452-
if isinstance(result, DataArray):
1453-
result = self._restore_dim_order(result)
1377+
result, *_ = self._flox_reduce(
1378+
dim=None, func=f"nan{op}" if skipna else op, keep_attrs=keep_attrs
1379+
)
1380+
else:
1381+
result = self.reduce(
1382+
getattr(duck_array_ops, op),
1383+
dim=[self._group_dim],
1384+
skipna=skipna,
1385+
keep_attrs=keep_attrs,
1386+
)
14541387
return result
14551388

14561389
def first(self, skipna: bool | None = None, keep_attrs: bool | None = None):

xarray/core/resample.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -109,7 +109,6 @@ def _first_or_last(
109109
from xarray.core.dataset import Dataset
110110

111111
result = super()._first_or_last(op=op, skipna=skipna, keep_attrs=keep_attrs)
112-
result = result.rename({RESAMPLE_DIM: self._group_dim})
113112
if isinstance(result, Dataset):
114113
# Can't do this in the base class because group_dim is RESAMPLE_DIM
115114
# which is not present in the original object

0 commit comments

Comments
 (0)