|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
| 14 | +from __future__ import annotations |
14 | 15 |
|
15 | 16 |
|
| 17 | +import builtins |
16 | 18 | import functools
|
17 |
| -from typing import NamedTuple |
| 19 | +from typing import Any, NamedTuple |
18 | 20 | import jax
|
19 | 21 | import jax.numpy as jnp
|
| 22 | +from jax._src.lib import xla_client as xc |
| 23 | +from jax._src.sharding import Sharding |
| 24 | +from jax._src.api import device_put |
20 | 25 |
|
21 | 26 |
|
22 | 27 | from jax.experimental.array_api._dtypes import (
|
@@ -124,9 +129,24 @@ def _promote_types(t1, t2):
|
124 | 129 | raise ValueError("No promotion path for {t1} & {t2}")
|
125 | 130 |
|
126 | 131 |
|
127 |
| -def astype(x, dtype, /, *, copy=True): |
128 |
| - return jnp.array(x, dtype=dtype, copy=copy) |
129 |
| - |
| 132 | +def astype(x, dtype, /, *, copy: builtins.bool | None = True, device: xc.Device | Sharding | None = None): |
| 133 | + arr = jnp.array(x, dtype=dtype) |
| 134 | + src_device = arr.devices().pop() |
| 135 | + # TODO(micky774): refactor into a common utility with _place_array in gh-20175 |
| 136 | + if device is not None: |
| 137 | + if copy is not None and not copy: |
| 138 | + raise ValueError( |
| 139 | + f"Specified {device=} which requires a copy since the source device " |
| 140 | + f"is {repr(src_device)}, however copy=False. Set copy=True or " |
| 141 | + "copy=None to perform the requested operation." |
| 142 | + ) |
| 143 | + else: |
| 144 | + return device_put(arr, device) |
| 145 | + if copy: |
| 146 | + # TODO(micky774): Remove if clause and replace with jnp.array(arr, copy=copy) |
| 147 | + # when we support Numpy 2.0 copy semantics |
| 148 | + return jnp.array(arr, copy=True) |
| 149 | + return arr |
130 | 150 |
|
131 | 151 | def can_cast(from_, to, /):
|
132 | 152 | if isinstance(from_, jax.Array):
|
|
0 commit comments