20
20
from xarray .core .alignment import align , broadcast
21
21
from xarray .core .arithmetic import DataArrayGroupbyArithmetic , DatasetGroupbyArithmetic
22
22
from xarray .core .common import ImplementsArrayReduce , ImplementsDatasetReduce
23
- from xarray .core .computation import apply_ufunc
24
23
from xarray .core .concat import concat
25
24
from xarray .core .coordinates import Coordinates , _coordinates_from_variable
26
25
from xarray .core .duck_array_ops import where
@@ -1359,8 +1358,6 @@ def where(self, cond, other=dtypes.NA) -> T_Xarray:
1359
1358
return ops .where_method (self , cond , other )
1360
1359
1361
1360
def _first_or_last (self , op : str , skipna : bool | None , keep_attrs : bool | None ):
1362
- from xarray .core .dataarray import DataArray
1363
-
1364
1361
if all (
1365
1362
isinstance (maybe_slice , slice )
1366
1363
and (maybe_slice .stop == maybe_slice .start + 1 )
@@ -1371,86 +1368,22 @@ def _first_or_last(self, op: str, skipna: bool | None, keep_attrs: bool | None):
1371
1368
return self ._obj
1372
1369
if keep_attrs is None :
1373
1370
keep_attrs = _get_keep_attrs (default = True )
1374
-
1375
- def _groupby_first_last_wrapper (
1376
- values ,
1377
- by ,
1378
- * ,
1379
- op : Literal ["first" , "last" ],
1380
- skipna : bool | None ,
1381
- group_indices ,
1371
+ if (
1372
+ skipna
1373
+ and module_available ("flox" , minversion = "0.9.16" )
1374
+ and OPTIONS ["use_flox" ]
1375
+ and contains_only_chunked_or_numpy (self ._obj )
1382
1376
):
1383
- no_nans = dtypes .isdtype (
1384
- values .dtype , "signed integer"
1385
- ) or dtypes .is_string (values .dtype )
1386
- if (skipna or skipna is None ) and not no_nans :
1387
- skipna = True
1388
- else :
1389
- skipna = False
1390
-
1391
- if TYPE_CHECKING :
1392
- assert isinstance (skipna , bool )
1393
-
1394
- if skipna is False or (skipna and no_nans ):
1395
- # this is an optimization: when skipna=False, we can simply index
1396
- # the whole object after picking the first/last member of each group
1397
- # in self.encoded.group_indices
1398
- if op == "first" :
1399
- indices = [
1400
- (idx .start if isinstance (idx , slice ) else idx [0 ])
1401
- for idx in group_indices
1402
- if idx
1403
- ]
1404
- else :
1405
- indices = [
1406
- (idx .stop - 1 if isinstance (idx , slice ) else idx [- 1 ])
1407
- for idx in self .encoded .group_indices
1408
- if idx
1409
- ]
1410
- return self ._obj .isel ({self ._group_dim : indices })
1411
-
1412
- elif (
1413
- skipna
1414
- and module_available ("flox" , minversion = "0.9.14" )
1415
- and OPTIONS ["use_flox" ]
1416
- and contains_only_chunked_or_numpy (self ._obj )
1417
- ):
1418
- import flox
1419
-
1420
- result , * _ = flox .groupby_reduce (
1421
- values , self .group1d .data , axis = - 1 , func = f"nan{ op } "
1422
- )
1423
- return result
1424
-
1425
- else :
1426
- return self .reduce (
1427
- getattr (duck_array_ops , op ),
1428
- dim = [self ._group_dim ],
1429
- skipna = skipna ,
1430
- keep_attrs = keep_attrs ,
1431
- )
1432
-
1433
- result = apply_ufunc (
1434
- _groupby_first_last_wrapper ,
1435
- self ._obj ,
1436
- self .group1d ,
1437
- input_core_dims = [[self ._group_dim ], [self ._group_dim ]],
1438
- output_core_dims = [[self .group1d .name ]],
1439
- dask = "allowed" ,
1440
- output_sizes = {self .group1d .name : len (self )},
1441
- exclude_dims = {self ._group_dim },
1442
- keep_attrs = keep_attrs ,
1443
- kwargs = {
1444
- "op" : op ,
1445
- "skipna" : skipna ,
1446
- "group_indices" : self .encoded .group_indices ,
1447
- },
1448
- )
1449
- result = result .assign_coords (self .encoded .coords )
1450
- result = self ._maybe_unstack (result )
1451
- result = self ._maybe_restore_empty_groups (result )
1452
- if isinstance (result , DataArray ):
1453
- result = self ._restore_dim_order (result )
1377
+ result , * _ = self ._flox_reduce (
1378
+ dim = None , func = f"nan{ op } " if skipna else op , keep_attrs = keep_attrs
1379
+ )
1380
+ else :
1381
+ result = self .reduce (
1382
+ getattr (duck_array_ops , op ),
1383
+ dim = [self ._group_dim ],
1384
+ skipna = skipna ,
1385
+ keep_attrs = keep_attrs ,
1386
+ )
1454
1387
return result
1455
1388
1456
1389
def first (self , skipna : bool | None = None , keep_attrs : bool | None = None ):
0 commit comments