|
1 | 1 | from .._flags import (set_array_api_strict_flags, get_array_api_strict_flags,
|
2 | 2 | reset_array_api_strict_flags)
|
| 3 | +from .._info import (capabilities, default_device, default_dtypes, devices, |
| 4 | + dtypes) |
3 | 5 |
|
4 | 6 | from .. import (asarray, unique_all, unique_counts, unique_inverse,
|
5 | 7 | unique_values, nonzero, repeat)
|
@@ -237,3 +239,39 @@ def test_fft(func_name):
|
237 | 239 |
|
238 | 240 | set_array_api_strict_flags(enabled_extensions=('fft',))
|
239 | 241 | func()
|
| 242 | + |
| 243 | +api_version_2023_12_examples = { |
| 244 | + '__array_namespace_info__': lambda: xp.__array_namespace_info__(), |
| 245 | + # Test these functions directly to ensure they are properly decorated |
| 246 | + 'capabilities': capabilities, |
| 247 | + 'default_device': default_device, |
| 248 | + 'default_dtypes': default_dtypes, |
| 249 | + 'devices': devices, |
| 250 | + 'dtypes': dtypes, |
| 251 | + 'clip': lambda: xp.clip(xp.asarray([1, 2, 3]), 1, 2), |
| 252 | + 'copysign': lambda: xp.copysign(xp.asarray([1., 2., 3.]), xp.asarray([-1., -1., -1.])), |
| 253 | + 'cumulative_sum': lambda: xp.cumulative_sum(xp.asarray([1, 2, 3])), |
| 254 | + 'hypot': lambda: xp.hypot(xp.asarray([3., 4.]), xp.asarray([4., 3.])), |
| 255 | + 'maximum': lambda: xp.maximum(xp.asarray([1, 2, 3]), xp.asarray([2, 3, 4])), |
| 256 | + 'minimum': lambda: xp.minimum(xp.asarray([1, 2, 3]), xp.asarray([2, 3, 4])), |
| 257 | + 'moveaxis': lambda: xp.moveaxis(xp.ones((3, 3)), 0, 1), |
| 258 | + 'repeat': lambda: xp.repeat(xp.asarray([1, 2, 3]), 3), |
| 259 | + 'searchsorted': lambda: xp.searchsorted(xp.asarray([1, 2, 3]), xp.asarray([0, 1, 2, 3, 4])), |
| 260 | + 'signbit': lambda: xp.signbit(xp.asarray([-1., 0., 1.])), |
| 261 | + 'tile': lambda: xp.tile(xp.ones((3, 3)), (2, 3)), |
| 262 | + 'unstack': lambda: xp.unstack(xp.ones((3, 3)), axis=0), |
| 263 | +} |
| 264 | + |
| 265 | +@pytest.mark.parametrize('func_name', api_version_2023_12_examples.keys()) |
| 266 | +def test_api_version_2023_12(func_name): |
| 267 | + func = api_version_2023_12_examples[func_name] |
| 268 | + |
| 269 | + # By default, these functions should error |
| 270 | + pytest.raises(RuntimeError, func) |
| 271 | + |
| 272 | + with pytest.warns(UserWarning): |
| 273 | + set_array_api_strict_flags(api_version='2023.12') |
| 274 | + func() |
| 275 | + |
| 276 | + set_array_api_strict_flags(api_version='2022.12') |
| 277 | + pytest.raises(RuntimeError, func) |
0 commit comments