Skip to content

Commit 0436b3b

Browse files
gcariaGiacomo Cariapre-commit-ci[bot]dcherian
authored
Fix infinite recursion when calling np.fix (#10248)
Co-authored-by: Giacomo Caria <giacomo@chloris.earth> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
1 parent e37bfbf commit 0436b3b

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

xarray/computation/apply_ufunc.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,7 @@
3030
from xarray.core.formatting import limit_lines
3131
from xarray.core.indexes import Index, filter_indexes_from_coords
3232
from xarray.core.options import _get_keep_attrs
33-
from xarray.core.utils import (
34-
is_dict_like,
35-
result_name,
36-
)
33+
from xarray.core.utils import is_dict_like, result_name
3734
from xarray.core.variable import Variable
3835
from xarray.namedarray.parallelcompat import get_chunked_array_type
3936
from xarray.namedarray.pycompat import is_chunked_array
@@ -1212,6 +1209,8 @@ def apply_ufunc(
12121209
dask_gufunc_kwargs.setdefault("output_sizes", output_sizes)
12131210

12141211
if kwargs:
1212+
if "where" in kwargs and isinstance(kwargs["where"], DataArray):
1213+
kwargs["where"] = kwargs["where"].data # type:ignore[index]
12151214
func = functools.partial(func, **kwargs)
12161215

12171216
if keep_attrs is None:

xarray/tests/test_computation.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2627,3 +2627,14 @@ def test_complex_number_reduce(compute_backend):
26272627
# Check that xarray doesn't call into numbagg, which doesn't compile for complex
26282628
# numbers at the moment (but will when numba supports dynamic compilation)
26292629
da.min()
2630+
2631+
2632+
def test_fix() -> None:
2633+
val = 3.0
2634+
val_fixed = np.fix(val)
2635+
2636+
da = xr.DataArray([val])
2637+
expected = xr.DataArray([val_fixed])
2638+
2639+
actual = np.fix(da)
2640+
assert_identical(expected, actual)

0 commit comments

Comments
 (0)