Skip to content

Commit 36cb337

Browse files
max-sixty and mathause authored
Rolling exp sum (#5178)
* Small simplification to RollingExp
* Add rolling_exp().sum()
* lint
* Update xarray/core/rolling_exp.py

  Co-authored-by: Mathias Hauser <mathause@users.noreply.github.com>
* Check for 0.2.1 version
* Add tests
* whatsnew
* Skip sum tests on older version of numbagg
* .

Co-authored-by: Mathias Hauser <mathause@users.noreply.github.com>
1 parent 9a7ab2b commit 36cb337

File tree

3 files changed

+91
-11
lines changed

3 files changed

+91
-11
lines changed

doc/whats-new.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ New Features
2525
indexing (:issue:`3015`, :pull:`5362`). By `Matthias Göbel <https://github.com/matzegoebel>`_.
2626
- Attempting to reduce a weighted object over missing dimensions now raises an error (:pull:`5362`).
2727
By `Mattia Almansi <https://github.com/malmans2>`_.
28+
- Add ``.sum`` to :py:meth:`~xarray.DataArray.rolling_exp` and
29+
:py:meth:`~xarray.Dataset.rolling_exp` for exponentially weighted rolling
30+
sums. These require numbagg 0.2.1 or later
31+
(:pull:`5178`).
32+
By `Maximilian Roos <https://github.com/max-sixty>`_.
2833
- :py:func:`xarray.cov` and :py:func:`xarray.corr` now lazily check for missing
2934
values if inputs are dask arrays (:issue:`4804`, :pull:`5284`).
3035
By `Andrew Williams <https://github.com/AndrewWilliams3142>`_.

xarray/core/rolling_exp.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import TYPE_CHECKING, Generic, Hashable, Mapping, Optional, TypeVar, Union
1+
from distutils.version import LooseVersion
2+
from typing import TYPE_CHECKING, Generic, Hashable, Mapping, TypeVar, Union
23

34
import numpy as np
45

@@ -26,12 +27,25 @@ def move_exp_nanmean(array, *, axis, alpha):
2627
raise TypeError("rolling_exp is not currently support for dask-like arrays")
2728
import numbagg
2829

30+
# No longer needed in numbagg > 0.2.0; remove in time
2931
if axis == ():
3032
return array.astype(np.float64)
3133
else:
3234
return numbagg.move_exp_nanmean(array, axis=axis, alpha=alpha)
3335

3436

37+
def move_exp_nansum(array, *, axis, alpha):
    """
    Exponentially weighted moving sum of ``array``.

    Thin wrapper around ``numbagg.move_exp_nansum`` that first rejects
    dask-like input and numbagg releases older than 0.2.1.

    Parameters
    ----------
    array : array-like
        Data to reduce. Must not be a duck dask array.
    axis
        Forwarded unchanged to ``numbagg.move_exp_nansum``.
    alpha
        Forwarded unchanged to ``numbagg.move_exp_nansum``.

    Returns
    -------
    The result of ``numbagg.move_exp_nansum(array, axis=axis, alpha=alpha)``.

    Raises
    ------
    TypeError
        If ``array`` is a duck dask array.
    ValueError
        If the installed numbagg is older than 0.2.1.
    """
    if is_duck_dask_array(array):
        raise TypeError("rolling_exp is not currently supported for dask-like arrays")
    import numbagg

    # numbagg <= 0.2.0 did not have a __version__ attribute, so a missing
    # attribute is treated as an old release ("0.1.0") and rejected below.
    installed = LooseVersion(getattr(numbagg, "__version__", "0.1.0"))
    # Compare against the documented minimum (0.2.1) so the check, the error
    # message and the test-suite skip condition all agree.
    if installed < LooseVersion("0.2.1"):
        raise ValueError("`rolling_exp(...).sum()` requires numbagg>=0.2.1.")

    return numbagg.move_exp_nansum(array, axis=axis, alpha=alpha)
47+
48+
3549
def _get_center_of_mass(comass, span, halflife, alpha):
3650
"""
3751
Vendored from pandas.core.window.common._get_center_of_mass
@@ -98,9 +112,9 @@ def __init__(
98112
self.dim = dim
99113
self.alpha = _get_alpha(**{window_type: window})
100114

101-
def mean(self, keep_attrs: Optional[bool] = None) -> T_DSorDA:
115+
def mean(self, keep_attrs: bool = None) -> T_DSorDA:
102116
"""
103-
Exponentially weighted moving average
117+
Exponentially weighted moving average.
104118
105119
Parameters
106120
----------
@@ -124,3 +138,30 @@ def mean(self, keep_attrs: Optional[bool] = None) -> T_DSorDA:
124138
return self.obj.reduce(
125139
move_exp_nanmean, dim=self.dim, alpha=self.alpha, keep_attrs=keep_attrs
126140
)
141+
142+
def sum(self, keep_attrs: bool = None) -> T_DSorDA:
    """
    Exponentially weighted moving sum.

    Reduces ``self.obj`` along ``self.dim`` with ``move_exp_nansum``, using
    the decay factor stored in ``self.alpha``.

    Parameters
    ----------
    keep_attrs : bool, default: None
        If True, the attributes (``attrs``) are copied from the original
        object to the result. If False, the result is returned without
        attributes. If None, the value is looked up via
        ``_get_keep_attrs(default=True)``.

    Examples
    --------
    >>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x")
    >>> da.rolling_exp(x=2, window_type="span").sum()
    <xarray.DataArray (x: 5)>
    array([1.        , 1.33333333, 2.44444444, 2.81481481, 2.9382716 ])
    Dimensions without coordinates: x
    """
    # An explicit argument wins; only fall back to the option when it is None.
    resolved_keep_attrs = (
        _get_keep_attrs(default=True) if keep_attrs is None else keep_attrs
    )

    return self.obj.reduce(
        move_exp_nansum,
        dim=self.dim,
        alpha=self.alpha,
        keep_attrs=resolved_keep_attrs,
    )

xarray/tests/test_dataarray.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7268,10 +7268,32 @@ def test_fallback_to_iris_AuxCoord(self, coord_values):
72687268
"window_type, window", [["span", 5], ["alpha", 0.5], ["com", 0.5], ["halflife", 5]]
72697269
)
72707270
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
7271-
def test_rolling_exp(da, dim, window_type, window):
7272-
da = da.isel(a=0)
7271+
@pytest.mark.parametrize("func", ["mean", "sum"])
7272+
def test_rolling_exp_runs(da, dim, window_type, window, func):
7273+
import numbagg
7274+
7275+
if (
7276+
LooseVersion(getattr(numbagg, "__version__", "0.1.0")) < "0.2.1"
7277+
and func == "sum"
7278+
):
7279+
pytest.skip("rolling_exp.sum requires numbagg 0.2.1")
7280+
72737281
da = da.where(da > 0.2)
72747282

