Skip to content

Commit 75af56c

Browse files
authored
Enable .rolling_exp to work on dask arrays (#8284)
1 parent 46643bb commit 75af56c

File tree

3 files changed

+8
-3
lines changed

3 files changed

+8
-3
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ New Features
2929
- :py:meth:`DataArray.sortby` & :py:meth:`Dataset.sortby` accept a callable for
3030
the ``variables`` parameter, passing the object as the only argument.
3131
By `Maximilian Roos <https://github.com/max-sixty>`_.
32+
- ``.rolling_exp`` functions can now operate on dask-backed arrays, assuming the
33+
core dim has exactly one chunk. (:pull:`8284`).
34+
By `Maximilian Roos <https://github.com/max-sixty>`_.
3235

3336
Breaking changes
3437
~~~~~~~~~~~~~~~~

xarray/core/rolling_exp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,9 @@ def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
147147
input_core_dims=[[self.dim]],
148148
kwargs=dict(alpha=self.alpha, axis=-1),
149149
output_core_dims=[[self.dim]],
150-
exclude_dims={self.dim},
151150
keep_attrs=keep_attrs,
152151
on_missing_core_dim="copy",
152+
dask="parallelized",
153153
).transpose(*dim_order)
154154

155155
def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
@@ -183,7 +183,7 @@ def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
183183
input_core_dims=[[self.dim]],
184184
kwargs=dict(alpha=self.alpha, axis=-1),
185185
output_core_dims=[[self.dim]],
186-
exclude_dims={self.dim},
187186
keep_attrs=keep_attrs,
188187
on_missing_core_dim="copy",
188+
dask="parallelized",
189189
).transpose(*dim_order)

xarray/tests/test_rolling.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,9 @@ def test_raise_no_warning_dask_rolling_assert_close(self, ds, name) -> None:
788788

789789
@requires_numbagg
790790
class TestDatasetRollingExp:
791-
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
791+
@pytest.mark.parametrize(
792+
"backend", ["numpy", pytest.param("dask", marks=requires_dask)], indirect=True
793+
)
792794
def test_rolling_exp(self, ds) -> None:
793795
result = ds.rolling_exp(time=10, window_type="span").mean()
794796
assert isinstance(result, Dataset)

0 commit comments

Comments
 (0)