Skip to content

Commit 6b249f0

Browse files
committed
Refactored common upcast for integral-type accumulators
1 parent 9bf1148 commit 6b249f0

File tree

7 files changed

+17
-37
lines changed

7 files changed

+17
-37
lines changed

jax/_src/dtypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,7 @@ def dtype(x: Any, *, canonicalize: bool = False) -> DType:
662662
elif type(x) in python_scalar_dtypes:
663663
dt = python_scalar_dtypes[type(x)]
664664
elif is_type and _issubclass(x, np.generic):
665-
return np.dtype(x)
665+
dt = np.dtype(x)
666666
elif issubdtype(getattr(x, 'dtype', None), extended):
667667
dt = x.dtype
668668
else:

jax/_src/numpy/lax_numpy.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3363,14 +3363,8 @@ def trace(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1,
33633363
raise NotImplementedError("The 'out' argument to jnp.trace is not supported.")
33643364
dtypes.check_user_dtype_supported(dtype, "trace")
33653365

3366+
a = asarray(a)
33663367
a_shape = shape(a)
3367-
if dtype is None:
3368-
dtype = _dtype(a)
3369-
if issubdtype(dtype, integer):
3370-
default_int = dtypes.canonicalize_dtype(int)
3371-
if iinfo(dtype).bits < iinfo(default_int).bits:
3372-
dtype = default_int
3373-
33743368
a = moveaxis(a, (axis1, axis2), (-2, -1))
33753369

33763370
# Mask out the diagonal and reduce.

jax/_src/numpy/reductions.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -730,11 +730,15 @@ def cumulative_sum(
730730

731731
axis = _canonicalize_axis(axis, x.ndim)
732732
dtypes.check_user_dtype_supported(dtype)
733-
kind = x.dtype.kind
734-
if (dtype is None and kind in {'i', 'u'}
735-
and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
736-
dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind])
737-
x = x.astype(dtype=dtype or x.dtype)
733+
if dtype is None:
734+
kind = x.dtype.kind
735+
if kind in {'i', 'u'}:
736+
default_dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind])
737+
if x.dtype.itemsize*8 < int(config.default_dtype_bits.value):
738+
dtype = default_dtype
739+
740+
if dtype:
741+
x = x.astype(dtype=dtype)
738742
out = cumsum(x, axis=axis)
739743
if include_initial:
740744
zeros_shape = list(x.shape)

jax/experimental/array_api/_data_type_functions.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -212,18 +212,3 @@ def result_type(*arrays_and_dtypes):
212212
if len(dtypes) == 1:
213213
return dtypes[0]
214214
return functools.reduce(_promote_types, dtypes)
215-
216-
217-
def _promote_to_default_dtype(x):
218-
if x.dtype.kind == 'b':
219-
return x
220-
elif x.dtype.kind == 'i':
221-
return x.astype(jnp.int_)
222-
elif x.dtype.kind == 'u':
223-
return x.astype(jnp.uint)
224-
elif x.dtype.kind == 'f':
225-
return x.astype(jnp.float_)
226-
elif x.dtype.kind == 'c':
227-
return x.astype(jnp.complex_)
228-
else:
229-
raise ValueError(f"Unrecognized {x.dtype=}")

jax/experimental/array_api/_linear_algebra_functions.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@
1313
# limitations under the License.
1414

1515
import jax
16-
from jax.experimental.array_api._data_type_functions import (
17-
_promote_to_default_dtype,
18-
)
1916

2017
def cholesky(x, /, *, upper=False):
2118
"""
@@ -140,7 +137,6 @@ def trace(x, /, *, offset=0, dtype=None):
140137
"""
141138
Returns the sum along the specified diagonals of a matrix (or a stack of matrices) x.
142139
"""
143-
x = _promote_to_default_dtype(x)
144140
return jax.numpy.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1)
145141

146142
def vecdot(x1, x2, /, *, axis=-1):

jax/experimental/array_api/_statistical_functions.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@
1313
# limitations under the License.
1414

1515
import jax
16-
from jax.experimental.array_api._data_type_functions import (
17-
_promote_to_default_dtype,
18-
)
1916

2017

2118
def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False):
@@ -39,7 +36,6 @@ def min(x, /, *, axis=None, keepdims=False):
3936

4037
def prod(x, /, *, axis=None, dtype=None, keepdims=False):
4138
"""Calculates the product of input array x elements."""
42-
x = _promote_to_default_dtype(x)
4339
return jax.numpy.prod(x, axis=axis, dtype=dtype, keepdims=keepdims)
4440

4541

@@ -50,7 +46,6 @@ def std(x, /, *, axis=None, correction=0.0, keepdims=False):
5046

5147
def sum(x, /, *, axis=None, dtype=None, keepdims=False):
5248
"""Calculates the sum of the input array x."""
53-
x = _promote_to_default_dtype(x)
5449
return jax.numpy.sum(x, axis=axis, dtype=dtype, keepdims=keepdims)
5550

5651

jax/experimental/array_api/skips.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,9 @@ array_api_tests/test_special_cases.py::test_unary
1111

1212
# fft test suite is buggy as of 83f0bcdc
1313
array_api_tests/test_fft.py
14+
15+
# Pending implementation update for proper dtype promotion behavior,
16+
# see https://github.com/data-apis/array-api-tests/issues/234
17+
array_api_tests/test_statistical_functions.py::test_sum
18+
array_api_tests/test_statistical_functions.py::test_prod
19+
array_api_tests/test_linalg.py::test_trace

0 commit comments

Comments
 (0)