Skip to content

Commit 281f5ac

Browse files
committed
Add support for device kwarg in astype to match Array API
1 parent 604f5ec commit 281f5ac

File tree

3 files changed

+16
-10
lines changed

3 files changed

+16
-10
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
import opt_einsum
4242

4343
import jax
44-
from jax import jit
44+
from jax import jit, device_put
4545
from jax import errors
4646
from jax import lax
4747
from jax.sharding import Sharding, SingleDeviceSharding
@@ -2209,19 +2209,22 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
22092209
else:
22102210
return x
22112211

2212-
22132212
@util.implements(getattr(np, "astype", None), lax_description="""
22142213
This is implemented via :func:`jax.lax.convert_element_type`, which may
22152214
have slightly different behavior than :func:`numpy.astype` in some cases.
22162215
In particular, the details of float-to-int and int-to-float casts are
22172216
implementation dependent.
22182217
""")
2219-
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array:
2220-
del copy # unused in JAX
2218+
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True, device: xc.Device | Sharding | None = None) -> Array:
2219+
src_devices = x.devices() if hasattr(x, "devices") and not isinstance(x, core.Tracer) else None
2220+
if device is not None and src_devices != {device}:
2221+
return device_put(x, device)
2222+
arr = _array_copy(x) if copy else x
22212223
if dtype is None:
22222224
dtype = dtypes.canonicalize_dtype(float_)
22232225
dtypes.check_user_dtype_supported(dtype, "astype")
2224-
return lax.convert_element_type(x, dtype)
2226+
return lax.convert_element_type(arr, dtype)
2227+
22252228

22262229

22272230
@util.implements(np.asarray, lax_description=_ARRAY_DOC)

jax/experimental/array_api/_data_type_functions.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
# limitations under the License.
1414

1515

16+
import builtins
1617
import functools
1718
from typing import NamedTuple
1819
import jax
20+
from jax._src.sharding import Sharding
1921
import jax.numpy as jnp
20-
21-
22+
from jax._src.lib import xla_client as xc
2223
from jax.experimental.array_api._dtypes import (
2324
bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64,
2425
float32, float64, complex64, complex128
@@ -124,8 +125,8 @@ def _promote_types(t1, t2):
124125
raise ValueError("No promotion path for {t1} & {t2}")
125126

126127

127-
def astype(x, dtype, /, *, copy=True):
128-
return jnp.array(x, dtype=dtype, copy=copy)
128+
def astype(x, dtype, /, *, copy: builtins.bool = True, device: xc.Device | Sharding | None = None):
129+
return jnp.astype(x, dtype, copy=copy, device=device)
129130

130131

131132
def can_cast(from_, to, /):

jax/numpy/__init__.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ from jax._src import dtypes as _dtypes
99
from jax._src.lax.lax import PrecisionLike
1010
from jax._src.lax.slicing import GatherScatterMode
1111
from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClass as _RClass
12+
from jax._src.sharding import Sharding
13+
from jax._src.lib import xla_client as xc
1214
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DimSize, DuckTypedArray, Shape
1315
from jax.numpy import fft as fft, linalg as linalg
1416
from jax.sharding import Sharding as _Sharding
@@ -112,7 +114,7 @@ def asarray(
112114
) -> Array: ...
113115
def asin(x: ArrayLike, /) -> Array: ...
114116
def asinh(x: ArrayLike, /) -> Array: ...
115-
def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ...) -> Array: ...
117+
def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ..., device: xc.Device | Sharding | None = ...) -> Array: ...
116118
def atan(x: ArrayLike, /) -> Array: ...
117119
def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: ...
118120
def atanh(x: ArrayLike, /) -> Array: ...

0 commit comments

Comments
 (0)