Skip to content

Commit c770c9b

Browse files
committed
Add tests for environment variables
They're not pretty, but they get the job done.
1 parent 7bc29d6 commit c770c9b

File tree

1 file changed

+121
-0
lines changed

1 file changed

+121
-0
lines changed

array_api_strict/tests/test_flags.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,3 +370,124 @@ def test_disabled_extensions():
370370
exec('from array_api_strict import *', ns)
371371
assert 'linalg' not in ns
372372
assert 'fft' not in ns
373+
374+
375+
def test_environment_variables():
376+
# Test that the environment variables work as expected
377+
subprocess_tests = [
378+
# ARRAY_API_STRICT_API_VERSION
379+
('''\
380+
import array_api_strict as xp
381+
assert xp.__array_api_version__ == '2022.12'
382+
383+
assert xp.get_array_api_strict_flags()['api_version'] == '2022.12'
384+
385+
''', {}),
386+
*[
387+
(f'''\
388+
import array_api_strict as xp
389+
assert xp.__array_api_version__ == '{version}'
390+
391+
assert xp.get_array_api_strict_flags()['api_version'] == '{version}'
392+
393+
if {version} == '2021.12':
394+
assert hasattr(xp, 'linalg')
395+
assert not hasattr(xp, 'fft')
396+
397+
''', {"ARRAY_API_STRICT_API_VERSION": version}) for version in ('2021.12', '2022.12', '2023.12')],
398+
399+
# ARRAY_API_STRICT_BOOLEAN_INDEXING
400+
('''\
401+
import array_api_strict as xp
402+
403+
a = xp.ones(3)
404+
mask = xp.asarray([True, False, True])
405+
406+
assert xp.all(a[mask] == xp.asarray([1., 1.]))
407+
assert xp.get_array_api_strict_flags()['boolean_indexing'] == True
408+
''', {}),
409+
*[(f'''\
410+
import array_api_strict as xp
411+
412+
a = xp.ones(3)
413+
mask = xp.asarray([True, False, True])
414+
415+
if {boolean_indexing}:
416+
assert xp.all(a[mask] == xp.asarray([1., 1.]))
417+
else:
418+
try:
419+
a[mask]
420+
except RuntimeError:
421+
pass
422+
else:
423+
assert False
424+
425+
assert xp.get_array_api_strict_flags()['boolean_indexing'] == {boolean_indexing}
426+
''', {"ARRAY_API_STRICT_BOOLEAN_INDEXING": boolean_indexing})
427+
for boolean_indexing in ('True', 'False')],
428+
429+
# ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES
430+
('''\
431+
import array_api_strict as xp
432+
433+
a = xp.ones(3)
434+
xp.unique_all(a)
435+
436+
assert xp.get_array_api_strict_flags()['data_dependent_shapes'] == True
437+
''', {}),
438+
*[(f'''\
439+
import array_api_strict as xp
440+
441+
a = xp.ones(3)
442+
if {data_dependent_shapes}:
443+
xp.unique_all(a)
444+
else:
445+
try:
446+
xp.unique_all(a)
447+
except RuntimeError:
448+
pass
449+
else:
450+
assert False
451+
452+
assert xp.get_array_api_strict_flags()['data_dependent_shapes'] == {data_dependent_shapes}
453+
''', {"ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES": data_dependent_shapes})
454+
for data_dependent_shapes in ('True', 'False')],
455+
456+
# ARRAY_API_STRICT_ENABLED_EXTENSIONS
457+
('''\
458+
import array_api_strict as xp
459+
assert hasattr(xp, 'linalg')
460+
assert hasattr(xp, 'fft')
461+
462+
assert xp.get_array_api_strict_flags()['enabled_extensions'] == ('linalg', 'fft')
463+
''', {}),
464+
*[(f'''\
465+
import array_api_strict as xp
466+
467+
assert hasattr(xp, 'linalg') == ('linalg' in {extensions.split(',')})
468+
assert hasattr(xp, 'fft') == ('fft' in {extensions.split(',')})
469+
470+
assert sorted(xp.get_array_api_strict_flags()['enabled_extensions']) == {sorted(set(extensions.split(','))-{''})}
471+
''', {"ARRAY_API_STRICT_ENABLED_EXTENSIONS": extensions})
472+
for extensions in ('', 'linalg', 'fft', 'linalg,fft')],
473+
]
474+
475+
for test, env in subprocess_tests:
476+
try:
477+
subprocess.run([sys.executable, '-c', test], check=True,
478+
capture_output=True, encoding='utf-8', env=env)
479+
except subprocess.CalledProcessError as e:
480+
print(e.stdout, end='')
481+
# Ensure the exception is shown in the output log
482+
raise AssertionError(f"""\
483+
STDOUT:
484+
{e.stderr}
485+
486+
STDERR:
487+
{e.stderr}
488+
489+
TEST:
490+
{test}
491+
492+
ENV:
493+
{env}""")

0 commit comments

Comments
 (0)