26
26
27
27
from jax import lax
28
28
from jax ._src import api
29
- from jax ._src import core , config
29
+ from jax ._src import core
30
30
from jax ._src import dtypes
31
31
from jax ._src .numpy import ufuncs
32
32
from jax ._src .numpy .util import (
@@ -65,6 +65,20 @@ def _upcast_f16(dtype: DTypeLike) -> DType:
65
65
return np .dtype ('float32' )
66
66
return np .dtype (dtype )
67
67
68
+ def _promote_integer_dtype (dtype : DTypeLike ) -> DTypeLike :
69
+ # Note: NumPy always promotes to 64-bit; jax instead promotes to the
70
+ # default dtype as defined by dtypes.int_ or dtypes.uint.
71
+ if dtypes .issubdtype (dtype , np .bool_ ):
72
+ return dtypes .int_
73
+ elif dtypes .issubdtype (dtype , np .unsignedinteger ):
74
+ if np .iinfo (dtype ).bits < np .iinfo (dtypes .uint ).bits :
75
+ return dtypes .uint
76
+ elif dtypes .issubdtype (dtype , np .integer ):
77
+ if np .iinfo (dtype ).bits < np .iinfo (dtypes .int_ ).bits :
78
+ return dtypes .int_
79
+ return dtype
80
+
81
+
68
82
ReductionOp = Callable [[Any , Any ], Any ]
69
83
70
84
def _reduction (a : ArrayLike , name : str , np_fun : Any , op : ReductionOp , init_val : ArrayLike ,
@@ -103,16 +117,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
103
117
result_dtype = dtype or dtypes .dtype (a )
104
118
105
119
if dtype is None and promote_integers :
106
- # Note: NumPy always promotes to 64-bit; jax instead promotes to the
107
- # default dtype as defined by dtypes.int_ or dtypes.uint.
108
- if dtypes .issubdtype (result_dtype , np .bool_ ):
109
- result_dtype = dtypes .int_
110
- elif dtypes .issubdtype (result_dtype , np .unsignedinteger ):
111
- if np .iinfo (result_dtype ).bits < np .iinfo (dtypes .uint ).bits :
112
- result_dtype = dtypes .uint
113
- elif dtypes .issubdtype (result_dtype , np .integer ):
114
- if np .iinfo (result_dtype ).bits < np .iinfo (dtypes .int_ ).bits :
115
- result_dtype = dtypes .int_
120
+ result_dtype = _promote_integer_dtype (result_dtype )
116
121
117
122
result_dtype = dtypes .canonicalize_dtype (result_dtype )
118
123
@@ -653,7 +658,8 @@ def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
653
658
654
659
class CumulativeReduction (Protocol ):
655
660
def __call__ (self , a : ArrayLike , axis : Axis = None ,
656
- dtype : DTypeLike | None = None , out : None = None ) -> Array : ...
661
+ dtype : DTypeLike | None = None , out : None = None ,
662
+ promote_integers : bool = False ) -> Array : ...
657
663
658
664
659
665
# TODO(jakevdp): should we change these semantics to match those of numpy?
@@ -667,12 +673,17 @@ def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array
667
673
@implements (np_reduction , skip_params = ['out' ],
668
674
lax_description = CUML_REDUCTION_LAX_DESCRIPTION )
669
675
def cumulative_reduction (a : ArrayLike , axis : Axis = None ,
670
- dtype : DTypeLike | None = None , out : None = None ) -> Array :
671
- return _cumulative_reduction (a , _ensure_optional_axes (axis ), dtype , out )
676
+ dtype : DTypeLike | None = None , out : None = None ,
677
+ promote_integers : bool = False ) -> Array :
678
+ return _cumulative_reduction (
679
+ a , _ensure_optional_axes (axis ), dtype ,
680
+ out , promote_integers = promote_integers
681
+ )
672
682
673
- @partial (api .jit , static_argnames = ('axis' , 'dtype' ))
683
+ @partial (api .jit , static_argnames = ('axis' , 'dtype' , 'promote_integers' ))
674
684
def _cumulative_reduction (a : ArrayLike , axis : Axis = None ,
675
- dtype : DTypeLike | None = None , out : None = None ) -> Array :
685
+ dtype : DTypeLike | None = None , out : None = None ,
686
+ promote_integers : bool = False ) -> Array :
676
687
check_arraylike (np_reduction .__name__ , a )
677
688
if out is not None :
678
689
raise NotImplementedError (f"The 'out' argument to jnp.{ np_reduction .__name__ } "
@@ -691,11 +702,15 @@ def _cumulative_reduction(a: ArrayLike, axis: Axis = None,
691
702
if fill_nan :
692
703
a = _where (lax_internal ._isnan (a ), _lax_const (a , fill_value ), a )
693
704
694
- if not dtype and dtypes .dtype (a ) == np .bool_ :
705
+ result_type = dtype or dtypes .dtype (a )
706
+ if result_type == np .bool_ :
695
707
dtype = dtypes .canonicalize_dtype (dtypes .int_ )
708
+ elif dtype is None and promote_integers :
709
+ dtype = _promote_integer_dtype (result_type )
696
710
if dtype :
697
711
a = lax .convert_element_type (a , dtype )
698
712
713
+
699
714
return reduction (a , axis )
700
715
701
716
return cumulative_reduction
@@ -730,12 +745,7 @@ def cumulative_sum(
730
745
731
746
axis = _canonicalize_axis (axis , x .ndim )
732
747
dtypes .check_user_dtype_supported (dtype )
733
- kind = x .dtype .kind
734
- if (dtype is None and kind in {'i' , 'u' }
735
- and x .dtype .itemsize * 8 < int (config .default_dtype_bits .value )):
736
- dtype = dtypes .canonicalize_dtype (dtypes ._default_types [kind ])
737
- x = x .astype (dtype = dtype or x .dtype )
738
- out = cumsum (x , axis = axis )
748
+ out = cumsum (x , axis = axis , dtype = dtype , promote_integers = True )
739
749
if include_initial :
740
750
zeros_shape = list (x .shape )
741
751
zeros_shape [axis ] = 1
0 commit comments