Skip to content

Commit 231ef95

Browse files
committed
Add inspection namespace for dask
1 parent 774d175 commit 231ef95

File tree

2 files changed

+355
-6
lines changed

2 files changed

+355
-6
lines changed

array_api_compat/dask/array/_aliases.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from ..._internal import get_xp
77

8+
from ._info import __array_namespace_info__
9+
810
import numpy as np
911
from numpy import (
1012
# Constants
@@ -208,12 +210,14 @@ def _isscalar(a):
208210

209211
common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]
210212

211-
__all__ = common_aliases + ['asarray', 'bool', 'acos',
212-
'acosh', 'asin', 'asinh', 'atan', 'atan2',
213+
__all__ = common_aliases + ['__array_namespace_info__', 'asarray', 'bool',
214+
'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2',
213215
'atanh', 'bitwise_left_shift', 'bitwise_invert',
214-
'bitwise_right_shift', 'concat', 'pow',
215-
'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8',
216-
'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64',
217-
'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type']
216+
'bitwise_right_shift', 'concat', 'pow', 'e',
217+
'inf', 'nan', 'pi', 'newaxis', 'float32',
218+
'float64', 'int8', 'int16', 'int32', 'int64',
219+
'uint8', 'uint16', 'uint32', 'uint64',
220+
'complex64', 'complex128', 'iinfo', 'finfo',
221+
'can_cast', 'result_type']
218222

219223
_all_ignore = ['get_xp', 'da', 'partial', 'common_aliases', 'np']

array_api_compat/dask/array/_info.py

