Skip to content

Commit 9f954e6

Browse files
committed
Implement 2023.12 behavior for trace
1 parent a437da3 commit 9f954e6

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

array_api_strict/linalg.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from ._manipulation_functions import reshape
1212
from ._elementwise_functions import conj
1313
from ._array_object import Array
14-
from ._flags import requires_extension
14+
from ._flags import requires_extension, get_array_api_strict_flags
1515

1616
try:
1717
from numpy._core.numeric import normalize_axis_tuple
@@ -377,10 +377,11 @@ def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Arr
377377
# Note: trace() works the same as sum() and prod() (see
378378
# _statistical_functions.py)
379379
if dtype is None:
380-
if x.dtype == float32:
381-
dtype = np.float64
382-
elif x.dtype == complex64:
383-
dtype = np.complex128
380+
if get_array_api_strict_flags()['api_version'] < '2023.12':
381+
if x.dtype == float32:
382+
dtype = np.float64
383+
elif x.dtype == complex64:
384+
dtype = np.complex128
384385
else:
385386
dtype = dtype._np_dtype
386387
# Note: trace always operates on the last two axes, whereas np.trace

0 commit comments

Comments
 (0)