Skip to content

Commit a5d296e

Browse files
authored
Fix reduction by subset of grouper dimensions (pydata#10258)
1 parent 729c4fa commit a5d296e

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

xarray/core/groupby.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ def check_reduce_dims(reduce_dims, dimensions):
7878
if any(dim not in dimensions for dim in reduce_dims):
7979
raise ValueError(
8080
f"cannot reduce over dimensions {reduce_dims!r}. expected either '...' "
81-
f"to reduce over all dimensions or one or more of {dimensions!r}."
81+
f"to reduce over all dimensions or one or more of {dimensions!r}. "
82+
f"Alternatively, install the `flox` package. "
8283
)
8384

8485

@@ -1135,7 +1136,7 @@ def _flox_reduce(
11351136
group_dims = set(grouper.group.dims)
11361137
new_coords = []
11371138
to_drop = []
1138-
if group_dims.issubset(set(parsed_dim)):
1139+
if group_dims & set(parsed_dim):
11391140
for grouper in self.groupers:
11401141
output_index = grouper.full_index
11411142
if isinstance(output_index, pd.RangeIndex):

xarray/tests/test_groupby.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1646,6 +1646,19 @@ def test_groupby_multidim(self) -> None:
16461646
actual_sum = array.groupby(dim).sum(...)
16471647
assert_identical(expected_sum, actual_sum)
16481648

1649+
if has_flox:
1650+
# GH9803
1651+
# reduce over one dim of a nD grouper
1652+
array.coords["labels"] = (("ny", "nx"), np.array([["a", "b"], ["b", "a"]]))
1653+
actual = array.groupby("labels").sum("nx")
1654+
expected_np = np.array([[[0, 1], [3, 2]], [[5, 10], [20, 15]]])
1655+
expected = xr.DataArray(
1656+
expected_np,
1657+
dims=("time", "ny", "labels"),
1658+
coords={"labels": ["a", "b"]},
1659+
)
1660+
assert_identical(expected, actual)
1661+
16491662
def test_groupby_multidim_map(self) -> None:
16501663
array = self.make_groupby_multidim_example_array()
16511664
actual = array.groupby("lon").map(lambda x: x - x.mean())

0 commit comments

Comments
 (0)