Skip to content

Commit dfd35ab

Browse files
committed
Update
1 parent d0eae05 commit dfd35ab

File tree

1 file changed

+22
-4
lines changed

1 file changed

+22
-4
lines changed

jax/experimental/array_api/_data_type_functions.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,17 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from __future__ import annotations
1415

1516

17+
import builtins
1618
import functools
17-
from typing import NamedTuple
19+
from typing import Any, NamedTuple
1820
import jax
1921
import jax.numpy as jnp
22+
from jax._src.lib import xla_client as xc
23+
from jax._src.sharding import Sharding
24+
from jax._src.api import device_put
2025

2126

2227
from jax.experimental.array_api._dtypes import (
@@ -124,9 +129,22 @@ def _promote_types(t1, t2):
124129
raise ValueError("No promotion path for {t1} & {t2}")
125130

126131

127-
def astype(x, dtype, /, *, copy=True):
128-
return jnp.array(x, dtype=dtype, copy=copy)
129-
132+
def astype(x, dtype, /, *, copy: builtins.bool | None = True, device: xc.Device | Sharding | Any | None = None):
133+
arr = jnp.array(x, dtype=dtype)
134+
src_device = arr.devices().pop()
135+
# TODO(micky774): refactor into a common utility with _place_array in gh-20175
136+
if device:
137+
if copy is not None and not copy:
138+
raise ValueError(
139+
f"Specified {device=} which requires a copy since the source device "
140+
f"is {repr(src_device)}, however copy=False. Set copy=True or "
141+
"copy=None to perform the requested operation."
142+
)
143+
else:
144+
return device_put(arr, device)
145+
if copy:
146+
return jnp.array(arr, copy=True)
147+
return arr
130148

131149
def can_cast(from_, to, /):
132150
if isinstance(from_, jax.Array):

0 commit comments

Comments
 (0)