26
26
27
27
from jax import lax
28
28
from jax ._src import api
29
- from jax ._src import core , config
29
+ from jax ._src import core
30
30
from jax ._src import dtypes
31
31
from jax ._src .numpy import ufuncs
32
32
from jax ._src .numpy .util import (
@@ -65,6 +65,20 @@ def _upcast_f16(dtype: DTypeLike) -> DType:
65
65
return np .dtype ('float32' )
66
66
return np .dtype (dtype )
67
67
68
+ def _promote_integer_dtype (dtype : DTypeLike ) -> DTypeLike :
69
+ # Note: NumPy always promotes to 64-bit; jax instead promotes to the
70
+ # default dtype as defined by dtypes.int_ or dtypes.uint.
71
+ if dtypes .issubdtype (dtype , np .bool_ ):
72
+ return dtypes .int_
73
+ elif dtypes .issubdtype (dtype , np .unsignedinteger ):
74
+ if np .iinfo (dtype ).bits < np .iinfo (dtypes .uint ).bits :
75
+ return dtypes .uint
76
+ elif dtypes .issubdtype (dtype , np .integer ):
77
+ if np .iinfo (dtype ).bits < np .iinfo (dtypes .int_ ).bits :
78
+ return dtypes .int_
79
+ return dtype
80
+
81
+
68
82
ReductionOp = Callable [[Any , Any ], Any ]
69
83
70
84
def _reduction (a : ArrayLike , name : str , np_fun : Any , op : ReductionOp , init_val : ArrayLike ,
@@ -103,16 +117,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
103
117
result_dtype = dtype or dtypes .dtype (a )
104
118
105
119
if dtype is None and promote_integers :
106
- # Note: NumPy always promotes to 64-bit; jax instead promotes to the
107
- # default dtype as defined by dtypes.int_ or dtypes.uint.
108
- if dtypes .issubdtype (result_dtype , np .bool_ ):
109
- result_dtype = dtypes .int_
110
- elif dtypes .issubdtype (result_dtype , np .unsignedinteger ):
111
- if np .iinfo (result_dtype ).bits < np .iinfo (dtypes .uint ).bits :
112
- result_dtype = dtypes .uint
113
- elif dtypes .issubdtype (result_dtype , np .integer ):
114
- if np .iinfo (result_dtype ).bits < np .iinfo (dtypes .int_ ).bits :
115
- result_dtype = dtypes .int_
120
+ result_dtype = _promote_integer_dtype (result_dtype )
116
121
117
122
result_dtype = dtypes .canonicalize_dtype (result_dtype )
118
123
@@ -663,7 +668,8 @@ def __call__(self, a: ArrayLike, axis: Axis = None,
663
668
"""
664
669
665
670
def _make_cumulative_reduction (np_reduction : Any , reduction : Callable [..., Array ],
666
- fill_nan : bool = False , fill_value : ArrayLike = 0 ) -> CumulativeReduction :
671
+ fill_nan : bool = False , fill_value : ArrayLike = 0 ,
672
+ promote_integers : bool = False ) -> CumulativeReduction :
667
673
@implements (np_reduction , skip_params = ['out' ],
668
674
lax_description = CUML_REDUCTION_LAX_DESCRIPTION )
669
675
def cumulative_reduction (a : ArrayLike , axis : Axis = None ,
@@ -691,12 +697,18 @@ def _cumulative_reduction(a: ArrayLike, axis: Axis = None,
691
697
if fill_nan :
692
698
a = _where (lax_internal ._isnan (a ), _lax_const (a , fill_value ), a )
693
699
694
- if not dtype and dtypes .dtype (a ) == np .bool_ :
695
- dtype = dtypes .canonicalize_dtype (dtypes .int_ )
696
- if dtype :
697
- a = lax .convert_element_type (a , dtype )
700
+ result_type : DTypeLike = dtypes .dtype (dtype or a )
701
+ if dtype is None and promote_integers or dtypes .issubdtype (result_type , np .bool_ ):
702
+ result_type = _promote_integer_dtype (result_type )
703
+ result_type = dtypes .canonicalize_dtype (result_type )
704
+
705
+ a = lax .convert_element_type (a , result_type )
706
+ result = reduction (a , axis )
698
707
699
- return reduction (a , axis )
708
+ # We downcast to boolean because we accumulate in integer types
709
+ if dtypes .issubdtype (dtype , np .bool_ ):
710
+ result = lax .convert_element_type (result , np .bool_ )
711
+ return result
700
712
701
713
return cumulative_reduction
702
714
@@ -707,6 +719,9 @@ def _cumulative_reduction(a: ArrayLike, axis: Axis = None,
707
719
fill_nan = True , fill_value = 0 )
708
720
nancumprod = _make_cumulative_reduction (np .nancumprod , lax .cumprod ,
709
721
fill_nan = True , fill_value = 1 )
722
+ _cumsum_with_promotion = _make_cumulative_reduction (
723
+ np .cumsum , lax .cumsum , fill_nan = False , promote_integers = True
724
+ )
710
725
711
726
@implements (getattr (np , 'cumulative_sum' , None ))
712
727
def cumulative_sum (
@@ -730,12 +745,7 @@ def cumulative_sum(
730
745
731
746
axis = _canonicalize_axis (axis , x .ndim )
732
747
dtypes .check_user_dtype_supported (dtype )
733
- kind = x .dtype .kind
734
- if (dtype is None and kind in {'i' , 'u' }
735
- and x .dtype .itemsize * 8 < int (config .default_dtype_bits .value )):
736
- dtype = dtypes .canonicalize_dtype (dtypes ._default_types [kind ])
737
- x = x .astype (dtype = dtype or x .dtype )
738
- out = cumsum (x , axis = axis )
748
+ out = _cumsum_with_promotion (x , axis = axis , dtype = dtype )
739
749
if include_initial :
740
750
zeros_shape = list (x .shape )
741
751
zeros_shape [axis ] = 1
0 commit comments