Skip to content

Commit 4a85f4a

Browse files
committed
Add support for max_version, dl_device, copy kwargs in __dlpack__
1 parent d7e5dde commit 4a85f4a

File tree

3 files changed

+165
-27
lines changed

3 files changed

+165
-27
lines changed

jax/_src/array.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -404,11 +404,25 @@ def __array__(self, dtype=None, context=None, copy=None):
404404
kwds = {} if copy is None else {'copy': copy}
405405
return np.asarray(self._value, dtype=dtype, **kwds)
406406

407-
def __dlpack__(self, *, stream: int | Any | None = None):
408-
if len(self._arrays) != 1:
409-
raise BufferError("__dlpack__ only supported for unsharded arrays.")
407+
def __dlpack__(self, *, stream: int | Any | None = None,
408+
max_version: tuple[int, int] | None = None,
409+
dl_device: tuple[enum.Enum, int] | None = None,
410+
copy: bool | None = None):
410411
from jax._src.dlpack import to_dlpack # pylint: disable=g-import-not-at-top
411-
return to_dlpack(self, stream=stream)
412+
413+
device_set = self.sharding.device_set
414+
if len(device_set) > 1:
415+
raise BufferError(
416+
"to_dlpack can only pack a dlpack tensor from an array on a singular "
417+
f"device, but an array with a Sharding over {len(device_set)} devices "
418+
"was provided."
419+
)
420+
device, = device_set
421+
return to_dlpack(self, stream=stream,
422+
max_version=max_version,
423+
src_device=device,
424+
dl_device=dl_device, # type: ignore
425+
copy=copy)
412426

413427
def __dlpack_device__(self) -> tuple[enum.Enum, int]:
414428
if len(self._arrays) != 1:

jax/_src/dlpack.py

Lines changed: 112 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,14 @@
2121
from jax import numpy as jnp
2222
from jax._src import array
2323
from jax._src import xla_bridge
24+
from jax._src.lax.lax import _array_copy
2425
from jax._src.lib import xla_client
2526
from jax._src.lib import xla_extension_version
2627
from jax._src.typing import Array
28+
from jax._src.api import device_put
2729

30+
DLPACK_VERSION = (0, 8)
31+
MIN_DLPACK_VERSION = (0, 5)
2832

2933
# A set of dtypes that dlpack supports.
3034
# Note: Make sure to use a "type", not a dtype instance, when looking up this set
@@ -48,9 +52,34 @@ class DLDeviceType(enum.IntEnum):
4852
kDLCUDA = 2
4953
kDLROCM = 10
5054

55+
def _to_dlpack(x: Array, stream: int | Any | None,
56+
src_device: xla_client.Device | None = None,
57+
device: xla_client.Device | None = None,
58+
copy: bool | None = None):
59+
60+
if src_device is None:
61+
src_device, = x.devices()
62+
if device and (src_device is None or device != src_device):
63+
if copy is not None and not copy:
64+
raise ValueError(
65+
f"Specified {device=} which requires a copy since the source device "
66+
f"is {repr(src_device)}, however copy=False. Set copy=True or "
67+
"copy=None to perform the requested operation."
68+
)
69+
else:
70+
arr = device_put(x, device)
71+
else:
72+
arr = _array_copy(x) if copy else x
73+
return xla_client._xla.buffer_to_dlpack_managed_tensor(
74+
arr.addressable_data(0), stream=stream
75+
)
5176

