|
| 1 | +import pytest |
| 2 | + |
| 3 | +import array_api_strict as xp |
| 4 | + |
| 5 | +@pytest.mark.parametrize('func_name', ['sum', 'prod', 'trace']) |
| 6 | +def test_sum_prod_trace_2023_12(func_name): |
| 7 | + # sum, prod, and trace were changed in 2023.12 to not upcast floating-point dtypes |
| 8 | + # with dtype=None |
| 9 | + if func_name == 'trace': |
| 10 | + func = getattr(xp.linalg, func_name) |
| 11 | + else: |
| 12 | + func = getattr(xp, func_name) |
| 13 | + |
| 14 | + a_real = xp.asarray([[1., 2.], [3., 4.]], dtype=xp.float32) |
| 15 | + a_complex = xp.asarray([[1., 2.], [3., 4.]], dtype=xp.complex64) |
| 16 | + a_int = xp.asarray([[1, 2], [3, 4]], dtype=xp.int32) |
| 17 | + |
| 18 | + assert func(a_real).dtype == xp.float64 |
| 19 | + assert func(a_complex).dtype == xp.complex128 |
| 20 | + assert func(a_int).dtype == xp.int64 |
| 21 | + |
| 22 | + with pytest.warns(UserWarning): |
| 23 | + xp.set_array_api_strict_flags(api_version='2023.12') |
| 24 | + |
| 25 | + assert func(a_real).dtype == xp.float32 |
| 26 | + assert func(a_complex).dtype == xp.complex64 |
| 27 | + assert func(a_int).dtype == xp.int64 |
0 commit comments