Skip to content

Commit 69d62cb

Browse files
committed
Review comments
1 parent 3bbcab2 commit 69d62cb

File tree

1 file changed

+29
-20
lines changed

1 file changed

+29
-20
lines changed

xarray/tests/test_groupby.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -873,7 +873,14 @@ def test_groupby_dataset_errors() -> None:
873873
data.groupby(data.coords["dim1"].to_index())
874874

875875

876-
def test_groupby_dataset_reduce() -> None:
876+
@pytest.mark.parametrize(
877+
"by_func",
878+
[
879+
pytest.param(lambda x: x, id="group-by-string"),
880+
pytest.param(lambda x: {x: UniqueGrouper()}, id="group-by-unique-grouper"),
881+
],
882+
)
883+
def test_groupby_dataset_reduce_ellipsis(by_func) -> None:
877884
data = Dataset(
878885
{
879886
"xy": (["x", "y"], np.random.randn(3, 4)),
@@ -885,12 +892,12 @@ def test_groupby_dataset_reduce() -> None:
885892

886893
expected = data.mean("y")
887894
expected["yonly"] = expected["yonly"].variable.set_dims({"x": 3})
888-
for gb in [data.groupby("x"), data.groupby(x=UniqueGrouper())]:
889-
actual = gb.mean(...)
890-
assert_allclose(expected, actual)
895+
gb = data.groupby(by_func("x"))
896+
actual = gb.mean(...)
897+
assert_allclose(expected, actual)
891898

892-
actual = gb.mean("y")
893-
assert_allclose(expected, actual)
899+
actual = gb.mean("y")
900+
assert_allclose(expected, actual)
894901

895902
letters = data["letters"]
896903
expected = Dataset(
@@ -900,9 +907,9 @@ def test_groupby_dataset_reduce() -> None:
900907
"yonly": data["yonly"].groupby(letters).mean(),
901908
}
902909
)
903-
for gb in [data.groupby("letters"), data.groupby(letters=UniqueGrouper())]:
904-
actual = gb.mean(...)
905-
assert_allclose(expected, actual)
910+
gb = data.groupby(by_func("letters"))
911+
actual = gb.mean(...)
912+
assert_allclose(expected, actual)
906913

907914

908915
@pytest.mark.parametrize("squeeze", [True, False])
@@ -1040,23 +1047,25 @@ def test_groupby_bins_cut_kwargs(use_flox: bool) -> None:
10401047

10411048

10421049
@pytest.mark.parametrize("indexed_coord", [True, False])
1043-
def test_groupby_bins_math(indexed_coord) -> None:
1050+
@pytest.mark.parametrize(
1051+
["groupby_method", "args"],
1052+
(
1053+
("groupby_bins", ("x", np.arange(0, 8, 3))),
1054+
("groupby", ({"x": BinGrouper(bins=np.arange(0, 8, 3))},)),
1055+
),
1056+
)
1057+
def test_groupby_bins_math(groupby_method, args, indexed_coord) -> None:
10441058
N = 7
10451059
da = DataArray(np.random.random((N, N)), dims=("x", "y"))
10461060
if indexed_coord:
10471061
da["x"] = np.arange(N)
10481062
da["y"] = np.arange(N)
10491063

1050-
for g in [
1051-
da.groupby_bins("x", np.arange(0, N + 1, 3)),
1052-
da.groupby(x=BinGrouper(bins=np.arange(0, N + 1, 3))),
1053-
]:
1054-
mean = g.mean()
1055-
expected = da.isel(x=slice(1, None)) - mean.isel(
1056-
x_bins=("x", [0, 0, 0, 1, 1, 1])
1057-
)
1058-
actual = g - mean
1059-
assert_identical(expected, actual)
1064+
g = getattr(da, groupby_method)(*args)
1065+
mean = g.mean()
1066+
expected = da.isel(x=slice(1, None)) - mean.isel(x_bins=("x", [0, 0, 0, 1, 1, 1]))
1067+
actual = g - mean
1068+
assert_identical(expected, actual)
10601069

10611070

10621071
def test_groupby_math_nD_group() -> None:

0 commit comments

Comments
 (0)