Skip to content

Commit 75758b2

Browse files
committed
Refactored common upcast for integral-type accumulators
1 parent 9bf1148 commit 75758b2

File tree

6 files changed

+40
-54
lines changed

6 files changed

+40
-54
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3363,14 +3363,8 @@ def trace(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1,
33633363
raise NotImplementedError("The 'out' argument to jnp.trace is not supported.")
33643364
dtypes.check_user_dtype_supported(dtype, "trace")
33653365

3366+
a = asarray(a)
33663367
a_shape = shape(a)
3367-
if dtype is None:
3368-
dtype = _dtype(a)
3369-
if issubdtype(dtype, integer):
3370-
default_int = dtypes.canonicalize_dtype(int)
3371-
if iinfo(dtype).bits < iinfo(default_int).bits:
3372-
dtype = default_int
3373-
33743368
a = moveaxis(a, (axis1, axis2), (-2, -1))
33753369

33763370
# Mask out the diagonal and reduce.

jax/_src/numpy/reductions.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from jax import lax
2828
from jax._src import api
29-
from jax._src import core, config
29+
from jax._src import core
3030
from jax._src import dtypes
3131
from jax._src.numpy import ufuncs
3232
from jax._src.numpy.util import (
@@ -65,6 +65,20 @@ def _upcast_f16(dtype: DTypeLike) -> DType:
6565
return np.dtype('float32')
6666
return np.dtype(dtype)
6767

68+
def _promote_integer_dtype(dtype: DTypeLike) -> DTypeLike:
69+
# Note: NumPy always promotes to 64-bit; jax instead promotes to the
70+
# default dtype as defined by dtypes.int_ or dtypes.uint.
71+
if dtypes.issubdtype(dtype, np.bool_):
72+
return dtypes.int_
73+
elif dtypes.issubdtype(dtype, np.unsignedinteger):
74+
if np.iinfo(dtype).bits < np.iinfo(dtypes.uint).bits:
75+
return dtypes.uint
76+
elif dtypes.issubdtype(dtype, np.integer):
77+
if np.iinfo(dtype).bits < np.iinfo(dtypes.int_).bits:
78+
return dtypes.int_
79+
return dtype
80+
81+
6882
ReductionOp = Callable[[Any, Any], Any]
6983

7084
def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: ArrayLike,
@@ -103,16 +117,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
103117
result_dtype = dtype or dtypes.dtype(a)
104118

105119
if dtype is None and promote_integers:
106-
# Note: NumPy always promotes to 64-bit; jax instead promotes to the
107-
# default dtype as defined by dtypes.int_ or dtypes.uint.
108-
if dtypes.issubdtype(result_dtype, np.bool_):
109-
result_dtype = dtypes.int_
110-
elif dtypes.issubdtype(result_dtype, np.unsignedinteger):
111-
if np.iinfo(result_dtype).bits < np.iinfo(dtypes.uint).bits:
112-
result_dtype = dtypes.uint
113-
elif dtypes.issubdtype(result_dtype, np.integer):
114-
if np.iinfo(result_dtype).bits < np.iinfo(dtypes.int_).bits:
115-
result_dtype = dtypes.int_
120+
result_dtype = _promote_integer_dtype(result_dtype)
116121

117122
result_dtype = dtypes.canonicalize_dtype(result_dtype)
118123

@@ -653,7 +658,8 @@ def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
653658

654659
class CumulativeReduction(Protocol):
655660
def __call__(self, a: ArrayLike, axis: Axis = None,
656-
dtype: DTypeLike | None = None, out: None = None) -> Array: ...
661+
dtype: DTypeLike | None = None, out: None = None,
662+
promote_integers: bool = False) -> Array: ...
657663

658664

659665
# TODO(jakevdp): should we change these semantics to match those of numpy?
@@ -667,12 +673,17 @@ def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array
667673
@implements(np_reduction, skip_params=['out'],
668674
lax_description=CUML_REDUCTION_LAX_DESCRIPTION)
669675
def cumulative_reduction(a: ArrayLike, axis: Axis = None,
670-
dtype: DTypeLike | None = None, out: None = None) -> Array:
671-
return _cumulative_reduction(a, _ensure_optional_axes(axis), dtype, out)
676+
dtype: DTypeLike | None = None, out: None = None,
677+
promote_integers: bool = False) -> Array:
678+
return _cumulative_reduction(
679+
a, _ensure_optional_axes(axis), dtype,
680+
out, promote_integers=promote_integers
681+
)
672682

