Skip to content

Commit 8572df3

Browse files
committed
Add a test for sum/trace/prod 2023.12 upcasting behavior
1 parent 9f954e6 commit 8572df3

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import pytest
2+
3+
import array_api_strict as xp
4+
5+
@pytest.mark.parametrize('func_name', ['sum', 'prod', 'trace'])
6+
def test_sum_prod_trace_2023_12(func_name):
7+
# sum, prod, and trace were changed in 2023.12 to not upcast floating-point dtypes
8+
# with dtype=None
9+
if func_name == 'trace':
10+
func = getattr(xp.linalg, func_name)
11+
else:
12+
func = getattr(xp, func_name)
13+
14+
a_real = xp.asarray([[1., 2.], [3., 4.]], dtype=xp.float32)
15+
a_complex = xp.asarray([[1., 2.], [3., 4.]], dtype=xp.complex64)
16+
a_int = xp.asarray([[1, 2], [3, 4]], dtype=xp.int32)
17+
18+
assert func(a_real).dtype == xp.float64
19+
assert func(a_complex).dtype == xp.complex128
20+
assert func(a_int).dtype == xp.int64
21+
22+
with pytest.warns(UserWarning):
23+
xp.set_array_api_strict_flags(api_version='2023.12')
24+
25+
assert func(a_real).dtype == xp.float32
26+
assert func(a_complex).dtype == xp.complex64
27+
assert func(a_int).dtype == xp.int64

0 commit comments

Comments
 (0)