18
18
from numpy import any as array_any # noqa
19
19
from numpy import (
20
20
around , # noqa
21
+ einsum , # noqa
22
+ gradient , # noqa
21
23
isclose ,
24
+ isin , # noqa
22
25
isnat ,
23
26
take , # noqa
24
- zeros_like , # noqa
27
+ tensordot , # noqa
28
+ transpose , # noqa
29
+ unravel_index , # noqa
30
+ zeros_like ,
25
31
)
26
32
from numpy import concatenate as _concatenate
27
33
from numpy .core .multiarray import normalize_axis_index # type: ignore[attr-defined]
28
34
from numpy .lib .stride_tricks import sliding_window_view # noqa
29
35
30
36
from xarray .core import dask_array_ops , dtypes , nputils
31
37
from xarray .core .utils import module_available
38
+ from xarray .namedarray ._array_api import _get_data_namespace
32
39
from xarray .namedarray ._typing import _arrayfunction_or_api
33
40
from xarray .namedarray .parallelcompat import get_chunked_array_type , is_chunked_array
34
41
from xarray .namedarray .pycompat import array_type
37
44
dask_available = module_available ("dask" )
38
45
39
46
40
- def get_array_namespace (x ):
41
- if hasattr (x , "__array_namespace__" ):
42
- return x .__array_namespace__ ()
43
- else :
44
- return np
45
-
46
-
47
47
def _dask_or_eager_func (
48
48
name ,
49
49
eager_module = np ,
@@ -121,7 +121,7 @@ def isnull(data):
121
121
return isnat (data )
122
122
elif issubclass (scalar_type , np .inexact ):
123
123
# float types use NaN for null
124
- xp = get_array_namespace (data )
124
+ xp = _get_data_namespace (data )
125
125
return xp .isnan (data )
126
126
elif issubclass (scalar_type , (np .bool_ , np .integer , np .character , np .void )):
127
127
# these types cannot represent missing values
@@ -179,7 +179,7 @@ def cumulative_trapezoid(y, x, axis):
179
179
180
180
def astype (data , dtype , ** kwargs ):
181
181
if hasattr (data , "__array_namespace__" ):
182
- xp = get_array_namespace (data )
182
+ xp = _get_data_namespace (data )
183
183
if xp == np :
184
184
# numpy currently doesn't have a astype:
185
185
return data .astype (dtype , ** kwargs )
@@ -211,7 +211,7 @@ def as_shared_dtype(scalars_or_arrays, xp=np):
211
211
212
212
213
213
def broadcast_to (array , shape ):
214
- xp = get_array_namespace (array )
214
+ xp = _get_data_namespace (array )
215
215
return xp .broadcast_to (array , shape )
216
216
217
217
@@ -289,7 +289,7 @@ def count(data, axis=None):
289
289
290
290
291
291
def sum_where (data , axis = None , dtype = None , where = None ):
292
- xp = get_array_namespace (data )
292
+ xp = _get_data_namespace (data )
293
293
if where is not None :
294
294
a = where_method (xp .zeros_like (data ), where , data )
295
295
else :
@@ -300,7 +300,7 @@ def sum_where(data, axis=None, dtype=None, where=None):
300
300
301
301
def where (condition , x , y ):
302
302
"""Three argument where() with better dtype promotion rules."""
303
- xp = get_array_namespace (condition )
303
+ xp = _get_data_namespace (condition )
304
304
return xp .where (condition , * as_shared_dtype ([x , y ], xp = xp ))
305
305
306
306
@@ -320,19 +320,19 @@ def fillna(data, other):
320
320
def concatenate (arrays , axis = 0 ):
321
321
"""concatenate() with better dtype promotion rules."""
322
322
if hasattr (arrays [0 ], "__array_namespace__" ):
323
- xp = get_array_namespace (arrays [0 ])
323
+ xp = _get_data_namespace (arrays [0 ])
324
324
return xp .concat (as_shared_dtype (arrays , xp = xp ), axis = axis )
325
325
return _concatenate (as_shared_dtype (arrays ), axis = axis )
326
326
327
327
328
328
def stack (arrays , axis = 0 ):
329
329
"""stack() with better dtype promotion rules."""
330
- xp = get_array_namespace (arrays [0 ])
330
+ xp = _get_data_namespace (arrays [0 ])
331
331
return xp .stack (as_shared_dtype (arrays , xp = xp ), axis = axis )
332
332
333
333
334
334
def reshape (array , shape ):
335
- xp = get_array_namespace (array )
335
+ xp = _get_data_namespace (array )
336
336
return xp .reshape (array , shape )
337
337
338
338
@@ -376,7 +376,7 @@ def f(values, axis=None, skipna=None, **kwargs):
376
376
if name in ["sum" , "prod" ]:
377
377
kwargs .pop ("min_count" , None )
378
378
379
- xp = get_array_namespace (values )
379
+ xp = _get_data_namespace (values )
380
380
func = getattr (xp , name )
381
381
382
382
try :
0 commit comments