20
20
from xarray .core .alignment import align , broadcast
21
21
from xarray .core .arithmetic import DataArrayGroupbyArithmetic , DatasetGroupbyArithmetic
22
22
from xarray .core .common import ImplementsArrayReduce , ImplementsDatasetReduce
23
+ from xarray .core .computation import apply_ufunc
23
24
from xarray .core .concat import concat
24
25
from xarray .core .coordinates import Coordinates , _coordinates_from_variable
25
26
from xarray .core .duck_array_ops import where
@@ -1357,7 +1358,9 @@ def where(self, cond, other=dtypes.NA) -> T_Xarray:
1357
1358
"""
1358
1359
return ops .where_method (self , cond , other )
1359
1360
1360
- def _first_or_last (self , op , skipna , keep_attrs ):
1361
+ def _first_or_last (self , op : str , skipna : bool | None , keep_attrs : bool | None ):
1362
+ from xarray .core .dataarray import DataArray
1363
+
1361
1364
if all (
1362
1365
isinstance (maybe_slice , slice )
1363
1366
and (maybe_slice .stop == maybe_slice .start + 1 )
@@ -1368,17 +1371,95 @@ def _first_or_last(self, op, skipna, keep_attrs):
1368
1371
return self ._obj
1369
1372
if keep_attrs is None :
1370
1373
keep_attrs = _get_keep_attrs (default = True )
1371
- return self .reduce (
1372
- op , dim = [self ._group_dim ], skipna = skipna , keep_attrs = keep_attrs
1374
+
1375
+ def _groupby_first_last_wrapper (
1376
+ values ,
1377
+ by ,
1378
+ * ,
1379
+ op : Literal ["first" , "last" ],
1380
+ skipna : bool | None ,
1381
+ group_indices ,
1382
+ ):
1383
+ no_nans = dtypes .isdtype (
1384
+ values .dtype , "signed integer"
1385
+ ) or dtypes .is_string (values .dtype )
1386
+ if (skipna or skipna is None ) and not no_nans :
1387
+ skipna = True
1388
+ else :
1389
+ skipna = False
1390
+
1391
+ if TYPE_CHECKING :
1392
+ assert isinstance (skipna , bool )
1393
+
1394
+ if skipna is False or (skipna and no_nans ):
1395
+ # this is an optimization: when skipna=False, we can simply index
1396
+ # the whole object after picking the first/last member of each group
1397
+ # in self.encoded.group_indices
1398
+ if op == "first" :
1399
+ indices = [
1400
+ (idx .start if isinstance (idx , slice ) else idx [0 ])
1401
+ for idx in group_indices
1402
+ if idx
1403
+ ]
1404
+ else :
1405
+ indices = [
1406
+ (idx .stop - 1 if isinstance (idx , slice ) else idx [- 1 ])
1407
+ for idx in self .encoded .group_indices
1408
+ if idx
1409
+ ]
1410
+ return self ._obj .isel ({self ._group_dim : indices })
1411
+
1412
+ elif (
1413
+ skipna
1414
+ and module_available ("flox" , minversion = "0.9.14" )
1415
+ and OPTIONS ["use_flox" ]
1416
+ and contains_only_chunked_or_numpy (self ._obj )
1417
+ ):
1418
+ import flox
1419
+
1420
+ result , * _ = flox .groupby_reduce (
1421
+ values , self .group1d .data , axis = - 1 , func = f"nan{ op } "
1422
+ )
1423
+ return result
1424
+
1425
+ else :
1426
+ return self .reduce (
1427
+ getattr (duck_array_ops , op ),
1428
+ dim = [self ._group_dim ],
1429
+ skipna = skipna ,
1430
+ keep_attrs = keep_attrs ,
1431
+ )
1432
+
1433
+ result = apply_ufunc (
1434
+ _groupby_first_last_wrapper ,
1435
+ self ._obj ,
1436
+ self .group1d ,
1437
+ input_core_dims = [[self ._group_dim ], [self ._group_dim ]],
1438
+ output_core_dims = [[self .group1d .name ]],
1439
+ dask = "allowed" ,
1440
+ output_sizes = {self .group1d .name : len (self )},
1441
+ exclude_dims = {self ._group_dim },
1442
+ keep_attrs = keep_attrs ,
1443
+ kwargs = {
1444
+ "op" : op ,
1445
+ "skipna" : skipna ,
1446
+ "group_indices" : self .encoded .group_indices ,
1447
+ },
1373
1448
)
1449
+ result = result .assign_coords (self .encoded .coords )
1450
+ result = self ._maybe_unstack (result )
1451
+ result = self ._maybe_restore_empty_groups (result )
1452
+ if isinstance (result , DataArray ):
1453
+ result = self ._restore_dim_order (result )
1454
+ return result
1374
1455
1375
1456
def first (self , skipna : bool | None = None , keep_attrs : bool | None = None ):
1376
1457
"""Return the first element of each group along the group dimension"""
1377
- return self ._first_or_last (duck_array_ops . first , skipna , keep_attrs )
1458
+ return self ._first_or_last (" first" , skipna , keep_attrs )
1378
1459
1379
1460
def last (self , skipna : bool | None = None , keep_attrs : bool | None = None ):
1380
1461
"""Return the last element of each group along the group dimension"""
1381
- return self ._first_or_last (duck_array_ops . last , skipna , keep_attrs )
1462
+ return self ._first_or_last (" last" , skipna , keep_attrs )
1382
1463
1383
1464
def assign_coords (self , coords = None , ** coords_kwargs ):
1384
1465
"""Assign coordinates by group.
0 commit comments