5277
def to_dlpack(x: Array, take_ownership: bool = False,
53-
stream: int | Any | None = None):
78+
stream: int | Any | None = None,
79+
src_device: xla_client.Device | None = None,
80+
dl_device: tuple[DLDeviceType, int] | None = None,
81+
max_version: tuple[int, int] | None = None,
82+
copy : bool | None = None):
5483
"""Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.
5584
5685
Args:
@@ -60,27 +89,97 @@ def to_dlpack(x: Array, take_ownership: bool = False,
6089
stream: optional platform-dependent stream to wait on until the buffer is
6190
ready. This corresponds to the `stream` argument to ``__dlpack__``
6291
documented in https://dmlc.github.io/dlpack/latest/python_spec.html.
92+
src_device: either a CPU or GPU :class:`~jax.Device`.
93+
dl_device: a tuple of ``(dl_device_type, local_hardware_id)`` in DLPack
94+
format e.g. as produced by ``__dlpack_device__``.
95+
max_version: the maximum DLPack version that the consumer (i.e. caller of
96+
``__dlpack__``) supports in the form of a 2-tuple of ``(major, minor)``.
97+
This function is not guaranteed to return a capsule of version
98+
``max_version``.
99+
copy: a boolean indicating whether or not to copy the input. If
100+
``copy=True`` then the function must always copy. When
101+
``copy=False`` then the function must never copy, and must raise an error
102+
when a copy is deemed necessary. If ``copy=None`` then the function must
103+
avoid a copy if possible but may copy if needed.
63104
64105
Returns:
65-
A dlpack PyCapsule object.
106+
A DLPack PyCapsule object.
66107
67108
Note:
68-
While JAX arrays are always immutable, dlpack buffers cannot be marked as
69-
immutable, and it is possible for processes external to JAX to mutate them
70-
in-place. If a dlpack buffer derived from a JAX array is mutated, it may
71-
lead to undefined behavior when using the associated JAX array.
109+
While JAX arrays are always immutable, ``DLPackManagedTensor`` buffers
110+
cannot be marked as immutable, and it is possible for processes external
111+
to JAX to mutate them in-place. If a DLPack buffer derived from a JAX array
112+
is mutated, it may lead to undefined behavior when using the associated JAX
113+
array. When JAX eventually supports ``DLManagedTensorVersioned``
114+
(DLPack 1.0), it will be possible to specify that a buffer is read-only.
72115
"""
73116
if not isinstance(x, array.ArrayImpl):
74117
raise TypeError("Argument to to_dlpack must be a jax.Array, "
75118
f"got {type(x)}")
76-
assert len(x.devices()) == 1
77119
if take_ownership:
78120
warnings.warn(
79121
"take_ownership in to_dlpack is deprecated and it is a no-op."
80122
)
81-
return xla_client._xla.buffer_to_dlpack_managed_tensor(
82-
x.addressable_data(0), stream=stream
83-
) # type: ignore
123+
124+
device = None
125+
dl_device_type, local_hardware_id = dl_device if dl_device else (None, None)
126+
if dl_device_type:
127+
try:
128+
dl_device_platform = {
129+
DLDeviceType.kDLCPU: "cpu",
130+
DLDeviceType.kDLCUDA: "cuda",
131+
DLDeviceType.kDLROCM: "rocm",
132+
}[dl_device_type]
133+
backend = xla_bridge.get_backend(dl_device_platform)
134+
device = backend.device_from_local_hardware_id(local_hardware_id)
135+
except TypeError:
136+
# https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
137+
# recommends using BufferError.
138+
raise BufferError(
139+
"The device specification passed to to_dlpack contains an unsupported "
140+
f"device type (DLDeviceType: {dl_device_type})")
141+
142+
# TODO(micky774): Deprecate default usage of DLPackManagedTensor when XLA
143+
# supports DLManagedTensorVersioned (DLPack version 1.0) and repurpose the
144+
# current _to_dlpack as a legacy path for (0,5) <= max_version < (1,0)
145+
if max_version is None:
146+
# Backwards compatible default
147+
return _to_dlpack(
148+
x, stream=stream,
149+
src_device=src_device,
150+
device=device,
151+
copy=copy
152+
)
153+
else:
154+
if max_version >= DLPACK_VERSION:
155+
# Latest
156+
return _to_dlpack(
157+
x, stream=stream,
158+
src_device=src_device,
159+
device=device,
160+
copy=copy
161+
)
162+
if max_version[0] == DLPACK_VERSION[0]:
163+
# ABI compatible
164+
return _to_dlpack(
165+
x, stream=stream,
166+
src_device=src_device,
167+
device=device,
168+
copy=copy
169+
)
170+
elif max_version >= MIN_DLPACK_VERSION:
171+
# Oldest supported
172+
return _to_dlpack(
173+
x, stream=stream,
174+
src_device=src_device,
175+
device=device,
176+
copy=copy
177+
)
178+
else:
179+
raise BufferError(
180+
f"JAX does not support any version below {MIN_DLPACK_VERSION} but "
181+
f"version ({max_version}) was requested."
182+
)
84183

