Skip to content

Commit e115797

Browse files
committed
Refactored common upcast for integral-type accumulators
1 parent 7e20e53 commit e115797

File tree

4 files changed

+48
-54
lines changed

4 files changed

+48
-54
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3364,13 +3364,6 @@ def trace(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1,
33643364
dtypes.check_user_dtype_supported(dtype, "trace")
33653365

33663366
a_shape = shape(a)
3367-
if dtype is None:
3368-
dtype = _dtype(a)
3369-
if issubdtype(dtype, integer):
3370-
default_int = dtypes.canonicalize_dtype(int)
3371-
if iinfo(dtype).bits < iinfo(default_int).bits:
3372-
dtype = default_int
3373-
33743367
a = moveaxis(a, (axis1, axis2), (-2, -1))
33753368

33763369
# Mask out the diagonal and reduce.

jax/_src/numpy/reductions.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from jax import lax
2828
from jax._src import api
29-
from jax._src import core, config
29+
from jax._src import core
3030
from jax._src import dtypes
3131
from jax._src.numpy import ufuncs
3232
from jax._src.numpy.util import (
@@ -65,6 +65,18 @@ def _upcast_f16(dtype: DTypeLike) -> DType:
6565
return np.dtype('float32')
6666
return np.dtype(dtype)
6767

68+
def _promote_integer_dtype(dtype: DTypeLike) -> DTypeLike:
69+
# Note: NumPy always promotes to 64-bit; jax instead promotes to the
70+
# default dtype as defined by dtypes.int_ or dtypes.uint.
71+
if dtypes.issubdtype(dtype, np.unsignedinteger):
72+
if np.iinfo(dtype).bits < np.iinfo(dtypes.uint).bits:
73+
return dtypes.uint
74+
elif dtypes.issubdtype(dtype, np.integer):
75+
if np.iinfo(dtype).bits < np.iinfo(dtypes.int_).bits:
76+
return dtypes.int_
77+
return dtype
78+
79+
6880
ReductionOp = Callable[[Any, Any], Any]
6981

7082
def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: ArrayLike,
@@ -103,16 +115,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
103115
result_dtype = dtype or dtypes.dtype(a)
104116

105117
if dtype is None and promote_integers:
106-
# Note: NumPy always promotes to 64-bit; jax instead promotes to the
107-
# default dtype as defined by dtypes.int_ or dtypes.uint.
108-
if dtypes.issubdtype(result_dtype, np.bool_):
109-
result_dtype = dtypes.int_
110-
elif dtypes.issubdtype(result_dtype, np.unsignedinteger):
111-
if np.iinfo(result_dtype).bits < np.iinfo(dtypes.uint).bits:
112-
result_dtype = dtypes.uint
113-
elif dtypes.issubdtype(result_dtype, np.integer):
114-
if np.iinfo(result_dtype).bits < np.iinfo(dtypes.int_).bits:
115-
result_dtype = dtypes.int_
118+
result_dtype = _promote_integer_dtype(result_dtype)
116119

117120
result_dtype = dtypes.canonicalize_dtype(result_dtype)
118121

@@ -653,7 +656,8 @@ def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
653656

654657
class CumulativeReduction(Protocol):
655658
def __call__(self, a: ArrayLike, axis: Axis = None,
656-
dtype: DTypeLike | None = None, out: None = None) -> Array: ...
659+
dtype: DTypeLike | None = None, out: None = None,
660+
promote_integers: bool = False) -> Array: ...
657661

658662

659663
# TODO(jakevdp): should we change these semantics to match those of numpy?
@@ -667,12 +671,17 @@ def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array
667671
@implements(np_reduction, skip_params=['out'],
668672
lax_description=CUML_REDUCTION_LAX_DESCRIPTION)
669673
def cumulative_reduction(a: ArrayLike, axis: Axis = None,
670-
dtype: DTypeLike | None = None, out: None = None) -> Array:
671-
return _cumulative_reduction(a, _ensure_optional_axes(axis), dtype, out)
674+
dtype: DTypeLike | None = None, out: None = None,
675+
promote_integers: bool = False) -> Array:
676+
return _cumulative_reduction(
677+
a, _ensure_optional_axes(axis), dtype,
678+
out, promote_integers=promote_integers
679+
)
672680

