diff --git a/jax/_src/array.py b/jax/_src/array.py index 1c24142981ca..5a94549df60d 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -46,7 +46,7 @@ SingleDeviceSharding, XLACompatibleSharding, PmapSharding, device_replica_id_map, hashed_index) from jax._src.layout import DeviceLocalLayout, Layout -from jax._src.typing import ArrayLike +from jax._src.typing import ArrayLike, DLDeviceType from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method @@ -404,11 +404,25 @@ def __array__(self, dtype=None, context=None, copy=None): kwds = {} if copy is None else {'copy': copy} return np.asarray(self._value, dtype=dtype, **kwds) - def __dlpack__(self, *, stream: int | Any | None = None): - if len(self._arrays) != 1: - raise BufferError("__dlpack__ only supported for unsharded arrays.") + def __dlpack__(self, *, stream: int | Any | None = None, + max_version: tuple[int, int] | None = None, + dl_device: tuple[DLDeviceType, int] | None = None, + copy: bool | None = None): from jax._src.dlpack import to_dlpack # pylint: disable=g-import-not-at-top - return to_dlpack(self, stream=stream) + + device_set = self.sharding.device_set + if len(device_set) > 1: + raise BufferError( + "to_dlpack can only pack a dlpack tensor from an array on a singular " + f"device, but an array with a Sharding over {len(device_set)} devices " + "was provided." 
+ ) + device, = device_set + return to_dlpack(self, stream=stream, + max_version=max_version, + src_device=device, + dl_device=dl_device, + copy=copy) def __dlpack_device__(self) -> tuple[enum.Enum, int]: if len(self._arrays) != 1: diff --git a/jax/_src/dlpack.py b/jax/_src/dlpack.py index 72503fe18a2c..45577eaa03d3 100644 --- a/jax/_src/dlpack.py +++ b/jax/_src/dlpack.py @@ -14,19 +14,21 @@ from __future__ import annotations -import enum from typing import Any -import warnings from jax._src.api import device_put from jax import numpy as jnp from jax._src import array from jax._src import xla_bridge +from jax._src.lax.lax import _array_copy from jax._src.lib import xla_client from jax._src.lib import xla_extension_version -from jax._src.typing import Array +from jax._src.typing import Array, DLDeviceType from jax._src.sharding import Sharding +DLPACK_VERSION = (0, 8) +MIN_DLPACK_VERSION = (0, 5) + # A set of dtypes that dlpack supports. # Note: Make sure to use a "type", not a dtype instance, when looking up this set # because their hashes are different. @@ -43,45 +45,112 @@ SUPPORTED_DTYPES = SUPPORTED_DTYPES | frozenset({jnp.bool_}) -# Mirror of dlpack.h enum -class DLDeviceType(enum.IntEnum): - kDLCPU = 1 - kDLCUDA = 2 - kDLROCM = 10 +def _to_dlpack(x: Array, stream: int | Any | None, + src_device: xla_client.Device | None = None, + device: xla_client.Device | None = None, + copy: bool | None = None): + if src_device is None: + src_device, = x.devices() + if device and (src_device is None or device != src_device): + if copy is not None and not copy: + raise ValueError( + f"Specified {device=} which requires a copy since the source device " + f"is {repr(src_device)}, however copy=False. Set copy=True or " + "copy=None to perform the requested operation." 
+ ) + else: + arr = device_put(x, device) + else: + arr = _array_copy(x) if copy else x + return xla_client._xla.buffer_to_dlpack_managed_tensor( + arr.addressable_data(0), stream=stream + ) -def to_dlpack(x: Array, take_ownership: bool = False, - stream: int | Any | None = None): +def to_dlpack(x: Array, stream: int | Any | None = None, + src_device: xla_client.Device | None = None, + dl_device: tuple[DLDeviceType, int] | None = None, + max_version: tuple[int, int] | None = None, + copy : bool | None = None): """Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``. Args: x: a :class:`~jax.Array`, on either CPU or GPU. - take_ownership: Deprecated. It is a no-op to set take_ownership. Will be - deleted in 01/2024. stream: optional platform-dependent stream to wait on until the buffer is ready. This corresponds to the `stream` argument to ``__dlpack__`` documented in https://dmlc.github.io/dlpack/latest/python_spec.html. + src_device: either a CPU or GPU :class:`~jax.Device`. + dl_device: a tuple of ``(dl_device_type, local_hardware_id)`` in DLPack + format e.g. as produced by ``__dlpack_device__``. + max_version: the maximum DLPack version that the consumer (i.e. caller of + ``__dlpack__``) supports in the form of a 2-tuple of ``(major, minor)``. + This function is not guaranteed to return a capsule of version + ``max_version``. + copy: a boolean indicating whether or not to copy the input. If + ``copy=True`` then the function must always copy. When + ``copy=False`` then the function must never copy, and must raise an error + when a copy is deemed necessary. If ``copy=None`` then the function must + avoid a copy if possible but may copy if needed. Returns: - A dlpack PyCapsule object. + A DLPack PyCapsule object. Note: - While JAX arrays are always immutable, dlpack buffers cannot be marked as - immutable, and it is possible for processes external to JAX to mutate them - in-place. 
If a dlpack buffer derived from a JAX array is mutated, it may - lead to undefined behavior when using the associated JAX array. + While JAX arrays are always immutable, ``DLPackManagedTensor`` buffers + cannot be marked as immutable, and it is possible for processes external + to JAX to mutate them in-place. If a DLPack buffer derived from a JAX array + is mutated, it may lead to undefined behavior when using the associated JAX + array. When JAX eventually supports ``DLManagedTensorVersioned`` + (DLPack 1.0), it will be possible to specify that a buffer is read-only. """ if not isinstance(x, array.ArrayImpl): raise TypeError("Argument to to_dlpack must be a jax.Array, " f"got {type(x)}") - assert len(x.devices()) == 1 - if take_ownership: - warnings.warn( - "take_ownership in to_dlpack is deprecated and it is a no-op." + + device = None + dl_device_type, local_hardware_id = dl_device if dl_device else (None, None) + if dl_device_type: + try: + dl_device_platform = { + DLDeviceType.kDLCPU: "cpu", + DLDeviceType.kDLCUDA: "cuda", + DLDeviceType.kDLROCM: "rocm", + }[dl_device_type] + backend = xla_bridge.get_backend(dl_device_platform) + device = backend.device_from_local_hardware_id(local_hardware_id) + except (KeyError, TypeError): + # https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html + # recommends using BufferError. + raise BufferError( + "The device specification passed to to_dlpack contains an unsupported " + f"device type (DLDeviceType: {dl_device_type})") + + # As new versions are adopted over time, we can maintain some legacy paths + # for compatibility mediated through the max_version parameter. + # TODO(micky774): Deprecate default usage of DLPackManagedTensor when XLA + # supports DLManagedTensorVersioned (DLPack version 1.0) and repurpose the + # current _to_dlpack as a legacy path for (0,5) <= max_version < (1,0). 
+ if max_version is None or max_version >= DLPACK_VERSION: + # Latest + return _to_dlpack( + x, stream=stream, + src_device=src_device, + device=device, + copy=copy + ) + elif max_version >= MIN_DLPACK_VERSION: + # Oldest supported + return _to_dlpack( + x, stream=stream, + src_device=src_device, + device=device, + copy=copy + ) + else: + raise BufferError( + f"JAX does not support any version below {MIN_DLPACK_VERSION} but " + f"version ({max_version}) was requested." ) - return xla_client._xla.buffer_to_dlpack_managed_tensor( - x.addressable_data(0), stream=stream - ) # type: ignore def _place_array(_arr, device, dlpack_device, copy): if device and dlpack_device != device: diff --git a/jax/_src/typing.py b/jax/_src/typing.py index 6cd466500f71..afbdedf5c936 100644 --- a/jax/_src/typing.py +++ b/jax/_src/typing.py @@ -29,6 +29,7 @@ from collections.abc import Sequence from typing import Any, Protocol, Union import numpy as np +import enum from jax._src.basearray import ( Array as Array, @@ -83,3 +84,9 @@ def shape(self) -> Shape: ... class DeprecatedArg: def __repr__(self): return "Deprecated" + +# Mirror of dlpack.h enum +class DLDeviceType(enum.IntEnum): + kDLCPU = 1 + kDLCUDA = 2 + kDLROCM = 10 diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 077ae796e3e2..65c95d9c2ea8 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -334,7 +334,7 @@ def _arg_jax_to_tf(arg_jax): if (isinstance(arg_jax, jax.Array) and list(arg_jax.devices())[0].platform in _DLPACK_PLATFORMS and arg_jax.dtype.type in dlpack.SUPPORTED_DTYPES): - arg_dlpack = jax.dlpack.to_dlpack(arg_jax, take_ownership=False) + arg_dlpack = jax.dlpack.to_dlpack(arg_jax) return tf.experimental.dlpack.from_dlpack(arg_dlpack) # The following avoids copies to the host on CPU, always for Array # and even for ndarray if they are sufficiently aligned. 
diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index 9935cd915530..6624ef723a3a 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -73,23 +73,48 @@ def setUp(self): @jtu.sample_product( shape=all_shapes, dtype=dlpack_dtypes, - gpu=[False, True], + copy=[False, True, None] ) - def testJaxRoundTrip(self, shape, dtype, gpu): + @jtu.run_on_devices("gpu") + def testJaxRoundTrip(self, shape, dtype, copy): + if xb.using_pjrt_c_api(): + self.skipTest("DLPack support is incomplete in the PJRT C API") # TODO(skyewm) rng = jtu.rand_default(self.rng()) np = rng(shape, dtype) - if gpu and jtu.test_device_matches(["cpu"]): - raise unittest.SkipTest("Skipping GPU test case on CPU") - device = jax.devices("gpu" if gpu else "cpu")[0] - x = jax.device_put(np, device) - dlpack = jax.dlpack.to_dlpack(x) - y = jax.dlpack.from_dlpack(dlpack) - self.assertEqual(y.devices(), {device}) - self.assertAllClose(np.astype(x.dtype), y) + def _check_copy(x: jax.Array, y: jax.Array, expect_copy): + copied = x.unsafe_buffer_pointer() != y.unsafe_buffer_pointer() + assert copied == expect_copy, f"Expected {'a' if expect_copy else 'no'} copy" + + # Check if the source device is preserved + x = jax.device_put(np, jax.devices("cpu")[0]) + device = jax.devices("gpu")[0] + y = jax.device_put(x, device) + dl_device = y.__dlpack_device__() + dlpack = jax.dlpack.to_dlpack(y, copy=copy) + z = jax.dlpack.from_dlpack(dlpack) + + self.assertEqual(z.devices(), {device}) + self.assertAllClose(np.astype(x.dtype), z) self.assertRaisesRegex(RuntimeError, - "DLPack tensor may be consumed at most once", - lambda: jax.dlpack.from_dlpack(dlpack)) + "DLPack tensor may be consumed at most once", + lambda: jax.dlpack.from_dlpack(dlpack)) + + if shape in nonempty_array_shapes: + _check_copy(y, z, bool(copy)) + + # Check if the destination device can be specified + make_dlpack = lambda: x.__dlpack__(dl_device=dl_device, copy=copy) + if 
+ copy is False: + self.assertRaisesRegex(ValueError, "copy=False", make_dlpack) + return + + z = jax.dlpack.from_dlpack(make_dlpack()) + self.assertEqual(z.devices(), {device}) + self.assertAllClose(x, z) + + if shape in nonempty_array_shapes: + _check_copy(x, z, True) @jtu.sample_product( shape=all_shapes,