Skip to content

Commit a2feff2

Browse files
committed
Add support for max_version, dl_device, copy kwargs in __dlpack__
1 parent 477a44f commit a2feff2

File tree

5 files changed

+157
-42
lines changed

5 files changed

+157
-42
lines changed

jax/_src/array.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
4747
device_replica_id_map, hashed_index)
4848
from jax._src.layout import DeviceLocalLayout, Layout
49-
from jax._src.typing import ArrayLike
49+
from jax._src.typing import ArrayLike, DLDeviceType
5050
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method
5151

5252

@@ -404,11 +404,25 @@ def __array__(self, dtype=None, context=None, copy=None):
404404
kwds = {} if copy is None else {'copy': copy}
405405
return np.asarray(self._value, dtype=dtype, **kwds)
406406

407-
def __dlpack__(self, *, stream: int | Any | None = None):
408-
if len(self._arrays) != 1:
409-
raise BufferError("__dlpack__ only supported for unsharded arrays.")
407+
def __dlpack__(self, *, stream: int | Any | None = None,
               max_version: tuple[int, int] | None = None,
               dl_device: tuple[DLDeviceType, int] | None = None,
               copy: bool | None = None):
  """Exports this array as a DLPack capsule by delegating to ``to_dlpack``.

  Args:
    stream: forwarded unchanged to ``to_dlpack``.
    max_version: forwarded unchanged to ``to_dlpack``.
    dl_device: forwarded unchanged to ``to_dlpack``.
    copy: forwarded unchanged to ``to_dlpack``.

  Returns:
    Whatever ``to_dlpack`` returns for this array.

  Raises:
    BufferError: if the array's sharding spans more than one device.
  """
  from jax._src.dlpack import to_dlpack  # pylint: disable=g-import-not-at-top

  devices = self.sharding.device_set
  num_devices = len(devices)
  if num_devices > 1:
    raise BufferError(
        "to_dlpack can only pack a dlpack tensor from an array on a singular "
        f"device, but an array with a Sharding over {num_devices} devices "
        "was provided."
    )
  # Exactly-one unpacking: the sole device of the sharding is the source
  # device handed to ``to_dlpack``.
  (only_device,) = devices
  return to_dlpack(
      self,
      stream=stream,
      max_version=max_version,
      src_device=only_device,
      dl_device=dl_device,
      copy=copy,
  )
412426

413427
def __dlpack_device__(self) -> tuple[enum.Enum, int]:
414428
if len(self._arrays) != 1:

jax/_src/dlpack.py

Lines changed: 93 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,21 @@
1414

1515
from __future__ import annotations
1616

17-
import enum
1817
from typing import Any
19-
import warnings
2018

2119
from jax._src.api import device_put
2220
from jax import numpy as jnp
2321
from jax._src import array
2422
from jax._src import xla_bridge
23+
from jax._src.lax.lax import _array_copy
2524
from jax._src.lib import xla_client
2625
from jax._src.lib import xla_extension_version
27-
from jax._src.typing import Array
26+
from jax._src.typing import Array, DLDeviceType
2827
from jax._src.sharding import Sharding
2928

29+
DLPACK_VERSION = (0, 8)
30+
MIN_DLPACK_VERSION = (0, 5)
31+
3032
# A set of dtypes that dlpack supports.
3133
# Note: Make sure to use a "type", not a dtype instance, when looking up this set
3234
# because their hashes are different.
@@ -43,45 +45,112 @@
4345
SUPPORTED_DTYPES = SUPPORTED_DTYPES | frozenset({jnp.bool_})
4446

4547

46-
# Mirror of dlpack.h enum
47-
class DLDeviceType(enum.IntEnum):
48-
kDLCPU = 1
49-
kDLCUDA = 2
50-
kDLROCM = 10
48+
def _to_dlpack(x: Array, stream: int | Any | None,
               src_device: xla_client.Device | None = None,
               device: xla_client.Device | None = None,
               copy: bool | None = None):
  """Builds a DLPack managed tensor from ``x``, moving or copying if asked.

  Args:
    x: the array to export.
    stream: passed through to ``buffer_to_dlpack_managed_tensor``.
    src_device: the device ``x`` lives on; when ``None`` it is taken from
      ``x.devices()``, which must then contain exactly one device.
    device: optional destination device. If it differs from ``src_device``
      the array is transferred with ``device_put``.
    copy: ``False`` forbids the transfer above; a truthy value forces a copy
      when no transfer happens; ``None`` leaves ``x`` as is unless a transfer
      is needed.

  Returns:
    The result of ``buffer_to_dlpack_managed_tensor`` for the chosen array.

  Raises:
    ValueError: if a cross-device transfer is required but ``copy`` is a
      non-``None`` falsy value.
  """
  if src_device is None:
    (src_device,) = x.devices()

  needs_transfer = bool(device) and (
      src_device is None or device != src_device)
  if needs_transfer:
    copy_forbidden = copy is not None and not copy
    if copy_forbidden:
      raise ValueError(
          f"Specified {device=} which requires a copy since the source device "
          f"is {src_device!r}, however copy=False. Set copy=True or "
          "copy=None to perform the requested operation."
      )
    exported = device_put(x, device)
  elif copy:
    exported = _array_copy(x)
  else:
    exported = x

  buffer = exported.addressable_data(0)
  return xla_client._xla.buffer_to_dlpack_managed_tensor(buffer, stream=stream)
