Skip to content

Commit 7837e90

Browse files
committed
Add support for copy kwarg in astype to match Array API
1 parent c7517b8 commit 7837e90

File tree

6 files changed

+94
-11
lines changed

6 files changed

+94
-11
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,13 @@ Remember to align the itemized text with the first line of an item within a list
4949
* Scalar arguments to {func}`jax.numpy.nonzero`, {func}`jax.numpy.where`, and
5050
related functions now raise an error, following a similar change in NumPy.
5151

52+
* Bug fixes
53+
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.
54+
Previously, no copy would be made when the output array would have the same
55+
dtype as the input array. This may result in some increased memory usage.
56+
To prevent copying when possible, set `copy=None`. To error when a copy is
57+
required, set `copy=False`.
58+
5259
## jaxlib 0.4.27
5360

5461
## jax 0.4.26 (April 3, 2024)

jax/_src/numpy/array_methods.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,13 @@
3131
import numpy as np
3232
import jax
3333
from jax import lax
34+
from jax.sharding import Sharding
3435
from jax._src import core
3536
from jax._src import dtypes
3637
from jax._src.api_util import _ensure_index_tuple
3738
from jax._src.array import ArrayImpl
3839
from jax._src.lax import lax as lax_internal
40+
from jax._src.lib import xla_client as xc
3941
from jax._src.numpy import lax_numpy
4042
from jax._src.numpy import reductions
4143
from jax._src.numpy import ufuncs
@@ -55,15 +57,15 @@
5557
# functions, which can themselves handle instances from any of these classes.
5658

5759

58-
def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array:
60+
def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = True, device: xc.Device | Sharding | None = None) -> Array:
5961
"""Copy the array and cast to a specified dtype.
6062
6163
This is implemented via :func:`jax.lax.convert_element_type`, which may
6264
have slightly different behavior than :meth:`numpy.ndarray.astype` in
6365
some cases. In particular, the details of float-to-int and int-to-float
6466
casts are implementation dependent.
6567
"""
66-
return lax_numpy.astype(arr, dtype)
68+
return lax_numpy.astype(arr, dtype, copy=copy, device=device)
6769

6870

6971
def _nbytes(arr: ArrayLike) -> int:

