Skip to content

Commit 47b4ad9

Browse files
authored
Add corr, cov, std & var to .rolling_exp (#8307)
* Add `corr`, `cov`, `std` & `var` to `.rolling_exp` From the new routines in numbagg. Maybe needs better tests (though these are quite heavily tested in numbagg), docs, and potentially need to think about types (maybe existing binary ops can help here?)
1 parent 3bc33ee commit 47b4ad9

File tree

3 files changed

+140
-9
lines changed

3 files changed

+140
-9
lines changed

xarray/core/rolling_exp.py

Lines changed: 138 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any, Generic
55

66
import numpy as np
7+
from packaging.version import Version
78

89
from xarray.core.computation import apply_ufunc
910
from xarray.core.options import _get_keep_attrs
@@ -14,9 +15,9 @@
1415
import numbagg
1516
from numbagg import move_exp_nanmean, move_exp_nansum
1617

17-
has_numbagg = numbagg.__version__
18+
has_numbagg: Version | None = Version(numbagg.__version__)
1819
except ImportError:
19-
has_numbagg = False
20+
has_numbagg = None
2021

2122

2223
def _get_alpha(
@@ -99,17 +100,17 @@ def __init__(
99100
window_type: str = "span",
100101
min_weight: float = 0.0,
101102
):
102-
if has_numbagg is False:
103+
if has_numbagg is None:
103104
raise ImportError(
104105
"numbagg >= 0.2.1 is required for rolling_exp but currently numbagg is not installed"
105106
)
106-
elif has_numbagg < "0.2.1":
107+
elif has_numbagg < Version("0.2.1"):
107108
raise ImportError(
108-
f"numbagg >= 0.2.1 is required for rolling_exp but currently version {has_numbagg} is installed"
109+
f"numbagg >= 0.2.1 is required for `rolling_exp` but currently version {has_numbagg} is installed"
109110
)
110-
elif has_numbagg < "0.3.1" and min_weight > 0:
111+
elif has_numbagg < Version("0.3.1") and min_weight > 0:
111112
raise ImportError(
112-
f"numbagg >= 0.3.1 is required for `min_weight > 0` but currently version {has_numbagg} is installed"
113+
f"numbagg >= 0.3.1 is required for `min_weight > 0` within `.rolling_exp` but currently version {has_numbagg} is installed"
113114
)
114115

115116
self.obj: T_DataWithCoords = obj
@@ -194,3 +195,133 @@ def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
194195
on_missing_core_dim="copy",
195196
dask="parallelized",
196197
).transpose(*dim_order)
198+
199+
def std(self) -> T_DataWithCoords:
200+
"""
201+
Exponentially weighted moving standard deviation.
202+
203+
`keep_attrs` is always True for this method. Drop attrs separately to remove attrs.
204+
205+
Examples
206+
--------
207+
>>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x")
208+
>>> da.rolling_exp(x=2, window_type="span").std()
209+
<xarray.DataArray (x: 5)>
210+
array([ nan, 0. , 0.67936622, 0.42966892, 0.25389527])
211+
Dimensions without coordinates: x
212+
"""
213+
214+
if has_numbagg is None or has_numbagg < Version("0.4.0"):
215+
raise ImportError(
216+
f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {has_numbagg} is installed"
217+
)
218+
dim_order = self.obj.dims
219+
220+
return apply_ufunc(
221+
numbagg.move_exp_nanstd,
222+
self.obj,
223+
input_core_dims=[[self.dim]],
224+
kwargs=self.kwargs,
225+
output_core_dims=[[self.dim]],
226+
keep_attrs=True,
227+
on_missing_core_dim="copy",
228+
dask="parallelized",
229+
).transpose(*dim_order)
230+
231+
def var(self) -> T_DataWithCoords:
232+
"""
233+
Exponentially weighted moving variance.
234+
235+
`keep_attrs` is always True for this method. Drop attrs separately to remove attrs.
236+
237+
Examples
238+
--------
239+
>>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x")
240+
>>> da.rolling_exp(x=2, window_type="span").var()
241+
<xarray.DataArray (x: 5)>
242+
array([ nan, 0. , 0.46153846, 0.18461538, 0.06446281])
243+
Dimensions without coordinates: x
244+
"""
245+
246+
if has_numbagg is None or has_numbagg < Version("0.4.0"):
247+
raise ImportError(
248+
f"numbagg >= 0.4.0 is required for rolling_exp().var(), currently {has_numbagg} is installed"
249+
)
250+
dim_order = self.obj.dims
251+
252+
return apply_ufunc(
253+
numbagg.move_exp_nanvar,
254+
self.obj,
255+
input_core_dims=[[self.dim]],
256+
kwargs=self.kwargs,
257+
output_core_dims=[[self.dim]],
258+
keep_attrs=True,
259+
on_missing_core_dim="copy",
260+
dask="parallelized",
261+
).transpose(*dim_order)
262+
263+
def cov(self, other: T_DataWithCoords) -> T_DataWithCoords:
264+
"""
265+
Exponentially weighted moving covariance.
266+
267+
`keep_attrs` is always True for this method. Drop attrs separately to remove attrs.
268+
269+
Examples
270+
--------
271+
>>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x")
272+
>>> da.rolling_exp(x=2, window_type="span").cov(da**2)
273+
<xarray.DataArray (x: 5)>
274+
array([ nan, 0. , 1.38461538, 0.55384615, 0.19338843])
275+
Dimensions without coordinates: x
276+
"""
277+
278+
if has_numbagg is None or has_numbagg < Version("0.4.0"):
279+
raise ImportError(
280+
f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {has_numbagg} is installed"
281+
)
282+
dim_order = self.obj.dims
283+
284+
return apply_ufunc(
285+
numbagg.move_exp_nancov,
286+
self.obj,
287+
other,
288+
input_core_dims=[[self.dim], [self.dim]],
289+
kwargs=self.kwargs,
290+
output_core_dims=[[self.dim]],
291+
keep_attrs=True,
292+
on_missing_core_dim="copy",
293+
dask="parallelized",
294+
).transpose(*dim_order)
295+
296+
def corr(self, other: T_DataWithCoords) -> T_DataWithCoords:
297+
"""
298+
Exponentially weighted moving correlation.
299+
300+
`keep_attrs` is always True for this method. Drop attrs separately to remove attrs.
301+
302+
Examples
303+
--------
304+
>>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x")
305+
>>> da.rolling_exp(x=2, window_type="span").corr(da.shift(x=1))
306+
<xarray.DataArray (x: 5)>
307+
array([ nan, nan, nan, 0.4330127 , 0.48038446])
308+
Dimensions without coordinates: x
309+
"""
310+
311+
if has_numbagg is None or has_numbagg < Version("0.4.0"):
312+
raise ImportError(
313+
f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {has_numbagg} is installed"
314+
)
315+
dim_order = self.obj.dims
316+
317+
return apply_ufunc(
318+
numbagg.move_exp_nancorr,
319+
self.obj,
320+
other,
321+
input_core_dims=[[self.dim], [self.dim]],
322+
kwargs=self.kwargs,
323+
output_core_dims=[[self.dim]],
324+
keep_attrs=True,
325+
on_missing_core_dim="copy",
326+
dask="parallelized",
327+
).transpose(*dim_order)

xarray/tests/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def _importorskip(
7575
has_zarr, requires_zarr = _importorskip("zarr")
7676
has_fsspec, requires_fsspec = _importorskip("fsspec")
7777
has_iris, requires_iris = _importorskip("iris")
78-
has_numbagg, requires_numbagg = _importorskip("numbagg")
78+
has_numbagg, requires_numbagg = _importorskip("numbagg", "0.4.0")
7979
has_seaborn, requires_seaborn = _importorskip("seaborn")
8080
has_sparse, requires_sparse = _importorskip("sparse")
8181
has_cupy, requires_cupy = _importorskip("cupy")

xarray/tests/test_rolling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ class TestDataArrayRollingExp:
394394
[["span", 5], ["alpha", 0.5], ["com", 0.5], ["halflife", 5]],
395395
)
396396
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
397-
@pytest.mark.parametrize("func", ["mean", "sum"])
397+
@pytest.mark.parametrize("func", ["mean", "sum", "var", "std"])
398398
def test_rolling_exp_runs(self, da, dim, window_type, window, func) -> None:
399399
da = da.where(da > 0.2)
400400

0 commit comments

Comments
 (0)