@@ -305,29 +305,25 @@ def __exit__(self, exc_type, exc_value, traceback):
305
305
]
306
306
307
307
def set_flags_from_environment ():
308
+ kwargs = {}
308
309
if "ARRAY_API_STRICT_API_VERSION" in os .environ :
309
- set_array_api_strict_flags (
310
- api_version = os .environ ["ARRAY_API_STRICT_API_VERSION" ]
311
- )
310
+ kwargs ["api_version" ] = os .environ ["ARRAY_API_STRICT_API_VERSION" ]
312
311
313
312
if "ARRAY_API_STRICT_BOOLEAN_INDEXING" in os .environ :
314
- set_array_api_strict_flags (
315
- boolean_indexing = os .environ ["ARRAY_API_STRICT_BOOLEAN_INDEXING" ].lower () == "true"
316
- )
313
+ kwargs ["boolean_indexing" ] = os .environ ["ARRAY_API_STRICT_BOOLEAN_INDEXING" ].lower () == "true"
317
314
318
315
if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os .environ :
319
- set_array_api_strict_flags (
320
- data_dependent_shapes = os .environ ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" ].lower () == "true"
321
- )
316
+ kwargs ["data_dependent_shapes" ] = os .environ ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" ].lower () == "true"
322
317
323
318
if "ARRAY_API_STRICT_ENABLED_EXTENSIONS" in os .environ :
324
319
enabled_extensions = os .environ ["ARRAY_API_STRICT_ENABLED_EXTENSIONS" ].split ("," )
325
320
if enabled_extensions == ["" ]:
326
321
enabled_extensions = []
327
- set_array_api_strict_flags (enabled_extensions = enabled_extensions )
328
- else :
329
- # Needed at first import to add linalg and fft to __all__
330
- set_array_api_strict_flags (enabled_extensions = default_extensions )
322
+ kwargs ["enabled_extensions" ] = enabled_extensions
323
+
324
+ # Called unconditionally because it is needed at first import to add
325
+ # linalg and fft to __all__
326
+ set_array_api_strict_flags (** kwargs )
331
327
332
328
set_flags_from_environment ()
333
329
0 commit comments