Skip to content

Commit bb27b92

Browse files
committed
Update
1 parent d448f66 commit bb27b92

File tree

2 files changed

+78
-12
lines changed

2 files changed

+78
-12
lines changed

jax/_src/array.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -404,11 +404,25 @@ def __array__(self, dtype=None, context=None, copy=None):
404404
kwds = {} if copy is None else {'copy': copy}
405405
return np.asarray(self._value, dtype=dtype, **kwds)
406406

407-
def __dlpack__(self, *, stream: int | Any | None = None):
408-
if len(self._arrays) != 1:
409-
raise BufferError("__dlpack__ only supported for unsharded arrays.")
407+
def __dlpack__(self, *, stream: int | Any | None = None,
408+
max_version: tuple[int, int] | None = None,
409+
dl_device: tuple[enum.Enum, int] | None = None,
410+
copy: bool | None = None):
410411
from jax._src.dlpack import to_dlpack # pylint: disable=g-import-not-at-top
411-
return to_dlpack(self, stream=stream)
412+
413+
device_set = self.sharding.device_set
414+
if len(device_set) > 1:
415+
raise BufferError(
416+
"to_dlpack can only pack a dlpack tensor from an array on a singular "
417+
f"device, but an array with a Sharding over {len(device_set)} devices "
418+
"was provided."
419+
)
420+
device, = device_set
421+
return to_dlpack(self, stream=stream,
422+
max_version=max_version,
423+
device=device,
424+
dl_device=dl_device, # type: ignore
425+
copy=copy)
412426

413427
def __dlpack_device__(self) -> tuple[enum.Enum, int]:
414428
if len(self._arrays) != 1:

jax/_src/dlpack.py

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@
2424
from jax._src.lib import xla_client
2525
from jax._src.lib import xla_extension_version
2626
from jax._src.typing import Array
27+
from jax._src.api import device_put
2728

29+
DLPACK_VERSION = (0, 1)
30+
MIN_DLPACK_VERSION = (0, 1)
2831

2932
# A set of dtypes that dlpack supports.
3033
# Note: Make sure to use a "type", not a dtype instance, when looking up this set
@@ -48,9 +51,32 @@ class DLDeviceType(enum.IntEnum):
4851
kDLCUDA = 2
4952
kDLROCM = 10
5053

54+
def _to_dlpack(x: Array, stream: int | Any | None,
55+
device: xla_client.Device | None = None,
56+
dlpack_device: xla_client.Device | None = None,
57+
copy: bool | None = None):
58+
if dlpack_device and dlpack_device != device:
59+
if copy is not None and not copy:
60+
raise ValueError(
61+
f"Specified {dlpack_device=} which requires a copy since the source device "
62+
f"is {repr(device)}, however copy=False. Set copy=True or "
63+
"copy=None to perform the requested operation."
64+
)
65+
else:
66+
arr = device_put(x, dlpack_device)
67+
else:
68+
arr = x.copy() if copy else x
69+
70+
return xla_client._xla.buffer_to_dlpack_managed_tensor(
71+
arr.addressable_data(0), stream=stream
72+
)
5173

5274
def to_dlpack(x: Array, take_ownership: bool = False,
53-
stream: int | Any | None = None):
75+
stream: int | Any | None = None,
76+
device: xla_client.Device | None = None,
77+
dl_device: tuple[DLDeviceType, int] | None = None,
78+
max_version: tuple[int, int] | None = None,
79+
copy : bool | None = None):
5480
"""Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.
5581
5682
Args:
@@ -73,14 +99,40 @@ def to_dlpack(x: Array, take_ownership: bool = False,
7399
if not isinstance(x, array.ArrayImpl):
74100
raise TypeError("Argument to to_dlpack must be a jax.Array, "
75101
f"got {type(x)}")
76-
assert len(x.devices()) == 1
77102
if take_ownership:
78103
warnings.warn(
79104
"take_ownership in to_dlpack is deprecated and it is a no-op."
80105
)
81-
return xla_client._xla.buffer_to_dlpack_managed_tensor(
82-
x.addressable_data(0), stream=stream
83-
) # type: ignore
106+
107+
dlpack_device = None
108+
dl_device_type, local_hardware_id = dl_device if dl_device else (None, None)
109+
if dl_device_type:
110+
try:
111+
dl_device_platform = {
112+
DLDeviceType.kDLCPU: "cpu",
113+
DLDeviceType.kDLCUDA: "cuda",
114+
DLDeviceType.kDLROCM: "rocm",
115+
}[dl_device_type]
116+
backend = xla_bridge.get_backend(dl_device_platform)
117+
dlpack_device = backend.device_from_local_hardware_id(local_hardware_id)
118+
except TypeError:
119+
# https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
120+
# recommends using BufferError.
121+
raise BufferError(
122+
"The device specification passed to to_dlpack contains an unsupported "
123+
f"device type (DLDeviceType: {dl_device_type})")
124+
125+
if max_version is None or max_version[0] >= DLPACK_VERSION[0]:
126+
return _to_dlpack(x, stream=stream, device=device, dlpack_device=dlpack_device, copy=copy)
127+
elif max_version >= MIN_DLPACK_VERSION:
128+
# Legacy path to be implemented when XLA adopts DLManagedTensorVersioned format
129+
raise RuntimeError("This branch should be unreachable. "
130+
"Please open a bug if you see this.")
131+
else:
132+
raise BufferError(
133+
f"JAX does not support any version below {MIN_DLPACK_VERSION} but "
134+
f"version ({max_version}) was requested."
135+
)
84136

85137

86138
def from_dlpack(external_array):
@@ -110,12 +162,12 @@ def from_dlpack(external_array):
110162
DLDeviceType.kDLCUDA: "cuda",
111163
DLDeviceType.kDLROCM: "rocm",
112164
}[dl_device_type]
113-
except TypeError:
165+
except TypeError as err:
114166
# https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
115167
# TypeError.
116-
raise TypeError(
168+
raise BufferError(
117169
"Array passed to from_dlpack is on unsupported device type "
118-
f"(DLDeviceType: {dl_device_type}, array: {external_array}")
170+
f"(DLDeviceType: {dl_device_type}, array: {external_array}") from err
119171

120172
backend = xla_bridge.get_backend(device_platform)
121173
device = backend.device_from_local_hardware_id(device_id)

0 commit comments

Comments
 (0)