Skip to content

Commit 2395e76

Browse files
committed
Add support for device kwarg in astype to match Array API
1 parent 281f5ac commit 2395e76

File tree

4 files changed

+68
-6
lines changed

4 files changed

+68
-6
lines changed

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list
88

99
## jax 0.4.26
1010

11+
* New Features
12+
* {func}`jax.numpy.astype` supports new `device` keyword argument.
13+
1114
* Deprecations & Removals
1215
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
1316
compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.
@@ -28,6 +31,12 @@ Remember to align the itemized text with the first line of an item within a list
2831
`jax.interpreters.ad.source_info_util` have now been removed. Use `jax.config`
2932
and `jax.extend.source_info_util` instead.
3033

34+
* Bug fixes
35+
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.
36+
Previously, no copy would be made when the output array would have the same
37+
dtype as the input array. This may result in some increased memory usage.
38+
To prevent copying when possible, set `copy=False`.
39+
3140
## jaxlib 0.4.26
3241

3342
## jax 0.4.25 (Feb 26, 2024)

jax/_src/numpy/array_methods.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,13 @@
3131
import numpy as np
3232
import jax
3333
from jax import lax
34+
from jax.sharding import Sharding
3435
from jax._src import core
3536
from jax._src import dtypes
3637
from jax._src.api_util import _ensure_index_tuple
3738
from jax._src.array import ArrayImpl
3839
from jax._src.lax import lax as lax_internal
40+
from jax._src.lib import xla_client as xc
3941
from jax._src.numpy import lax_numpy
4042
from jax._src.numpy import reductions
4143
from jax._src.numpy import ufuncs
@@ -55,15 +57,15 @@
5557
# functions, which can themselves handle instances from any of these classes.
5658

5759

58-
def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array:
60+
def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = True, device: xc.Device | Sharding | None = None) -> Array:
5961
"""Copy the array and cast to a specified dtype.
6062
6163
This is implemented via :func:`jax.lax.convert_element_type`, which may
6264
have slightly different behavior than :meth:`numpy.ndarray.astype` in
6365
some cases. In particular, the details of float-to-int and int-to-float
6466
casts are implementation dependent.
6567
"""
66-
return lax_numpy.astype(arr, dtype)
68+
return lax_numpy.astype(arr, dtype, copy=copy, device=device)
6769

6870

6971
def _nbytes(arr: ArrayLike) -> int:

jax/_src/numpy/lax_numpy.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2216,13 +2216,29 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
22162216
implementation dependent.
22172217
""")
22182218
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True, device: xc.Device | Sharding | None = None) -> Array:
2219-
src_devices = x.devices() if hasattr(x, "devices") and not isinstance(x, core.Tracer) else None
2220-
if device is not None and src_devices != {device}:
2221-
return device_put(x, device)
2222-
arr = _array_copy(x) if copy else x
22232219
if dtype is None:
22242220
dtype = dtypes.canonicalize_dtype(float_)
22252221
dtypes.check_user_dtype_supported(dtype, "astype")
2222+
src_dtype = x.dtype if hasattr(x, "dtype") else dtypes.dtype(x)
2223+
if (
2224+
src_dtype is not None
2225+
and dtypes.isdtype(src_dtype, "complex floating")
2226+
and dtypes.isdtype(dtype, ("integral", "real floating"))
2227+
):
2228+
raise ValueError(
2229+
"Casting from complex to non-complex dtypes is not permitted. Please "
2230+
"first use jnp.real or jnp.imag to take the real/imaginary component of "
2231+
"your input."
2232+
)
2233+
src_devices = (
2234+
x.devices() if hasattr(x, "devices")
2235+
and not isinstance(x, core.Tracer) else None
2236+
)
2237+
arr = x
2238+
if device is not None and src_devices != {device}:
2239+
arr = device_put(x, device)
2240+
elif copy:
2241+
arr = _array_copy(x)
22262242
return lax.convert_element_type(arr, dtype)
22272243

22282244

tests/lax_numpy_test.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3789,6 +3789,41 @@ def testAstype(self, from_dtype, to_dtype, use_method):
37893789
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
37903790
self._CompileAndCheck(jnp_op, args_maker)
37913791

3792+
@jtu.sample_product(
3793+
change_dtype=[True, False],
3794+
copy=[True, False],
3795+
change_device=[True, False],
3796+
)
3797+
def testAstypeCopy(self, change_dtype, copy, change_device):
3798+
if jax.device_count() == 1 and change_device:
3799+
raise unittest.SkipTest(
3800+
"Testing device transfer requires at least two available devices."
3801+
)
3802+
3803+
dtype = 'float32' if change_dtype else 'int32'
3804+
device = jax.devices()[-1] if change_device else None
3805+
expect_copy = change_dtype or copy or change_device
3806+
x = jnp.arange(5, dtype='int32')
3807+
y = x.astype(dtype, copy=copy, device=device)
3808+
3809+
assert y.dtype == dtype
3810+
if change_device:
3811+
assert y.devices() == {device}
3812+
else:
3813+
y.delete()
3814+
get_val = lambda: np.array(x)
3815+
err_msg = "Array has been deleted"
3816+
if expect_copy:
3817+
get_val()
3818+
else:
3819+
jtu.check_raises(get_val, RuntimeError, err_msg)
3820+
3821+
def testAstypeComplexDowncast(self):
3822+
x = jnp.array(2.0+1.5j, dtype='complex64')
3823+
complex_downcast = lambda: x.astype('float32')
3824+
err_msg = "Casting from complex to non-complex "
3825+
jtu.check_raises(complex_downcast, ValueError, err_msg)
3826+
37923827
def testAstypeInt4(self):
37933828
# Test converting from int4 to int8
37943829
x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)

0 commit comments

Comments
 (0)