Skip to content

Commit ae41d82

Browse files
dcherianandersy005
andauthored
Enable numbagg for reductions (#8316)
Co-authored-by: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com>
1 parent 087fe45 commit ae41d82

File tree

2 files changed

+49
-15
lines changed

2 files changed

+49
-15
lines changed

xarray/core/nputils.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import pandas as pd
77
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]
8+
from packaging.version import Version
89

910
# remove once numpy 2.0 is the oldest supported version
1011
try:
@@ -18,11 +19,20 @@
1819
try:
1920
import bottleneck as bn
2021

21-
_USE_BOTTLENECK = True
22+
_BOTTLENECK_AVAILABLE = True
2223
except ImportError:
2324
# use numpy methods instead
2425
bn = np
25-
_USE_BOTTLENECK = False
26+
_BOTTLENECK_AVAILABLE = False
27+
28+
try:
29+
import numbagg
30+
31+
_HAS_NUMBAGG = Version(numbagg.__version__) >= Version("0.5.0")
32+
except ImportError:
33+
# use numpy methods instead
34+
numbagg = np
35+
_HAS_NUMBAGG = False
2636

2737

2838
def _select_along_axis(values, idx, axis):
@@ -161,13 +171,30 @@ def __setitem__(self, key, value):
161171
self._array[key] = np.moveaxis(value, vindex_positions, mixed_positions)
162172

163173

164-
def _create_bottleneck_method(name, npmodule=np):
174+
def _create_method(name, npmodule=np):
165175
def f(values, axis=None, **kwargs):
166176
dtype = kwargs.get("dtype", None)
167177
bn_func = getattr(bn, name, None)
178+
nba_func = getattr(numbagg, name, None)
168179

169180
if (
170-
_USE_BOTTLENECK
181+
_HAS_NUMBAGG
182+
and OPTIONS["use_numbagg"]
183+
and isinstance(values, np.ndarray)
184+
and nba_func is not None
185+
# numbagg uses ddof=1 only, but numpy uses ddof=0 by default
186+
and (("var" in name or "std" in name) and kwargs.get("ddof", 0) == 1)
187+
# TODO: bool?
188+
and values.dtype.kind in "uifc"
189+
# and values.dtype.isnative
190+
and (dtype is None or np.dtype(dtype) == values.dtype)
191+
):
192+
# numbagg does not take care dtype, ddof
193+
kwargs.pop("dtype", None)
194+
kwargs.pop("ddof", None)
195+
result = nba_func(values, axis=axis, **kwargs)
196+
elif (
197+
_BOTTLENECK_AVAILABLE
171198
and OPTIONS["use_bottleneck"]
172199
and isinstance(values, np.ndarray)
173200
and bn_func is not None
@@ -233,14 +260,14 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
233260
return coeffs, residuals
234261

235262

236-
nanmin = _create_bottleneck_method("nanmin")
237-
nanmax = _create_bottleneck_method("nanmax")
238-
nanmean = _create_bottleneck_method("nanmean")
239-
nanmedian = _create_bottleneck_method("nanmedian")
240-
nanvar = _create_bottleneck_method("nanvar")
241-
nanstd = _create_bottleneck_method("nanstd")
242-
nanprod = _create_bottleneck_method("nanprod")
243-
nancumsum = _create_bottleneck_method("nancumsum")
244-
nancumprod = _create_bottleneck_method("nancumprod")
245-
nanargmin = _create_bottleneck_method("nanargmin")
246-
nanargmax = _create_bottleneck_method("nanargmax")
263+
nanmin = _create_method("nanmin")
264+
nanmax = _create_method("nanmax")
265+
nanmean = _create_method("nanmean")
266+
nanmedian = _create_method("nanmedian")
267+
nanvar = _create_method("nanvar")
268+
nanstd = _create_method("nanstd")
269+
nanprod = _create_method("nanprod")
270+
nancumsum = _create_method("nancumsum")
271+
nancumprod = _create_method("nancumprod")
272+
nanargmin = _create_method("nanargmin")
273+
nanargmax = _create_method("nanargmax")

xarray/core/options.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"keep_attrs",
2828
"warn_for_unclosed_files",
2929
"use_bottleneck",
30+
"use_numbagg",
3031
"use_flox",
3132
]
3233

@@ -50,6 +51,7 @@ class T_Options(TypedDict):
5051
warn_for_unclosed_files: bool
5152
use_bottleneck: bool
5253
use_flox: bool
54+
use_numbagg: bool
5355

5456

5557
OPTIONS: T_Options = {
@@ -72,6 +74,7 @@ class T_Options(TypedDict):
7274
"warn_for_unclosed_files": False,
7375
"use_bottleneck": True,
7476
"use_flox": True,
77+
"use_numbagg": True,
7578
}
7679

7780
_JOIN_OPTIONS = frozenset(["inner", "outer", "left", "right", "exact"])
@@ -98,6 +101,7 @@ def _positive_integer(value: int) -> bool:
98101
"file_cache_maxsize": _positive_integer,
99102
"keep_attrs": lambda choice: choice in [True, False, "default"],
100103
"use_bottleneck": lambda value: isinstance(value, bool),
104+
"use_numbagg": lambda value: isinstance(value, bool),
101105
"use_flox": lambda value: isinstance(value, bool),
102106
"warn_for_unclosed_files": lambda value: isinstance(value, bool),
103107
}
@@ -230,6 +234,9 @@ class set_options:
230234
use_flox : bool, default: True
231235
Whether to use ``numpy_groupies`` and `flox`` to
232236
accelerate groupby and resampling reductions.
237+
use_numbagg : bool, default: True
238+
Whether to use ``numbagg`` to accelerate reductions.
239+
Takes precedence over ``use_bottleneck`` when both are True.
233240
warn_for_unclosed_files : bool, default: False
234241
Whether or not to issue a warning when unclosed files are
235242
deallocated. This is mostly useful for debugging.

0 commit comments

Comments
 (0)