Skip to content

Refactored common upcast for integral-type accumulators #20842

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3364,13 +3364,6 @@ def trace(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1,
dtypes.check_user_dtype_supported(dtype, "trace")

a_shape = shape(a)
if dtype is None:
dtype = _dtype(a)
if issubdtype(dtype, integer):
default_int = dtypes.canonicalize_dtype(int)
if iinfo(dtype).bits < iinfo(default_int).bits:
dtype = default_int

a = moveaxis(a, (axis1, axis2), (-2, -1))

# Mask out the diagonal and reduce.
Expand Down
56 changes: 33 additions & 23 deletions jax/_src/numpy/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from jax import lax
from jax._src import api
from jax._src import core, config
from jax._src import core
from jax._src import dtypes
from jax._src.numpy import ufuncs
from jax._src.numpy.util import (
Expand Down Expand Up @@ -65,6 +65,20 @@ def _upcast_f16(dtype: DTypeLike) -> DType:
return np.dtype('float32')
return np.dtype(dtype)

def _promote_integer_dtype(dtype: DTypeLike) -> DTypeLike:
# Note: NumPy always promotes to 64-bit; jax instead promotes to the
# default dtype as defined by dtypes.int_ or dtypes.uint.
if dtypes.issubdtype(dtype, np.bool_):
return dtypes.int_
elif dtypes.issubdtype(dtype, np.unsignedinteger):
if np.iinfo(dtype).bits < np.iinfo(dtypes.uint).bits:
return dtypes.uint
elif dtypes.issubdtype(dtype, np.integer):
if np.iinfo(dtype).bits < np.iinfo(dtypes.int_).bits:
return dtypes.int_
return dtype


ReductionOp = Callable[[Any, Any], Any]

def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: ArrayLike,
Expand Down Expand Up @@ -103,16 +117,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
result_dtype = dtype or dtypes.dtype(a)

if dtype is None and promote_integers:
# Note: NumPy always promotes to 64-bit; jax instead promotes to the
# default dtype as defined by dtypes.int_ or dtypes.uint.
if dtypes.issubdtype(result_dtype, np.bool_):
result_dtype = dtypes.int_
elif dtypes.issubdtype(result_dtype, np.unsignedinteger):
if np.iinfo(result_dtype).bits < np.iinfo(dtypes.uint).bits:
result_dtype = dtypes.uint
elif dtypes.issubdtype(result_dtype, np.integer):
if np.iinfo(result_dtype).bits < np.iinfo(dtypes.int_).bits:
result_dtype = dtypes.int_
result_dtype = _promote_integer_dtype(result_dtype)

result_dtype = dtypes.canonicalize_dtype(result_dtype)

Expand Down Expand Up @@ -663,7 +668,8 @@ def __call__(self, a: ArrayLike, axis: Axis = None,
"""

def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array],
fill_nan: bool = False, fill_value: ArrayLike = 0) -> CumulativeReduction:
fill_nan: bool = False, fill_value: ArrayLike = 0,
promote_integers: bool = False) -> CumulativeReduction:
@implements(np_reduction, skip_params=['out'],
lax_description=CUML_REDUCTION_LAX_DESCRIPTION)
def cumulative_reduction(a: ArrayLike, axis: Axis = None,
Expand Down Expand Up @@ -691,12 +697,18 @@ def _cumulative_reduction(a: ArrayLike, axis: Axis = None,
if fill_nan:
a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)

if not dtype and dtypes.dtype(a) == np.bool_:
dtype = dtypes.canonicalize_dtype(dtypes.int_)
if dtype:
a = lax.convert_element_type(a, dtype)
result_type: DTypeLike = dtypes.dtype(dtype or a)
if dtype is None and promote_integers or dtypes.issubdtype(result_type, np.bool_):
result_type = _promote_integer_dtype(result_type)
result_type = dtypes.canonicalize_dtype(result_type)

a = lax.convert_element_type(a, result_type)
result = reduction(a, axis)

return reduction(a, axis)
# We downcast to boolean because we accumulate in integer types
if dtypes.issubdtype(dtype, np.bool_):
result = lax.convert_element_type(result, np.bool_)
return result

return cumulative_reduction

Expand All @@ -707,6 +719,9 @@ def _cumulative_reduction(a: ArrayLike, axis: Axis = None,
fill_nan=True, fill_value=0)
nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
fill_nan=True, fill_value=1)
_cumsum_with_promotion = _make_cumulative_reduction(
np.cumsum, lax.cumsum, fill_nan=False, promote_integers=True
)

@implements(getattr(np, 'cumulative_sum', None))
def cumulative_sum(
Expand All @@ -730,12 +745,7 @@ def cumulative_sum(

axis = _canonicalize_axis(axis, x.ndim)
dtypes.check_user_dtype_supported(dtype)
kind = x.dtype.kind
if (dtype is None and kind in {'i', 'u'}
and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind])
x = x.astype(dtype=dtype or x.dtype)
out = cumsum(x, axis=axis)
out = _cumsum_with_promotion(x, axis=axis, dtype=dtype)
if include_initial:
zeros_shape = list(x.shape)
zeros_shape[axis] = 1
Expand Down
15 changes: 0 additions & 15 deletions jax/experimental/array_api/_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,3 @@ def finfo(type, /) -> FInfo:
smallest_normal=float(info.smallest_normal),
dtype=jnp.dtype(type)
)

# TODO(micky774): Update utility to only promote integral types
def _promote_to_default_dtype(x):
if x.dtype.kind == 'b':
return x
elif x.dtype.kind == 'i':
return x.astype(jnp.int_)
elif x.dtype.kind == 'u':
return x.astype(jnp.uint)
elif x.dtype.kind == 'f':
return x.astype(jnp.float_)
elif x.dtype.kind == 'c':
return x.astype(jnp.complex_)
else:
raise ValueError(f"Unrecognized {x.dtype=}")
7 changes: 1 addition & 6 deletions tests/lax_numpy_reducers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,13 +791,8 @@ def testCumulativeSum(self, shape, axis, dtype, out_dtype, include_initial):
rng = jtu.rand_some_zero(self.rng())

def np_mock_op(x, axis=None, dtype=None, include_initial=False):
kind = x.dtype.kind
if (dtype is None and kind in {'i', 'u'}
and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind])
axis = axis or 0
x = x.astype(dtype=dtype or x.dtype)
out = jnp.cumsum(x, axis=axis)
out = np.cumsum(x, axis=axis, dtype=dtype or x.dtype)
if include_initial:
zeros_shape = list(x.shape)
zeros_shape[axis] = 1
Expand Down