Skip to content

Commit 8417f49

Browse files
keewisdcherian
andauthored
rely on numpy's version of nanprod and nansum (#6873)
Co-authored-by: dcherian <deepak@cherian.net>
1 parent 3c8ce0f commit 8417f49

File tree

2 files changed

+22
-18
lines changed

2 files changed

+22
-18
lines changed

xarray/core/nanops.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,6 @@
1717
dask_array_compat = None # type: ignore[assignment]
1818

1919

20-
def _replace_nan(a, val):
21-
"""
22-
replace nan in a by val, and returns the replaced array and the nan
23-
position
24-
"""
25-
mask = isnull(a)
26-
return where_method(val, mask, a), mask
27-
28-
2920
def _maybe_null_out(result, axis, mask, min_count=1):
3021
"""
3122
xarray version of pandas.core.nanops._maybe_null_out
@@ -105,8 +96,8 @@ def nanargmax(a, axis=None):
10596

10697

10798
def nansum(a, axis=None, dtype=None, out=None, min_count=None):
108-
a, mask = _replace_nan(a, 0)
109-
result = np.sum(a, axis=axis, dtype=dtype)
99+
mask = isnull(a)
100+
result = np.nansum(a, axis=axis, dtype=dtype)
110101
if min_count is not None:
111102
return _maybe_null_out(result, axis, mask, min_count)
112103
else:
@@ -173,7 +164,7 @@ def nanstd(a, axis=None, dtype=None, out=None, ddof=0):
173164

174165

175166
def nanprod(a, axis=None, dtype=None, out=None, min_count=None):
176-
a, mask = _replace_nan(a, 1)
167+
mask = isnull(a)
177168
result = nputils.nanprod(a, axis=axis, dtype=dtype, out=out)
178169
if min_count is not None:
179170
return _maybe_null_out(result, axis, mask, min_count)

xarray/tests/test_units.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import pandas as pd
88
import pytest
9+
from packaging import version
910

1011
import xarray as xr
1112
from xarray.core import dtypes, duck_array_ops
@@ -1530,8 +1531,12 @@ class TestVariable:
15301531
ids=repr,
15311532
)
15321533
def test_aggregation(self, func, dtype):
1533-
if func.name == "prod" and dtype.kind == "f":
1534-
pytest.xfail(reason="nanprod is not supported, yet")
1534+
if (
1535+
func.name == "prod"
1536+
and dtype.kind == "f"
1537+
and version.parse(pint.__version__) < version.parse("0.19")
1538+
):
1539+
pytest.xfail(reason="nanprod is not by older `pint` versions")
15351540

15361541
array = np.linspace(0, 1, 10).astype(dtype) * (
15371542
unit_registry.m if func.name != "cumprod" else unit_registry.dimensionless
@@ -2387,8 +2392,12 @@ def test_repr(self, func, variant, dtype):
23872392
ids=repr,
23882393
)
23892394
def test_aggregation(self, func, dtype):
2390-
if func.name == "prod" and dtype.kind == "f":
2391-
pytest.xfail(reason="nanprod is not supported, yet")
2395+
if (
2396+
func.name == "prod"
2397+
and dtype.kind == "f"
2398+
and version.parse(pint.__version__) < version.parse("0.19")
2399+
):
2400+
pytest.xfail(reason="nanprod is not by older `pint` versions")
23922401

23932402
array = np.arange(10).astype(dtype) * (
23942403
unit_registry.m if func.name != "cumprod" else unit_registry.dimensionless
@@ -4082,8 +4091,12 @@ def test_repr(self, func, variant, dtype):
40824091
ids=repr,
40834092
)
40844093
def test_aggregation(self, func, dtype):
4085-
if func.name == "prod" and dtype.kind == "f":
4086-
pytest.xfail(reason="nanprod is not supported, yet")
4094+
if (
4095+
func.name == "prod"
4096+
and dtype.kind == "f"
4097+
and version.parse(pint.__version__) < version.parse("0.19")
4098+
):
4099+
pytest.xfail(reason="nanprod is not by older `pint` versions")
40874100

40884101
unit_a, unit_b = (
40894102
(unit_registry.Pa, unit_registry.degK)

0 commit comments

Comments
 (0)