From 0f10d9198151eb8167d190b0fc995b199705825b Mon Sep 17 00:00:00 2001
From: DHRUVA KUMAR KAUSHAL
Date: Thu, 10 Jul 2025 00:01:09 +0530
Subject: [PATCH 1/3] Support chunking

---
 doc/whats-new.rst            |   4 +-
 xarray/core/dataset.py       |  43 +++------------
 xarray/groupers.py           |  68 ++++++++++++++++++++++-
 xarray/tests/test_dataset.py | 102 ++++++++++++++++++++++++++++++++++-
 4 files changed, 179 insertions(+), 38 deletions(-)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index f97f12ffc9f..bc54e03db0c 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -14,6 +14,8 @@ New Features
 ~~~~~~~~~~~~
 - Allow skipping the creation of default indexes when opening datasets (:pull:`8051`).
   By `Benoit Bovy `_ and `Justus Magin `_.
+- Support chunking by :py:class:`~xarray.groupers.SeasonResampler` for seasonal data analysis (:issue:`10425`, :pull:`10517`).
+  By `Dhruva Kumar Kaushal `_.
 
 Breaking changes
 ~~~~~~~~~~~~~~~~
@@ -30,7 +32,7 @@ Bug fixes
   creates extra variables that don't match the provided coordinate names, instead
   of silently ignoring them. The error message suggests using the factory method
   pattern with :py:meth:`xarray.Coordinates.from_xindex` and
-  :py:meth:`Dataset.assign_coords` for advanced use cases (:issue:`10499`).
+  :py:meth:`Dataset.assign_coords` for advanced use cases (:issue:`10499`, :pull:`10503`).
   By `Dhruva Kumar Kaushal `_.
 
 Documentation
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index ac4bfc32df5..bdb311de231 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -2411,13 +2411,14 @@ def chunk(
         sizes along that dimension will not be updated; non-dask arrays will be
         converted into dask arrays with a single block.
 
-        Along datetime-like dimensions, a :py:class:`groupers.TimeResampler` object is also accepted.
+        Along datetime-like dimensions, a :py:class:`groupers.TimeResampler` or :py:class:`groupers.SeasonResampler` object is also accepted.
 
         Parameters
         ----------
-        chunks : int, tuple of int, "auto" or mapping of hashable to int or a TimeResampler, optional
+        chunks : int, tuple of int, "auto" or mapping of hashable to int or a Resampler, optional
             Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, or
-            ``{"x": 5, "y": 5}`` or ``{"x": 5, "time": TimeResampler(freq="YE")}``.
+            ``{"x": 5, "y": 5}`` or ``{"x": 5, "time": TimeResampler(freq="YE")}`` or
+            ``{"time": SeasonResampler(["DJF", "MAM", "JJA", "SON"])}``.
         name_prefix : str, default: "xarray-"
             Prefix for the name of any new dask arrays.
         token : str, optional
@@ -2452,8 +2453,7 @@ def chunk(
         xarray.unify_chunks
         dask.array.from_array
         """
-        from xarray.core.dataarray import DataArray
-        from xarray.groupers import TimeResampler
+        from xarray.groupers import Resampler
 
         if chunks is None and not chunks_kwargs:
             warnings.warn(
@@ -2481,41 +2481,14 @@ def chunk(
                 f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.sizes.keys())}"
             )
 
-        def _resolve_frequency(
-            name: Hashable, resampler: TimeResampler
-        ) -> tuple[int, ...]:
+        def _resolve_frequency(name: Hashable, resampler: Resampler) -> tuple[int, ...]:
             variable = self._variables.get(name, None)
-            if variable is None:
-                raise ValueError(
-                    f"Cannot chunk by resampler {resampler!r} for virtual variables."
-                )
-            elif not _contains_datetime_like_objects(variable):
-                raise ValueError(
-                    f"chunks={resampler!r} only supported for datetime variables. "
-                    f"Received variable {name!r} with dtype {variable.dtype!r} instead."
-                )
-
-            assert variable.ndim == 1
-            chunks = (
-                DataArray(
-                    np.ones(variable.shape, dtype=int),
-                    dims=(name,),
-                    coords={name: variable},
-                )
-                .resample({name: resampler})
-                .sum()
-            )
-            # When bins (binning) or time periods are missing (resampling)
-            # we can end up with NaNs. Drop them.
-            if chunks.dtype.kind == "f":
-                chunks = chunks.dropna(name).astype(int)
-            chunks_tuple: tuple[int, ...] = tuple(chunks.data.tolist())
-            return chunks_tuple
+            return resampler.resolve_chunks(name, variable)
 
         chunks_mapping_ints: Mapping[Any, T_ChunkDim] = {
             name: (
                 _resolve_frequency(name, chunks)
-                if isinstance(chunks, TimeResampler)
+                if isinstance(chunks, Resampler)
                 else chunks
             )
             for name, chunks in chunks_mapping.items()
diff --git a/xarray/groupers.py b/xarray/groupers.py
index 4424c65a94b..df271d4356c 100644
--- a/xarray/groupers.py
+++ b/xarray/groupers.py
@@ -52,6 +52,7 @@
     "EncodedGroups",
     "Grouper",
     "Resampler",
+    "SeasonResampler",
     "TimeResampler",
     "UniqueGrouper",
 ]
@@ -169,7 +170,63 @@ class Resampler(Grouper):
     Currently only used for TimeResampler, but could be used for SpaceResampler in
     the future.
     """
-    pass
+    def resolve_chunks(self, name: str, variable: Variable) -> tuple[int, ...]:
+        """
+        Resolve chunk sizes for this resampler.
+
+        This method is used during chunking operations to determine appropriate
+        chunk sizes for the given variable when using this resampler.
+
+        Parameters
+        ----------
+        name : str
+            The name of the dimension being chunked.
+        variable : Variable
+            The variable being chunked.
+
+        Returns
+        -------
+        tuple[int, ...]
+            A tuple of chunk sizes for the dimension.
+        """
+
+        if variable is None:
+            raise ValueError(
+                f"Cannot chunk by resampler {self!r} for virtual variables."
+            )
+        elif not _contains_datetime_like_objects(variable):
+            raise ValueError(
+                f"chunks={self!r} only supported for datetime variables. "
+                f"Received variable {name!r} with dtype {variable.dtype!r} instead."
+            )
+
+        if variable.ndim != 1:
+            raise ValueError(
+                f"chunks={self!r} only supported for 1D variables. "
+                f"Received variable {name!r} with {variable.ndim} dimensions instead."
+            )
+
+        # Create a temporary resampler that ignores drop_incomplete for chunking
+        # This prevents data from being silently dropped during chunking
+        resampler_for_chunking = (
+            self._for_chunking() if hasattr(self, "_for_chunking") else self
+        )
+
+        chunks = (
+            DataArray(
+                np.ones(variable.shape, dtype=int),
+                dims=(name,),
+                coords={name: variable},
+            )
+            .resample({name: resampler_for_chunking})
+            .sum()
+        )
+        # When bins (binning) or time periods are missing (resampling)
+        # we can end up with NaNs. Drop them.
+        if chunks.dtype.kind == "f":
+            chunks = chunks.dropna(name).astype(int)
+        chunks_tuple: tuple[int, ...] = tuple(chunks.data.tolist())
+        return chunks_tuple
 
 
 @dataclass
@@ -968,5 +1025,14 @@ def get_label(year, season):
 
         return EncodedGroups(codes=codes, full_index=full_index)
 
+    def _for_chunking(self) -> Self:
+        """
+        Return a version of this resampler suitable for chunking.
+
+        For SeasonResampler, this returns a version with drop_incomplete=False
+        to prevent data from being silently dropped during chunking operations.
+        """
+        return type(self)(seasons=self.seasons, drop_incomplete=False)
+
     def reset(self) -> Self:
         return type(self)(seasons=self.seasons, drop_incomplete=self.drop_incomplete)
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index 3e0734c8a1a..85c30e9fa0e 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -46,7 +46,7 @@
 from xarray.core.indexes import Index, PandasIndex
 from xarray.core.types import ArrayLike
 from xarray.core.utils import is_scalar
-from xarray.groupers import TimeResampler
+from xarray.groupers import SeasonResampler, TimeResampler
 from xarray.namedarray.pycompat import array_type, integer_types
 from xarray.testing import _assert_internal_invariants
 from xarray.tests import (
@@ -1137,6 +1137,106 @@ def test_chunks_does_not_load_data(self) -> None:
 
     @requires_dask
     def test_chunk(self) -> None:
+        data = create_test_data()
+        for chunks in [1, 2, 3, 4, 5]:
+            rechunked = data.chunk({"dim1": chunks})
+            assert rechunked.chunks["dim1"] == (chunks,) * (8 // chunks) + (
+                (8 % chunks,) if 8 % chunks else ()
+            )
+
+            rechunked = data.chunk({"dim2": chunks})
+            assert rechunked.chunks["dim2"] == (chunks,) * (9 // chunks) + (
+                (9 % chunks,) if 9 % chunks else ()
+            )
+
+            rechunked = data.chunk({"dim1": chunks, "dim2": chunks})
+            assert rechunked.chunks["dim1"] == (chunks,) * (8 // chunks) + (
+                (8 % chunks,) if 8 % chunks else ()
+            )
+            assert rechunked.chunks["dim2"] == (chunks,) * (9 // chunks) + (
+                (9 % chunks,) if 9 % chunks else ()
+            )
+
+    @requires_dask
+    def test_chunk_by_season_resampler(self) -> None:
+        """Test chunking using SeasonResampler."""
+        import dask.array
+
+        N = 365 * 2  # 2 years
+        time = xr.date_range("2001-01-01", periods=N, freq="D")
+        ds = Dataset(
+            {
+                "pr": ("time", dask.array.random.random((N), chunks=(20))),
+                "pr2d": (("x", "time"), dask.array.random.random((10, N), chunks=(20))),
+                "ones": ("time", np.ones((N,))),
+            },
+            coords={"time": time},
+        )
+
+        # Test standard seasons
+        rechunked = ds.chunk(x=2, time=SeasonResampler(["DJF", "MAM", "JJA", "SON"]))
+        expected = tuple(
+            ds.ones.resample(
+                time=SeasonResampler(
+                    ["DJF", "MAM", "JJA", "SON"], drop_incomplete=False
+                )
+            )
+            .sum()
+            .dropna("time")
+            .astype(int)
+            .data.tolist()
+        )
+        assert rechunked.chunksizes["time"] == expected
+        assert rechunked.chunksizes["x"] == (2,) * 5
+
+        # Test custom seasons
+        rechunked = ds.chunk(
+            {"x": 2, "time": SeasonResampler(["DJFM", "AM", "JJA", "SON"])}
+        )
+        expected = tuple(
+            ds.ones.resample(
+                time=SeasonResampler(
+                    ["DJFM", "AM", "JJA", "SON"], drop_incomplete=False
+                )
+            )
+            .sum()
+            .dropna("time")
+            .astype(int)
+            .data.tolist()
+        )
+        assert rechunked.chunksizes["time"] == expected
+        assert rechunked.chunksizes["x"] == (2,) * 5
+
+        # Test that drop_incomplete doesn't affect chunking
+        rechunked_drop_true = ds.chunk(
+            time=SeasonResampler(["DJF", "MAM", "JJA", "SON"], drop_incomplete=True)
+        )
+        rechunked_drop_false = ds.chunk(
+            time=SeasonResampler(["DJF", "MAM", "JJA", "SON"], drop_incomplete=False)
+        )
+        assert (
+            rechunked_drop_true.chunksizes["time"]
+            == rechunked_drop_false.chunksizes["time"]
+        )
+
+    def test_chunk_by_season_resampler_errors(self):
+        """Test error handling for SeasonResampler chunking."""
+        ds = Dataset({"foo": ("x", [1, 2, 3])})
+
+        # Test error on virtual variable
+        with pytest.raises(ValueError, match="virtual variable"):
+            ds.chunk(x=SeasonResampler(["DJF", "MAM", "JJA", "SON"]))
+
+        # Test error on non-datetime variable
+        ds["x"] = ("x", [1, 2, 3])
+        with pytest.raises(ValueError, match="datetime variables"):
+            ds.chunk(x=SeasonResampler(["DJF", "MAM", "JJA", "SON"]))
+
+        # Test successful case with 1D datetime variable
+        ds["x"] = ("x", xr.date_range("2001-01-01", periods=3, freq="D"))
+        # This should work
+        result = ds.chunk(x=SeasonResampler(["DJF", "MAM", "JJA", "SON"]))
+        assert result.chunks is not None
         data = create_test_data()
         for v in data.variables.values():
             assert isinstance(v.data, np.ndarray)

From 3d4b62e623a389f995af6c638d0ee0cf6a326578 Mon Sep 17 00:00:00 2001
From: DHRUVA KUMAR KAUSHAL
Date: Thu, 10 Jul 2025 00:06:24 +0530
Subject: [PATCH 2/3] whats new

---
 doc/whats-new.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index bc54e03db0c..b31d20c9298 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -14,7 +14,7 @@ New Features
 ~~~~~~~~~~~~
 - Allow skipping the creation of default indexes when opening datasets (:pull:`8051`).
   By `Benoit Bovy `_ and `Justus Magin `_.
-- Support chunking by :py:class:`~xarray.groupers.SeasonResampler` for seasonal data analysis (:issue:`10425`, :pull:`10517`).
+- Support chunking by :py:class:`~xarray.groupers.SeasonResampler` for seasonal data analysis (:issue:`10425`, :pull:`10519`).
   By `Dhruva Kumar Kaushal `_.
 
 Breaking changes

From 527c48ccb09537812b668bd8cba03b7bcf5b4d9b Mon Sep 17 00:00:00 2001
From: DHRUVA KUMAR KAUSHAL
Date: Thu, 10 Jul 2025 00:46:58 +0530
Subject: [PATCH 3/3] error resolving

---
 xarray/core/dataset.py       |  4 ++++
 xarray/core/types.py         |  4 ++--
 xarray/groupers.py           | 13 +++++--------
 xarray/tests/test_dataset.py |  1 +
 4 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index bdb311de231..0a38e2307a1 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -2483,6 +2483,10 @@ def chunk(
 
         def _resolve_frequency(name: Hashable, resampler: Resampler) -> tuple[int, ...]:
             variable = self._variables.get(name, None)
+            if variable is None:
+                raise ValueError(
+                    f"Cannot chunk by resampler {resampler!r} for virtual variables."
+                )
             return resampler.resolve_chunks(name, variable)
 
         chunks_mapping_ints: Mapping[Any, T_ChunkDim] = {
diff --git a/xarray/core/types.py b/xarray/core/types.py
index 736a11f5f17..4b511b18387 100644
--- a/xarray/core/types.py
+++ b/xarray/core/types.py
@@ -32,7 +32,7 @@
     from xarray.core.indexes import Index, Indexes
     from xarray.core.utils import Frozen
     from xarray.core.variable import IndexVariable, Variable
-    from xarray.groupers import Grouper, TimeResampler
+    from xarray.groupers import Grouper, SeasonResampler, TimeResampler
     from xarray.structure.alignment import Aligner
 
     GroupInput: TypeAlias = (
@@ -201,7 +201,7 @@ def copy(
 # FYI in some cases we don't allow `None`, which this doesn't take account of.
 # FYI the `str` is for a size string, e.g. "16MB", supported by dask.
 T_ChunkDim: TypeAlias = str | int | Literal["auto"] | tuple[int, ...] | None  # noqa: PYI051
-T_ChunkDimFreq: TypeAlias = Union["TimeResampler", T_ChunkDim]
+T_ChunkDimFreq: TypeAlias = Union["TimeResampler", "SeasonResampler", T_ChunkDim]
 T_ChunksFreq: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDimFreq]
 # We allow the tuple form of this (though arguably we could transition to named dims only)
 T_Chunks: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDim]
diff --git a/xarray/groupers.py b/xarray/groupers.py
index df271d4356c..b4454ace173 100644
--- a/xarray/groupers.py
+++ b/xarray/groupers.py
@@ -12,7 +12,7 @@
 import operator
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from collections.abc import Mapping, Sequence
+from collections.abc import Hashable, Mapping, Sequence
 from dataclasses import dataclass, field
 from itertools import chain, pairwise
 from typing import TYPE_CHECKING, Any, Literal, cast
@@ -170,7 +170,7 @@ class Resampler(Grouper):
     Currently only used for TimeResampler, but could be used for SpaceResampler in
     the future.
     """
-    def resolve_chunks(self, name: str, variable: Variable) -> tuple[int, ...]:
+    def resolve_chunks(self, name: Hashable, variable: Variable) -> tuple[int, ...]:
         """
         Resolve chunk sizes for this resampler.
 
@@ -179,7 +179,7 @@ def resolve_chunks(self, name: str, variable: Variable) -> tuple[int, ...]:
 
         Parameters
         ----------
-        name : str
+        name : Hashable
            The name of the dimension being chunked.
         variable : Variable
             The variable being chunked.
@@ -189,12 +189,9 @@ def resolve_chunks(self, name: str, variable: Variable) -> tuple[int, ...]:
         tuple[int, ...]
             A tuple of chunk sizes for the dimension.
         """
+        from xarray.core.dataarray import DataArray
 
-        if variable is None:
-            raise ValueError(
-                f"Cannot chunk by resampler {self!r} for virtual variables."
-            )
-        elif not _contains_datetime_like_objects(variable):
+        if not _contains_datetime_like_objects(variable):
             raise ValueError(
                 f"chunks={self!r} only supported for datetime variables. "
                 f"Received variable {name!r} with dtype {variable.dtype!r} instead."
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index 85c30e9fa0e..88f1d4c92d6 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -1219,6 +1219,7 @@ def test_chunk_by_season_resampler(self) -> None:
             == rechunked_drop_false.chunksizes["time"]
         )
 
+    @requires_dask
     def test_chunk_by_season_resampler_errors(self):
         """Test error handling for SeasonResampler chunking."""
         ds = Dataset({"foo": ("x", [1, 2, 3])})
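
Note for reviewers: a minimal usage sketch of the behaviour this series adds, not part of the patches themselves. It assumes dask is installed; the dataset, variable name "pr", and the three-year daily time axis are illustrative only. The public API it exercises (Dataset.chunk accepting a SeasonResampler along a datetime dimension, with drop_incomplete ignored for chunking) is the same one covered by test_chunk_by_season_resampler above.

    import numpy as np
    import xarray as xr
    from xarray.groupers import SeasonResampler

    # Three years of illustrative daily data on a datetime coordinate.
    time = xr.date_range("2001-01-01", periods=3 * 365, freq="D")
    ds = xr.Dataset(
        {"pr": ("time", np.random.rand(time.size))},
        coords={"time": time},
    )

    # One dask chunk per (possibly incomplete) season along "time";
    # drop_incomplete is ignored during chunking, so no data is dropped.
    chunked = ds.chunk(time=SeasonResampler(["DJF", "MAM", "JJA", "SON"]))
    print(chunked.chunksizes["time"])  # e.g. (59, 92, 92, 91, 90, ...)

Chunk boundaries therefore line up with season boundaries, so a subsequent resample or groupby over the same seasons operates chunk-wise instead of shuffling data across chunks.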