Skip to content

Commit 4d3ff6c

Browse files
committed
Fix test failures
1 parent 161acaa commit 4d3ff6c

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

array_api_strict/tests/test_flags.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def test_flags():
5555
flags = get_array_api_strict_flags()
5656
assert flags == {
5757
'api_version': '2021.12',
58+
'boolean_indexing': True,
5859
'data_dependent_shapes': True,
5960
'enabled_extensions': ('linalg',),
6061
}
@@ -68,6 +69,7 @@ def test_flags():
6869
flags = get_array_api_strict_flags()
6970
assert flags == {
7071
'api_version': '2023.12',
72+
'boolean_indexing': True,
7173
'data_dependent_shapes': True,
7274
'enabled_extensions': ('linalg', 'fft'),
7375
}
@@ -132,6 +134,8 @@ def test_data_dependent_shapes():
132134
pytest.raises(RuntimeError, lambda: unique_inverse(a))
133135
pytest.raises(RuntimeError, lambda: unique_values(a))
134136
pytest.raises(RuntimeError, lambda: nonzero(a))
137+
pytest.raises(RuntimeError, lambda: repeat(a, repeats))
138+
repeat(a, 2) # Should never error
135139
a[mask] # No error (boolean indexing is a separate flag)
136140

137141
def test_boolean_indexing():
@@ -144,8 +148,6 @@ def test_boolean_indexing():
144148
set_array_api_strict_flags(boolean_indexing=False)
145149

146150
pytest.raises(RuntimeError, lambda: a[mask])
147-
pytest.raises(RuntimeError, lambda: repeat(a, repeats))
148-
repeat(a, 2) # Should never error
149151

150152
linalg_examples = {
151153
'cholesky': lambda: xp.linalg.cholesky(xp.eye(3)),

0 commit comments

Comments
 (0)