From 770e90d699a1a701bd760d582404c1eff65768f9 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 19 Nov 2024 15:55:42 -0700 Subject: [PATCH] Fix array_namespace on numpy scalars The check for Python scalars needed to be moved to the end since NumPy scalars subclass the Python scalar types. --- array_api_compat/common/_helpers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 452f4668..b011f08d 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -492,9 +492,7 @@ def your_function(x, y): namespaces = set() for x in xs: - if isinstance(x, (bool, int, float, complex, type(None))): - continue - elif is_numpy_array(x): + if is_numpy_array(x): from .. import numpy as numpy_namespace import numpy as np if use_compat is True: @@ -558,6 +556,8 @@ def your_function(x, y): if use_compat is True: raise ValueError("The given array does not have an array-api-compat wrapper") namespaces.add(x.__array_namespace__(api_version=api_version)) + elif isinstance(x, (bool, int, float, complex, type(None))): + continue else: # TODO: Support Python scalars? raise TypeError(f"{type(x).__name__} is not a supported array type")