Skip to content

Commit 8c6e896

Browse files
committed
replace is_duck_array with _arrayfunction_or_api instead
1 parent f26e259 commit 8c6e896

File tree

10 files changed

+46
-72
lines changed

10 files changed

+46
-72
lines changed

xarray/core/arithmetic.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,9 @@
1414
VariableOpsMixin,
1515
)
1616
from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
17-
from xarray.core.ops import (
18-
IncludeNumpySameMethods,
19-
IncludeReduceMethods,
20-
)
17+
from xarray.core.ops import IncludeNumpySameMethods, IncludeReduceMethods
2118
from xarray.core.options import OPTIONS, _get_keep_attrs
22-
from xarray.namedarray.pycompat import is_duck_array
19+
from xarray.namedarray._typing import _arrayfunction_or_api
2320

2421

2522
class SupportsArithmetic:
@@ -48,7 +45,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
4845
# See the docstring example for numpy.lib.mixins.NDArrayOperatorsMixin.
4946
out = kwargs.get("out", ())
5047
for x in inputs + out:
51-
if not is_duck_array(x) and not isinstance(
48+
if not isinstance(x, _arrayfunction_or_api) and not isinstance(
5249
x, self._HANDLED_TYPES + (SupportsArithmetic,)
5350
):
5451
return NotImplemented

xarray/core/dataset.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,17 +112,14 @@
112112
broadcast_variables,
113113
calculate_dimensions,
114114
)
115+
from xarray.namedarray._typing import _arrayfunction_or_api
115116
from xarray.namedarray.daskmanager import DaskManager
116117
from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager
117-
from xarray.namedarray.pycompat import (
118-
array_type,
119-
is_chunked_array,
120-
)
118+
from xarray.namedarray.pycompat import array_type, is_chunked_array
121119
from xarray.namedarray.utils import (
122120
consolidate_dask_from_array_kwargs,
123121
either_dict_or_kwargs,
124122
is_dict_like,
125-
is_duck_array,
126123
is_duck_dask_array,
127124
)
128125
from xarray.plot.accessor import DatasetPlotAccessor
@@ -2719,7 +2716,7 @@ def _validate_indexers(
27192716
elif isinstance(v, Sequence) and len(v) == 0:
27202717
yield k, np.empty((0,), dtype="int64")
27212718
else:
2722-
if not is_duck_array(v):
2719+
if not isinstance(v, _arrayfunction_or_api):
27232720
v = np.asarray(v)
27242721

27252722
if v.dtype.kind in "US":

xarray/core/duck_array_ops.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,11 @@
1616
import pandas as pd
1717
from numpy import all as array_all # noqa
1818
from numpy import any as array_any # noqa
19-
from numpy import ( # noqa
19+
from numpy import (
2020
around, # noqa
21-
einsum,
22-
gradient,
2321
isclose,
24-
isin,
2522
isnat,
26-
take,
27-
tensordot,
28-
transpose,
29-
unravel_index,
23+
take, # noqa
3024
zeros_like, # noqa
3125
)
3226
from numpy import concatenate as _concatenate
@@ -35,9 +29,10 @@
3529

3630
from xarray.core import dask_array_ops, dtypes, nputils
3731
from xarray.core.utils import module_available
32+
from xarray.namedarray._typing import _arrayfunction_or_api
3833
from xarray.namedarray.parallelcompat import get_chunked_array_type, is_chunked_array
3934
from xarray.namedarray.pycompat import array_type
40-
from xarray.namedarray.utils import is_duck_array, is_duck_dask_array
35+
from xarray.namedarray.utils import is_duck_dask_array
4136

4237
dask_available = module_available("dask")
4338

@@ -193,7 +188,7 @@ def astype(data, dtype, **kwargs):
193188

194189

195190
def asarray(data, xp=np):
196-
return data if is_duck_array(data) else xp.asarray(data)
191+
return data if isinstance(data, _arrayfunction_or_api) else xp.asarray(data)
197192

198193

199194
def as_shared_dtype(scalars_or_arrays, xp=np):

xarray/core/formatting.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from xarray.core.duck_array_ops import array_equiv, astype
2020
from xarray.core.indexing import MemoryCachedArray
2121
from xarray.core.options import OPTIONS, _get_boolean_with_default
22+
from xarray.namedarray._typing import _arrayfunction_or_api
2223
from xarray.namedarray.pycompat import array_type, to_duck_array, to_numpy
23-
from xarray.namedarray.utils import is_duck_array
2424

2525
if TYPE_CHECKING:
2626
from xarray.core.coordinates import AbstractCoordinates
@@ -630,7 +630,7 @@ def short_data_repr(array):
630630
internal_data = getattr(array, "variable", array)._data
631631
if isinstance(array, np.ndarray):
632632
return short_array_repr(array)
633-
elif is_duck_array(internal_data):
633+
elif isinstance(internal_data, _arrayfunction_or_api):
634634
return limit_lines(repr(array.data), limit=40)
635635
elif getattr(array, "_in_memory", None):
636636
return short_array_repr(array)
@@ -789,7 +789,9 @@ def extra_items_repr(extra_keys, mapping, ab_side, kwargs):
789789
is_variable = True
790790
except AttributeError:
791791
# compare attribute value
792-
if is_duck_array(a_mapping[k]) or is_duck_array(b_mapping[k]):
792+
if isinstance(a_mapping[k], _arrayfunction_or_api) or isinstance(
793+
b_mapping[k], _arrayfunction_or_api
794+
):
793795
compatible = array_equiv(a_mapping[k], b_mapping[k])
794796
else:
795797
compatible = a_mapping[k] == b_mapping[k]

xarray/core/indexing.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,10 @@
2424
is_scalar,
2525
to_0d_array,
2626
)
27+
from xarray.namedarray._typing import _arrayfunction_or_api
2728
from xarray.namedarray.parallelcompat import get_chunked_array_type, is_chunked_array
2829
from xarray.namedarray.pycompat import array_type, integer_types
29-
from xarray.namedarray.utils import (
30-
either_dict_or_kwargs,
31-
is_duck_array,
32-
is_duck_dask_array,
33-
)
30+
from xarray.namedarray.utils import either_dict_or_kwargs, is_duck_dask_array
3431

3532
if TYPE_CHECKING:
3633
from numpy.typing import DTypeLike
@@ -377,7 +374,7 @@ def __init__(self, key):
377374
k = int(k)
378375
elif isinstance(k, slice):
379376
k = as_integer_slice(k)
380-
elif is_duck_array(k):
377+
elif isinstance(k, _arrayfunction_or_api):
381378
if not np.issubdtype(k.dtype, np.integer):
382379
raise TypeError(
383380
f"invalid indexer array, does not have integer dtype: {k!r}"
@@ -424,7 +421,7 @@ def __init__(self, key):
424421
"Please pass a numpy array by calling ``.compute``. "
425422
"See https://github.com/dask/dask/issues/8958."
426423
)
427-
elif is_duck_array(k):
424+
elif isinstance(k, _arrayfunction_or_api):
428425
if not np.issubdtype(k.dtype, np.integer):
429426
raise TypeError(
430427
f"invalid indexer array, does not have integer dtype: {k!r}"

xarray/core/nputils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from numpy import RankWarning
1515

1616
from xarray.core.options import OPTIONS
17-
from xarray.namedarray.pycompat import is_duck_array
17+
from xarray.namedarray._typing import _arrayfunction_or_api
1818

1919
try:
2020
import bottleneck as bn
@@ -143,7 +143,10 @@ def _advanced_indexer_subspaces(key):
143143

144144
non_slices = [k for k in key if not isinstance(k, slice)]
145145
broadcasted_shape = np.broadcast_shapes(
146-
*[item.shape if is_duck_array(item) else (0,) for item in non_slices]
146+
*[
147+
item.shape if isinstance(item, _arrayfunction_or_api) else (0,)
148+
for item in non_slices
149+
]
147150
)
148151
ndim = len(broadcasted_shape)
149152
mixed_positions = advanced_index_positions[0] + np.arange(ndim)

xarray/core/variable.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,13 @@
3636
infix_dims,
3737
maybe_coerce_to_str,
3838
)
39+
from xarray.namedarray._typing import _arrayfunction_or_api
3940
from xarray.namedarray.core import NamedArray
4041
from xarray.namedarray.parallelcompat import get_chunked_array_type
41-
from xarray.namedarray.pycompat import (
42-
integer_types,
43-
is_0d_dask_array,
44-
is_chunked_array,
45-
)
42+
from xarray.namedarray.pycompat import integer_types, is_0d_dask_array, is_chunked_array
4643
from xarray.namedarray.utils import (
4744
either_dict_or_kwargs,
4845
is_dict_like,
49-
is_duck_array,
5046
is_duck_dask_array,
5147
)
5248

@@ -411,7 +407,7 @@ def data(self):
411407
Variable.as_numpy
412408
Variable.values
413409
"""
414-
if is_duck_array(self._data):
410+
if isinstance(self._data, _arrayfunction_or_api):
415411
return self._data
416412
elif isinstance(self._data, indexing.ExplicitlyIndexed):
417413
return self._data.get_duck_array()
@@ -636,7 +632,7 @@ def _validate_indexers(self, key):
636632
for dim, k in zip(self.dims, key):
637633
if not isinstance(k, BASIC_INDEXING_TYPES):
638634
if not isinstance(k, Variable):
639-
if not is_duck_array(k):
635+
if not isinstance(k, _arrayfunction_or_api):
640636
k = np.asarray(k)
641637
if k.ndim > 1:
642638
raise IndexError(
@@ -681,7 +677,7 @@ def _broadcast_indexes_outer(self, key):
681677
if isinstance(k, Variable):
682678
k = k.data
683679
if not isinstance(k, BASIC_INDEXING_TYPES):
684-
if not is_duck_array(k):
680+
if not isinstance(k, _arrayfunction_or_api):
685681
k = np.asarray(k)
686682
if k.size == 0:
687683
# Slice by empty list; numpy could not infer the dtype
@@ -940,7 +936,7 @@ def load(self, **kwargs):
940936
self._data = as_compatible_data(loaded_data)
941937
elif isinstance(self._data, ExplicitlyIndexed):
942938
self._data = self._data.get_duck_array()
943-
elif not is_duck_array(self._data):
939+
elif not isinstance(self._data, _arrayfunction_or_api):
944940
self._data = np.asarray(self._data)
945941
return self
946942

xarray/namedarray/pycompat.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from packaging.version import Version
99

1010
from xarray.core.utils import is_scalar
11-
from xarray.namedarray.utils import is_duck_array, is_duck_dask_array
11+
from xarray.namedarray._typing import _arrayfunction_or_api
12+
from xarray.namedarray.utils import is_duck_dask_array
1213

1314
integer_types = (int, np.integer)
1415

@@ -85,7 +86,9 @@ def mod_version(mod: ModType) -> Version:
8586

8687

8788
def is_chunked_array(x) -> bool:
88-
return is_duck_dask_array(x) or (is_duck_array(x) and hasattr(x, "chunks"))
89+
return is_duck_dask_array(x) or (
90+
isinstance(x, _arrayfunction_or_api) and hasattr(x, "chunks")
91+
)
8992

9093

9194
def is_0d_dask_array(x):
@@ -120,7 +123,7 @@ def to_duck_array(data):
120123

121124
if isinstance(data, ExplicitlyIndexed):
122125
return data.get_duck_array()
123-
elif is_duck_array(data):
126+
elif isinstance(data, _arrayfunction_or_api):
124127
return data
125128
else:
126129
return np.asarray(data)

xarray/namedarray/utils.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
import numpy as np
99

10+
from xarray.namedarray._typing import _arrayfunction_or_api
11+
1012
if TYPE_CHECKING:
1113
if sys.version_info >= (3, 10):
1214
from typing import TypeGuard
@@ -15,7 +17,7 @@
1517

1618
from numpy.typing import NDArray
1719

18-
from xarray.namedarray._typing import T_DuckArray, duckarray
20+
from xarray.namedarray._typing import duckarray
1921

2022
try:
2123
from dask.array.core import Array as DaskArray
@@ -26,10 +28,6 @@
2628

2729
T = TypeVar("T")
2830

29-
# https://stackoverflow.com/questions/74633074/how-to-type-hint-a-generic-numpy-array
30-
T_DType_co = TypeVar("T_DType_co", bound=np.dtype[np.generic], covariant=True)
31-
T_DType = TypeVar("T_DType", bound=np.dtype[np.generic])
32-
3331

3432
# Singleton type, as per https://github.com/python/typing/pull/240
3533
class Default(Enum):
@@ -68,21 +66,7 @@ def is_dask_collection(x: object) -> TypeGuard[DaskCollection]:
6866

6967

7068
def is_duck_dask_array(x: duckarray[Any, Any]) -> TypeGuard[DaskArray]:
71-
return is_dask_collection(x)
72-
73-
74-
def is_duck_array(value: Any) -> TypeGuard[T_DuckArray]:
75-
if isinstance(value, np.ndarray):
76-
return True
77-
return (
78-
hasattr(value, "ndim")
79-
and hasattr(value, "shape")
80-
and hasattr(value, "dtype")
81-
and (
82-
(hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__"))
83-
or hasattr(value, "__array_namespace__")
84-
)
85-
)
69+
return isinstance(x, _arrayfunction_or_api) and is_dask_collection(x)
8670

8771

8872
def to_0d_object_array(

xarray/testing.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from xarray.core.dataset import Dataset
1414
from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex, default_indexes
1515
from xarray.core.variable import IndexVariable, Variable
16-
from xarray.namedarray.utils import is_duck_array
16+
from xarray.namedarray._typing import _arrayfunction_or_api
1717

1818
__all__ = (
1919
"assert_allclose",
@@ -229,14 +229,14 @@ def assert_duckarray_equal(x, y, err_msg="", verbose=True):
229229
"""Like `np.testing.assert_array_equal`, but for duckarrays"""
230230
__tracebackhide__ = True
231231

232-
if not is_duck_array(x) and not utils.is_scalar(x):
232+
if not isinstance(x, _arrayfunction_or_api) and not utils.is_scalar(x):
233233
x = np.asarray(x)
234234

235-
if not is_duck_array(y) and not utils.is_scalar(y):
235+
if not isinstance(y, _arrayfunction_or_api) and not utils.is_scalar(y):
236236
y = np.asarray(y)
237237

238-
if (is_duck_array(x) and utils.is_scalar(y)) or (
239-
utils.is_scalar(x) and is_duck_array(y)
238+
if (isinstance(x, _arrayfunction_or_api) and utils.is_scalar(y)) or (
239+
utils.is_scalar(x) and isinstance(y, _arrayfunction_or_api)
240240
):
241241
equiv = (x == y).all()
242242
else:

0 commit comments

Comments
 (0)