Skip to content

Commit 6aa6b80

Browse files
committed
fix imports
1 parent ebf0252 commit 6aa6b80

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

xarray/core/duck_array_ops.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,24 @@
1818
from numpy import any as array_any # noqa
1919
from numpy import (
2020
around, # noqa
21+
einsum, # noqa
22+
gradient, # noqa
2123
isclose,
24+
isin, # noqa
2225
isnat,
2326
take, # noqa
24-
zeros_like, # noqa
27+
tensordot, # noqa
28+
transpose, # noqa
29+
unravel_index, # noqa
30+
zeros_like,
2531
)
2632
from numpy import concatenate as _concatenate
2733
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]
2834
from numpy.lib.stride_tricks import sliding_window_view # noqa
2935

3036
from xarray.core import dask_array_ops, dtypes, nputils
3137
from xarray.core.utils import module_available
38+
from xarray.namedarray._array_api import _get_data_namespace
3239
from xarray.namedarray._typing import _arrayfunction_or_api
3340
from xarray.namedarray.parallelcompat import get_chunked_array_type, is_chunked_array
3441
from xarray.namedarray.pycompat import array_type
@@ -37,13 +44,6 @@
3744
dask_available = module_available("dask")
3845

3946

40-
def get_array_namespace(x):
41-
if hasattr(x, "__array_namespace__"):
42-
return x.__array_namespace__()
43-
else:
44-
return np
45-
46-
4747
def _dask_or_eager_func(
4848
name,
4949
eager_module=np,
@@ -121,7 +121,7 @@ def isnull(data):
121121
return isnat(data)
122122
elif issubclass(scalar_type, np.inexact):
123123
# float types use NaN for null
124-
xp = get_array_namespace(data)
124+
xp = _get_data_namespace(data)
125125
return xp.isnan(data)
126126
elif issubclass(scalar_type, (np.bool_, np.integer, np.character, np.void)):
127127
# these types cannot represent missing values
@@ -179,7 +179,7 @@ def cumulative_trapezoid(y, x, axis):
179179

180180
def astype(data, dtype, **kwargs):
181181
if hasattr(data, "__array_namespace__"):
182-
xp = get_array_namespace(data)
182+
xp = _get_data_namespace(data)
183183
if xp == np:
184184
# numpy currently doesn't have a astype:
185185
return data.astype(dtype, **kwargs)
@@ -211,7 +211,7 @@ def as_shared_dtype(scalars_or_arrays, xp=np):
211211

212212

213213
def broadcast_to(array, shape):
214-
xp = get_array_namespace(array)
214+
xp = _get_data_namespace(array)
215215
return xp.broadcast_to(array, shape)
216216

217217

@@ -289,7 +289,7 @@ def count(data, axis=None):
289289

290290

291291
def sum_where(data, axis=None, dtype=None, where=None):
292-
xp = get_array_namespace(data)
292+
xp = _get_data_namespace(data)
293293
if where is not None:
294294
a = where_method(xp.zeros_like(data), where, data)
295295
else:
@@ -300,7 +300,7 @@ def sum_where(data, axis=None, dtype=None, where=None):
300300

301301
def where(condition, x, y):
302302
"""Three argument where() with better dtype promotion rules."""
303-
xp = get_array_namespace(condition)
303+
xp = _get_data_namespace(condition)
304304
return xp.where(condition, *as_shared_dtype([x, y], xp=xp))
305305

306306

@@ -320,19 +320,19 @@ def fillna(data, other):
320320
def concatenate(arrays, axis=0):
321321
"""concatenate() with better dtype promotion rules."""
322322
if hasattr(arrays[0], "__array_namespace__"):
323-
xp = get_array_namespace(arrays[0])
323+
xp = _get_data_namespace(arrays[0])
324324
return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis)
325325
return _concatenate(as_shared_dtype(arrays), axis=axis)
326326

327327

328328
def stack(arrays, axis=0):
329329
"""stack() with better dtype promotion rules."""
330-
xp = get_array_namespace(arrays[0])
330+
xp = _get_data_namespace(arrays[0])
331331
return xp.stack(as_shared_dtype(arrays, xp=xp), axis=axis)
332332

333333

334334
def reshape(array, shape):
335-
xp = get_array_namespace(array)
335+
xp = _get_data_namespace(array)
336336
return xp.reshape(array, shape)
337337

338338

@@ -376,7 +376,7 @@ def f(values, axis=None, skipna=None, **kwargs):
376376
if name in ["sum", "prod"]:
377377
kwargs.pop("min_count", None)
378378

379-
xp = get_array_namespace(values)
379+
xp = _get_data_namespace(values)
380380
func = getattr(xp, name)
381381

382382
try:

0 commit comments

Comments
 (0)