673-
@partial(api.jit, static_argnames=('axis', 'dtype'))
683+
@partial(api.jit, static_argnames=('axis', 'dtype', 'promote_integers'))
674684
def _cumulative_reduction(a: ArrayLike, axis: Axis = None,
675-
dtype: DTypeLike | None = None, out: None = None) -> Array:
685+
dtype: DTypeLike | None = None, out: None = None,
686+
promote_integers: bool = False) -> Array:
676687
check_arraylike(np_reduction.__name__, a)
677688
if out is not None:
678689
raise NotImplementedError(f"The 'out' argument to jnp.{np_reduction.__name__} "
@@ -691,11 +702,15 @@ def _cumulative_reduction(a: ArrayLike, axis: Axis = None,
691702
if fill_nan:
692703
a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)
693704

694-
if not dtype and dtypes.dtype(a) == np.bool_:
705+
result_type = dtype or dtypes.dtype(a)
706+
if result_type == np.bool_:
695707
dtype = dtypes.canonicalize_dtype(dtypes.int_)
708+
elif dtype is None and promote_integers:
709+
dtype = _promote_integer_dtype(result_type)
696710
if dtype:
697711
a = lax.convert_element_type(a, dtype)
698712

713+
699714
return reduction(a, axis)
700715

701716
return cumulative_reduction
@@ -730,12 +745,7 @@ def cumulative_sum(
730745

731746
axis = _canonicalize_axis(axis, x.ndim)
732747
dtypes.check_user_dtype_supported(dtype)
733-
kind = x.dtype.kind
734-
if (dtype is None and kind in {'i', 'u'}
735-
and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
736-
dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind])
737-
x = x.astype(dtype=dtype or x.dtype)
738-
out = cumsum(x, axis=axis)
748+
out = cumsum(x, axis=axis, dtype=dtype, promote_integers=True)
739749
if include_initial:
740750
zeros_shape = list(x.shape)
741751
zeros_shape[axis] = 1

jax/experimental/array_api/_data_type_functions.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -212,18 +212,3 @@ def result_type(*arrays_and_dtypes):
212212
if len(dtypes) == 1:
213213
return dtypes[0]
214214
return functools.reduce(_promote_types, dtypes)
215-
216-
217-
def _promote_to_default_dtype(x):
218-
if x.dtype.kind == 'b':
219-
return x
220-
elif x.dtype.kind == 'i':
221-
return x.astype(jnp.int_)
222-
elif x.dtype.kind == 'u':
223-
return x.astype(jnp.uint)
224-
elif x.dtype.kind == 'f':
225-
return x.astype(jnp.float_)
226-
elif x.dtype.kind == 'c':
227-
return x.astype(jnp.complex_)
228-
else:
229-
raise ValueError(f"Unrecognized {x.dtype=}")

jax/experimental/array_api/_linear_algebra_functions.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@
1313
# limitations under the License.
1414

1515
import jax
16-
from jax.experimental.array_api._data_type_functions import (
17-
_promote_to_default_dtype,
18-
)
1916

2017
def cholesky(x, /, *, upper=False):
2118
"""
@@ -140,7 +137,6 @@ def trace(x, /, *, offset=0, dtype=None):
140137
"""
141138
Returns the sum along the specified diagonals of a matrix (or a stack of matrices) x.
142139
"""
143-
x = _promote_to_default_dtype(x)
144140
return jax.numpy.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1)
145141

146142
def vecdot(x1, x2, /, *, axis=-1):

jax/experimental/array_api/_statistical_functions.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@
1313
# limitations under the License.
1414

1515
import jax
16-
from jax.experimental.array_api._data_type_functions import (
17-
_promote_to_default_dtype,
18-
)
1916

2017

2118
def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False):
@@ -39,7 +36,6 @@ def min(x, /, *, axis=None, keepdims=False):
3936

4037
def prod(x, /, *, axis=None, dtype=None, keepdims=False):
4138
"""Calculates the product of input array x elements."""
42-
x = _promote_to_default_dtype(x)
4339
return jax.numpy.prod(x, axis=axis, dtype=dtype, keepdims=keepdims)
4440

4541

@@ -50,7 +46,6 @@ def std(x, /, *, axis=None, correction=0.0, keepdims=False):
5046

5147
def sum(x, /, *, axis=None, dtype=None, keepdims=False):
5248
"""Calculates the sum of the input array x."""
53-
x = _promote_to_default_dtype(x)
5449
return jax.numpy.sum(x, axis=axis, dtype=dtype, keepdims=keepdims)
5550

5651

jax/experimental/array_api/skips.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,9 @@ array_api_tests/test_special_cases.py::test_unary
1111

1212
# fft test suite is buggy as of 83f0bcdc
1313
array_api_tests/test_fft.py
14+
15+
# Pending implementation update for proper dtype promotion behavior,
16+
# see https://github.com/data-apis/array-api-tests/issues/234
17+
array_api_tests/test_statistical_functions.py::test_sum
18+
array_api_tests/test_statistical_functions.py::test_prod
19+
array_api_tests/test_linalg.py::test_trace

0 commit comments

Comments
 (0)