|
6 | 6 | )
|
7 | 7 | from ._array_object import Array
|
8 | 8 | from ._creation_functions import asarray
|
9 |
| -from ._dtypes import float32, float64 |
| 9 | +from ._dtypes import float32, float64, complex64, complex128 |
10 | 10 |
|
11 | 11 | from typing import TYPE_CHECKING, Optional, Tuple, Union
|
12 | 12 |
|
@@ -62,10 +62,14 @@ def prod(
|
62 | 62 | ) -> Array:
|
63 | 63 | if x.dtype not in _numeric_dtypes:
|
64 | 64 | raise TypeError("Only numeric dtypes are allowed in prod")
|
65 |
| - # Note: sum() and prod() always upcast float32 to float64 for dtype=None |
66 |
| - # We need to do so here before computing the product to avoid overflow |
67 |
| - if dtype is None and x.dtype == float32: |
68 |
| - dtype = float64 |
| 65 | + # Note: sum() and prod() always upcast for dtype=None. `np.prod` does that |
| 66 | + # for integers, but not for float32 or complex64, so we need to |
| 67 | + # special-case it here |
| 68 | + if dtype is None: |
| 69 | + if x.dtype == float32: |
| 70 | + dtype = float64 |
| 71 | + elif x.dtype == complex64: |
| 72 | + dtype = complex128 |
69 | 73 | return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims))
|
70 | 74 |
|
71 | 75 |
|
@@ -93,11 +97,14 @@ def sum(
|
93 | 97 | ) -> Array:
|
94 | 98 | if x.dtype not in _numeric_dtypes:
|
95 | 99 | raise TypeError("Only numeric dtypes are allowed in sum")
|
96 |
| - # Note: sum() and prod() always upcast integers to (u)int64 and float32 to |
97 |
| - # float64 for dtype=None. `np.sum` does that too for integers, but not for |
98 |
| - # float32, so we need to special-case it here |
99 |
| - if dtype is None and x.dtype == float32: |
100 |
| - dtype = float64 |
| 100 | + # Note: sum() and prod() always upcast for dtype=None. `np.sum` does that |
| 101 | + # for integers, but not for float32 or complex64, so we need to |
| 102 | + # special-case it here |
| 103 | + if dtype is None: |
| 104 | + if x.dtype == float32: |
| 105 | + dtype = float64 |
| 106 | + elif x.dtype == complex64: |
| 107 | + dtype = complex128 |
101 | 108 | return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims))
|
102 | 109 |
|
103 | 110 |
|
|
0 commit comments