@@ -370,3 +370,124 @@ def test_disabled_extensions():
370
370
exec ('from array_api_strict import *' , ns )
371
371
assert 'linalg' not in ns
372
372
assert 'fft' not in ns
373
+
374
+
375
+ def test_environment_variables ():
376
+ # Test that the environment variables work as expected
377
+ subprocess_tests = [
378
+ # ARRAY_API_STRICT_API_VERSION
379
+ ('''\
380
+ import array_api_strict as xp
381
+ assert xp.__array_api_version__ == '2022.12'
382
+
383
+ assert xp.get_array_api_strict_flags()['api_version'] == '2022.12'
384
+
385
+ ''' , {}),
386
+ * [
387
+ (f'''\
388
+ import array_api_strict as xp
389
+ assert xp.__array_api_version__ == '{ version } '
390
+
391
+ assert xp.get_array_api_strict_flags()['api_version'] == '{ version } '
392
+
393
+ if { version } == '2021.12':
394
+ assert hasattr(xp, 'linalg')
395
+ assert not hasattr(xp, 'fft')
396
+
397
+ ''' , {"ARRAY_API_STRICT_API_VERSION" : version }) for version in ('2021.12' , '2022.12' , '2023.12' )],
398
+
399
+ # ARRAY_API_STRICT_BOOLEAN_INDEXING
400
+ ('''\
401
+ import array_api_strict as xp
402
+
403
+ a = xp.ones(3)
404
+ mask = xp.asarray([True, False, True])
405
+
406
+ assert xp.all(a[mask] == xp.asarray([1., 1.]))
407
+ assert xp.get_array_api_strict_flags()['boolean_indexing'] == True
408
+ ''' , {}),
409
+ * [(f'''\
410
+ import array_api_strict as xp
411
+
412
+ a = xp.ones(3)
413
+ mask = xp.asarray([True, False, True])
414
+
415
+ if { boolean_indexing } :
416
+ assert xp.all(a[mask] == xp.asarray([1., 1.]))
417
+ else:
418
+ try:
419
+ a[mask]
420
+ except RuntimeError:
421
+ pass
422
+ else:
423
+ assert False
424
+
425
+ assert xp.get_array_api_strict_flags()['boolean_indexing'] == { boolean_indexing }
426
+ ''' , {"ARRAY_API_STRICT_BOOLEAN_INDEXING" : boolean_indexing })
427
+ for boolean_indexing in ('True' , 'False' )],
428
+
429
+ # ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES
430
+ ('''\
431
+ import array_api_strict as xp
432
+
433
+ a = xp.ones(3)
434
+ xp.unique_all(a)
435
+
436
+ assert xp.get_array_api_strict_flags()['data_dependent_shapes'] == True
437
+ ''' , {}),
438
+ * [(f'''\
439
+ import array_api_strict as xp
440
+
441
+ a = xp.ones(3)
442
+ if { data_dependent_shapes } :
443
+ xp.unique_all(a)
444
+ else:
445
+ try:
446
+ xp.unique_all(a)
447
+ except RuntimeError:
448
+ pass
449
+ else:
450
+ assert False
451
+
452
+ assert xp.get_array_api_strict_flags()['data_dependent_shapes'] == { data_dependent_shapes }
453
+ ''' , {"ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" : data_dependent_shapes })
454
+ for data_dependent_shapes in ('True' , 'False' )],
455
+
456
+ # ARRAY_API_STRICT_ENABLED_EXTENSIONS
457
+ ('''\
458
+ import array_api_strict as xp
459
+ assert hasattr(xp, 'linalg')
460
+ assert hasattr(xp, 'fft')
461
+
462
+ assert xp.get_array_api_strict_flags()['enabled_extensions'] == ('linalg', 'fft')
463
+ ''' , {}),
464
+ * [(f'''\
465
+ import array_api_strict as xp
466
+
467
+ assert hasattr(xp, 'linalg') == ('linalg' in { extensions .split (',' )} )
468
+ assert hasattr(xp, 'fft') == ('fft' in { extensions .split (',' )} )
469
+
470
+ assert sorted(xp.get_array_api_strict_flags()['enabled_extensions']) == { sorted (set (extensions .split (',' ))- {'' })}
471
+ ''' , {"ARRAY_API_STRICT_ENABLED_EXTENSIONS" : extensions })
472
+ for extensions in ('' , 'linalg' , 'fft' , 'linalg,fft' )],
473
+ ]
474
+
475
+ for test , env in subprocess_tests :
476
+ try :
477
+ subprocess .run ([sys .executable , '-c' , test ], check = True ,
478
+ capture_output = True , encoding = 'utf-8' , env = env )
479
+ except subprocess .CalledProcessError as e :
480
+ print (e .stdout , end = '' )
481
+ # Ensure the exception is shown in the output log
482
+ raise AssertionError (f"""\
483
+ STDOUT:
484
+ { e .stderr }
485
+
486
+ STDERR:
487
+ { e .stderr }
488
+
489
+ TEST:
490
+ { test }
491
+
492
+ ENV:
493
+ { env } """ )
0 commit comments