Skip to content

Commit 5e2fdfb

Browse files
committed
Add tests
1 parent 7a99c8f commit 5e2fdfb

File tree

3 files changed

+103
-15
lines changed

3 files changed

+103
-15
lines changed

xarray/core/groupby.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,7 @@ def sizes(self) -> Mapping[Hashable, int]:
566566
self._sizes = self._obj.isel({self._group_dim: index}).sizes
567567
return self._sizes
568568

569-
def shuffle(self, chunks: T_Chunks = None) -> DataArrayGroupBy | DatasetGroupBy:
569+
def shuffle(self, chunks: T_Chunks = None):
570570
"""
571571
Sort or "shuffle" the underlying object.
572572
@@ -610,7 +610,10 @@ def shuffle(self, chunks: T_Chunks = None) -> DataArrayGroupBy | DatasetGroupBy:
610610
"""
611611
(grouper,) = self.groupers
612612
return self._shuffle_obj(chunks).groupby(
613-
{grouper.name: grouper.grouper.reset()},
613+
# Using group.name handles the BinGrouper case
614+
# It does *not* handle the TimeResampler case,
615+
# so we just override this method in Resample
616+
{grouper.group.name: grouper.grouper.reset()},
614617
restore_coord_dims=self._restore_coord_dims,
615618
)
616619

@@ -624,11 +627,11 @@ def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
624627
as_dataset = self._obj._to_temp_dataset() if was_array else self._obj
625628
no_slices: list[list[int]] = [
626629
list(range(*idx.indices(size))) if isinstance(idx, slice) else idx
627-
for idx in self._group_indices
630+
for idx in self.encoded.group_indices
628631
]
629632

630633
if grouper.name not in as_dataset._variables:
631-
as_dataset.coords[grouper.name] = grouper.group1d
634+
as_dataset.coords[grouper.name] = grouper.group
632635

633636
# Shuffling is only different from `isel` for chunked arrays.
634637
# Extract them out, and treat them specially. The rest, we route through isel.
@@ -644,7 +647,7 @@ def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
644647
shuffled = subset.isel({dim: np.concatenate(no_slices)})
645648
for name, var in is_chunked.items():
646649
shuffled[name] = var._shuffle(
647-
indices=list(self._group_indices), dim=dim, chunks=chunks
650+
indices=list(self.encoded.group_indices), dim=dim, chunks=chunks
648651
)
649652
shuffled = self._maybe_unstack(shuffled)
650653
new_obj = self._obj._from_temp_dataset(shuffled) if was_array else shuffled

xarray/core/resample.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import warnings
44
from collections.abc import Callable, Hashable, Iterable, Sequence
5-
from typing import TYPE_CHECKING, Any
5+
from typing import TYPE_CHECKING, Any, cast
66