673-
@partial(api.jit, static_argnames=('axis', 'dtype'))
681+
@partial(api.jit, static_argnames=('axis', 'dtype', 'promote_integers'))
674682
def _cumulative_reduction(a: ArrayLike, axis: Axis = None,
675-
dtype: DTypeLike | None = None, out: None = None) -> Array:
683+
dtype: DTypeLike | None = None, out: None = None,
684+
promote_integers: bool = False) -> Array:
676685
check_arraylike(np_reduction.__name__, a)
677686
if out is not None:
678687
raise NotImplementedError(f"The 'out' argument to jnp.{np_reduction.__name__} "
@@ -691,11 +700,13 @@ def _cumulative_reduction(a: ArrayLike, axis: Axis = None,
691700
if fill_nan:
692701
a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)
693702

694-
if not dtype and dtypes.dtype(a) == np.bool_:
695-
dtype = dtypes.canonicalize_dtype(dtypes.int_)
703+
result_type = dtype or dtypes.dtype(a)
704+
if dtype is None and promote_integers:
705+
dtype = dtypes.canonicalize_dtype(_promote_integer_dtype(result_type))
696706
if dtype:
697707
a = lax.convert_element_type(a, dtype)
698708

709+
699710
return reduction(a, axis)
700711

701712
return cumulative_reduction
@@ -730,12 +741,12 @@ def cumulative_sum(
730741

731742
axis = _canonicalize_axis(axis, x.ndim)
732743
dtypes.check_user_dtype_supported(dtype)
733-
kind = x.dtype.kind
734-
if (dtype is None and kind in {'i', 'u'}
735-
and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
736-
dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind])
737-
x = x.astype(dtype=dtype or x.dtype)
738-
out = cumsum(x, axis=axis)
744+
if dtypes.issubdtype(dtype, np.bool_):
745+
raise ValueError(
746+
"cumulative_sum does not support boolean dtype output. Please select a "
747+
"numerical output instead."
748+
)
749+
out = cumsum(x, axis=axis, dtype=dtype, promote_integers=True)
739750
if include_initial:
740751
zeros_shape = list(x.shape)
741752
zeros_shape[axis] = 1

jax/experimental/array_api/_data_type_functions.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,3 @@ def finfo(type, /) -> FInfo:
7676
smallest_normal=float(info.smallest_normal),
7777
dtype=jnp.dtype(type)
7878
)
79-
80-
# TODO(micky774): Update utility to only promote integral types
81-
def _promote_to_default_dtype(x):
82-
if x.dtype.kind == 'b':
83-
return x
84-
elif x.dtype.kind == 'i':
85-
return x.astype(jnp.int_)
86-
elif x.dtype.kind == 'u':
87-
return x.astype(jnp.uint)
88-
elif x.dtype.kind == 'f':
89-
return x.astype(jnp.float_)
90-
elif x.dtype.kind == 'c':
91-
return x.astype(jnp.complex_)
92-
else:
93-
raise ValueError(f"Unrecognized {x.dtype=}")

tests/lax_numpy_reducers_test.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,7 @@ def test_f16_mean(self, dtype):
780780
for dtype in (all_dtypes+[None])
781781
for out_dtype in (
782782
complex_dtypes if np.issubdtype(dtype, np.complexfloating)
783-
else all_dtypes
783+
else number_dtypes
784784
)
785785
],
786786
include_initial=[False, True],
@@ -791,13 +791,8 @@ def testCumulativeSum(self, shape, axis, dtype, out_dtype, include_initial):
791791
rng = jtu.rand_some_zero(self.rng())
792792

793793
def np_mock_op(x, axis=None, dtype=None, include_initial=False):
794-
kind = x.dtype.kind
795-
if (dtype is None and kind in {'i', 'u'}
796-
and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
797-
dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind])
798794
axis = axis or 0
799-
x = x.astype(dtype=dtype or x.dtype)
800-
out = jnp.cumsum(x, axis=axis)
795+
out = np.cumsum(x, axis=axis, dtype=dtype or x.dtype)
801796
if include_initial:
802797
zeros_shape = list(x.shape)
803798
zeros_shape[axis] = 1
@@ -818,7 +813,7 @@ def np_mock_op(x, axis=None, dtype=None, include_initial=False):
818813

819814

820815
@jtu.sample_product(
821-
shape=filter(lambda x: len(x) != 1, all_shapes), dtype=all_dtypes,
816+
shape=filter(lambda x: len(x) != 1, all_shapes), dtype=number_dtypes,
822817
include_initial=[False, True])
823818
def testCumulativeSumErrors(self, shape, dtype, include_initial):
824819
rng = jtu.rand_some_zero(self.rng())
@@ -834,5 +829,15 @@ def testCumulativeSumErrors(self, shape, dtype, include_initial):
834829
jnp.cumulative_sum(x, include_initial=include_initial)
835830

836831

832+
@jtu.sample_product(
833+
shape=nonempty_nonscalar_array_shapes, dtype=all_dtypes)
834+
def testCumulativeSumBoolError(self, shape, dtype):
835+
rng = jtu.rand_some_zero(self.rng())
836+
x = rng(shape, dtype)
837+
msg = "cumulative_sum does not support boolean dtype output"
838+
with self.assertRaisesRegex(ValueError, msg):
839+
jnp.cumulative_sum(x, axis=0, dtype=jnp.bool_)
840+
841+
837842
if __name__ == "__main__":
838843
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)