Skip to content

Commit c246a97

Browse files
author
jax authors
committed
Merge pull request #20644 from jakevdp:complex-astype
PiperOrigin-RevId: 623323549
2 parents 763c6ff + e07325a commit c246a97

File tree

3 files changed

+26
-1
lines changed

3 files changed

+26
-1
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ Remember to align the itemized text with the first line of an item within a list
1313
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover
1414
the old behavior by transforming the arguments via
1515
`jax.tree.map(np.asarray, args)` before passing them to the callback.
16+
* `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning
17+
False where `complex_arr` is equal to `0 + 0j`, and True otherwise.
1618

1719
* Deprecations & Removals
1820
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old

jax/_src/numpy/lax_numpy.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2263,11 +2263,16 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
22632263
implementation dependent.
22642264
""")
22652265
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array:
2266+
util.check_arraylike("astype", x)
2267+
x_arr = asarray(x)
22662268
del copy # unused in JAX
22672269
if dtype is None:
22682270
dtype = dtypes.canonicalize_dtype(float_)
22692271
dtypes.check_user_dtype_supported(dtype, "astype")
2270-
return lax.convert_element_type(x, dtype)
2272+
# convert_element_type(complex, bool) has the wrong semantics.
2273+
if np.dtype(dtype) == bool and issubdtype(x_arr.dtype, complexfloating):
2274+
return (x_arr != _lax_const(x_arr, 0))
2275+
return lax.convert_element_type(x_arr, dtype)
22712276

22722277

22732278
@util.implements(np.asarray, lax_description=_ARRAY_DOC)

tests/lax_numpy_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3822,6 +3822,24 @@ def testAstype(self, from_dtype, to_dtype, use_method):
38223822
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
38233823
self._CompileAndCheck(jnp_op, args_maker)
38243824

3825+
@jtu.sample_product(
3826+
from_dtype=['int32', 'float32', 'complex64'],
3827+
use_method=[True, False],
3828+
)
3829+
def testAstypeBool(self, from_dtype, use_method, to_dtype='bool'):
3830+
rng = jtu.rand_some_zero(self.rng())
3831+
args_maker = lambda: [rng((3, 4), from_dtype)]
3832+
if (not use_method) and hasattr(np, "astype"): # Added in numpy 2.0
3833+
np_op = lambda x: np.astype(x, to_dtype)
3834+
else:
3835+
np_op = lambda x: np.asarray(x).astype(to_dtype)
3836+
if use_method:
3837+
jnp_op = lambda x: jnp.asarray(x).astype(to_dtype)
3838+
else:
3839+
jnp_op = lambda x: jnp.astype(x, to_dtype)
3840+
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
3841+
self._CompileAndCheck(jnp_op, args_maker)
3842+
38253843
def testAstypeInt4(self):
38263844
# Test converting from int4 to int8
38273845
x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)

0 commit comments

Comments
 (0)