Skip to content

Commit a22c7ed

Browse files
committed
Support shuffling with multiple groupers
1 parent 5e2fdfb commit a22c7ed

File tree

2 files changed

+31
-31
lines changed

2 files changed

+31
-31
lines changed

xarray/core/groupby.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -608,19 +608,21 @@ def shuffle(self, chunks: T_Chunks = None):
608608
dask.dataframe.DataFrame.shuffle
609609
dask.array.shuffle
610610
"""
611-
(grouper,) = self.groupers
612-
return self._shuffle_obj(chunks).groupby(
611+
new_groupers = {
613612
# Using group.name handles the BinGrouper case
614613
# It does *not* handle the TimeResampler case,
615614
# so we just override this method in Resample
616-
{grouper.group.name: grouper.grouper.reset()},
615+
grouper.group.name: grouper.grouper.reset()
616+
for grouper in self.groupers
617+
}
618+
return self._shuffle_obj(chunks).groupby(
619+
new_groupers,
617620
restore_coord_dims=self._restore_coord_dims,
618621
)
619622

620623
def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
621624
from xarray.core.dataarray import DataArray
622625

623-
(grouper,) = self.groupers
624626
dim = self._group_dim
625627
size = self._obj.sizes[dim]
626628
was_array = isinstance(self._obj, DataArray)
@@ -629,9 +631,11 @@ def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
629631
list(range(*idx.indices(size))) if isinstance(idx, slice) else idx
630632
for idx in self.encoded.group_indices
631633
]
634+
no_slices = [idx for idx in no_slices if idx]
632635

633-
if grouper.name not in as_dataset._variables:
634-
as_dataset.coords[grouper.name] = grouper.group
636+
for grouper in self.groupers:
637+
if grouper.name not in as_dataset._variables:
638+
as_dataset.coords[grouper.name] = grouper.group
635639

636640
# Shuffling is only different from `isel` for chunked arrays.
637641
# Extract them out, and treat them specially. The rest, we route through isel.
@@ -644,10 +648,13 @@ def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
644648
subset = as_dataset[
645649
[name for name in as_dataset._variables if name not in is_chunked]
646650
]
651+
647652
shuffled = subset.isel({dim: np.concatenate(no_slices)})
648653
for name, var in is_chunked.items():
649654
shuffled[name] = var._shuffle(
650-
indices=list(self.encoded.group_indices), dim=dim, chunks=chunks
655+
indices=list(idx for idx in self.encoded.group_indices if idx),
656+
dim=dim,
657+
chunks=chunks,
651658
)
652659
shuffled = self._maybe_unstack(shuffled)
653660
new_obj = self._obj._from_temp_dataset(shuffled) if was_array else shuffled
@@ -861,7 +868,9 @@ def _maybe_unstack(self, obj):
861868
# and `inserted_dims`
862869
# if multiple groupers all share the same single dimension, then
863870
# we don't stack/unstack. Do that manually now.
864-
obj = obj.unstack(*self.encoded.unique_coord.dims)
871+
dims_to_unstack = self.encoded.unique_coord.dims
872+
if all(dim in obj.dims for dim in dims_to_unstack):
873+
obj = obj.unstack(*dims_to_unstack)
865874
to_drop = [
866875
grouper.name
867876
for grouper in self.groupers

xarray/tests/test_groupby.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2684,8 +2684,9 @@ def test_weather_data_resample(use_flox):
26842684
assert expected.location.attrs == ds.location.attrs
26852685

26862686

2687+
@pytest.mark.parametrize("shuffle", [True, False])
26872688
@pytest.mark.parametrize("use_flox", [True, False])
2688-
def test_multiple_groupers(use_flox) -> None:
2689+
def test_multiple_groupers(use_flox: bool, shuffle: bool) -> None:
26892690
da = DataArray(
26902691
np.array([1, 2, 3, 0, 2, np.nan]),
26912692
dims="d",
@@ -2697,6 +2698,8 @@ def test_multiple_groupers(use_flox) -> None:
26972698
)
26982699

26992700
gb = da.groupby(labels1=UniqueGrouper(), labels2=UniqueGrouper())
2701+
if shuffle:
2702+
gb = gb.shuffle()
27002703
repr(gb)
27012704

27022705
expected = DataArray(
@@ -2716,6 +2719,8 @@ def test_multiple_groupers(use_flox) -> None:
27162719
coords = {"a": ("x", [0, 0, 1, 1]), "b": ("y", [0, 0, 1, 1])}
27172720
square = DataArray(np.arange(16).reshape(4, 4), coords=coords, dims=["x", "y"])
27182721
gb = square.groupby(a=UniqueGrouper(), b=UniqueGrouper())
2722+
if shuffle:
2723+
gb = gb.shuffle()
27192724
repr(gb)
27202725
with xr.set_options(use_flox=use_flox):
27212726
actual = gb.mean()
@@ -2739,11 +2744,15 @@ def test_multiple_groupers(use_flox) -> None:
27392744
dims=["x", "y", "z"],
27402745
)
27412746
gb = b.groupby(x=UniqueGrouper(), y=UniqueGrouper())
2747+
if shuffle:
2748+
gb = gb.shuffle()
27422749
repr(gb)
27432750
with xr.set_options(use_flox=use_flox):
27442751
assert_identical(gb.mean("z"), b.mean("z"))
27452752

27462753
gb = b.groupby(x=UniqueGrouper(), xy=UniqueGrouper())
2754+
if shuffle:
2755+
gb = gb.shuffle()
27472756
repr(gb)
27482757
with xr.set_options(use_flox=use_flox):
27492758
actual = gb.mean()
@@ -2758,13 +2767,16 @@ def test_multiple_groupers(use_flox) -> None:
27582767

27592768

27602769
@pytest.mark.parametrize("use_flox", [True, False])
2761-
def test_multiple_groupers_mixed(use_flox) -> None:
2770+
@pytest.mark.parametrize("shuffle", [True, False])
2771+
def test_multiple_groupers_mixed(use_flox: bool, shuffle: bool) -> None:
27622772
# This groupby has missing groups
27632773
ds = xr.Dataset(
27642774
{"foo": (("x", "y"), np.arange(12).reshape((4, 3)))},
27652775
coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))},
27662776
)
27672777
gb = ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper())
2778+
if shuffle:
2779+
gb = gb.shuffle()
27682780
expected_data = np.array(
27692781
[
27702782
[[0.0, np.nan], [np.nan, 3.0]],
@@ -2803,27 +2815,6 @@ def test_multiple_groupers_mixed(use_flox) -> None:
28032815
# ------
28042816

28052817

2806-
@requires_dask
2807-
def test_groupby_shuffle():
2808-
import dask
2809-
2810-
da = DataArray(
2811-
dask.array.from_array(np.array([1, 2, 3, 0, 2, np.nan]), chunks=2),
2812-
dims="d",
2813-
coords=dict(
2814-
labels1=("d", np.array(["a", "b", "c", "c", "b", "a"])),
2815-
labels2=("d", np.array(["x", "y", "z", "z", "y", "x"])),
2816-
),
2817-
name="foo",
2818-
)
2819-
2820-
gb = da.groupby("labels1")
2821-
shuffled = gb.shuffle()
2822-
shuffled_obj = shuffled._obj
2823-
with xr.set_options(use_flox=False):
2824-
xr.testing.assert_identical(gb.mean(), shuffled.mean())
2825-
2826-
28272818
# Possible property tests
28282819
# 1. lambda x: x
28292820
# 2. grouped-reduce on unique coords is identical to array

0 commit comments

Comments
 (0)