Skip to content

Commit 5a79c9d

Browse files
authored
Add array API inspection utilities (#592)
* Add array API inspection utilities * JAX default_device() returns None
1 parent 7ca5ae9 commit 5a79c9d

File tree

5 files changed

+67
-7
lines changed

5 files changed

+67
-7
lines changed

api_status.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@ This table shows which parts of the the [Array API](https://data-apis.org/array-
4646
| | Multi-axis | :white_check_mark: | | |
4747
| | Boolean array | :x: | | Shape is data dependent, [#73](https://github.com/cubed-dev/cubed/issues/73) |
4848
| Indexing Functions | `take` | :white_check_mark: | 2022.12 | |
49-
| Inspection | `capabilities` | :x: | 2023.12 | |
50-
| | `default_device` | :x: | 2023.12 | |
51-
| | `default_dtypes` | :x: | 2023.12 | |
52-
| | `devices` | :x: | 2023.12 | |
53-
| | `dtypes` | :x: | 2023.12 | |
49+
| Inspection | `capabilities` | :white_check_mark: | 2023.12 | |
50+
| | `default_device` | :white_check_mark: | 2023.12 | |
51+
| | `default_dtypes` | :white_check_mark: | 2023.12 | |
52+
| | `devices` | :white_check_mark: | 2023.12 | |
53+
| | `dtypes` | :white_check_mark: | 2023.12 | |
5454
| Linear Algebra Functions | `matmul` | :white_check_mark: | | |
5555
| | `matrix_transpose` | :white_check_mark: | | |
5656
| | `tensordot` | :white_check_mark: | | |

cubed/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@
4747

4848
__array_api_version__ = "2022.12"
4949

50-
__all__ += ["__array_api_version__"]
50+
from .array_api.inspection import __array_namespace_info__
51+
52+
__all__ += ["__array_api_version__", "__array_namespace_info__"]
53+
5154

5255
from .array_api.array_object import Array
5356

cubed/array_api/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
__array_api_version__ = "2022.12"
44

5-
__all__ += ["__array_api_version__"]
5+
from .inspection import __array_namespace_info__
6+
7+
__all__ += ["__array_api_version__", "__array_namespace_info__"]
68

79
from .array_object import Array
810

cubed/array_api/inspection.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from cubed.backend_array_api import namespace as nxp
2+
3+
4+
class __array_namespace_info__:
5+
# capabilities are determined by Cubed, not the backend array API
6+
def capabilities(self):
7+
return {
8+
"boolean indexing": False,
9+
"data-dependent shapes": False,
10+
}
11+
12+
# devices and dtypes are determined by the backend array API
13+
14+
def default_device(self):
15+
return nxp.__array_namespace_info__().default_device()
16+
17+
def default_dtypes(self, *, device=None):
18+
return nxp.__array_namespace_info__().default_dtypes(device=device)
19+
20+
def devices(self):
21+
return nxp.__array_namespace_info__().devices()
22+
23+
def dtypes(self, *, device=None, kind=None):
24+
return nxp.__array_namespace_info__().dtypes(device=device, kind=kind)

cubed/tests/test_inspection.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import cubed.array_api as xp
2+
3+
info = xp.__array_namespace_info__()
4+
5+
6+
def test_capabilities():
7+
capabilities = info.capabilities()
8+
assert capabilities["boolean indexing"] is False
9+
assert capabilities["data-dependent shapes"] is False
10+
11+
12+
def test_default_device():
13+
assert (
14+
info.default_device() is None or info.default_device() == xp.asarray(0).device
15+
)
16+
17+
18+
def test_default_dtypes():
19+
dtypes = info.default_dtypes()
20+
assert dtypes["real floating"] == xp.asarray(0.0).dtype
21+
assert dtypes["complex floating"] == xp.asarray(0.0j).dtype
22+
assert dtypes["integral"] == xp.asarray(0).dtype
23+
assert dtypes["indexing"] == xp.argmax(xp.zeros(10)).dtype
24+
25+
26+
def test_devices():
27+
assert len(info.devices()) > 0
28+
29+
30+
def test_dtypes():
31+
assert len(info.dtypes()) > 0

0 commit comments

Comments
 (0)