Lines changed: 345 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,345 @@
1+
"""
2+
Array API Inspection namespace
3+
4+
This is the namespace for inspection functions as defined by the array API
5+
standard. See
6+
https://data-apis.org/array-api/latest/API_specification/inspection.html for
7+
more details.
8+
9+
"""
10+
from numpy import (
11+
dtype,
12+
bool_ as bool,
13+
intp,
14+
int8,
15+
int16,
16+
int32,
17+
int64,
18+
uint8,
19+
uint16,
20+
uint32,
21+
uint64,
22+
float32,
23+
float64,
24+
complex64,
25+
complex128,
26+
)
27+
28+
from ...common._helpers import _DASK_DEVICE
29+
30+
class __array_namespace_info__:
31+
"""
32+
Get the array API inspection namespace for Dask.
33+
34+
The array API inspection namespace defines the following functions:
35+
36+
- capabilities()
37+
- default_device()
38+
- default_dtypes()
39+
- dtypes()
40+
- devices()
41+
42+
See
43+
https://data-apis.org/array-api/latest/API_specification/inspection.html
44+
for more details.
45+
46+
Returns
47+
-------
48+
info : ModuleType
49+
The array API inspection namespace for Dask.
50+
51+
Examples
52+
--------
53+
>>> info = np.__array_namespace_info__()
54+
>>> info.default_dtypes()
55+
{'real floating': dask.float64,
56+
'complex floating': dask.complex128,
57+
'integral': dask.int64,
58+
'indexing': dask.int64}
59+
60+
"""
61+
62+
__module__ = 'dask.array'
63+
64+
def capabilities(self):
65+
"""
66+
Return a dictionary of array API library capabilities.
67+
68+
The resulting dictionary has the following keys:
69+
70+
- **"boolean indexing"**: boolean indicating whether an array library
71+
supports boolean indexing. Always ``False`` for Dask.
72+
73+
- **"data-dependent shapes"**: boolean indicating whether an array
74+
library supports data-dependent output shapes. Always ``False`` for
75+
Dask.
76+
77+
See
78+
https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
79+
for more details.
80+
81+
See Also
82+
--------
83+
__array_namespace_info__.default_device,
84+
__array_namespace_info__.default_dtypes,
85+
__array_namespace_info__.dtypes,
86+
__array_namespace_info__.devices
87+
88+
Returns
89+
-------
90+
capabilities : dict
91+
A dictionary of array API library capabilities.
92+
93+
Examples
94+
--------
95+
>>> info = np.__array_namespace_info__()
96+
>>> info.capabilities()
97+
{'boolean indexing': True,
98+
'data-dependent shapes': True}
99+
100+
"""
101+
return {
102+
"boolean indexing": False,
103+
"data-dependent shapes": False,
104+
# 'max rank' will be part of the 2024.12 standard
105+
# "max rank": 64,
106+
}
107+
108+
def default_device(self):
109+
"""
110+
The default device used for new Dask arrays.
111+
112+
For Dask, this always returns ``'cpu'``.
113+
114+
See Also
115+
--------
116+
__array_namespace_info__.capabilities,
117+
__array_namespace_info__.default_dtypes,
118+
__array_namespace_info__.dtypes,
119+
__array_namespace_info__.devices
120+
121+
Returns
122+
-------
123+
device : str
124+
The default device used for new Dask arrays.
125+
126+
Examples
127+
--------
128+
>>> info = np.__array_namespace_info__()
129+
>>> info.default_device()
130+
'cpu'
131+
132+
"""
133+
return "cpu"
134+
135+
def default_dtypes(self, *, device=None):
136+
"""
137+
The default data types used for new Dask arrays.
138+
139+
For Dask, this always returns the following dictionary:
140+
141+
- **"real floating"**: ``numpy.float64``
142+
- **"complex floating"**: ``numpy.complex128``
143+
- **"integral"**: ``numpy.intp``
144+
- **"indexing"**: ``numpy.intp``
145+
146+
Parameters
147+
----------
148+
device : str, optional
149+
The device to get the default data types for.
150+
151+
Returns
152+
-------
153+
dtypes : dict
154+
A dictionary describing the default data types used for new Dask
155+
arrays.
156+
157+
See Also
158+
--------
159+
__array_namespace_info__.capabilities,
160+
__array_namespace_info__.default_device,
161+
__array_namespace_info__.dtypes,
162+
__array_namespace_info__.devices
163+
164+
Examples
165+
--------
166+
>>> info = np.__array_namespace_info__()
167+
>>> info.default_dtypes()
168+
{'real floating': dask.float64,
169+
'complex floating': dask.complex128,
170+
'integral': dask.int64,
171+
'indexing': dask.int64}
172+
173+
"""
174+
if device not in ["cpu", _DASK_DEVICE, None]:
175+
raise ValueError(
176+
'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:'
177+
f' {device}'
178+
)
179+
return {
180+
"real floating": dtype(float64),
181+
"complex floating": dtype(complex128),
182+
"integral": dtype(intp),
183+
"indexing": dtype(intp),
184+
}
185+
186+
def dtypes(self, *, device=None, kind=None):
187+
"""
188+
The array API data types supported by Dask.
189+
190+
Note that this function only returns data types that are defined by
191+
the array API.
192+
193+
Parameters
194+
----------
195+
device : str, optional
196+
The device to get the data types for.
197+
kind : str or tuple of str, optional
198+
The kind of data types to return. If ``None``, all data types are
199+
returned. If a string, only data types of that kind are returned.
200+
If a tuple, a dictionary containing the union of the given kinds
201+
is returned. The following kinds are supported:
202+
203+
- ``'bool'``: boolean data types (i.e., ``bool``).
204+
- ``'signed integer'``: signed integer data types (i.e., ``int8``,
205+
``int16``, ``int32``, ``int64``).
206+
- ``'unsigned integer'``: unsigned integer data types (i.e.,
207+
``uint8``, ``uint16``, ``uint32``, ``uint64``).
208+
- ``'integral'``: integer data types. Shorthand for ``('signed
209+
integer', 'unsigned integer')``.
210+
- ``'real floating'``: real-valued floating-point data types
211+
(i.e., ``float32``, ``float64``).
212+
- ``'complex floating'``: complex floating-point data types (i.e.,
213+
``complex64``, ``complex128``).
214+
- ``'numeric'``: numeric data types. Shorthand for ``('integral',
215+
'real floating', 'complex floating')``.
216+
217+
Returns
218+
-------
219+
dtypes : dict
220+
A dictionary mapping the names of data types to the corresponding
221+
Dask data types.
222+
223+
See Also
224+
--------
225+
__array_namespace_info__.capabilities,
226+
__array_namespace_info__.default_device,
227+
__array_namespace_info__.default_dtypes,
228+
__array_namespace_info__.devices
229+
230+
Examples
231+
--------
232+
>>> info = np.__array_namespace_info__()
233+
>>> info.dtypes(kind='signed integer')
234+
{'int8': dask.int8,
235+
'int16': dask.int16,
236+
'int32': dask.int32,
237+
'int64': dask.int64}
238+
239+
"""
240+
if device not in ["cpu", _DASK_DEVICE, None]:
241+
raise ValueError(
242+
'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:'
243+
f' {device}'
244+
)
245+
if kind is None:
246+
return {
247+
"bool": dtype(bool),
248+
"int8": dtype(int8),
249+
"int16": dtype(int16),
250+
"int32": dtype(int32),
251+
"int64": dtype(int64),
252+
"uint8": dtype(uint8),
253+
"uint16": dtype(uint16),
254+
"uint32": dtype(uint32),
255+
"uint64": dtype(uint64),
256+
"float32": dtype(float32),
257+
"float64": dtype(float64),
258+
"complex64": dtype(complex64),
259+
"complex128": dtype(complex128),
260+
}
261+
if kind == "bool":
262+
return {"bool": bool}
263+
if kind == "signed integer":
264+
return {
265+
"int8": dtype(int8),
266+
"int16": dtype(int16),
267+
"int32": dtype(int32),
268+
"int64": dtype(int64),
269+
}
270+
if kind == "unsigned integer":
271+
return {
272+
"uint8": dtype(uint8),
273+
"uint16": dtype(uint16),
274+
"uint32": dtype(uint32),
275+
"uint64": dtype(uint64),
276+
}
277+
if kind == "integral":
278+
return {
279+
"int8": dtype(int8),
280+
"int16": dtype(int16),
281+
"int32": dtype(int32),
282+
"int64": dtype(int64),
283+
"uint8": dtype(uint8),
284+
"uint16": dtype(uint16),
285+
"uint32": dtype(uint32),
286+
"uint64": dtype(uint64),
287+
}
288+
if kind == "real floating":
289+
return {
290+
"float32": dtype(float32),
291+
"float64": dtype(float64),
292+
}
293+
if kind == "complex floating":
294+
return {
295+
"complex64": dtype(complex64),
296+
"complex128": dtype(complex128),
297+
}
298+
if kind == "numeric":
299+
return {
300+
"int8": dtype(int8),
301+
"int16": dtype(int16),
302+
"int32": dtype(int32),
303+
"int64": dtype(int64),
304+
"uint8": dtype(uint8),
305+
"uint16": dtype(uint16),
306+
"uint32": dtype(uint32),
307+
"uint64": dtype(uint64),
308+
"float32": dtype(float32),
309+
"float64": dtype(float64),
310+
"complex64": dtype(complex64),
311+
"complex128": dtype(complex128),
312+
}
313+
if isinstance(kind, tuple):
314+
res = {}
315+
for k in kind:
316+
res.update(self.dtypes(kind=k))
317+
return res
318+
raise ValueError(f"unsupported kind: {kind!r}")
319+
320+
def devices(self):
321+
"""
322+
The devices supported by Dask.
323+
324+
For Dask, this always returns ``['cpu', DASK_DEVICE]``.
325+
326+
Returns
327+
-------
328+
devices : list of str
329+
The devices supported by Dask.
330+
331+
See Also
332+
--------
333+
__array_namespace_info__.capabilities,
334+
__array_namespace_info__.default_device,
335+
__array_namespace_info__.default_dtypes,
336+
__array_namespace_info__.dtypes
337+
338+
Examples
339+
--------
340+
>>> info = np.__array_namespace_info__()
341+
>>> info.devices()
342+
['cpu', DASK_DEVICE]
343+
344+
"""
345+
return ["cpu", _DASK_DEVICE]

0 commit comments

Comments
 (0)