26
26
27
27
from jax import lax
28
28
from jax ._src import api
29
- from jax ._src import core , config
29
+ from jax ._src import core
30
30
from jax ._src import dtypes
31
31
from jax ._src .numpy import ufuncs
32
32
from jax ._src .numpy .util import (
@@ -65,6 +65,18 @@ def _upcast_f16(dtype: DTypeLike) -> DType:
65
65
return np .dtype ('float32' )
66
66
return np .dtype (dtype )
67
67
68
+ def _promote_integer_dtype (dtype : DTypeLike ) -> DTypeLike :
69
+ # Note: NumPy always promotes to 64-bit; jax instead promotes to the
70
+ # default dtype as defined by dtypes.int_ or dtypes.uint.
71
+ if dtypes .issubdtype (dtype , np .unsignedinteger ):
72
+ if np .iinfo (dtype ).bits < np .iinfo (dtypes .uint ).bits :
73
+ return dtypes .uint
74
+ elif dtypes .issubdtype (dtype , np .integer ):
75
+ if np .iinfo (dtype ).bits < np .iinfo (dtypes .int_ ).bits :
76
+ return dtypes .int_
77
+ return dtype
78
+
79
+
68
80
ReductionOp = Callable [[Any , Any ], Any ]
69
81
70
82
def _reduction (a : ArrayLike , name : str , np_fun : Any , op : ReductionOp , init_val : ArrayLike ,
@@ -103,16 +115,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
103
115
result_dtype = dtype or dtypes .dtype (a )
104
116
105
117
if dtype is None and promote_integers :
106
- # Note: NumPy always promotes to 64-bit; jax instead promotes to the
107
- # default dtype as defined by dtypes.int_ or dtypes.uint.
108
- if dtypes .issubdtype (result_dtype , np .bool_ ):
109
- result_dtype = dtypes .int_
110
- elif dtypes .issubdtype (result_dtype , np .unsignedinteger ):
111
- if np .iinfo (result_dtype ).bits < np .iinfo (dtypes .uint ).bits :
112
- result_dtype = dtypes .uint
113
- elif dtypes .issubdtype (result_dtype , np .integer ):
114
- if np .iinfo (result_dtype ).bits < np .iinfo (dtypes .int_ ).bits :
115
- result_dtype = dtypes .int_
118
+ result_dtype = _promote_integer_dtype (result_dtype )
116
119
117
120
result_dtype = dtypes .canonicalize_dtype (result_dtype )
118
121
@@ -653,7 +656,8 @@ def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
653
656
654
657
class CumulativeReduction (Protocol ):
655
658
def __call__ (self , a : ArrayLike , axis : Axis = None ,
656
- dtype : DTypeLike | None = None , out : None = None ) -> Array : ...
659
+ dtype : DTypeLike | None = None , out : None = None ,
660
+ promote_integers : bool = False ) -> Array : ...
657
661
658
662
659
663
# TODO(jakevdp): should we change these semantics to match those of numpy?
@@ -667,12 +671,17 @@ def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array
667
671
@implements (np_reduction , skip_params = ['out' ],
668
672
lax_description = CUML_REDUCTION_LAX_DESCRIPTION )
669
673
def cumulative_reduction (a : ArrayLike , axis : Axis = None ,
670
- dtype : DTypeLike | None = None , out : None = None ) -> Array :
671
- return _cumulative_reduction (a , _ensure_optional_axes (axis ), dtype , out )
674
+ dtype : DTypeLike | None = None , out : None = None ,
675
+ promote_integers : bool = False ) -> Array :
676
+ return _cumulative_reduction (
677
+ a , _ensure_optional_axes (axis ), dtype ,
678
+ out , promote_integers = promote_integers
679
+ )
672
680
673
- @partial (api .jit , static_argnames = ('axis' , 'dtype' ))
681
+ @partial (api .jit , static_argnames = ('axis' , 'dtype' , 'promote_integers' ))
674
682
def _cumulative_reduction (a : ArrayLike , axis : Axis = None ,
675
- dtype : DTypeLike | None = None , out : None = None ) -> Array :
683
+ dtype : DTypeLike | None = None , out : None = None ,
684
+ promote_integers : bool = False ) -> Array :
676
685
check_arraylike (np_reduction .__name__ , a )
677
686
if out is not None :
678
687
raise NotImplementedError (f"The 'out' argument to jnp.{ np_reduction .__name__ } "
@@ -691,11 +700,15 @@ def _cumulative_reduction(a: ArrayLike, axis: Axis = None,
691
700
if fill_nan :
692
701
a = _where (lax_internal ._isnan (a ), _lax_const (a , fill_value ), a )
693
702
694
- if not dtype and dtypes .dtype (a ) == np .bool_ :
703
+ result_type = dtype or dtypes .dtype (a )
704
+ if result_type == np .bool_ :
695
705
dtype = dtypes .canonicalize_dtype (dtypes .int_ )
706
+ elif dtype is None and promote_integers :
707
+ dtype = _promote_integer_dtype (result_type )
696
708
if dtype :
697
709
a = lax .convert_element_type (a , dtype )
698
710
711
+
699
712
return reduction (a , axis )
700
713
701
714
return cumulative_reduction
@@ -730,12 +743,7 @@ def cumulative_sum(
730
743
731
744
axis = _canonicalize_axis (axis , x .ndim )
732
745
dtypes .check_user_dtype_supported (dtype )
733
- kind = x .dtype .kind
734
- if (dtype is None and kind in {'i' , 'u' }
735
- and x .dtype .itemsize * 8 < int (config .default_dtype_bits .value )):
736
- dtype = dtypes .canonicalize_dtype (dtypes ._default_types [kind ])
737
- x = x .astype (dtype = dtype or x .dtype )
738
- out = cumsum (x , axis = axis )
746
+ out = cumsum (x , axis = axis , dtype = dtype , promote_integers = True )
739
747
if include_initial :
740
748
zeros_shape = list (x .shape )
741
749
zeros_shape [axis ] = 1
0 commit comments