Skip to content

Commit 012de1e

Browse files
authored
FIX Fix array API tests for array_api_strict 2.0.1 (scikit-learn#29387)
1 parent ecdc957 commit 012de1e

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

sklearn/utils/_array_api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,11 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None):
557557
if namespace.__name__ in {"cupy.array_api"}:
558558
namespace = _ArrayAPIWrapper(namespace)
559559

560+
if namespace.__name__ == "array_api_strict" and hasattr(
561+
namespace, "set_array_api_strict_flags"
562+
):
563+
namespace.set_array_api_strict_flags(api_version="2023.12")
564+
560565
return namespace, is_array_api_compliant
561566

562567

0 commit comments

Comments
 (0)