|
7 | 7 | )
|
8 | 8 | from ._array_object import Array
|
9 | 9 | from ._dtypes import float32, complex64
|
10 |
| -from ._flags import requires_api_version |
| 10 | +from ._flags import requires_api_version, get_array_api_strict_flags |
11 | 11 | from ._creation_functions import zeros
|
12 | 12 | from ._manipulation_functions import concat
|
13 | 13 |
|
@@ -89,14 +89,16 @@ def prod(
|
89 | 89 | ) -> Array:
|
90 | 90 | if x.dtype not in _numeric_dtypes:
|
91 | 91 | raise TypeError("Only numeric dtypes are allowed in prod")
|
92 |
| - # Note: sum() and prod() always upcast for dtype=None. `np.prod` does that |
93 |
| - # for integers, but not for float32 or complex64, so we need to |
94 |
| - # special-case it here |
| 92 | + |
95 | 93 | if dtype is None:
|
96 |
| - if x.dtype == float32: |
97 |
| - dtype = np.float64 |
98 |
| - elif x.dtype == complex64: |
99 |
| - dtype = np.complex128 |
| 94 | + # Note: In versions prior to 2023.12, sum() and prod() upcast for all |
| 95 | + # dtypes when dtype=None. For 2023.12, the behavior is the same as in |
| 96 | + # NumPy (only upcast for integral dtypes). |
| 97 | + if get_array_api_strict_flags()['api_version'] < '2023.12': |
| 98 | + if x.dtype == float32: |
| 99 | + dtype = np.float64 |
| 100 | + elif x.dtype == complex64: |
| 101 | + dtype = np.complex128 |
100 | 102 | else:
|
101 | 103 | dtype = dtype._np_dtype
|
102 | 104 | return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims))
|
@@ -126,14 +128,16 @@ def sum(
|
126 | 128 | ) -> Array:
|
127 | 129 | if x.dtype not in _numeric_dtypes:
|
128 | 130 | raise TypeError("Only numeric dtypes are allowed in sum")
|
129 |
| - # Note: sum() and prod() always upcast for dtype=None. `np.sum` does that |
130 |
| - # for integers, but not for float32 or complex64, so we need to |
131 |
| - # special-case it here |
| 131 | + |
132 | 132 | if dtype is None:
|
133 |
| - if x.dtype == float32: |
134 |
| - dtype = np.float64 |
135 |
| - elif x.dtype == complex64: |
136 |
| - dtype = np.complex128 |
| 133 | + # Note: In versions prior to 2023.12, sum() and prod() upcast for all |
| 134 | + # dtypes when dtype=None. For 2023.12, the behavior is the same as in |
| 135 | + # NumPy (only upcast for integral dtypes). |
| 136 | + if get_array_api_strict_flags()['api_version'] < '2023.12': |
| 137 | + if x.dtype == float32: |
| 138 | + dtype = np.float64 |
| 139 | + elif x.dtype == complex64: |
| 140 | + dtype = np.complex128 |
137 | 141 | else:
|
138 | 142 | dtype = dtype._np_dtype
|
139 | 143 | return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims))
|
|
0 commit comments