5269

53-
def to_dlpack(x: Array, take_ownership: bool = False,
54-
stream: int | Any | None = None):
70+
def to_dlpack(x: Array, stream: int | Any | None = None,
              src_device: xla_client.Device | None = None,
              dl_device: tuple[DLDeviceType, int] | None = None,
              max_version: tuple[int, int] | None = None,
              copy : bool | None = None):
  """Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.

  Args:
    x: a :class:`~jax.Array`, on either CPU or GPU.
    stream: optional platform-dependent stream to wait on until the buffer is
      ready. This corresponds to the `stream` argument to ``__dlpack__``
      documented in https://dmlc.github.io/dlpack/latest/python_spec.html.
    src_device: either a CPU or GPU :class:`~jax.Device`.
    dl_device: a tuple of ``(dl_device_type, local_hardware_id)`` in DLPack
      format e.g. as produced by ``__dlpack_device__``.
    max_version: the maximum DLPack version that the consumer (i.e. caller of
      ``__dlpack__``) supports in the form of a 2-tuple of ``(major, minor)``.
      This function is not guaranteed to return a capsule of version
      ``max_version``.
    copy: a boolean indicating whether or not to copy the input. If
      ``copy=True`` then the function must always copy. When
      ``copy=False`` then the function must never copy, and must raise an error
      when a copy is deemed necessary. If ``copy=None`` then the function must
      avoid a copy if possible but may copy if needed.

  Returns:
    A DLPack PyCapsule object.

  Raises:
    TypeError: if ``x`` is not a ``jax.Array`` implementation.
    BufferError: if ``dl_device`` names a device type that is not supported,
      or if ``max_version`` is lower than ``MIN_DLPACK_VERSION``.

  Note:
    While JAX arrays are always immutable, ``DLPackManagedTensor`` buffers
    cannot be marked as immutable, and it is possible for processes external
    to JAX to mutate them in-place. If a DLPack buffer derived from a JAX array
    is mutated, it may lead to undefined behavior when using the associated JAX
    array. When JAX eventually supports ``DLManagedTensorVersioned``
    (DLPack 1.0), it will be possible to specify that a buffer is read-only.
  """
  if not isinstance(x, array.ArrayImpl):
    raise TypeError("Argument to to_dlpack must be a jax.Array, "
                    f"got {type(x)}")

  device = None
  dl_device_type, local_hardware_id = dl_device if dl_device else (None, None)
  if dl_device_type:
    # https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
    # recommends using BufferError.
    unsupported_msg = (
        "The device specification passed to to_dlpack contains an unsupported "
        f"device type (DLDeviceType: {dl_device_type})")
    try:
      dl_device_platform = {
          DLDeviceType.kDLCPU: "cpu",
          DLDeviceType.kDLCUDA: "cuda",
          DLDeviceType.kDLROCM: "rocm",
      }[dl_device_type]
    except (KeyError, TypeError) as err:
      # An unknown device type makes the lookup raise KeyError (TypeError if
      # the value is unhashable); both mean "unsupported device type".
      raise BufferError(unsupported_msg) from err
    try:
      backend = xla_bridge.get_backend(dl_device_platform)
      device = backend.device_from_local_hardware_id(local_hardware_id)
    except TypeError as err:
      raise BufferError(unsupported_msg) from err

  # As new versions are adopted over time, we can maintain some legacy paths
  # for compatibility mediated through the max_version parameter. For now
  # every supported version, from MIN_DLPACK_VERSION up to and beyond
  # DLPACK_VERSION, is served by the same export path.
  # TODO(micky774): Deprecate default usage of DLPackManagedTensor when XLA
  # supports DLManagedTensorVersioned (DLPack version 1.0) and repurpose the
  # current _to_dlpack as a legacy path for (0,5) <= max_version < (1,0).
  if max_version is not None and max_version < MIN_DLPACK_VERSION:
    raise BufferError(
        f"JAX does not support any version below {MIN_DLPACK_VERSION} but "
        f"version ({max_version}) was requested."
    )
  return _to_dlpack(
      x, stream=stream,
      src_device=src_device,
      device=device,
      copy=copy
  )
