diff --git a/xarray/groupers.py b/xarray/groupers.py
index 3a27d725116..337938b830b 100644
--- a/xarray/groupers.py
+++ b/xarray/groupers.py
@@ -52,6 +52,8 @@
     "EncodedGroups",
     "Grouper",
     "Resampler",
+    "SeasonGrouper",
+    "SeasonResampler",
     "TimeResampler",
     "UniqueGrouper",
 ]
@@ -235,12 +237,14 @@ def _factorize_given_labels(self, group: T_Group) -> EncodedGroups:
             output_dtypes=[np.int64],
             keep_attrs=True,
         )
+        # Convert labels to a sequence that pandas Index can handle
+        labels_array = np.asarray(self.labels)
         return EncodedGroups(
             codes=codes,
-            full_index=pd.Index(self.labels),  # type: ignore[arg-type]
+            full_index=pd.Index(labels_array),
             unique_coord=Variable(
                 dims=codes.name,
-                data=self.labels,
+                data=labels_array,
                 attrs=self.group.attrs,
             ),
         )
@@ -607,6 +611,47 @@ def unique_value_groups(
     return values, inverse
 
 
+def _adjust_years_for_season(
+    years: np.ndarray,
+    months: np.ndarray,
+    season_tuple: tuple[int, ...],
+    season_str: str,
+) -> np.ndarray:
+    """
+    Adjust years for seasons that span December and January (e.g., DJF).
+
+    For seasons like DJF, January and February should be considered part of the
+    winter that started in the previous December.
+
+    Parameters
+    ----------
+    years : np.ndarray
+        Array of years corresponding to each timestamp
+    months : np.ndarray
+        Array of months corresponding to each timestamp
+    season_tuple : tuple of int
+        Tuple of month numbers that make up the season
+    season_str : str
+        String representation of the season (e.g., "DJF")
+
+    Returns
+    -------
+    np.ndarray
+        Adjusted years array
+    """
+    year_adjusted = years.copy()
+    # Handle seasons that contain December followed by other months
+    if "D" in season_str and 12 in season_tuple:
+        # Find the position of "D" (December) in the season string
+        d_index = season_str.index("D")
+        # Get all months that come after December in the season
+        months_after_dec = season_tuple[d_index + 1 :]
+        # Reduce year by 1 for months that come after December
+        for month_num in months_after_dec:
+            year_adjusted[months == month_num] -= 1
+    return year_adjusted
+
+
 def season_to_month_tuple(seasons: Sequence[str]) -> tuple[tuple[int, ...], ...]:
     """
     >>> season_to_month_tuple(["DJF", "MAM", "JJA", "SON"])
@@ -737,25 +782,30 @@ class SeasonGrouper(Grouper):
     seasons: sequence of str
         List of strings representing seasons. E.g. ``"JF"`` or ``"JJA"`` etc.
         Overlapping seasons are allowed (e.g. ``["DJFM", "MAMJ", "JJAS", "SOND"]``)
+    drop_incomplete: bool, default: False
+        Whether to drop seasons that are not completely included in the data.
+        For example, if a time series starts in Jan-2001 and the seasons include
+        ``"DJF"``, then observations from Jan-2001 and Feb-2001 are ignored in the
+        grouping since Dec-2000 isn't present. This check is performed for each year.
 
     Examples
     --------
     >>> SeasonGrouper(["JF", "MAM", "JJAS", "OND"])
-    SeasonGrouper(seasons=['JF', 'MAM', 'JJAS', 'OND'])
+    SeasonGrouper(seasons=['JF', 'MAM', 'JJAS', 'OND'], drop_incomplete=False)
 
     The ordering is preserved
 
     >>> SeasonGrouper(["MAM", "JJAS", "OND", "JF"])
-    SeasonGrouper(seasons=['MAM', 'JJAS', 'OND', 'JF'])
+    SeasonGrouper(seasons=['MAM', 'JJAS', 'OND', 'JF'], drop_incomplete=False)
 
     Overlapping seasons are allowed
 
     >>> SeasonGrouper(["DJFM", "MAMJ", "JJAS", "SOND"])
-    SeasonGrouper(seasons=['DJFM', 'MAMJ', 'JJAS', 'SOND'])
+    SeasonGrouper(seasons=['DJFM', 'MAMJ', 'JJAS', 'SOND'], drop_incomplete=False)
     """
 
     seasons: Sequence[str]
-    # drop_incomplete: bool = field(default=True)  # TODO
+    drop_incomplete: bool = field(default=False, kw_only=True)
 
     def factorize(self, group: T_Group) -> EncodedGroups:
         if TYPE_CHECKING:
@@ -767,15 +817,43 @@ def factorize(self, group: T_Group) -> EncodedGroups:
         months = group.dt.month.data
         seasons_groups = find_independent_seasons(self.seasons)
         codes_ = np.full((len(seasons_groups),) + group.shape, -1, dtype=np.int8)
-        group_indices: list[list[int]] = [[]] * len(self.seasons)
+        group_indices: list[list[int]] = [[] for _ in range(len(self.seasons))]
+
+        if self.drop_incomplete:
+            year = group.dt.year.data
+
         for axis_index, seasgroup in enumerate(seasons_groups):
             for season_tuple, code in zip(
                 seasgroup.inds, seasgroup.codes, strict=False
             ):
                 mask = np.isin(months, season_tuple)
-                codes_[axis_index, mask] = code
-                (indices,) = mask.nonzero()
-                group_indices[code] = indices.tolist()
+                if not self.drop_incomplete:
+                    codes_[axis_index, mask] = code
+                    (indices,) = mask.nonzero()
+                    group_indices[code] = indices.tolist()
+                else:
+                    season_str = self.seasons[code]
+                    year_adjusted = _adjust_years_for_season(
+                        year, months, season_tuple, season_str
+                    )
+
+                    if not np.any(mask):
+                        continue
+                    # find unique years for this season
+                    unique_years = np.unique(year_adjusted[mask])
+
+                    for yr in unique_years:
+                        year_mask = year_adjusted == yr
+
+                        # elements for this season in this year
+                        year_season_mask = mask & year_mask
+
+                        # check for completeness
+                        present_months = np.unique(months[year_season_mask])
+                        if len(present_months) == len(season_tuple):
+                            codes_[axis_index, year_season_mask] = code
+                            (indices,) = year_season_mask.nonzero()
+                            group_indices[code].extend(indices.tolist())
 
         if np.all(codes_ == -1):
             raise ValueError(
@@ -788,8 +866,16 @@ def factorize(self, group: T_Group) -> EncodedGroups:
             attrs=group.attrs,
             name="season",
         )
-        unique_coord = Variable("season", self.seasons, attrs=group.attrs)
-        full_index = pd.Index(self.seasons)
+
+        # Always filter coordinates to match actual data present
+        # This avoids dimension mismatches regardless of drop_incomplete setting
+        present_codes = np.unique(codes.data.ravel())
+        present_codes = present_codes[present_codes >= 0]  # Remove -1 (missing data)
+        present_seasons = [self.seasons[code] for code in present_codes]
+
+        unique_coord = Variable("season", present_seasons, attrs=group.attrs)
+        full_index = pd.Index(present_seasons)
+
         return EncodedGroups(
             codes=codes,
             group_indices=tuple(group_indices),
@@ -798,7 +884,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:
         )
 
     def reset(self) -> Self:
-        return type(self)(self.seasons)
+        return type(self)(seasons=self.seasons, drop_incomplete=self.drop_incomplete)
 
 
 @dataclass
@@ -868,10 +954,15 @@ def factorize(self, group: T_Group) -> EncodedGroups:
         # offset years for seasons with December and January
         for season_str, season_ind in zip(seasons, season_inds, strict=True):
             season_label[month.isin(season_ind)] = season_str
-            if "DJ" in season_str:
-                after_dec = season_ind[season_str.index("D") + 1 :]
-                # important: this is assuming non-overlapping seasons
-                year[month.isin(after_dec)] -= 1
+
+        # Apply year adjustment for cross-year seasons
+        year_adjusted = year.copy()
+        for season_str, season_ind in zip(seasons, season_inds, strict=True):
+            if "D" in season_str and 12 in season_ind:
+                # Use helper function for year adjustment
+                year_adjusted[:] = _adjust_years_for_season(
+                    year_adjusted.values, month.values, tuple(season_ind), season_str
+                )
 
         # Allow users to skip one or more months?
         # present_seasons is a mask that is True for months that are requested in the output
@@ -885,7 +976,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:
                 "month": month[present_seasons],
             },
             index=pd.MultiIndex.from_arrays(
-                [year.data[present_seasons], season_label[present_seasons]],
+                [year_adjusted.data[present_seasons], season_label[present_seasons]],
                 names=["year", "season"],
             ),
         )
@@ -924,7 +1015,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:
             [
                 datetime_class(year=y, month=m, day=1)
                 for y, m in itertools.product(
-                    range(year[0].item(), year[-1].item() + 1),
+                    range(year_adjusted[0].item(), year_adjusted[-1].item() + 1),
                     [s[0] for s in season_inds],
                 )
             ]
@@ -939,7 +1030,7 @@ def get_label(year, season):
         unique_codes = np.arange(len(unique_coord))
         valid_season_mask = season_label != ""
         first_valid_season, last_valid_season = season_label[valid_season_mask][[0, -1]]
-        first_year, last_year = year.data[[0, -1]]
+        first_year, last_year = year_adjusted.data[[0, -1]]
         if self.drop_incomplete:
             if month.data[valid_season_mask][0] != season_tuples[first_valid_season][0]:
                 if "DJ" in first_valid_season:
diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py
index 54cc21b5d2c..c3bf6c1408d 100644
--- a/xarray/tests/test_groupby.py
+++ b/xarray/tests/test_groupby.py
@@ -3601,6 +3601,144 @@ def test_season_resampler_groupby_identical(self):
         gb = da.groupby(time=resampler).sum()
         assert_identical(rs, gb)
 
+    @pytest.mark.parametrize("calendar", ["standard"])
+    def test_season_grouper_drop_incomplete_default_false(self, calendar):
+        """Test that drop_incomplete=False is the default and includes partial seasons."""
+        # Create data that starts mid-winter (missing Dec 2000)
+        time = date_range("2001-01-01", "2001-12-31", freq="MS", calendar=calendar)
+        data = np.arange(len(time))
+        da = DataArray(data, dims="time", coords={"time": time})
+
+        # Default behavior should include incomplete seasons
+        result_default = da.groupby(
+            time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"])
+        ).mean()
+        result_explicit = da.groupby(
+            time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=False)
+        ).mean()
+
+        assert_identical(result_default, result_explicit)
+
+        # Should include 4 seasons (including incomplete DJF with just Jan-Feb)
+        assert len(result_default) == 4
+        assert list(result_default.season.values) == ["DJF", "MAM", "JJA", "SON"]
+
+    @pytest.mark.parametrize("calendar", ["standard"])
+    def test_season_grouper_drop_incomplete_true(self, calendar):
+        """Test that drop_incomplete=True excludes partial seasons."""
+        # Create data that starts mid-winter (missing Dec 2000) and ends mid-autumn (missing Nov-Dec 2002)
+        time = date_range("2001-01-01", "2002-10-31", freq="MS", calendar=calendar)
+        data = np.arange(len(time))
+        da = DataArray(data, dims="time", coords={"time": time})
+
+        # With drop_incomplete=True, should exclude incomplete seasons
+        result_drop = da.groupby(
+            time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=True)
+        ).mean()
+
+        # Should only include complete seasons
+        # 2001: DJF is incomplete (missing Dec 2000), MAM/JJA/SON are complete
+        # 2002: DJF/MAM/JJA are complete, SON is incomplete (missing Nov 2002)
+        assert len(result_drop) <= 6  # At most 6 complete season instances
+
+        # Compare with default behavior
+        result_default = da.groupby(
+            time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=False)
+        ).mean()
+        assert len(result_drop) <= len(result_default)
+
+    @pytest.mark.parametrize("calendar", ["standard"])
+    def test_season_grouper_drop_incomplete_cross_year_seasons(self, calendar):
+        """Test drop_incomplete with seasons that span calendar years like DJF."""
+        # Create 2 complete years of data
+        time = date_range("2001-01-01", "2002-12-31", freq="MS", calendar=calendar)
+        data = np.arange(len(time))
+        da = DataArray(data, dims="time", coords={"time": time})
+
+        # Test with DJF season (spans calendar year)
+        result_keep = da.groupby(
+            time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=False)
+        ).mean()
+        result_drop = da.groupby(
+            time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=True)
+        ).mean()
+
+        # Every season label has at least one complete instance here, so both
+        # should give the same number of groups
+        assert len(result_keep) == len(result_drop)
+
+        # Now test with incomplete data - start from Feb (missing Dec 2000 and Jan 2001)
+        time_incomplete = date_range(
+            "2001-02-01",
+            "2001-12-31",
+            freq="MS",
+            calendar=calendar,  # Stop before next year
+        )
+        data_incomplete = np.arange(len(time_incomplete))
+        da_incomplete = DataArray(
+            data_incomplete, dims="time", coords={"time": time_incomplete}
+        )
+
+        result_keep_inc = da_incomplete.groupby(
+            time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=False)
+        ).mean()
+        result_drop_inc = da_incomplete.groupby(
+            time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=True)
+        ).mean()
+
+        # drop_incomplete should exclude the incomplete first DJF season
+        # Data starts in Feb 2001, so 2000-2001 DJF is incomplete (missing Dec 2000, Jan 2001)
+        assert len(result_drop_inc) < len(result_keep_inc)
+
+    @pytest.mark.parametrize("calendar", ["standard"])
+    def test_season_grouper_drop_incomplete_all_incomplete(self, calendar):
+        """Test that drop_incomplete handles the case where all seasons are incomplete."""
+        # Create data with only January (incomplete for any multi-month season)
+        time = date_range("2001-01-01", "2001-01-31", freq="D", calendar=calendar)
+        data = np.arange(len(time))
+        da = DataArray(data, dims="time", coords={"time": time})
+
+        # Should raise error when all seasons are incomplete and drop_incomplete=True
+        with pytest.raises(ValueError, match="Failed to group data"):
+            da.groupby(
+                time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=True)
+            ).mean()
+
+    def test_season_grouper_reset_preserves_drop_incomplete(self):
+        """Test that the reset method preserves the drop_incomplete setting."""
+        grouper1 = SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=True)
+        grouper2 = grouper1.reset()
+
+        assert grouper2.drop_incomplete == grouper1.drop_incomplete
+        assert grouper2.seasons == grouper1.seasons
+
+        grouper3 = SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=False)
+        grouper4 = grouper3.reset()
+
+        assert grouper4.drop_incomplete == grouper3.drop_incomplete
+        assert grouper4.seasons == grouper3.seasons
+
+    def test_adjust_years_for_season_helper(self):
+        """Test the helper function _adjust_years_for_season."""
+        from xarray.groupers import _adjust_years_for_season
+
+        years = np.array([2001, 2001, 2001, 2002, 2002, 2002])
+        months = np.array([12, 1, 2, 12, 1, 2])
+
+        # Test DJF season (December, January, February)
+        adjusted = _adjust_years_for_season(years, months, (12, 1, 2), "DJF")
+        expected = np.array(
+            [2001, 2000, 2000, 2002, 2001, 2001]
+        )  # Jan/Feb get previous year
+        np.testing.assert_array_equal(adjusted, expected)
+
+        # Test MAM season (no cross-year adjustment needed)
+        adjusted_mam = _adjust_years_for_season(years, months, (3, 4, 5), "MAM")
+        np.testing.assert_array_equal(adjusted_mam, years)  # Should be unchanged
+
+        # Test single month season
+        adjusted_jan = _adjust_years_for_season(years, months, (1,), "J")
+        np.testing.assert_array_equal(adjusted_jan, years)  # Should be unchanged
+
 
 # TODO: Possible property tests to add to this module
 # 1. lambda x: x
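
For reviewers, a minimal usage sketch (not part of the patch) mirroring the behavior the tests above exercise; it assumes this diff is applied so that SeasonGrouper accepts drop_incomplete:

    import numpy as np
    import xarray as xr
    from xarray.groupers import SeasonGrouper

    # Monthly data starting in Jan-2001: the 2000-2001 DJF season is incomplete
    # because Dec-2000 is missing, and the 2002-2003 DJF season only has Dec-2002.
    time = xr.date_range("2001-01-01", "2002-12-31", freq="MS")
    da = xr.DataArray(np.arange(time.size), dims="time", coords={"time": time})

    # Default (drop_incomplete=False): Jan/Feb-2001 still count towards DJF.
    kept = da.groupby(time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"])).mean()

    # drop_incomplete=True: only season instances with every month present
    # contribute, so Jan/Feb-2001 and Dec-2002 are excluded from DJF.
    dropped = da.groupby(
        time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=True)
    ).mean()

Both results have a "season" dimension; with drop_incomplete=True the DJF mean comes only from the complete Dec-2001/Jan-2002/Feb-2002 winter.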