Skip to content

Commit ffb25b8

Browse files
committed
Add support for copy kwarg in astype to match Array API
1 parent 2215021 commit ffb25b8

File tree

1 file changed

+14
-17
lines changed

1 file changed

+14
-17
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2270,28 +2270,25 @@ def astype(x: ArrayLike, dtype: DTypeLike | None,
22702270
if dtype is None:
22712271
dtype = dtypes.canonicalize_dtype(float_)
22722272
dtypes.check_user_dtype_supported(dtype, "astype")
2273-
if (
2274-
issubdtype(x_arr.dtype, complexfloating)
2275-
and dtypes.isdtype(dtype, ("integral", "real floating"))
2276-
):
2277-
warnings.warn(
2278-
"Casting from complex to real dtypes will soon raise a ValueError. "
2279-
"Please first use jnp.real or jnp.imag to take the real/imaginary "
2280-
"component of your input.",
2281-
DeprecationWarning, stacklevel=2
2282-
)
2283-
2284-
out = x_arr
2285-
# convert_element_type(complex, bool) has the wrong semantics.
2286-
if np.dtype(dtype) == bool and issubdtype(x_arr.dtype, complexfloating):
2287-
out = (x_arr != _lax_const(x_arr, 0))
2288-
out = _place_array(out, device=device, copy=copy)
2273+
if issubdtype(x_arr.dtype, complexfloating):
2274+
if dtypes.isdtype(dtype, ("integral", "real floating")):
2275+
warnings.warn(
2276+
"Casting from complex to real dtypes will soon raise a ValueError. "
2277+
"Please first use jnp.real or jnp.imag to take the real/imaginary "
2278+
"component of your input.",
2279+
DeprecationWarning, stacklevel=2
2280+
)
2281+
elif np.dtype(dtype) == bool:
2282+
# convert_element_type(complex, bool) has the wrong semantics.
2283+
x_arr = (x_arr != _lax_const(x_arr, 0))
2284+
2285+
x_arr = _place_array(x_arr, device=device, copy=copy)
22892286

22902287
# We offer a more specific warning than the usual ComplexWarning so we prefer
22912288
# to issue our warning.
22922289
with warnings.catch_warnings():
22932290
warnings.simplefilter("ignore", ComplexWarning)
2294-
return lax.convert_element_type(out, dtype)
2291+
return lax.convert_element_type(x_arr, dtype)
22952292

22962293
def _place_array(x, device=None, copy=None):
22972294
# TODO(micky774): Implement in future PRs as we formalize device placement

0 commit comments

Comments
 (0)