From 79222c685551ad8eb96bad9ee516065282af5ffb Mon Sep 17 00:00:00 2001 From: DHRUVA KUMAR KAUSHAL Date: Sat, 21 Jun 2025 15:59:06 +0530 Subject: [PATCH 1/4] drop_incomplete support in SeasonGrouper --- xarray/groupers.py | 52 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 9 deletions(-) diff --git a/xarray/groupers.py b/xarray/groupers.py index 9ed948956a8..3e837dd3c1d 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -741,25 +741,30 @@ class SeasonGrouper(Grouper): seasons: sequence of str List of strings representing seasons. E.g. ``"JF"`` or ``"JJA"`` etc. Overlapping seasons are allowed (e.g. ``["DJFM", "MAMJ", "JJAS", "SOND"]``) + drop_incomplete: bool, default: False + Whether to drop seasons that are not completely included in the data. + For example, if a time series starts in Jan-2001, and seasons includes `"DJF"` + then observations from Jan-2001, and Feb-2001 are ignored in the grouping + since Dec-2000 isn't present. This check is performed for each year. Examples -------- >>> SeasonGrouper(["JF", "MAM", "JJAS", "OND"]) - SeasonGrouper(seasons=['JF', 'MAM', 'JJAS', 'OND']) + SeasonGrouper(seasons=['JF', 'MAM', 'JJAS', 'OND'], drop_incomplete=False) The ordering is preserved >>> SeasonGrouper(["MAM", "JJAS", "OND", "JF"]) - SeasonGrouper(seasons=['MAM', 'JJAS', 'OND', 'JF']) + SeasonGrouper(seasons=['MAM', 'JJAS', 'OND', 'JF'], drop_incomplete=False) Overlapping seasons are allowed >>> SeasonGrouper(["DJFM", "MAMJ", "JJAS", "SOND"]) - SeasonGrouper(seasons=['DJFM', 'MAMJ', 'JJAS', 'SOND']) + SeasonGrouper(seasons=['DJFM', 'MAMJ', 'JJAS', 'SOND'], drop_incomplete=False) """ seasons: Sequence[str] - # drop_incomplete: bool = field(default=True) # TODO + drop_incomplete: bool = field(default=False, kw_only=True) def factorize(self, group: T_Group) -> EncodedGroups: if TYPE_CHECKING: @@ -771,15 +776,44 @@ def factorize(self, group: T_Group) -> EncodedGroups: months = group.dt.month.data seasons_groups = find_independent_seasons(self.seasons) codes_ = np.full((len(seasons_groups),) + group.shape, -1, dtype=np.int8) - group_indices: list[list[int]] = [[]] * len(self.seasons) + group_indices: list[list[int]] = [[] for _ in range(len(self.seasons))] + + if self.drop_incomplete: + year = group.dt.year.data + for axis_index, seasgroup in enumerate(seasons_groups): for season_tuple, code in zip( seasgroup.inds, seasgroup.codes, strict=False ): mask = np.isin(months, season_tuple) - codes_[axis_index, mask] = code - (indices,) = mask.nonzero() - group_indices[code] = indices.tolist() + if not self.drop_incomplete: + codes_[axis_index, mask] = code + (indices,) = mask.nonzero() + group_indices[code] = indices.tolist() + else: + year_adjusted = year.copy() + # handle seasons like DJF + if 12 in season_tuple and 1 in season_tuple: + jan_or_later = [m for m in season_tuple if m < 12] + year_adjusted[np.isin(months, jan_or_later)] -= 1 + + # find unique years for this season + if not np.any(mask): + continue + unique_years = np.unique(year_adjusted[mask]) + + for yr in unique_years: + year_mask = year_adjusted == yr + + # elements for this season in this year + year_season_mask = mask & year_mask + + # check for completeness + present_months = np.unique(months[year_season_mask]) + if len(present_months) == len(season_tuple): + codes_[axis_index, year_season_mask] = code + (indices,) = year_season_mask.nonzero() + group_indices[code].extend(indices.tolist()) if np.all(codes_ == -1): raise ValueError( @@ -802,7 +836,7 @@ def factorize(self, group: T_Group) -> EncodedGroups: ) def reset(self) -> Self: - return type(self)(self.seasons) + return type(self)(seasons=self.seasons, drop_incomplete=self.drop_incomplete) @dataclass From 4c0801325d8b595b084a2a5b0e3042f876811336 Mon Sep 17 00:00:00 2001 From: DHRUVA KUMAR KAUSHAL Date: Fri, 27 Jun 2025 14:53:20 +0530 Subject: [PATCH 2/4] precommit fix --- xarray/groupers.py | 65 +++++++++++++---- xarray/tests/test_groupby.py | 134 +++++++++++++++++++++++++++++++++++ 2 files changed, 186 insertions(+), 13 deletions(-) diff --git a/xarray/groupers.py b/xarray/groupers.py index 3e837dd3c1d..63038ea4569 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -611,6 +611,37 @@ def unique_value_groups( return values, inverse +def _adjust_years_for_season( + years: np.ndarray, months: np.ndarray, season_tuple: tuple[int, ...] +) -> np.ndarray: + """ + Adjust years for seasons that span December and January (e.g., DJF). + + For seasons like DJF, January and February should be considered part of the + winter that started in the previous December. + + Parameters + ---------- + years : np.ndarray + Array of years corresponding to each timestamp + months : np.ndarray + Array of months corresponding to each timestamp + season_tuple : tuple of int + Tuple of month numbers that make up the season + + Returns + ------- + np.ndarray + Adjusted years array + """ + year_adjusted = years.copy() + # Handle seasons like DJF where December is in one year but Jan/Feb are in the next + if 12 in season_tuple and 1 in season_tuple: + jan_or_later = [m for m in season_tuple if m < 12] + year_adjusted[np.isin(months, jan_or_later)] -= 1 + return year_adjusted + + def season_to_month_tuple(seasons: Sequence[str]) -> tuple[tuple[int, ...], ...]: """ >>> season_to_month_tuple(["DJF", "MAM", "JJA", "SON"]) @@ -791,11 +822,7 @@ def factorize(self, group: T_Group) -> EncodedGroups: (indices,) = mask.nonzero() group_indices[code] = indices.tolist() else: - year_adjusted = year.copy() - # handle seasons like DJF - if 12 in season_tuple and 1 in season_tuple: - jan_or_later = [m for m in season_tuple if m < 12] - year_adjusted[np.isin(months, jan_or_later)] -= 1 + year_adjusted = _adjust_years_for_season(year, months, season_tuple) # find unique years for this season if not np.any(mask): @@ -826,8 +853,16 @@ def factorize(self, group: T_Group) -> EncodedGroups: attrs=group.attrs, name="season", ) - unique_coord = Variable("season", self.seasons, attrs=group.attrs) - full_index = pd.Index(self.seasons) + + # Always filter coordinates to match actual data present + # This avoids dimension mismatches regardless of drop_incomplete setting + present_codes = np.unique(codes.data.ravel()) + present_codes = present_codes[present_codes >= 0] # Remove -1 (missing data) + present_seasons = [self.seasons[code] for code in present_codes] + + unique_coord = Variable("season", present_seasons, attrs=group.attrs) + full_index = pd.Index(present_seasons) + return EncodedGroups( codes=codes, group_indices=tuple(group_indices), @@ -906,10 +941,14 @@ def factorize(self, group: T_Group) -> EncodedGroups: # offset years for seasons with December and January for season_str, season_ind in zip(seasons, season_inds, strict=True): season_label[month.isin(season_ind)] = season_str + + # Apply year adjustment for cross-year seasons + year_adjusted = year.copy() + for season_str, season_ind in zip(seasons, season_inds, strict=True): if "DJ" in season_str: - after_dec = season_ind[season_str.index("D") + 1 :] - # important: this is assuming non-overlapping seasons - year[month.isin(after_dec)] -= 1 + year_adjusted = _adjust_years_for_season( + year_adjusted, month.data, season_ind + ) # Allow users to skip one or more months? # present_seasons is a mask that is True for months that are requested in the output @@ -923,7 +962,7 @@ def factorize(self, group: T_Group) -> EncodedGroups: "month": month[present_seasons], }, index=pd.MultiIndex.from_arrays( - [year.data[present_seasons], season_label[present_seasons]], + [year_adjusted.data[present_seasons], season_label[present_seasons]], names=["year", "season"], ), ) @@ -962,7 +1001,7 @@ def factorize(self, group: T_Group) -> EncodedGroups: [ datetime_class(year=y, month=m, day=1) for y, m in itertools.product( - range(year[0].item(), year[-1].item() + 1), + range(year_adjusted[0].item(), year_adjusted[-1].item() + 1), [s[0] for s in season_inds], ) ] @@ -977,7 +1016,7 @@ def get_label(year, season): unique_codes = np.arange(len(unique_coord)) valid_season_mask = season_label != "" first_valid_season, last_valid_season = season_label[valid_season_mask][[0, -1]] - first_year, last_year = year.data[[0, -1]] + first_year, last_year = year_adjusted.data[[0, -1]] if self.drop_incomplete: if month.data[valid_season_mask][0] != season_tuples[first_valid_season][0]: if "DJ" in first_valid_season: diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index a64dfc97bb6..159ac655cb4 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3604,6 +3604,140 @@ def test_season_resampler_groupby_identical(self): gb = da.groupby(time=resampler).sum() assert_identical(rs, gb) + @pytest.mark.parametrize("calendar", ["standard"]) + def test_season_grouper_drop_incomplete_default_false(self, calendar): + """Test that drop_incomplete=False is the default and includes partial seasons.""" + # Create data that starts mid-winter (missing Dec 2000) + time = date_range("2001-01-01", "2001-12-31", freq="MS", calendar=calendar) + data = np.arange(len(time)) + da = DataArray(data, dims="time", coords={"time": time}) + + # Default behavior should include incomplete seasons + result_default = da.groupby( + time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"]) + ).mean() + result_explicit = da.groupby( + time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=False) + ).mean() + + assert_identical(result_default, result_explicit) + + # Should include 4 seasons (including incomplete DJF with just Jan-Feb) + assert len(result_default) == 4 + assert list(result_default.season.values) == ["DJF", "MAM", "JJA", "SON"] + + @pytest.mark.parametrize("calendar", ["standard"]) + def test_season_grouper_drop_incomplete_true(self, calendar): + """Test that drop_incomplete=True excludes partial seasons.""" + # Create data that starts mid-winter (missing Dec 2000) and ends mid-autumn (missing Nov-Dec 2002) + time = date_range("2001-01-01", "2002-10-31", freq="MS", calendar=calendar) + data = np.arange(len(time)) + da = DataArray(data, dims="time", coords={"time": time}) + + # With drop_incomplete=True, should exclude incomplete seasons + result_drop = da.groupby( + time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=True) + ).mean() + + # Should only include complete seasons + # 2001: DJF is incomplete (missing Dec 2000), MAM/JJA/SON are complete + # 2002: DJF/MAM/JJA are complete, SON is incomplete (missing Nov-Dec) + assert len(result_drop) <= 6 # At most 6 complete seasons + + # Compare with default behavior + result_default = da.groupby( + time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=False) + ).mean() + assert len(result_drop) <= len(result_default) + + @pytest.mark.parametrize("calendar", ["standard"]) + def test_season_grouper_drop_incomplete_cross_year_seasons(self, calendar): + """Test drop_incomplete with seasons that span calendar years like DJF.""" + # Create 2 complete years of data + time = date_range("2001-01-01", "2002-12-31", freq="MS", calendar=calendar) + data = np.arange(len(time)) + da = DataArray(data, dims="time", coords={"time": time}) + + # Test with DJF season (spans calendar year) + result_keep = da.groupby( + time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=False) + ).mean() + result_drop = da.groupby( + time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=True) + ).mean() + + # With complete data, both should give same number of seasons + assert len(result_keep) == len(result_drop) + + # Now test with incomplete data - start from Feb (missing Dec-Jan of first DJF) + time_incomplete = date_range( + "2001-02-01", "2002-12-31", freq="MS", calendar=calendar + ) + data_incomplete = np.arange(len(time_incomplete)) + da_incomplete = DataArray( + data_incomplete, dims="time", coords={"time": time_incomplete} + ) + + result_keep_inc = da_incomplete.groupby( + time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=False) + ).mean() + result_drop_inc = da_incomplete.groupby( + time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=True) + ).mean() + + # drop_incomplete should exclude the incomplete first DJF season + assert len(result_drop_inc) < len(result_keep_inc) + + @pytest.mark.parametrize("calendar", ["standard"]) + def test_season_grouper_drop_incomplete_all_incomplete(self, calendar): + """Test that drop_incomplete handles the case where all seasons are incomplete.""" + # Create data with only January (incomplete for any multi-month season) + time = date_range("2001-01-01", "2001-01-31", freq="D", calendar=calendar) + data = np.arange(len(time)) + da = DataArray(data, dims="time", coords={"time": time}) + + # Should raise error when all seasons are incomplete and drop_incomplete=True + with pytest.raises(ValueError, match="Failed to group data"): + da.groupby( + time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=True) + ).mean() + + def test_season_grouper_reset_preserves_drop_incomplete(self): + """Test that the reset method preserves the drop_incomplete setting.""" + grouper1 = SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=True) + grouper2 = grouper1.reset() + + assert grouper2.drop_incomplete == grouper1.drop_incomplete + assert grouper2.seasons == grouper1.seasons + + grouper3 = SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=False) + grouper4 = grouper3.reset() + + assert grouper4.drop_incomplete == grouper3.drop_incomplete + assert grouper4.seasons == grouper3.seasons + + def test_adjust_years_for_season_helper(self): + """Test the helper function _adjust_years_for_season.""" + from xarray.groupers import _adjust_years_for_season + + years = np.array([2001, 2001, 2001, 2002, 2002, 2002]) + months = np.array([12, 1, 2, 12, 1, 2]) + + # Test DJF season (December, January, February) + adjusted = _adjust_years_for_season(years, months, (12, 1, 2)) + expected = np.array( + [2001, 2000, 2000, 2002, 2001, 2001] + ) # Jan/Feb get previous year + np.testing.assert_array_equal(adjusted, expected) + + # Test MAM season (no cross-year adjustment needed) + adjusted_mam = _adjust_years_for_season(years, months, (3, 4, 5)) + np.testing.assert_array_equal(adjusted_mam, years) # Should be unchanged + + # Test single month season + adjusted_jan = _adjust_years_for_season(years, months, (1,)) + np.testing.assert_array_equal(adjusted_jan, years) # Should be unchanged + # TODO: Possible property tests to add to this module # 1. lambda x: x From dafab77d3dfbe1f59c41c34d583a601db6ba5842 Mon Sep 17 00:00:00 2001 From: DHRUVA KUMAR KAUSHAL Date: Fri, 27 Jun 2025 15:30:36 +0530 Subject: [PATCH 3/4] error resolve2 --- xarray/groupers.py | 36 ++++++++++++++++++++++++++---------- xarray/tests/test_groupby.py | 14 +++++++++----- 2 files changed, 35 insertions(+), 15 deletions(-) diff --git a/xarray/groupers.py b/xarray/groupers.py index 63038ea4569..57dcbf3b2a3 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -52,6 +52,8 @@ "EncodedGroups", "Grouper", "Resampler", + "SeasonGrouper", + "SeasonResampler", "TimeResampler", "UniqueGrouper", ] @@ -237,7 +239,7 @@ def _factorize_given_labels(self, group: T_Group) -> EncodedGroups: ) return EncodedGroups( codes=codes, - full_index=pd.Index(self.labels), # type: ignore[arg-type] + full_index=pd.Index(self.labels), unique_coord=Variable( dims=codes.name, data=self.labels, @@ -612,7 +614,10 @@ def unique_value_groups( def _adjust_years_for_season( - years: np.ndarray, months: np.ndarray, season_tuple: tuple[int, ...] + years: np.ndarray, + months: np.ndarray, + season_tuple: tuple[int, ...], + season_str: str, ) -> np.ndarray: """ Adjust years for seasons that span December and January (e.g., DJF). @@ -628,6 +633,8 @@ def _adjust_years_for_season( Array of months corresponding to each timestamp season_tuple : tuple of int Tuple of month numbers that make up the season + season_str : str + String representation of the season (e.g., "DJF") Returns ------- @@ -635,10 +642,15 @@ def _adjust_years_for_season( Adjusted years array """ year_adjusted = years.copy() - # Handle seasons like DJF where December is in one year but Jan/Feb are in the next - if 12 in season_tuple and 1 in season_tuple: - jan_or_later = [m for m in season_tuple if m < 12] - year_adjusted[np.isin(months, jan_or_later)] -= 1 + # Handle seasons that contain December followed by other months + if "D" in season_str and 12 in season_tuple: + # Find the position of "D" (December) in the season string + d_index = season_str.index("D") + # Get all months that come after December in the season + months_after_dec = season_tuple[d_index + 1 :] + # Reduce year by 1 for months that come after December + for month_num in months_after_dec: + year_adjusted[months == month_num] -= 1 return year_adjusted @@ -822,7 +834,10 @@ def factorize(self, group: T_Group) -> EncodedGroups: (indices,) = mask.nonzero() group_indices[code] = indices.tolist() else: - year_adjusted = _adjust_years_for_season(year, months, season_tuple) + season_str = self.seasons[code] + year_adjusted = _adjust_years_for_season( + year, months, season_tuple, season_str + ) # find unique years for this season if not np.any(mask): @@ -945,9 +960,10 @@ def factorize(self, group: T_Group) -> EncodedGroups: # Apply year adjustment for cross-year seasons year_adjusted = year.copy() for season_str, season_ind in zip(seasons, season_inds, strict=True): - if "DJ" in season_str: - year_adjusted = _adjust_years_for_season( - year_adjusted, month.data, season_ind + if "D" in season_str and 12 in season_ind: + # Use helper function for year adjustment + year_adjusted[:] = _adjust_years_for_season( + year_adjusted.values, month.values, tuple(season_ind), season_str ) # Allow users to skip one or more months? diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 159ac655cb4..37decdb6d15 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3669,9 +3669,12 @@ def test_season_grouper_drop_incomplete_cross_year_seasons(self, calendar): # With complete data, both should give same number of seasons assert len(result_keep) == len(result_drop) - # Now test with incomplete data - start from Feb (missing Dec-Jan of first DJF) + # Now test with incomplete data - start from Feb (missing Dec 2000 and Jan 2001) time_incomplete = date_range( - "2001-02-01", "2002-12-31", freq="MS", calendar=calendar + "2001-02-01", + "2001-12-31", + freq="MS", + calendar=calendar, # Stop before next year ) data_incomplete = np.arange(len(time_incomplete)) da_incomplete = DataArray( @@ -3686,6 +3689,7 @@ def test_season_grouper_drop_incomplete_cross_year_seasons(self, calendar): ).mean() # drop_incomplete should exclude the incomplete first DJF season + # Data starts in Feb 2001, so 2000-2001 DJF is incomplete (missing Dec 2000, Jan 2001) assert len(result_drop_inc) < len(result_keep_inc) @pytest.mark.parametrize("calendar", ["standard"]) @@ -3724,18 +3728,18 @@ def test_adjust_years_for_season_helper(self): months = np.array([12, 1, 2, 12, 1, 2]) # Test DJF season (December, January, February) - adjusted = _adjust_years_for_season(years, months, (12, 1, 2)) + adjusted = _adjust_years_for_season(years, months, (12, 1, 2), "DJF") expected = np.array( [2001, 2000, 2000, 2002, 2001, 2001] ) # Jan/Feb get previous year np.testing.assert_array_equal(adjusted, expected) # Test MAM season (no cross-year adjustment needed) - adjusted_mam = _adjust_years_for_season(years, months, (3, 4, 5)) + adjusted_mam = _adjust_years_for_season(years, months, (3, 4, 5), "MAM") np.testing.assert_array_equal(adjusted_mam, years) # Should be unchanged # Test single month season - adjusted_jan = _adjust_years_for_season(years, months, (1,)) + adjusted_jan = _adjust_years_for_season(years, months, (1,), "J") np.testing.assert_array_equal(adjusted_jan, years) # Should be unchanged From 4469bb6098f9b87e112e4d4112a46f9887d26ec9 Mon Sep 17 00:00:00 2001 From: DHRUVA KUMAR KAUSHAL Date: Fri, 4 Jul 2025 08:13:22 +0530 Subject: [PATCH 4/4] mypy resolve --- xarray/groupers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/groupers.py b/xarray/groupers.py index 57dcbf3b2a3..a12d5b860a0 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -237,12 +237,14 @@ def _factorize_given_labels(self, group: T_Group) -> EncodedGroups: output_dtypes=[np.int64], keep_attrs=True, ) + # Convert labels to a sequence that pandas Index can handle + labels_array = np.asarray(self.labels) return EncodedGroups( codes=codes, - full_index=pd.Index(self.labels), + full_index=pd.Index(labels_array), unique_coord=Variable( dims=codes.name, - data=self.labels, + data=labels_array, attrs=self.group.attrs, ), )