26
26
27
27
from jax import lax
28
28
from jax ._src import api
29
- from jax ._src import core , config
29
+ from jax ._src import core
30
30
from jax ._src import dtypes
31
31
from jax ._src .numpy import ufuncs
32
32
from jax ._src .numpy .util import (
@@ -65,6 +65,18 @@ def _upcast_f16(dtype: DTypeLike) -> DType:
65
65
return np .dtype ('float32' )
66
66
return np .dtype (dtype )
67
67
68
+ def _promote_integer_dtype (dtype : DTypeLike ) -> DTypeLike :
69
+ # Note: NumPy always promotes to 64-bit; jax instead promotes to the
70
+ # default dtype as defined by dtypes.int_ or dtypes.uint.
71
+ if dtypes .issubdtype (dtype , np .unsignedinteger ):
72
+ if np .iinfo (dtype ).bits < np .iinfo (dtypes .uint ).bits :
73
+ return dtypes .uint
74
+ elif dtypes .issubdtype (dtype , np .integer ):
75
+ if np .iinfo (dtype ).bits < np .iinfo (dtypes .int_ ).bits :
76
+ return dtypes .int_
77
+ return dtype
78
+
79
+
68
80
ReductionOp = Callable [[Any , Any ], Any ]
69
81
70
82
def _reduction (a : ArrayLike , name : str , np_fun : Any , op : ReductionOp , init_val : ArrayLike ,
@@ -103,16 +115,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
103
115
result_dtype = dtype or dtypes .dtype (a )
104
116
105
117
if dtype is None and promote_integers :
106
- # Note: NumPy always promotes to 64-bit; jax instead promotes to the
107
- # default dtype as defined by dtypes.int_ or dtypes.uint.
108
- if dtypes .issubdtype (result_dtype , np .bool_ ):
109
- result_dtype = dtypes .int_
110
- elif dtypes .issubdtype (result_dtype , np .unsignedinteger ):
111
- if np .iinfo (result_dtype ).bits < np .iinfo (dtypes .uint ).bits :
112
- result_dtype = dtypes .uint
113
- elif dtypes .issubdtype (result_dtype , np .integer ):
114
- if np .iinfo (result_dtype ).bits < np .iinfo (dtypes .int_ ).bits :
115
- result_dtype = dtypes .int_
118
+ result_dtype = _promote_integer_dtype (result_dtype )
116
119
117
120
result_dtype = dtypes .canonicalize_dtype (result_dtype )
118
121
@@ -653,7 +656,8 @@ def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
653
656
654
657
class CumulativeReduction (Protocol ):
655
658
def __call__ (self , a : ArrayLike , axis : Axis = None ,
656
- dtype : DTypeLike | None = None , out : None = None ) -> Array : ...
659
+ dtype : DTypeLike | None = None , out : None = None ,
660
+ promote_integers : bool = False ) -> Array : ...
657
661
658
662
659
663
# TODO(jakevdp): should we change these semantics to match those of numpy?
@@ -667,12 +671,17 @@ def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array
667
671
@implements (np_reduction , skip_params = ['out' ],
668
672
lax_description = CUML_REDUCTION_LAX_DESCRIPTION )
669
673
def cumulative_reduction (a : ArrayLike , axis : Axis = None ,
670
- dtype : DTypeLike | None = None , out : None = None ) -> Array :
671
- return _cumulative_reduction (a , _ensure_optional_axes (axis ), dtype , out )
674
+ dtype : DTypeLike | None = None , out : None = None ,
675
+ promote_integers : bool = False ) -> Array :
676
+ return _cumulative_reduction (
677
+ a , _ensure_optional_axes (axis ), dtype ,
678
+ out , promote_integers = promote_integers
679
+ )
672
680
673
- @partial (api .jit , static_argnames = ('axis' , 'dtype' ))
681
+ @partial (api .jit , static_argnames = ('axis' , 'dtype' , 'promote_integers' ))
674
682
def _cumulative_reduction (a : ArrayLike , axis : Axis = None ,
675
- dtype : DTypeLike | None = None , out : None = None ) -> Array :
683
+ dtype : DTypeLike | None = None , out : None = None ,
684
+ promote_integers : bool = False ) -> Array :
676
685
check_arraylike (np_reduction .__name__ , a )
677
686
if out is not None :
678
687
raise NotImplementedError (f"The 'out' argument to jnp.{ np_reduction .__name__ } "
@@ -691,11 +700,13 @@ def _cumulative_reduction(a: ArrayLike, axis: Axis = None,
691
700
if fill_nan :
692
701
a = _where (lax_internal ._isnan (a ), _lax_const (a , fill_value ), a )
693
702
694
- if not dtype and dtypes .dtype (a ) == np .bool_ :
695
- dtype = dtypes .canonicalize_dtype (dtypes .int_ )
703
+ result_type = dtype or dtypes .dtype (a )
704
+ if dtype is None and promote_integers :
705
+ dtype = dtypes .canonicalize_dtype (_promote_integer_dtype (result_type ))
696
706
if dtype :
697
707
a = lax .convert_element_type (a , dtype )
698
708
709
+
699
710
return reduction (a , axis )
700
711
701
712
return cumulative_reduction
@@ -730,12 +741,12 @@ def cumulative_sum(
730
741
731
742
axis = _canonicalize_axis (axis , x .ndim )
732
743
dtypes .check_user_dtype_supported (dtype )
733
- kind = x . dtype . kind
734
- if ( dtype is None and kind in { 'i' , 'u' }
735
- and x . dtype . itemsize * 8 < int ( config . default_dtype_bits . value )):
736
- dtype = dtypes . canonicalize_dtype ( dtypes . _default_types [ kind ])
737
- x = x . astype ( dtype = dtype or x . dtype )
738
- out = cumsum (x , axis = axis )
744
+ if dtypes . issubdtype ( dtype , np . bool_ ):
745
+ raise ValueError (
746
+ "cumulative_sum does not support boolean dtype output. Please select a "
747
+ "numerical output instead."
748
+ )
749
+ out = cumsum (x , axis = axis , dtype = dtype , promote_integers = True )
739
750
if include_initial :
740
751
zeros_shape = list (x .shape )
741
752
zeros_shape [axis ] = 1
0 commit comments