|
| 1 | +from .._flags import (set_array_api_strict_flags, get_array_api_strict_flags, |
| 2 | + reset_array_api_strict_flags) |
| 3 | + |
| 4 | +from .. import (asarray, unique_all, unique_counts, unique_inverse, |
| 5 | + unique_values, nonzero) |
| 6 | + |
| 7 | +import pytest |
| 8 | + |
| 9 | +@pytest.fixture(autouse=True) |
| 10 | +def reset_flags(): |
| 11 | + reset_array_api_strict_flags() |
| 12 | + yield |
| 13 | + reset_array_api_strict_flags() |
| 14 | + |
| 15 | +def test_flags(): |
| 16 | + # Test defaults |
| 17 | + flags = get_array_api_strict_flags() |
| 18 | + assert flags == { |
| 19 | + 'standard_version': '2022.12', |
| 20 | + 'data_dependent_shapes': True, |
| 21 | + 'enabled_extensions': ('linalg', 'fft'), |
| 22 | + } |
| 23 | + |
| 24 | + # Test setting flags |
| 25 | + set_array_api_strict_flags(data_dependent_shapes=False) |
| 26 | + flags = get_array_api_strict_flags() |
| 27 | + assert flags == { |
| 28 | + 'standard_version': '2022.12', |
| 29 | + 'data_dependent_shapes': False, |
| 30 | + 'enabled_extensions': ('linalg', 'fft'), |
| 31 | + } |
| 32 | + set_array_api_strict_flags(enabled_extensions=('fft',)) |
| 33 | + flags = get_array_api_strict_flags() |
| 34 | + assert flags == { |
| 35 | + 'standard_version': '2022.12', |
| 36 | + 'data_dependent_shapes': False, |
| 37 | + 'enabled_extensions': ('fft',), |
| 38 | + } |
| 39 | + # Make sure setting the version to 2021.12 disables fft |
| 40 | + set_array_api_strict_flags(standard_version='2021.12') |
| 41 | + flags = get_array_api_strict_flags() |
| 42 | + assert flags == { |
| 43 | + 'standard_version': '2021.12', |
| 44 | + 'data_dependent_shapes': False, |
| 45 | + 'enabled_extensions': ('linalg',), |
| 46 | + } |
| 47 | + |
| 48 | + # Test setting flags with invalid values |
| 49 | + pytest.raises(ValueError, lambda: |
| 50 | + set_array_api_strict_flags(standard_version='2020.12')) |
| 51 | + pytest.raises(ValueError, lambda: set_array_api_strict_flags( |
| 52 | + enabled_extensions=('linalg', 'fft', 'invalid'))) |
| 53 | + pytest.raises(ValueError, lambda: set_array_api_strict_flags( |
| 54 | + standard_version='2021.12', |
| 55 | + enabled_extensions=('linalg', 'fft'))) |
| 56 | + |
| 57 | + |
| 58 | +def test_data_dependent_shapes(): |
| 59 | + a = asarray([0, 0, 1, 2, 2]) |
| 60 | + mask = asarray([True, False, True, False, True]) |
| 61 | + |
| 62 | + # Should not error |
| 63 | + unique_all(a) |
| 64 | + unique_counts(a) |
| 65 | + unique_inverse(a) |
| 66 | + unique_values(a) |
| 67 | + nonzero(a) |
| 68 | + a[mask] |
| 69 | + # TODO: add repeat when it is implemented |
| 70 | + |
| 71 | + set_array_api_strict_flags(data_dependent_shapes=False) |
| 72 | + |
| 73 | + pytest.raises(RuntimeError, lambda: unique_all(a)) |
| 74 | + pytest.raises(RuntimeError, lambda: unique_counts(a)) |
| 75 | + pytest.raises(RuntimeError, lambda: unique_inverse(a)) |
| 76 | + pytest.raises(RuntimeError, lambda: unique_values(a)) |
| 77 | + pytest.raises(RuntimeError, lambda: nonzero(a)) |
| 78 | + pytest.raises(RuntimeError, lambda: a[mask]) |
0 commit comments