85184

86185
def from_dlpack(external_array):
@@ -110,12 +209,12 @@ def from_dlpack(external_array):
110209
DLDeviceType.kDLCUDA: "cuda",
111210
DLDeviceType.kDLROCM: "rocm",
112211
}[dl_device_type]
113-
except TypeError:
212+
except TypeError as err:
114213
# https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
115214
# TypeError.
116-
raise TypeError(
215+
raise BufferError(
117216
"Array passed to from_dlpack is on unsupported device type "
118-
f"(DLDeviceType: {dl_device_type}, array: {external_array}")
217+
f"(DLDeviceType: {dl_device_type}, array: {external_array}") from err
119218

120219
backend = xla_bridge.get_backend(device_platform)
121220
device = backend.device_from_local_hardware_id(device_id)

tests/array_interoperability_test.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,24 +74,49 @@ def setUp(self):
7474
shape=all_shapes,
7575
dtype=dlpack_dtypes,
7676
gpu=[False, True],
77+
copy=[False, True, None]
7778
)
78-
def testJaxRoundTrip(self, shape, dtype, gpu):
79+
def testJaxRoundTrip(self, shape, dtype, gpu, copy):
7980
if xb.using_pjrt_c_api():
8081
self.skipTest("DLPack support is incomplete in the PJRT C API") # TODO(skyewm)
8182
rng = jtu.rand_default(self.rng())
8283
np = rng(shape, dtype)
83-
if gpu and jtu.test_device_matches(["cpu"]):
84-
raise unittest.SkipTest("Skipping GPU test case on CPU")
84+
if gpu and not jtu.test_device_matches(["gpu"]):
85+
raise unittest.SkipTest("Skipping GPU test case on CPU/TPU")
86+
87+
def _check_copy(x, y, expect_copy):
88+
is_copy = x.unsafe_buffer_pointer() != y.unsafe_buffer_pointer()
89+
assert is_copy == expect_copy
90+
91+
# Check if the source device is preserved
92+
x = jax.device_put(np, jax.devices("cpu")[0])
8593
device = jax.devices("gpu" if gpu else "cpu")[0]
86-
x = jax.device_put(np, device)
87-
dlpack = jax.dlpack.to_dlpack(x)
88-
y = jax.dlpack.from_dlpack(dlpack)
89-
self.assertEqual(y.devices(), {device})
90-
self.assertAllClose(np.astype(x.dtype), y)
94+
y = jax.device_put(x, device)
95+
dlpack = jax.dlpack.to_dlpack(y, copy=copy)
96+
z = jax.dlpack.from_dlpack(dlpack)
9197

98+
self.assertEqual(z.devices(), {device})
99+
self.assertAllClose(np.astype(x.dtype), z)
92100
self.assertRaisesRegex(RuntimeError,
93-
"DLPack tensor may be consumed at most once",
94-
lambda: jax.dlpack.from_dlpack(dlpack))
101+
"DLPack tensor may be consumed at most once",
102+
lambda: jax.dlpack.from_dlpack(dlpack))
103+
104+
if shape in nonempty_array_shapes:
105+
_check_copy(y, z, bool(copy))
106+
107+
# Check if the destination device can be specified
108+
dl_device = y.__dlpack_device__()
109+
make_dlpack = lambda: x.__dlpack__(dl_device=dl_device, copy=copy)
110+
if gpu and copy == False:
111+
self.assertRaisesRegex(ValueError, "copy=False", make_dlpack)
112+
return
113+
114+
z = jax.dlpack.from_dlpack(make_dlpack())
115+
self.assertEqual(z.devices(), {device})
116+
self.assertAllClose(np.astype(x.dtype), z)
117+
118+
if shape in nonempty_array_shapes:
119+
_check_copy(x, z, bool(copy) or gpu)
95120

96121
@jtu.sample_product(
97122
shape=all_shapes,

0 commit comments

Comments
 (0)