jax/_src/numpy/lax_numpy.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2272,17 +2272,53 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
22722272
In particular, the details of float-to-int and int-to-float casts are
22732273
implementation dependent.
22742274
""")
2275-
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array:
2275+
def astype(x: ArrayLike, dtype: DTypeLike | None,
2276+
/, *, copy: bool | DeprecatedArg = DeprecatedArg(),
2277+
device: xc.Device | Sharding | None = None) -> Array:
22762278
util.check_arraylike("astype", x)
22772279
x_arr = asarray(x)
2278-
del copy # unused in JAX
2280+
2281+
# TODO(micky774): Deprecated 2024-4-19, remove after deprecation completed.
2282+
if isinstance(copy, DeprecatedArg):
2283+
warnings.warn(
2284+
"The copy keyword of astype was previously ignored but is now "
2285+
"implemented. The default of copy=True will lead to the expected "
2286+
"behavior of creating a copy, which may potentially lead to extra "
2287+
"memory usage. To preserve previous behavior, use copy=False. To "
2288+
"suppress this warning, please explicitly set copy.",
2289+
DeprecationWarning, stacklevel=2)
2290+
copy = False
2291+
22792292
if dtype is None:
22802293
dtype = dtypes.canonicalize_dtype(float_)
22812294
dtypes.check_user_dtype_supported(dtype, "astype")
2282-
# convert_element_type(complex, bool) has the wrong semantics.
2283-
if np.dtype(dtype) == bool and issubdtype(x_arr.dtype, complexfloating):
2284-
return (x_arr != _lax_const(x_arr, 0))
2285-
return lax.convert_element_type(x_arr, dtype)
2295+
if issubdtype(x_arr.dtype, complexfloating):
2296+
if dtypes.isdtype(dtype, ("integral", "real floating")):
2297+
warnings.warn(
2298+
"Casting from complex to real dtypes will soon raise a ValueError. "
2299+
"Please first use jnp.real or jnp.imag to take the real/imaginary "
2300+
"component of your input.",
2301+
DeprecationWarning, stacklevel=2
2302+
)
2303+
elif np.dtype(dtype) == bool:
2304+
# convert_element_type(complex, bool) has the wrong semantics.
2305+
x_arr = (x_arr != _lax_const(x_arr, 0))
2306+
2307+
# We offer a more specific warning than the usual ComplexWarning so we prefer
2308+
# to issue our warning.
2309+
with warnings.catch_warnings():
2310+
warnings.simplefilter("ignore", ComplexWarning)
2311+
return _place_array(
2312+
lax.convert_element_type(x_arr, dtype),
2313+
device=device, copy=copy,
2314+
)
2315+
2316+
def _place_array(x, device=None, copy=None):
2317+
# TODO(micky774): Implement in future PRs as we formalize device placement
2318+
# semantics
2319+
if copy:
2320+
return _array_copy(x)
2321+
return x
22862322

22872323

22882324
@util.implements(np.asarray, lax_description=_ARRAY_DOC)

jax/experimental/array_api/_data_type_functions.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,18 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import builtins
1618
import functools
1719
from typing import NamedTuple
1820
import jax
1921
import jax.numpy as jnp
2022

2123

24+
from jax._src.lib import xla_client as xc
25+
from jax._src.sharding import Sharding
26+
from jax._src import dtypes as _dtypes
2227
from jax.experimental.array_api._dtypes import (
2328
bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64,
2429
float32, float64, complex64, complex128
@@ -124,8 +129,19 @@ def _promote_types(t1, t2):
124129
raise ValueError("No promotion path for {t1} & {t2}")
125130

126131

127-
def astype(x, dtype, /, *, copy=True):
128-
return jnp.array(x, dtype=dtype, copy=copy)
132+
def astype(x, dtype, /, *, copy: builtins.bool = True, device: xc.Device | Sharding | None = None):
133+
src_dtype = x.dtype if hasattr(x, "dtype") else _dtypes.dtype(x)
134+
if (
135+
src_dtype is not None
136+
and _dtypes.isdtype(src_dtype, "complex floating")
137+
and _dtypes.isdtype(dtype, ("integral", "real floating"))
138+
):
139+
raise ValueError(
140+
"Casting from complex to non-complex dtypes is not permitted. Please "
141+
"first use jnp.real or jnp.imag to take the real/imaginary component of "
142+
"your input."
143+
)
144+
return jnp.astype(x, dtype, copy=copy, device=device)
129145

130146

131147
def can_cast(from_, to, /):

jax/numpy/__init__.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ from jax._src.typing import (
1313
Array, ArrayLike, DType, DTypeLike,
1414
DimSize, DuckTypedArray, Shape, DeprecatedArg
1515
)
16+
from jax._src.sharding import Sharding
17+
from jax._src.lib import xla_client as xc
1618
from jax.numpy import fft as fft, linalg as linalg
1719
from jax.sharding import Sharding as _Sharding
1820
import numpy as _np
@@ -115,7 +117,7 @@ def asarray(
115117
) -> Array: ...
116118
def asin(x: ArrayLike, /) -> Array: ...
117119
def asinh(x: ArrayLike, /) -> Array: ...
118-
def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ...) -> Array: ...
120+
def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ..., device: xc.Device | Sharding | None = ...) -> Array: ...
119121
def atan(x: ArrayLike, /) -> Array: ...
120122
def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: ...
121123
def atanh(x: ArrayLike, /) -> Array: ...

tests/lax_numpy_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3870,6 +3870,26 @@ def testAstypeBool(self, from_dtype, use_method, to_dtype='bool'):
38703870
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
38713871
self._CompileAndCheck(jnp_op, args_maker)
38723872

3873+
@jtu.sample_product(
3874+
change_dtype=[True, False],
3875+
copy=[True, False],
3876+
)
3877+
def testAstypeCopy(self, change_dtype, copy):
3878+
dtype = 'float32' if change_dtype else 'int32'
3879+
expect_copy = change_dtype or copy
3880+
x = jnp.arange(5, dtype='int32')
3881+
y = x.astype(dtype, copy=copy)
3882+
3883+
assert y.dtype == dtype
3884+
y.delete()
3885+
assert x.is_deleted() != expect_copy
3886+
3887+
def testAstypeComplexDowncast(self):
3888+
x = jnp.array(2.0+1.5j, dtype='complex64')
3889+
msg = "Casting from complex to non-complex dtypes will soon raise "
3890+
with self.assertWarns(DeprecationWarning, msg=msg):
3891+
x.astype('float32')
3892+
38733893
def testAstypeInt4(self):
38743894
# Test converting from int4 to int8
38753895
x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)

0 commit comments

Comments
 (0)