Skip to content

Commit 7d143da

Browse files
committed
Changes to allow nan functions to work with xarray
1 parent 4c29b50 commit 7d143da

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

cubed/array/nan_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
# https://github.com/data-apis/array-api/issues/621
99

1010

11-
def nanmean(x, /, *, axis=None, keepdims=False, split_every=None):
11+
def nanmean(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
1212
"""Compute the arithmetic mean along the specified axis, ignoring NaNs."""
13-
dtype = x.dtype
13+
dtype = dtype or x.dtype
1414
intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)]
1515
return reduction(
1616
x,

cubed/array_api/array_object.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -367,9 +367,9 @@ def __array_namespace__(self, /, *, api_version=None):
367367
"2023.12",
368368
):
369369
raise ValueError(f"Unrecognized array API version: {api_version!r}")
370-
import cubed.array_api as array_api
370+
import cubed
371371

372-
return array_api
372+
return cubed
373373

374374
def __bool__(self, /):
375375
if self.ndim != 0:

0 commit comments

Comments
 (0)