Skip to content

Commit a437da3

Browse files
committed
Implement 2023.12 behavior for sum() and prod()
1 parent 8333107 commit a437da3

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

array_api_strict/_statistical_functions.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
)
88
from ._array_object import Array
99
from ._dtypes import float32, complex64
10-
from ._flags import requires_api_version
10+
from ._flags import requires_api_version, get_array_api_strict_flags
1111
from ._creation_functions import zeros
1212
from ._manipulation_functions import concat
1313

@@ -89,14 +89,16 @@ def prod(
8989
) -> Array:
9090
if x.dtype not in _numeric_dtypes:
9191
raise TypeError("Only numeric dtypes are allowed in prod")
92-
# Note: sum() and prod() always upcast for dtype=None. `np.prod` does that
93-
# for integers, but not for float32 or complex64, so we need to
94-
# special-case it here
92+
9593
if dtype is None:
96-
if x.dtype == float32:
97-
dtype = np.float64
98-
elif x.dtype == complex64:
99-
dtype = np.complex128
94+
# Note: In versions prior to 2023.12, sum() and prod() upcast for all
95+
# dtypes when dtype=None. For 2023.12, the behavior is the same as in
96+
# NumPy (only upcast for integral dtypes).
97+
if get_array_api_strict_flags()['api_version'] < '2023.12':
98+
if x.dtype == float32:
99+
dtype = np.float64
100+
elif x.dtype == complex64:
101+
dtype = np.complex128
100102
else:
101103
dtype = dtype._np_dtype
102104
return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims))
@@ -126,14 +128,16 @@ def sum(
126128
) -> Array:
127129
if x.dtype not in _numeric_dtypes:
128130
raise TypeError("Only numeric dtypes are allowed in sum")
129-
# Note: sum() and prod() always upcast for dtype=None. `np.sum` does that
130-
# for integers, but not for float32 or complex64, so we need to
131-
# special-case it here
131+
132132
if dtype is None:
133-
if x.dtype == float32:
134-
dtype = np.float64
135-
elif x.dtype == complex64:
136-
dtype = np.complex128
133+
# Note: In versions prior to 2023.12, sum() and prod() upcast for all
134+
# dtypes when dtype=None. For 2023.12, the behavior is the same as in
135+
# NumPy (only upcast for integral dtypes).
136+
if get_array_api_strict_flags()['api_version'] < '2023.12':
137+
if x.dtype == float32:
138+
dtype = np.float64
139+
elif x.dtype == complex64:
140+
dtype = np.complex128
137141
else:
138142
dtype = dtype._np_dtype
139143
return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims))

0 commit comments

Comments
 (0)