Skip to content

Commit befd28c

Browse files
committed
Set the api version flag in __array_namespace__
1 parent 632900f commit befd28c

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

array_api_strict/_array_object.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
_result_type,
3333
_dtype_categories,
3434
)
35-
from ._flags import get_array_api_strict_flags
35+
from ._flags import get_array_api_strict_flags, set_array_api_strict_flags
3636

3737
from typing import TYPE_CHECKING, Optional, Tuple, Union, Any, SupportsIndex
3838
import types
@@ -501,8 +501,7 @@ def __array_namespace__(
501501
{class}`array_api_strict.ArrayApiStrictFlags` context manager.
502502
503503
"""
504-
if api_version is not None and api_version not in ["2021.12", "2022.12"]:
505-
raise ValueError(f"Unrecognized array API version: {api_version!r}")
504+
set_array_api_strict_flags(standard_version=api_version)
506505
if api_version == "2021.12":
507506
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
508507
import array_api_strict

array_api_strict/_flags.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def set_array_api_strict_flags(
116116

117117
if standard_version is not None:
118118
if standard_version not in supported_versions:
119-
raise ValueError(f"Unsupported standard version {standard_version}")
119+
raise ValueError(f"Unsupported standard version {standard_version!r}")
120120
STANDARD_VERSION = standard_version
121121

122122
if data_dependent_shapes is not None:

0 commit comments

Comments
 (0)