7283+
rolling_exp = da.rolling_exp(window_type=window_type, **{dim: window})
7284+
result = getattr(rolling_exp, func)()
7285+
assert isinstance(result, DataArray)
7286+
7287+
7288+
@requires_numbagg
7289+
@pytest.mark.parametrize("dim", ["time", "x"])
7290+
@pytest.mark.parametrize(
7291+
"window_type, window", [["span", 5], ["alpha", 0.5], ["com", 0.5], ["halflife", 5]]
7292+
)
7293+
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
7294+
def test_rolling_exp_mean_pandas(da, dim, window_type, window):
7295+
da = da.isel(a=0).where(lambda x: x > 0.2)
7296+
72757297
result = da.rolling_exp(window_type=window_type, **{dim: window}).mean()
72767298
assert isinstance(result, DataArray)
72777299

@@ -7288,30 +7310,42 @@ def test_rolling_exp(da, dim, window_type, window):
72887310

72897311
@requires_numbagg
72907312
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
7291-
def test_rolling_exp_keep_attrs(da):
7313+
@pytest.mark.parametrize("func", ["mean", "sum"])
7314+
def test_rolling_exp_keep_attrs(da, func):
7315+
import numbagg
7316+
7317+
if (
7318+
LooseVersion(getattr(numbagg, "__version__", "0.1.0")) < "0.2.1"
7319+
and func == "sum"
7320+
):
7321+
pytest.skip("rolling_exp.sum requires numbagg 0.2.1")
7322+
72927323
attrs = {"attrs": "da"}
72937324
da.attrs = attrs
72947325

7326+
# Equivalent of `da.rolling_exp(time=10).mean`
7327+
rolling_exp_func = getattr(da.rolling_exp(time=10), func)
7328+
72957329
# attrs are kept per default
7296-
result = da.rolling_exp(time=10).mean()
7330+
result = rolling_exp_func()
72977331
assert result.attrs == attrs
72987332

72997333
# discard attrs
7300-
result = da.rolling_exp(time=10).mean(keep_attrs=False)
7334+
result = rolling_exp_func(keep_attrs=False)
73017335
assert result.attrs == {}
73027336

73037337
# test discard attrs using global option
73047338
with set_options(keep_attrs=False):
7305-
result = da.rolling_exp(time=10).mean()
7339+
result = rolling_exp_func()
73067340
assert result.attrs == {}
73077341

73087342
# keyword takes precedence over global option
73097343
with set_options(keep_attrs=False):
7310-
result = da.rolling_exp(time=10).mean(keep_attrs=True)
7344+
result = rolling_exp_func(keep_attrs=True)
73117345
assert result.attrs == attrs
73127346

73137347
with set_options(keep_attrs=True):
7314-
result = da.rolling_exp(time=10).mean(keep_attrs=False)
7348+
result = rolling_exp_func(keep_attrs=False)
73157349
assert result.attrs == {}
73167350

73177351
with pytest.warns(

0 commit comments

Comments
 (0)