
Commit 4174aa1

Preserve label ordering for multi-variable GroupBy (#10151)
* Preserve label ordering for multi-variable GroupBy
* fix mypy
1 parent fd7c765 commit 4174aa1

File tree

xarray/core/groupby.py
xarray/groupers.py
xarray/tests/test_groupby.py

3 files changed: +61 -15 lines changed

xarray/core/groupby.py

Lines changed: 22 additions & 5 deletions
@@ -534,6 +534,11 @@ def factorize(self) -> EncodedGroups:
             list(grouper.full_index.values for grouper in groupers),
             names=tuple(grouper.name for grouper in groupers),
         )
+        if not full_index.is_unique:
+            raise ValueError(
+                "The output index for the GroupBy is non-unique. "
+                "This is a bug in the Grouper provided."
+            )
         # This will be unused when grouping by dask arrays, so skip..
         if not is_chunked_array(_flatcodes):
             # Constructing an index from the product is wrong when there are missing groups
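One note on the new guard: pd.MultiIndex.from_product produces duplicate tuples whenever a grouper hands back a repeated label, which previously flowed through silently. A minimal sketch of the failure mode the check now reports (the labels below are illustrative, not from the commit):

    import pandas as pd

    # A buggy Grouper that emits a repeated label ("DJF" twice) makes the
    # product index non-unique; factorize() now raises ValueError for this.
    full_index = pd.MultiIndex.from_product(
        [["DJF", "DJF", "MAM"], [2001, 2002]], names=["season", "year"]
    )
    print(full_index.is_unique)  # False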
@@ -942,17 +947,29 @@ def _binary_op(self, other, f, reflexive=False):
     def _restore_dim_order(self, stacked):
         raise NotImplementedError

-    def _maybe_restore_empty_groups(self, combined):
-        """Our index contained empty groups (e.g., from a resampling or binning). If we
+    def _maybe_reindex(self, combined):
+        """Reindexing is needed in two cases:
+        1. Our index contained empty groups (e.g., from a resampling or binning). If we
         reduced on that dimension, we want to restore the full index.
+
+        2. We use a MultiIndex for multi-variable GroupBy.
+        The MultiIndex stores each level's labels in sorted order
+        which are then assigned on unstacking. So we need to restore
+        the correct order here.
         """
         has_missing_groups = (
             self.encoded.unique_coord.size != self.encoded.full_index.size
         )
         indexers = {}
         for grouper in self.groupers:
-            if has_missing_groups and grouper.name in combined._indexes:
+            index = combined._indexes.get(grouper.name, None)
+            if has_missing_groups and index is not None:
                 indexers[grouper.name] = grouper.full_index
+            elif len(self.groupers) > 1:
+                if not isinstance(
+                    grouper.full_index, pd.RangeIndex
+                ) and not index.index.equals(grouper.full_index):
+                    indexers[grouper.name] = grouper.full_index
         if indexers:
             combined = combined.reindex(**indexers)
         return combined
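The second docstring case is the heart of this commit: the MultiIndex keeps each level's labels in sorted order, so after unstacking the grouped dimensions come back sorted rather than in the order the user requested, and the reindex to grouper.full_index undoes that. A small standalone sketch of that final step (the data and label orders are made up):

    import numpy as np
    import xarray as xr

    # Suppose unstacking left the "season" dimension alphabetically sorted...
    combined = xr.DataArray(
        np.arange(4), coords={"season": ["DJF", "JJA", "MAM", "SON"]}, dims="season"
    )
    # ...while the grouper's full_index asked for calendar order. The method
    # restores it with exactly this kind of reindex:
    restored = combined.reindex(season=["DJF", "MAM", "JJA", "SON"])
    print(restored.season.values)  # ['DJF' 'MAM' 'JJA' 'SON']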
@@ -1595,7 +1612,7 @@ def _combine(self, applied, shortcut=False):
         if dim not in applied_example.dims:
             combined = combined.assign_coords(self.encoded.coords)
         combined = self._maybe_unstack(combined)
-        combined = self._maybe_restore_empty_groups(combined)
+        combined = self._maybe_reindex(combined)
         return combined

     def reduce(
@@ -1751,7 +1768,7 @@ def _combine(self, applied):
         if dim not in applied_example.dims:
             combined = combined.assign_coords(self.encoded.coords)
         combined = self._maybe_unstack(combined)
-        combined = self._maybe_restore_empty_groups(combined)
+        combined = self._maybe_reindex(combined)
         return combined

     def reduce(

xarray/groupers.py

Lines changed: 1 addition & 1 deletion
@@ -521,7 +521,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
         counts = grouped.count()
         # This way we generate codes for the final output index: full_index.
         # So for _flox_reduce we avoid one reindex and copy by avoiding
-        # _maybe_restore_empty_groups
+        # _maybe_reindex
         codes = np.repeat(np.arange(len(first_items)), counts)
         return first_items, codes

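The renamed comment describes a shortcut worth spelling out: because the codes are generated directly against the final output index (full_index), the flox reduction can skip the extra reindex-and-copy that _maybe_reindex would otherwise perform. A toy illustration of the codes construction (the counts are invented):

    import numpy as np

    # Three resample bins holding 2, 1, and 3 elements respectively.
    counts = np.array([2, 1, 3])
    # Each element is coded with its bin's position in the output index,
    # mirroring np.repeat(np.arange(len(first_items)), counts) above.
    codes = np.repeat(np.arange(len(counts)), counts)
    print(codes)  # [0 0 1 2 2 2]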

xarray/tests/test_groupby.py

Lines changed: 38 additions & 9 deletions
@@ -154,7 +154,7 @@ def test_multi_index_groupby_sum() -> None:


 @requires_pandas_ge_2_2
-def test_multi_index_propagation():
+def test_multi_index_propagation() -> None:
     # regression test for GH9648
     times = pd.date_range("2023-01-01", periods=4)
     locations = ["A", "B"]
@@ -2291,7 +2291,7 @@ def test_resample_origin(self) -> None:
         times = pd.date_range("2000-01-01T02:03:01", freq="6h", periods=10)
         array = DataArray(np.arange(10), [("time", times)])

-        origin = "start"
+        origin: Literal["start"] = "start"
         actual = array.resample(time="24h", origin=origin).mean()
         expected = DataArray(array.to_series().resample("24h", origin=origin).mean())
         assert_identical(expected, actual)
@@ -2696,7 +2696,7 @@ def test_default_flox_method() -> None:

 @requires_cftime
 @pytest.mark.filterwarnings("ignore")
-def test_cftime_resample_gh_9108():
+def test_cftime_resample_gh_9108() -> None:
     import cftime

     ds = Dataset(
@@ -3046,7 +3046,7 @@ def test_gappy_resample_reductions(reduction):
     assert_identical(expected, actual)


-def test_groupby_transpose():
+def test_groupby_transpose() -> None:
     # GH5361
     data = xr.DataArray(
         np.random.randn(4, 2),
@@ -3106,7 +3106,7 @@ def test_lazy_grouping(grouper, expect_index):


 @requires_dask
-def test_lazy_grouping_errors():
+def test_lazy_grouping_errors() -> None:
     import dask.array

     data = DataArray(
@@ -3132,15 +3132,15 @@ def test_lazy_grouping_errors():


 @requires_dask
-def test_lazy_int_bins_error():
+def test_lazy_int_bins_error() -> None:
     import dask.array

     with pytest.raises(ValueError, match="Bin edges must be provided"):
         with raise_if_dask_computes():
             _ = BinGrouper(bins=4).factorize(DataArray(dask.array.arange(3)))


-def test_time_grouping_seasons_specified():
+def test_time_grouping_seasons_specified() -> None:
     time = xr.date_range("2001-01-01", "2002-01-01", freq="D")
     ds = xr.Dataset({"foo": np.arange(time.size)}, coords={"time": ("time", time)})
     labels = ["DJF", "MAM", "JJA", "SON"]
@@ -3149,7 +3149,36 @@ def test_time_grouping_seasons_specified():
     assert_identical(actual, expected.reindex(season=labels))


-def test_groupby_multiple_bin_grouper_missing_groups():
+def test_multiple_grouper_unsorted_order() -> None:
+    time = xr.date_range("2001-01-01", "2003-01-01", freq="MS")
+    ds = xr.Dataset({"foo": np.arange(time.size)}, coords={"time": ("time", time)})
+    labels = ["DJF", "MAM", "JJA", "SON"]
+    actual = ds.groupby(
+        {
+            "time.season": UniqueGrouper(labels=labels),
+            "time.year": UniqueGrouper(labels=[2002, 2001]),
+        }
+    ).sum()
+    expected = (
+        ds.groupby({"time.season": UniqueGrouper(), "time.year": UniqueGrouper()})
+        .sum()
+        .reindex(season=labels, year=[2002, 2001])
+    )
+    assert_identical(actual, expected.reindex(season=labels))
+
+    b = xr.DataArray(
+        np.random.default_rng(0).random((2, 3, 4)),
+        coords={"x": [0, 1], "y": [0, 1, 2]},
+        dims=["x", "y", "z"],
+    )
+    actual2 = b.groupby(
+        x=UniqueGrouper(labels=[1, 0]), y=UniqueGrouper(labels=[2, 0, 1])
+    ).sum()
+    expected2 = b.reindex(x=[1, 0], y=[2, 0, 1]).transpose("z", ...)
+    assert_identical(actual2, expected2)
+
+
+def test_groupby_multiple_bin_grouper_missing_groups() -> None:
     from numpy import nan

     ds = xr.Dataset(
@@ -3226,7 +3255,7 @@ def test_shuffle_by(chunks, expected_chunks):


 @requires_dask
-def test_groupby_dask_eager_load_warnings():
+def test_groupby_dask_eager_load_warnings() -> None:
     ds = xr.Dataset(
         {"foo": (("z"), np.arange(12))},
         coords={"x": ("z", np.arange(12)), "y": ("z", np.arange(12))},
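For reference, the user-visible behavior that test_multiple_grouper_unsorted_order locks in, as a standalone snippet (it mirrors the second half of the test):

    import numpy as np
    import xarray as xr
    from xarray.groupers import UniqueGrouper

    b = xr.DataArray(
        np.random.default_rng(0).random((2, 3, 4)),
        coords={"x": [0, 1], "y": [0, 1, 2]},
        dims=["x", "y", "z"],
    )
    # With this commit, the grouped dims keep the requested label orders
    # [1, 0] and [2, 0, 1] instead of coming back sorted.
    result = b.groupby(
        x=UniqueGrouper(labels=[1, 0]), y=UniqueGrouper(labels=[2, 0, 1])
    ).sum()
    print(result.x.values, result.y.values)  # [1 0] [2 0 1]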
