Skip to content

Commit 1b3aea8

Browse files
Jake VanderPlasjax authors
authored andcommitted
Finalize the deprecation of the arr.device() method
The method has been emitting an DeprecationWarning since JAX v0.4.21, released December 2023. Existing uses can be replaced with `arr.devices()` or `arr.sharding`, depending on the context. PiperOrigin-RevId: 623015500
1 parent af3dcd2 commit 1b3aea8

File tree

6 files changed

+6
-31
lines changed

6 files changed

+6
-31
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ Remember to align the itemized text with the first line of an item within a list
2121
* {func}`jax.numpy.clip` has a new argument signature: `a`, `a_min`, and
2222
`a_max` are deprecated in favor of `x` (positonal only), `min`, and
2323
`max` ({jax-issue}`20550`).
24+
* The `device()` method of JAX arrays has been removed, after being deprecated
25+
since JAX v0.4.21. Use `arr.devices()` instead.
2426

2527

2628
## jaxlib 0.4.27

jax/_src/array.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from jax._src import basearray
3131
from jax._src import config
3232
from jax._src import core
33-
from jax._src import deprecations
3433
from jax._src import dispatch
3534
from jax._src import dtypes
3635
from jax._src import errors
@@ -50,7 +49,6 @@
5049
from jax._src.typing import ArrayLike
5150
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method
5251

53-
deprecations.register(__name__, "device-method")
5452

5553
Shape = tuple[int, ...]
5654
Device = xc.Device
@@ -471,21 +469,6 @@ def on_device_size_in_bytes(self):
471469
per_shard_size = arr.on_device_size_in_bytes() # type: ignore
472470
return per_shard_size * len(self.sharding.device_set)
473471

474-
# TODO(yashkatariya): Remove this method when everyone is using devices().
475-
def device(self) -> Device:
476-
if deprecations.is_accelerated(__name__, "device-method"):
477-
raise NotImplementedError("arr.device() is deprecated. Use arr.devices() instead.")
478-
else:
479-
warnings.warn("arr.device() is deprecated. Use arr.devices() instead.",
480-
DeprecationWarning, stacklevel=2)
481-
self._check_if_deleted()
482-
device_set = self.sharding.device_set
483-
if len(device_set) == 1:
484-
single_device, = device_set
485-
return single_device
486-
raise ValueError('Length of devices is greater than 1. '
487-
'Please use `.devices()`.')
488-
489472
def devices(self) -> set[Device]:
490473
self._check_if_deleted()
491474
return self.sharding.device_set

jax/_src/basearray.pyi

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,6 @@ class Array(abc.ABC):
196196
def block_until_ready(self) -> Array: ...
197197
def copy_to_host_async(self) -> None: ...
198198
def delete(self) -> None: ...
199-
def device(self) -> Device: ...
200199
def devices(self) -> set[Device]: ...
201200
@property
202201
def sharding(self) -> Sharding: ...

jax/_src/core.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -849,11 +849,6 @@ def delete(self):
849849
f"The delete() method was called on {self._error_repr()}."
850850
f"{self._origin_msg()}")
851851

852-
def device(self):
853-
raise ConcretizationTypeError(self,
854-
f"The device() method was called on {self._error_repr()}."
855-
f"{self._origin_msg()}")
856-
857852
def devices(self):
858853
raise ConcretizationTypeError(self,
859854
f"The devices() method was called on {self._error_repr()}."

jax/experimental/array_api/_array_methods.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import jax
2020
from jax._src.array import ArrayImpl
2121
from jax.experimental.array_api._version import __array_api_version__
22+
from jax.sharding import Sharding
2223

2324
from jax._src.lib import xla_extension as xe
2425

@@ -30,16 +31,15 @@ def _array_namespace(self, /, *, api_version: None | str = None):
3031
return jax.experimental.array_api
3132

3233

33-
def _to_device(self, device: xe.Device | Callable[[], xe.Device], /, *,
34+
def _to_device(self, device: xe.Device | Sharding | None, *,
3435
stream: int | Any | None = None):
3536
if stream is not None:
3637
raise NotImplementedError("stream argument of array.to_device()")
37-
# The type of device is defined by Array.device. In JAX, this is a callable that
38-
# returns a device, so we must handle this case to satisfy the API spec.
39-
return jax.device_put(self, device() if callable(device) else device)
38+
return jax.device_put(self, device)
4039

4140

4241
def add_array_object_methods():
4342
# TODO(jakevdp): set on tracers as well?
4443
setattr(ArrayImpl, "__array_namespace__", _array_namespace)
4544
setattr(ArrayImpl, "to_device", _to_device)
45+
setattr(ArrayImpl, "device", property(lambda self: self.sharding))

tests/random_test.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
from jax import random
3434
from jax._src import config
3535
from jax._src import core
36-
from jax._src import deprecations
3736
from jax._src import dtypes
3837
from jax._src import test_util as jtu
3938
from jax import vmap
@@ -1019,9 +1018,6 @@ def test_array_impl_attributes(self):
10191018

10201019
self.assertEqual(key.is_fully_addressable, key._base_array.is_fully_addressable)
10211020
self.assertEqual(key.is_fully_replicated, key._base_array.is_fully_replicated)
1022-
if not deprecations.is_accelerated('jax._src.array', 'device-method'):
1023-
with jtu.ignore_warning(category=DeprecationWarning, message="arr.device"):
1024-
self.assertEqual(key.device(), key._base_array.device())
10251021
self.assertEqual(key.devices(), key._base_array.devices())
10261022
self.assertEqual(key.on_device_size_in_bytes(),
10271023
key._base_array.on_device_size_in_bytes())

0 commit comments

Comments
 (0)