Skip to content

Commit 93cda81

Browse files
committed
Refactored common upcast for integral-type accumulators
1 parent b1cb90c commit 93cda81

File tree

7 files changed

+34
-37
lines changed

7 files changed

+34
-37
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3024,14 +3024,10 @@ def trace(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1,
30243024
raise NotImplementedError("The 'out' argument to jnp.trace is not supported.")
30253025
dtypes.check_user_dtype_supported(dtype, "trace")
30263026

3027+
a = asarray(a)
30273028
a_shape = shape(a)
30283029
if dtype is None:
3029-
dtype = _dtype(a)
3030-
if issubdtype(dtype, integer):
3031-
default_int = dtypes.canonicalize_dtype(int)
3032-
if iinfo(dtype).bits < iinfo(default_int).bits:
3033-
dtype = default_int
3034-
3030+
a = util.promote_dtypes_integral_default(a)[0]
30353031
a = moveaxis(a, (axis1, axis2), (-2, -1))
30363032

30373033
# Mask out the diagonal and reduce.

jax/_src/numpy/reductions.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,13 @@
2626

2727
from jax import lax
2828
from jax._src import api
29-
from jax._src import core, config
29+
from jax._src import core
3030
from jax._src import dtypes
3131
from jax._src.numpy import ufuncs
3232
from jax._src.numpy.util import (
3333
_broadcast_to, check_arraylike, _complex_elem_type,
34-
promote_dtypes_inexact, promote_dtypes_numeric, _where, implements)
34+
promote_dtypes_inexact, promote_dtypes_numeric, _where, implements,
35+
promote_dtypes_integral_default, )
3536
from jax._src.lax import lax as lax_internal
3637
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
3738
from jax._src.util import (
@@ -730,11 +731,10 @@ def cumulative_sum(
730731

731732
axis = _canonicalize_axis(axis, x.ndim)
732733
dtypes.check_user_dtype_supported(dtype)
733-
kind = x.dtype.kind
734-
if (dtype is None and kind in {'i', 'u'}
735-
and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
736-
dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind])
737-
x = x.astype(dtype=dtype or x.dtype)
734+
if dtype is None:
735+
x = promote_dtypes_integral_default(x)[0]
736+
else:
737+
x = x.astype(dtype=dtype)
738738
out = cumsum(x, axis=axis)
739739
if include_initial:
740740
zeros_shape = list(x.shape)

jax/_src/numpy/util.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,22 @@ def promote_dtypes_complex(*args: ArrayLike) -> list[Array]:
311311
for x in args]
312312

313313

314+
def promote_dtypes_integral_default(*args: ArrayLike) -> list[Array]:
315+
"""Convenience function to apply default promotion to integral accumulators.
316+
317+
Promotes arguments to their corresponding default integral type, or returns
318+
the arguments unchanged."""
319+
def _promote(x: ArrayLike) -> Array:
320+
x = lax.asarray(x)
321+
kind = x.dtype.kind
322+
if kind in {'i', 'u'}:
323+
default_dtype = dtypes.dtype(dtypes._default_types[kind])
324+
if x.dtype.itemsize*8 < default_dtype.itemsize:
325+
return x.astype(default_dtype)
326+
return x
327+
return [_promote(x) for x in args]
328+
329+
314330
def _complex_elem_type(dtype: DTypeLike) -> DType:
315331
"""Returns the float type of the real/imaginary parts of a complex dtype."""
316332
return np.abs(np.zeros((), dtype)).dtype

jax/experimental/array_api/_data_type_functions.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -212,18 +212,3 @@ def result_type(*arrays_and_dtypes):
212212
if len(dtypes) == 1:
213213
return dtypes[0]
214214
return functools.reduce(_promote_types, dtypes)
215-
216-
217-
def _promote_to_default_dtype(x):
218-
if x.dtype.kind == 'b':
219-
return x
220-
elif x.dtype.kind == 'i':
221-
return x.astype(jnp.int_)
222-
elif x.dtype.kind == 'u':
223-
return x.astype(jnp.uint)
224-
elif x.dtype.kind == 'f':
225-
return x.astype(jnp.float_)
226-
elif x.dtype.kind == 'c':
227-
return x.astype(jnp.complex_)
228-
else:
229-
raise ValueError(f"Unrecognized {x.dtype=}")

jax/experimental/array_api/_linear_algebra_functions.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@
1313
# limitations under the License.
1414

1515
import jax
16-
from jax.experimental.array_api._data_type_functions import (
17-
_promote_to_default_dtype,
18-
)
16+
from jax._src.numpy.util import promote_dtypes_integral_default
1917

2018
def cholesky(x, /, *, upper=False):
2119
"""
@@ -140,7 +138,8 @@ def trace(x, /, *, offset=0, dtype=None):
140138
"""
141139
Returns the sum along the specified diagonals of a matrix (or a stack of matrices) x.
142140
"""
143-
x = _promote_to_default_dtype(x)
141+
if dtype is None:
142+
x = promote_dtypes_integral_default(x)[0]
144143
return jax.numpy.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1)
145144

146145
def vecdot(x1, x2, /, *, axis=-1):

jax/experimental/array_api/_statistical_functions.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@
1313
# limitations under the License.
1414

1515
import jax
16-
from jax.experimental.array_api._data_type_functions import (
17-
_promote_to_default_dtype,
18-
)
1916

2017

2118
def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False):
@@ -39,7 +36,6 @@ def min(x, /, *, axis=None, keepdims=False):
3936

4037
def prod(x, /, *, axis=None, dtype=None, keepdims=False):
4138
"""Calculates the product of input array x elements."""
42-
x = _promote_to_default_dtype(x)
4339
return jax.numpy.prod(x, axis=axis, dtype=dtype, keepdims=keepdims)
4440

4541

@@ -50,7 +46,6 @@ def std(x, /, *, axis=None, correction=0.0, keepdims=False):
5046

5147
def sum(x, /, *, axis=None, dtype=None, keepdims=False):
5248
"""Calculates the sum of the input array x."""
53-
x = _promote_to_default_dtype(x)
5449
return jax.numpy.sum(x, axis=axis, dtype=dtype, keepdims=keepdims)
5550

5651

jax/experimental/array_api/skips.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,9 @@ array_api_tests/test_special_cases.py::test_unary
1111

1212
# fft test suite is buggy as of 83f0bcdc
1313
array_api_tests/test_fft.py
14+
15+
# Pending implementation update for proper dtype promotion behavior,
16+
# see https://github.com/data-apis/array-api-tests/issues/234
17+
array_api_tests/test_statistical_functions.py::test_sum
18+
array_api_tests/test_statistical_functions.py::test_prod
19+
array_api_tests/test_linalg.py::test_trace

0 commit comments

Comments
 (0)