Skip to content

Commit 94cc065

Browse files
committed
Update numpy.array_api sum() and prod() to handle complex dtypes
Original NumPy Commit: e023bc611661bbed26292b098945170728e67d48
1 parent 3fde6a1 commit 94cc065

File tree

1 file changed

+17
-10
lines changed

1 file changed

+17
-10
lines changed

array_api_strict/_statistical_functions.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
)
77
from ._array_object import Array
88
from ._creation_functions import asarray
9-
from ._dtypes import float32, float64
9+
from ._dtypes import float32, float64, complex64, complex128
1010

1111
from typing import TYPE_CHECKING, Optional, Tuple, Union
1212

@@ -62,10 +62,14 @@ def prod(
6262
) -> Array:
6363
if x.dtype not in _numeric_dtypes:
6464
raise TypeError("Only numeric dtypes are allowed in prod")
65-
# Note: sum() and prod() always upcast float32 to float64 for dtype=None
66-
# We need to do so here before computing the product to avoid overflow
67-
if dtype is None and x.dtype == float32:
68-
dtype = float64
65+
# Note: sum() and prod() always upcast for dtype=None. `np.prod` does that
66+
# for integers, but not for float32 or complex64, so we need to
67+
# special-case it here
68+
if dtype is None:
69+
if x.dtype == float32:
70+
dtype = float64
71+
elif x.dtype == complex64:
72+
dtype = complex128
6973
return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims))
7074

7175

@@ -93,11 +97,14 @@ def sum(
9397
) -> Array:
9498
if x.dtype not in _numeric_dtypes:
9599
raise TypeError("Only numeric dtypes are allowed in sum")
96-
# Note: sum() and prod() always upcast integers to (u)int64 and float32 to
97-
# float64 for dtype=None. `np.sum` does that too for integers, but not for
98-
# float32, so we need to special-case it here
99-
if dtype is None and x.dtype == float32:
100-
dtype = float64
100+
# Note: sum() and prod() always upcast for dtype=None. `np.sum` does that
101+
# for integers, but not for float32 or complex64, so we need to
102+
# special-case it here
103+
if dtype is None:
104+
if x.dtype == float32:
105+
dtype = float64
106+
elif x.dtype == complex64:
107+
dtype = complex128
101108
return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims))
102109

103110

0 commit comments

Comments
 (0)