Skip to content

Commit 05fa0b5

Browse files
committed
Add tests that the new 2023.12 functions are properly decorated
1 parent 84d2aa5 commit 05fa0b5

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

array_api_strict/tests/test_flags.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from .._flags import (set_array_api_strict_flags, get_array_api_strict_flags,
22
reset_array_api_strict_flags)
3+
from .._info import (capabilities, default_device, default_dtypes, devices,
4+
dtypes)
35

46
from .. import (asarray, unique_all, unique_counts, unique_inverse,
57
unique_values, nonzero, repeat)
@@ -237,3 +239,39 @@ def test_fft(func_name):
237239

238240
set_array_api_strict_flags(enabled_extensions=('fft',))
239241
func()
242+
243+
api_version_2023_12_examples = {
244+
'__array_namespace_info__': lambda: xp.__array_namespace_info__(),
245+
# Test these functions directly to ensure they are properly decorated
246+
'capabilities': capabilities,
247+
'default_device': default_device,
248+
'default_dtypes': default_dtypes,
249+
'devices': devices,
250+
'dtypes': dtypes,
251+
'clip': lambda: xp.clip(xp.asarray([1, 2, 3]), 1, 2),
252+
'copysign': lambda: xp.copysign(xp.asarray([1., 2., 3.]), xp.asarray([-1., -1., -1.])),
253+
'cumulative_sum': lambda: xp.cumulative_sum(xp.asarray([1, 2, 3])),
254+
'hypot': lambda: xp.hypot(xp.asarray([3., 4.]), xp.asarray([4., 3.])),
255+
'maximum': lambda: xp.maximum(xp.asarray([1, 2, 3]), xp.asarray([2, 3, 4])),
256+
'minimum': lambda: xp.minimum(xp.asarray([1, 2, 3]), xp.asarray([2, 3, 4])),
257+
'moveaxis': lambda: xp.moveaxis(xp.ones((3, 3)), 0, 1),
258+
'repeat': lambda: xp.repeat(xp.asarray([1, 2, 3]), 3),
259+
'searchsorted': lambda: xp.searchsorted(xp.asarray([1, 2, 3]), xp.asarray([0, 1, 2, 3, 4])),
260+
'signbit': lambda: xp.signbit(xp.asarray([-1., 0., 1.])),
261+
'tile': lambda: xp.tile(xp.ones((3, 3)), (2, 3)),
262+
'unstack': lambda: xp.unstack(xp.ones((3, 3)), axis=0),
263+
}
264+
265+
@pytest.mark.parametrize('func_name', api_version_2023_12_examples.keys())
266+
def test_api_version_2023_12(func_name):
267+
func = api_version_2023_12_examples[func_name]
268+
269+
# By default, these functions should error
270+
pytest.raises(RuntimeError, func)
271+
272+
with pytest.warns(UserWarning):
273+
set_array_api_strict_flags(api_version='2023.12')
274+
func()
275+
276+
set_array_api_strict_flags(api_version='2022.12')
277+
pytest.raises(RuntimeError, func)

0 commit comments

Comments
 (0)