Skip to content

Commit 039ec69

Browse files
committed
Changes to allow nan functions to work with xarray
1 parent 5f75ba2 commit 039ec69

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

cubed/array_api/array_object.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -367,9 +367,9 @@ def __array_namespace__(self, /, *, api_version=None):
367367
"2023.12",
368368
):
369369
raise ValueError(f"Unrecognized array API version: {api_version!r}")
370-
import cubed.array_api as array_api
370+
import cubed
371371

372-
return array_api
372+
return cubed
373373

374374
def __bool__(self, /):
375375
if self.ndim != 0:

cubed/nan_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
# https://github.com/data-apis/array-api/issues/621
1919

2020

21-
def nanmean(x, /, *, axis=None, keepdims=False, split_every=None):
21+
def nanmean(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
2222
"""Compute the arithmetic mean along the specified axis, ignoring NaNs."""
23-
dtype = x.dtype
23+
dtype = dtype or x.dtype
2424
intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)]
2525
return reduction(
2626
x,

0 commit comments

Comments
 (0)