82-
return xla_client._xla.buffer_to_dlpack_managed_tensor(
83-
x.addressable_data(0), stream=stream
84-
) # type: ignore
85154

86155
def _place_array(_arr, device, dlpack_device, copy):
87156
if device and dlpack_device != device:

jax/_src/typing.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from collections.abc import Sequence
3030
from typing import Any, Protocol, Union
3131
import numpy as np
32+
import enum
3233

3334
from jax._src.basearray import (
3435
Array as Array,
@@ -83,3 +84,9 @@ def shape(self) -> Shape: ...
8384
class DeprecatedArg:
  """Marker type whose instances always render as ``Deprecated``."""

  def __repr__(self) -> str:
    return "Deprecated"
87+
88+
# Mirror of dlpack.h enum
89+
class DLDeviceType(enum.IntEnum):
  """Integer device-type codes mirroring the enum in ``dlpack.h``."""

  # CPU device.
  kDLCPU = 1
  # CUDA device.
  kDLCUDA = 2
  # ROCm device.
  kDLROCM = 10

jax/experimental/jax2tf/call_tf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ def _arg_jax_to_tf(arg_jax):
334334
if (isinstance(arg_jax, jax.Array) and
335335
list(arg_jax.devices())[0].platform in _DLPACK_PLATFORMS and
336336
arg_jax.dtype.type in dlpack.SUPPORTED_DTYPES):
337-
arg_dlpack = jax.dlpack.to_dlpack(arg_jax, take_ownership=False)
337+
arg_dlpack = jax.dlpack.to_dlpack(arg_jax)
338338
return tf.experimental.dlpack.from_dlpack(arg_dlpack)
339339
# The following avoids copies to the host on CPU, always for Array
340340
# and even for ndarray if they are sufficiently aligned.

tests/array_interoperability_test.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -73,23 +73,48 @@ def setUp(self):
7373
@jtu.sample_product(
  shape=all_shapes,
  dtype=dlpack_dtypes,
  copy=[False, True, None]
)
@jtu.run_on_devices("gpu")
def testJaxRoundTrip(self, shape, dtype, copy):
  # Round-trips a JAX array through DLPack twice: once exporting from the
  # device the array already lives on (checking the ``copy`` flag), and once
  # asking ``__dlpack__`` to export a CPU array to a GPU ``dl_device``.
  if xb.using_pjrt_c_api():
    self.skipTest("DLPack support is incomplete in the PJRT C API") # TODO(skyewm)
  rng = jtu.rand_default(self.rng())
  host_data = rng(shape, dtype)

  def _check_copy(x: jax.Array, y: jax.Array, expect_copy):
    # Two arrays share storage iff their buffer pointers are equal.
    copied = x.unsafe_buffer_pointer() != y.unsafe_buffer_pointer()
    self.assertEqual(
        copied, expect_copy,
        f"Expected {'a' if expect_copy else 'no'} copy")

  # Check if the source device is preserved
  x = jax.device_put(host_data, jax.devices("cpu")[0])
  device = jax.devices("gpu")[0]
  y = jax.device_put(x, device)
  dl_device = y.__dlpack_device__()
  dlpack = jax.dlpack.to_dlpack(y, copy=copy)
  z = jax.dlpack.from_dlpack(dlpack)

  self.assertEqual(z.devices(), {device})
  self.assertAllClose(host_data.astype(x.dtype), z)
  # The capsule was consumed by the first from_dlpack call above.
  self.assertRaisesRegex(RuntimeError,
                         "DLPack tensor may be consumed at most once",
                         lambda: jax.dlpack.from_dlpack(dlpack))

  if shape in nonempty_array_shapes:
    # copy=None must not copy when source and destination device agree.
    _check_copy(y, z, bool(copy))

  # Check if the destination device can be specified
  def make_dlpack():
    return x.__dlpack__(dl_device=dl_device, copy=copy)

  if copy is False:
    # Moving CPU -> GPU needs a copy, which copy=False forbids.
    self.assertRaisesRegex(ValueError, "copy=False", make_dlpack)
    return

  z = jax.dlpack.from_dlpack(make_dlpack())
  self.assertEqual(z.devices(), {device})
  self.assertAllClose(x, z)

  if shape in nonempty_array_shapes:
    _check_copy(x, z, True)
93118

94119
@jtu.sample_product(
95120
shape=all_shapes,

0 commit comments

Comments
 (0)