77
from xarray.core._aggregations import (
88
DataArrayResampleAggregations,
@@ -14,6 +14,8 @@
1414
if TYPE_CHECKING:
1515
from xarray.core.dataarray import DataArray
1616
from xarray.core.dataset import Dataset
17+
from xarray.core.types import T_Chunks
18+
from xarray.groupers import Resampler
1719

1820
from xarray.groupers import RESAMPLE_DIM
1921

@@ -58,6 +60,60 @@ def _flox_reduce(
5860
result = result.rename({RESAMPLE_DIM: self._group_dim})
5961
return result
6062

63+
def shuffle(self, chunks: T_Chunks = None):
64+
"""
65+
Sort or "shuffle" the underlying object.
66+
67+
"Shuffle" means the object is sorted so that all group members occur sequentially,
68+
in the same chunk. Multiple groups may occur in the same chunk.
69+
This method is useful for chunked arrays (e.g. dask, cubed),
70+
particularly when you need to map a function that requires all members of a group
71+
to be present in a single chunk. For chunked array types, the order of appearance
72+
is not guaranteed, but will depend on the input chunking.
73+
74+
.. warning::
75+
76+
With resampling it is a lot better to use ``.chunk`` instead of ``.shuffle``,
77+
since one can only resample a sorted time coordinate.
78+
79+
Parameters
80+
----------
81+
chunks : int, tuple of int, "auto" or mapping of hashable to int or tuple of int, optional
82+
How to adjust chunks along dimensions not present in the array being grouped by.
83+
84+
Returns
85+
-------
86+
DataArrayResample or DatasetResample
87+
88+
Examples
89+
--------
90+
>>> import dask
91+
>>> da = xr.DataArray(
92+
... dims="x",
93+
... data=dask.array.arange(10, chunks=3),
94+
... coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
95+
... name="a",
96+
... )
97+
>>> shuffled = da.groupby("x").shuffle()
98+
>>> shuffled.quantile(q=0.5).compute()
99+
<xarray.DataArray 'a' (x: 4)> Size: 32B
100+
array([9., 3., 4., 5.])
101+
Coordinates:
102+
quantile float64 8B 0.5
103+
* x (x) int64 32B 0 1 2 3
104+
105+
See Also
106+
--------
107+
dask.dataframe.DataFrame.shuffle
108+
dask.array.shuffle
109+
"""
110+
(grouper,) = self.groupers
111+
shuffled = self._shuffle_obj(chunks).drop_vars(RESAMPLE_DIM)
112+
return shuffled.resample(
113+
{self._group_dim: cast("Resampler", grouper.grouper.reset())},
114+
restore_coord_dims=self._restore_coord_dims,
115+
)
116+
61117
def _drop_coords(self) -> T_Xarray:
62118
"""Drop non-dimension coordinates along the resampled dimension."""
63119
obj = self._obj

xarray/tests/test_groupby.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1659,13 +1659,14 @@ def test_groupby_bins(
16591659
)
16601660

16611661
with xr.set_options(use_flox=use_flox):
1662-
actual = array.groupby_bins("dim_0", bins=bins, **cut_kwargs).sum()
1662+
gb = array.groupby_bins("dim_0", bins=bins, **cut_kwargs)
1663+
actual = gb.sum()
16631664
assert_identical(expected, actual)
1665+
assert_identical(expected, gb.shuffle().sum())
16641666

1665-
actual = array.groupby_bins("dim_0", bins=bins, **cut_kwargs).map(
1666-
lambda x: x.sum()
1667-
)
1667+
actual = gb.map(lambda x: x.sum())
16681668
assert_identical(expected, actual)
1669+
assert_identical(expected, gb.shuffle().map(lambda x: x.sum()))
16691670

16701671
# make sure original array dims are unchanged
16711672
assert len(array.dim_0) == 4
@@ -1810,8 +1811,9 @@ def test_groupby_fastpath_for_monotonic(self, use_flox: bool) -> None:
18101811

18111812

18121813
class TestDataArrayResample:
1814+
@pytest.mark.parametrize("shuffle", [True, False])
18131815
@pytest.mark.parametrize("use_cftime", [True, False])
1814-
def test_resample(self, use_cftime: bool) -> None:
1816+
def test_resample(self, use_cftime: bool, shuffle: bool) -> None:
18151817
if use_cftime and not has_cftime:
18161818
pytest.skip()
18171819
times = xr.date_range(
@@ -1833,16 +1835,22 @@ def resample_as_pandas(array, *args, **kwargs):
18331835

18341836
array = DataArray(np.arange(10), [("time", times)])
18351837

1836-
actual = array.resample(time="24h").mean()
1838+
rs = array.resample(time="24h")
1839+
1840+
actual = rs.mean()
18371841
expected = resample_as_pandas(array, "24h")
18381842
assert_identical(expected, actual)
1843+
assert_identical(expected, rs.shuffle().mean())
18391844

1840-
actual = array.resample(time="24h").reduce(np.mean)
1841-
assert_identical(expected, actual)
1845+
assert_identical(expected, rs.reduce(np.mean))
1846+
assert_identical(expected, rs.shuffle().reduce(np.mean))
18421847

1843-
actual = array.resample(time="24h", closed="right").mean()
1848+
rs = array.resample(time="24h", closed="right")
1849+
actual = rs.mean()
1850+
shuffled = rs.shuffle().mean()
18441851
expected = resample_as_pandas(array, "24h", closed="right")
18451852
assert_identical(expected, actual)
1853+
assert_identical(expected, shuffled)
18461854

18471855
with pytest.raises(ValueError, match=r"Index must be monotonic"):
18481856
array[[2, 0, 1]].resample(time="1D")
@@ -2795,6 +2803,27 @@ def test_multiple_groupers_mixed(use_flox) -> None:
27952803
# ------
27962804

27972805

2806+
@requires_dask
2807+
def test_groupby_shuffle():
2808+
import dask
2809+
2810+
da = DataArray(
2811+
dask.array.from_array(np.array([1, 2, 3, 0, 2, np.nan]), chunks=2),
2812+
dims="d",
2813+
coords=dict(
2814+
labels1=("d", np.array(["a", "b", "c", "c", "b", "a"])),
2815+
labels2=("d", np.array(["x", "y", "z", "z", "y", "x"])),
2816+
),
2817+
name="foo",
2818+
)
2819+
2820+
gb = da.groupby("labels1")
2821+
shuffled = gb.shuffle()
2822+
shuffled_obj = shuffled._obj
2823+
with xr.set_options(use_flox=False):
2824+
xr.testing.assert_identical(gb.mean(), shuffled.mean())
2825+
2826+
27982827
# Possible property tests
27992828
# 1. lambda x: x
28002829
# 2. grouped-reduce on unique coords is identical to array

0 commit comments

Comments
 (0)