Skip to content

Commit 161acaa

Browse files
committed
Add the inspection APIs
1 parent f247130 commit 161acaa

File tree

4 files changed

+193
-0
lines changed

4 files changed

+193
-0
lines changed

array_api_strict/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,12 @@
260260

261261
__all__ += ["take"]
262262

263+
from ._info import __array_namespace_info__
264+
265+
__all__ += [
266+
"__array_namespace_info__",
267+
]
268+
263269
# linalg is an extension in the array API spec, which is a sub-namespace. Only
264270
# a subset of functions in it are imported into the top-level namespace.
265271
from . import linalg

array_api_strict/_flags.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,14 @@ def get_array_api_strict_flags():
178178
This function is **not** part of the array API standard. It only exists
179179
in array-api-strict.
180180
181+
.. note::
182+
183+
The `inspection API
184+
<https://data-apis.org/array-api/latest/API_specification/inspection.html>`__
185+
provides a portable way to access most of this information. However, it
186+
is only present in standard versions starting with 2023.12. The array
187+
API version can be accessed portably using `xp.__array_api_version__`.
188+
181189
Returns
182190
-------
183191
dict

array_api_strict/_info.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
from __future__ import annotations
2+
3+
__all__ = [
4+
"__array_namespace_info__",
5+
"capabilities",
6+
"default_device",
7+
"default_dtypes",
8+
"devices",
9+
"dtypes",
10+
]
11+
12+
from typing import TYPE_CHECKING
13+
14+
if TYPE_CHECKING:
15+
from typing import Optional, Union, Tuple, List
16+
from ._typing import device, DefaultDataTypes, DataTypes, Capabilities, Info
17+
18+
from ._array_object import CPU_DEVICE
19+
from ._flags import get_array_api_strict_flags, requires_api_version
20+
from ._dtypes import bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128
21+
22+
@requires_api_version('2023.12')
23+
def __array_namespace_info__() -> Info:
24+
import array_api_strict._info
25+
return array_api_strict._info
26+
27+
@requires_api_version('2023.12')
28+
def capabilities() -> Capabilities:
29+
flags = get_array_api_strict_flags()
30+
return {"boolean indexing": flags['boolean_indexing'],
31+
"data-dependent shapes": flags['data_dependent_shapes'],
32+
}
33+
34+
@requires_api_version('2023.12')
35+
def default_device() -> device:
36+
return CPU_DEVICE
37+
38+
@requires_api_version('2023.12')
39+
def default_dtypes(
40+
*,
41+
device: Optional[device] = None,
42+
) -> DefaultDataTypes:
43+
return {
44+
"real floating": float64,
45+
"complex floating": complex128,
46+
"integral": int64,
47+
"indexing": int64,
48+
}
49+
50+
@requires_api_version('2023.12')
51+
def dtypes(
52+
*,
53+
device: Optional[device] = None,
54+
kind: Optional[Union[str, Tuple[str, ...]]] = None,
55+
) -> DataTypes:
56+
if kind is None:
57+
return {
58+
"bool": bool,
59+
"int8": int8,
60+
"int16": int16,
61+
"int32": int32,
62+
"int64": int64,
63+
"uint8": uint8,
64+
"uint16": uint16,
65+
"uint32": uint32,
66+
"uint64": uint64,
67+
"float32": float32,
68+
"float64": float64,
69+
"complex64": complex64,
70+
"complex128": complex128,
71+
}
72+
if kind == "bool":
73+
return {"bool": bool}
74+
if kind == "signed integer":
75+
return {
76+
"int8": int8,
77+
"int16": int16,
78+
"int32": int32,
79+
"int64": int64,
80+
}
81+
if kind == "unsigned integer":
82+
return {
83+
"uint8": uint8,
84+
"uint16": uint16,
85+
"uint32": uint32,
86+
"uint64": uint64,
87+
}
88+
if kind == "integral":
89+
return {
90+
"int8": int8,
91+
"int16": int16,
92+
"int32": int32,
93+
"int64": int64,
94+
"uint8": uint8,
95+
"uint16": uint16,
96+
"uint32": uint32,
97+
"uint64": uint64,
98+
}
99+
if kind == "real floating":
100+
return {
101+
"float32": float32,
102+
"float64": float64,
103+
}
104+
if kind == "complex floating":
105+
return {
106+
"complex64": complex64,
107+
"complex128": complex128,
108+
}
109+
if kind == "numeric":
110+
return {
111+
"int8": int8,
112+
"int16": int16,
113+
"int32": int32,
114+
"int64": int64,
115+
"uint8": uint8,
116+
"uint16": uint16,
117+
"uint32": uint32,
118+
"uint64": uint64,
119+
"float32": float32,
120+
"float64": float64,
121+
"complex64": complex64,
122+
"complex128": complex128,
123+
}
124+
if isinstance(kind, tuple):
125+
res = {}
126+
for k in kind:
127+
res.update(dtypes(kind=k))
128+
return res
129+
raise ValueError(f"unsupported kind: {kind!r}")
130+
131+
@requires_api_version('2023.12')
132+
def devices() -> List[device]:
133+
return [CPU_DEVICE]
134+
135+
__all__ = [
136+
"capabilities",
137+
"default_device",
138+
"default_dtypes",
139+
"devices",
140+
"dtypes",
141+
]

array_api_strict/_typing.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
from typing import (
2323
Any,
24+
ModuleType,
25+
TypedDict,
2426
TypeVar,
2527
Protocol,
2628
)
@@ -39,6 +41,8 @@ def __len__(self, /) -> int: ...
3941

4042
Dtype = _DType
4143

44+
Info = ModuleType
45+
4246
if sys.version_info >= (3, 12):
4347
from collections.abc import Buffer as SupportsBufferProtocol
4448
else:
@@ -48,3 +52,37 @@ def __len__(self, /) -> int: ...
4852

4953
class SupportsDLPack(Protocol):
5054
def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ...
55+
56+
Capabilities = TypedDict(
57+
"Capabilities", {"boolean indexing": bool, "data-dependent shapes": bool}
58+
)
59+
60+
DefaultDataTypes = TypedDict(
61+
"DefaultDataTypes",
62+
{
63+
"real floating": Dtype,
64+
"complex floating": Dtype,
65+
"integral": Dtype,
66+
"indexing": Dtype,
67+
},
68+
)
69+
70+
DataTypes = TypedDict(
71+
"DataTypes",
72+
{
73+
"bool": Dtype,
74+
"float32": Dtype,
75+
"float64": Dtype,
76+
"complex64": Dtype,
77+
"complex128": Dtype,
78+
"int8": Dtype,
79+
"int16": Dtype,
80+
"int32": Dtype,
81+
"int64": Dtype,
82+
"uint8": Dtype,
83+
"uint16": Dtype,
84+
"uint32": Dtype,
85+
"uint64": Dtype,
86+
},
87+
total=False,
88+
)

0 commit comments

Comments
 (0)