Skip to content

Commit a7da560

Browse files
committed
Make use_compat a public flag to array_namespace()
If None, it does the default behavior. If True or False it forces the use or non-use of the compat wrapper. This commit also changes it so that NumPy 2.0 does not return the wrapper library at all by default (when use_compat=None), since it is completely array API compatible on its own.
1 parent 311d0aa commit a7da560

File tree

3 files changed

+54
-20
lines changed

3 files changed

+54
-20
lines changed

array_api_compat/common/_helpers.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def _check_api_version(api_version):
178178
elif api_version is not None and api_version != '2022.12':
179179
raise ValueError("Only the 2022.12 version of the array API specification is currently supported")
180180

181-
def array_namespace(*xs, api_version=None, _use_compat=True):
181+
def array_namespace(*xs, api_version=None, use_compat=None):
182182
"""
183183
Get the array API compatible namespace for the arrays `xs`.
184184
@@ -191,6 +191,12 @@ def array_namespace(*xs, api_version=None, _use_compat=True):
191191
The newest version of the spec that you need support for (currently
192192
the compat library wrapped APIs support v2022.12).
193193
194+
use_compat: bool or None
195+
If None (the default), the native namespace will be returned if it is
196+
already array API compatible, otherwise a compat wrapper is used. If
197+
True, the compat library wrapped library will be returned. If False,
198+
the native library namespace is returned.
199+
194200
Returns
195201
-------
196202
@@ -234,16 +240,28 @@ def your_function(x, y):
234240
is_jax_array
235241
236242
"""
243+
if use_compat not in [None, True, False]:
244+
raise ValueError("use_compat must be None, True, or False")
245+
246+
_use_compat = use_compat in [None, True]
247+
237248
namespaces = set()
238249
for x in xs:
239250
if is_numpy_array(x):
240251
_check_api_version(api_version)
241-
if _use_compat:
242-
from .. import numpy as numpy_namespace
252+
from .. import numpy as numpy_namespace
253+
import numpy as np
254+
if use_compat is True:
243255
namespaces.add(numpy_namespace)
244-
else:
245-
import numpy as np
256+
elif use_compat is False:
246257
namespaces.add(np)
258+
else:
259+
# numpy 2.0 has __array_namespace__ and is fully array API
260+
# compatible.
261+
if hasattr(x, '__array_namespace__'):
262+
namespaces.add(x.__array_namespace__(api_version=api_version))
263+
else:
264+
namespaces.add(numpy_namespace)
247265
elif is_cupy_array(x):
248266
_check_api_version(api_version)
249267
if _use_compat:
@@ -266,14 +284,22 @@ def your_function(x, y):
266284
from ..dask import array as dask_namespace
267285
namespaces.add(dask_namespace)
268286
else:
269-
raise TypeError("_use_compat cannot be False if input array is a dask array!")
287+
import dask.array as da
288+
namespaces.add(da)
270289
elif is_jax_array(x):
271290
_check_api_version(api_version)
272-
# jax.experimental.array_api is already an array namespace. We do
273-
# not have a wrapper submodule for it.
274-
import jax.experimental.array_api as jnp
291+
if use_compat is True:
292+
raise ValueError("JAX does not have an array-api-compat wrapper")
293+
elif use_compat is False:
294+
import jax.numpy as jnp
295+
else:
296+
# jax.experimental.array_api is already an array namespace. We do
297+
# not have a wrapper submodule for it.
298+
import jax.experimental.array_api as jnp
275299
namespaces.add(jnp)
276300
elif hasattr(x, '__array_namespace__'):
301+
if use_compat is True:
302+
raise ValueError("The given array does not have an array-api-compat wrapper")
277303
namespaces.add(x.__array_namespace__(api_version=api_version))
278304
else:
279305
# TODO: Support Python scalars?

tests/_helpers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22

33
import pytest
44

5-
wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"]
6-
all_libraries = wrapped_libraries + ["jax.numpy"]
5+
wrapped_libraries = ["cupy", "torch", "dask.array"]
6+
all_libraries = wrapped_libraries + ["numpy", "jax.numpy"]
7+
import numpy as np
8+
if np.__version__[0] == '1':
9+
wrapped_libraries.append("numpy")
710

811
def import_(library, wrapper=False):
912
if library == 'cupy':

tests/test_array_namespace.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,29 @@
99
import array_api_compat
1010
from array_api_compat import array_namespace
1111

12-
from ._helpers import import_, all_libraries
12+
from ._helpers import import_, all_libraries, wrapped_libraries
1313

14-
@pytest.mark.parametrize("library", all_libraries)
14+
@pytest.mark.parametrize("use_compat", [True, False, None])
1515
@pytest.mark.parametrize("api_version", [None, "2021.12"])
16-
def test_array_namespace(library, api_version):
16+
@pytest.mark.parametrize("library", all_libraries + ['array_api_strict'])
17+
def test_array_namespace(library, api_version, use_compat):
1718
xp = import_(library)
1819

1920
array = xp.asarray([1.0, 2.0, 3.0])
20-
namespace = array_api_compat.array_namespace(array, api_version=api_version)
21+
if use_compat is True and library in ['array_api_strict', 'jax.numpy']:
22+
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
23+
return
24+
namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)
2125

22-
if "array_api" in library:
23-
assert namespace == xp
26+
if use_compat is False or use_compat is None and library not in wrapped_libraries:
27+
if library == "jax.numpy" and use_compat is None:
28+
import jax.experimental.array_api
29+
assert namespace == jax.experimental.array_api
30+
else:
31+
assert namespace == xp
2432
else:
2533
if library == "dask.array":
2634
assert namespace == array_api_compat.dask.array
27-
elif library == "jax.numpy":
28-
import jax.experimental.array_api
29-
assert namespace == jax.experimental.array_api
3035
else:
3136
assert namespace == getattr(array_api_compat, library)
3237

0 commit comments

Comments
 (0)