Skip to content

Support rechunking to seasonal frequency with SeasonalResampler #10519

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ New Features
~~~~~~~~~~~~
- Allow skipping the creation of default indexes when opening datasets (:pull:`8051`).
By `Benoit Bovy <https://github.com/benbovy>`_ and `Justus Magin <https://github.com/keewis>`_.
- Support chunking by :py:class:`~xarray.groupers.SeasonResampler` for seasonal data analysis (:issue:`10425`, :pull:`10519`).
By `Dhruva Kumar Kaushal <https://github.com/dhruvak001>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand All @@ -30,7 +32,7 @@ Bug fixes
creates extra variables that don't match the provided coordinate names, instead
of silently ignoring them. The error message suggests using the factory method
pattern with :py:meth:`xarray.Coordinates.from_xindex` and
:py:meth:`Dataset.assign_coords` for advanced use cases (:issue:`10499`).
:py:meth:`Dataset.assign_coords` for advanced use cases (:issue:`10499`, :pull:`10503`).
By `Dhruva Kumar Kaushal <https://github.com/dhruvak001>`_.

Documentation
Expand Down
39 changes: 8 additions & 31 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2411,13 +2411,14 @@ def chunk(
sizes along that dimension will not be updated; non-dask arrays will be
converted into dask arrays with a single block.

Along datetime-like dimensions, a :py:class:`groupers.TimeResampler` object is also accepted.
Along datetime-like dimensions, a :py:class:`groupers.TimeResampler` or :py:class:`groupers.SeasonResampler` object is also accepted.

Parameters
----------
chunks : int, tuple of int, "auto" or mapping of hashable to int or a TimeResampler, optional
chunks : int, tuple of int, "auto" or mapping of hashable to int or a Resampler, optional
Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, or
``{"x": 5, "y": 5}`` or ``{"x": 5, "time": TimeResampler(freq="YE")}``.
``{"x": 5, "y": 5}`` or ``{"x": 5, "time": TimeResampler(freq="YE")}`` or
``{"time": SeasonResampler(["DJF", "MAM", "JJA", "SON"])}``.
name_prefix : str, default: "xarray-"
Prefix for the name of any new dask arrays.
token : str, optional
Expand Down Expand Up @@ -2452,8 +2453,7 @@ def chunk(
xarray.unify_chunks
dask.array.from_array
"""
from xarray.core.dataarray import DataArray
from xarray.groupers import TimeResampler
from xarray.groupers import Resampler

if chunks is None and not chunks_kwargs:
warnings.warn(
Expand Down Expand Up @@ -2481,41 +2481,18 @@ def chunk(
f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.sizes.keys())}"
)

def _resolve_frequency(
name: Hashable, resampler: TimeResampler
) -> tuple[int, ...]:
def _resolve_frequency(name: Hashable, resampler: Resampler) -> tuple[int, ...]:
variable = self._variables.get(name, None)
if variable is None:
raise ValueError(
f"Cannot chunk by resampler {resampler!r} for virtual variables."
)
elif not _contains_datetime_like_objects(variable):
raise ValueError(
f"chunks={resampler!r} only supported for datetime variables. "
f"Received variable {name!r} with dtype {variable.dtype!r} instead."
)

assert variable.ndim == 1
chunks = (
DataArray(
np.ones(variable.shape, dtype=int),
dims=(name,),
coords={name: variable},
)
.resample({name: resampler})
.sum()
)
# When bins (binning) or time periods are missing (resampling)
# we can end up with NaNs. Drop them.
if chunks.dtype.kind == "f":
chunks = chunks.dropna(name).astype(int)
chunks_tuple: tuple[int, ...] = tuple(chunks.data.tolist())
return chunks_tuple
return resampler.resolve_chunks(name, variable)

chunks_mapping_ints: Mapping[Any, T_ChunkDim] = {
name: (
_resolve_frequency(name, chunks)
if isinstance(chunks, TimeResampler)
if isinstance(chunks, Resampler)
else chunks
)
for name, chunks in chunks_mapping.items()
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from xarray.core.indexes import Index, Indexes
from xarray.core.utils import Frozen
from xarray.core.variable import IndexVariable, Variable
from xarray.groupers import Grouper, TimeResampler
from xarray.groupers import Grouper, SeasonResampler, TimeResampler
from xarray.structure.alignment import Aligner

GroupInput: TypeAlias = (
Expand Down Expand Up @@ -201,7 +201,7 @@ def copy(
# FYI in some cases we don't allow `None`, which this doesn't take account of.
# FYI the `str` is for a size string, e.g. "16MB", supported by dask.
T_ChunkDim: TypeAlias = str | int | Literal["auto"] | tuple[int, ...] | None # noqa: PYI051
T_ChunkDimFreq: TypeAlias = Union["TimeResampler", T_ChunkDim]
T_ChunkDimFreq: TypeAlias = Union["TimeResampler", "SeasonResampler", T_ChunkDim]
T_ChunksFreq: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDimFreq]
# We allow the tuple form of this (though arguably we could transition to named dims only)
T_Chunks: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDim]
Expand Down
67 changes: 65 additions & 2 deletions xarray/groupers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import operator
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Mapping, Sequence
from collections.abc import Hashable, Mapping, Sequence
from dataclasses import dataclass, field
from itertools import chain, pairwise
from typing import TYPE_CHECKING, Any, Literal, cast
Expand Down Expand Up @@ -52,6 +52,7 @@
"EncodedGroups",
"Grouper",
"Resampler",
"SeasonResampler",
"TimeResampler",
"UniqueGrouper",
]
Expand Down Expand Up @@ -169,7 +170,60 @@ class Resampler(Grouper):
Currently only used for TimeResampler, but could be used for SpaceResampler in the future.
"""

pass
def resolve_chunks(self, name: Hashable, variable: Variable) -> tuple[int, ...]:
    """
    Compute chunk sizes along ``name`` for this resampler.

    Called during chunking operations to translate the resampler into
    concrete chunk sizes for the given 1D datetime-like variable: each
    resampling period becomes one chunk.

    Parameters
    ----------
    name : Hashable
        The name of the dimension being chunked.
    variable : Variable
        The variable being chunked.

    Returns
    -------
    tuple[int, ...]
        A tuple of chunk sizes for the dimension.

    Raises
    ------
    ValueError
        If ``variable`` is not datetime-like or is not one-dimensional.
    """
    from xarray.core.dataarray import DataArray

    if not _contains_datetime_like_objects(variable):
        raise ValueError(
            f"chunks={self!r} only supported for datetime variables. "
            f"Received variable {name!r} with dtype {variable.dtype!r} instead."
        )

    if variable.ndim != 1:
        raise ValueError(
            f"chunks={self!r} only supported for 1D variables. "
            f"Received variable {name!r} with {variable.ndim} dimensions instead."
        )

    # Chunking must never silently discard data, so when the subclass can
    # provide a drop_incomplete-free variant (via _for_chunking), use it.
    if hasattr(self, "_for_chunking"):
        chunking_resampler = self._for_chunking()
    else:
        chunking_resampler = self

    # Resample a vector of ones along this dimension; the per-period sums
    # are exactly the desired chunk lengths.
    ones = DataArray(
        np.ones(variable.shape, dtype=int),
        dims=(name,),
        coords={name: variable},
    )
    counts = ones.resample({name: chunking_resampler}).sum()

    # Missing bins (binning) or time periods (resampling) surface as NaN,
    # which floats the dtype. Drop them and restore integer chunk sizes.
    if counts.dtype.kind == "f":
        counts = counts.dropna(name).astype(int)
    return tuple(counts.data.tolist())


@dataclass
Expand Down Expand Up @@ -968,5 +1022,14 @@ def get_label(year, season):

return EncodedGroups(codes=codes, full_index=full_index)

def _for_chunking(self) -> Self:
    """
    Return an equivalent resampler configured for chunking.

    Chunking operations must not silently drop data, so the copy always
    uses ``drop_incomplete=False`` while keeping the same seasons.
    """
    cls = type(self)
    return cls(seasons=self.seasons, drop_incomplete=False)

def reset(self) -> Self:
    """Return a fresh copy of this resampler with identical configuration."""
    kwargs = {"seasons": self.seasons, "drop_incomplete": self.drop_incomplete}
    return type(self)(**kwargs)
103 changes: 102 additions & 1 deletion xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from xarray.core.indexes import Index, PandasIndex
from xarray.core.types import ArrayLike
from xarray.core.utils import is_scalar
from xarray.groupers import TimeResampler
from xarray.groupers import SeasonResampler, TimeResampler
from xarray.namedarray.pycompat import array_type, integer_types
from xarray.testing import _assert_internal_invariants
from xarray.tests import (
Expand Down Expand Up @@ -1137,6 +1137,107 @@ def test_chunks_does_not_load_data(self) -> None:

@requires_dask
def test_chunk(self) -> None:
    """Chunking along one or both dims yields uniform chunks plus a
    possibly-smaller remainder chunk at the end."""

    def expected(size: int, chunksize: int) -> tuple[int, ...]:
        # Full chunks followed by the remainder chunk, if any.
        full, rem = divmod(size, chunksize)
        return (chunksize,) * full + ((rem,) if rem else ())

    data = create_test_data()
    for chunks in [1, 2, 3, 4, 5]:
        rechunked = data.chunk({"dim1": chunks})
        assert rechunked.chunks["dim1"] == expected(8, chunks)

        rechunked = data.chunk({"dim2": chunks})
        assert rechunked.chunks["dim2"] == expected(9, chunks)

        rechunked = data.chunk({"dim1": chunks, "dim2": chunks})
        assert rechunked.chunks["dim1"] == expected(8, chunks)
        assert rechunked.chunks["dim2"] == expected(9, chunks)

@requires_dask
def test_chunk_by_season_resampler(self) -> None:
    """Test chunking using SeasonResampler: one chunk per season."""
    import dask.array

    N = 365 * 2  # 2 years of daily data
    time = xr.date_range("2001-01-01", periods=N, freq="D")
    ds = Dataset(
        {
            "pr": ("time", dask.array.random.random((N), chunks=(20))),
            "pr2d": (("x", "time"), dask.array.random.random((10, N), chunks=(20))),
            "ones": ("time", np.ones((N,))),
        },
        coords={"time": time},
    )

    def expected_time_chunks(seasons: list[str]) -> tuple[int, ...]:
        # Season lengths as computed with drop_incomplete=False; NaNs mark
        # absent periods and are dropped, mirroring resolve_chunks.
        return tuple(
            ds.ones.resample(time=SeasonResampler(seasons, drop_incomplete=False))
            .sum()
            .dropna("time")
            .astype(int)
            .data.tolist()
        )

    # Standard seasons.
    standard = ["DJF", "MAM", "JJA", "SON"]
    rechunked = ds.chunk(x=2, time=SeasonResampler(standard))
    assert rechunked.chunksizes["time"] == expected_time_chunks(standard)
    assert rechunked.chunksizes["x"] == (2,) * 5

    # Custom (unequal-length) seasons.
    custom = ["DJFM", "AM", "JJA", "SON"]
    rechunked = ds.chunk({"x": 2, "time": SeasonResampler(custom)})
    assert rechunked.chunksizes["time"] == expected_time_chunks(custom)
    assert rechunked.chunksizes["x"] == (2,) * 5

    # drop_incomplete must not influence chunk boundaries: chunking never
    # drops data regardless of the flag.
    rechunked_drop_true = ds.chunk(
        time=SeasonResampler(standard, drop_incomplete=True)
    )
    rechunked_drop_false = ds.chunk(
        time=SeasonResampler(standard, drop_incomplete=False)
    )
    assert (
        rechunked_drop_true.chunksizes["time"]
        == rechunked_drop_false.chunksizes["time"]
    )

@requires_dask
def test_chunk_by_season_resampler_errors(self) -> None:
    """Test error handling for SeasonResampler chunking."""
    ds = Dataset({"foo": ("x", [1, 2, 3])})

    # Chunking by a dimension without a backing coordinate variable
    # (a "virtual" variable) is rejected.
    with pytest.raises(ValueError, match="virtual variable"):
        ds.chunk(x=SeasonResampler(["DJF", "MAM", "JJA", "SON"]))

    # A non-datetime coordinate is rejected.
    ds["x"] = ("x", [1, 2, 3])
    with pytest.raises(ValueError, match="datetime variables"):
        ds.chunk(x=SeasonResampler(["DJF", "MAM", "JJA", "SON"]))

    # A 1D datetime coordinate succeeds.
    ds["x"] = ("x", xr.date_range("2001-01-01", periods=3, freq="D"))
    result = ds.chunk(x=SeasonResampler(["DJF", "MAM", "JJA", "SON"]))
    assert result.chunks is not None
data = create_test_data()
for v in data.variables.values():
assert isinstance(v.data, np.ndarray)
Expand Down
Loading