21
21
from jax import numpy as jnp
22
22
from jax ._src import array
23
23
from jax ._src import xla_bridge
24
+ from jax ._src .lax .lax import _array_copy
24
25
from jax ._src .lib import xla_client
25
26
from jax ._src .lib import xla_extension_version
26
27
from jax ._src .typing import Array
28
+ from jax ._src .api import device_put
27
29
30
+ DLPACK_VERSION = (0 , 8 )
31
+ MIN_DLPACK_VERSION = (0 , 5 )
28
32
29
33
# A set of dtypes that dlpack supports.
30
34
# Note: Make sure to use a "type", not a dtype instance, when looking up this set
@@ -48,9 +52,34 @@ class DLDeviceType(enum.IntEnum):
48
52
kDLCUDA = 2
49
53
kDLROCM = 10
50
54
55
+ def _to_dlpack (x : Array , stream : int | Any | None ,
56
+ src_device : xla_client .Device | None = None ,
57
+ device : xla_client .Device | None = None ,
58
+ copy : bool | None = None ):
59
+
60
+ if src_device is None :
61
+ src_device , = x .devices ()
62
+ if device and (src_device is None or device != src_device ):
63
+ if copy is not None and not copy :
64
+ raise ValueError (
65
+ f"Specified { device = } which requires a copy since the source device "
66
+ f"is { repr (src_device )} , however copy=False. Set copy=True or "
67
+ "copy=None to perform the requested operation."
68
+ )
69
+ else :
70
+ arr = device_put (x , device )
71
+ else :
72
+ arr = _array_copy (x ) if copy else x
73
+ return xla_client ._xla .buffer_to_dlpack_managed_tensor (
74
+ arr .addressable_data (0 ), stream = stream
75
+ )
51
76
52
77
def to_dlpack (x : Array , take_ownership : bool = False ,
53
- stream : int | Any | None = None ):
78
+ stream : int | Any | None = None ,
79
+ src_device : xla_client .Device | None = None ,
80
+ dl_device : tuple [DLDeviceType , int ] | None = None ,
81
+ max_version : tuple [int , int ] | None = None ,
82
+ copy : bool | None = None ):
54
83
"""Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.
55
84
56
85
Args:
@@ -60,27 +89,82 @@ def to_dlpack(x: Array, take_ownership: bool = False,
60
89
stream: optional platform-dependent stream to wait on until the buffer is
61
90
ready. This corresponds to the `stream` argument to ``__dlpack__``
62
91
documented in https://dmlc.github.io/dlpack/latest/python_spec.html.
92
+ src_device: either a CPU or GPU :class:`~jax.Device`.
93
+ dl_device: a tuple of ``(dl_device_type, local_hardware_id)`` in DLPack
94
+ format e.g. as produced by ``__dlpack_device__``.
95
+ max_version: the maximum DLPack version that the consumer (i.e. caller of
96
+ ``__dlpack__``) supports in the form of a 2-tuple of ``(major, minor)``.
97
+ This function is not guaranteed to return a capsule of version
98
+ ``max_version``.
99
+ copy: a boolean indicating whether or not to copy the input. If
100
+ ``copy=True`` then the function must always copy. When
101
+ ``copy=False`` then the function must never copy, and must raise an error
102
+ when a copy is deemed necessary. If ``copy=None`` then the function must
103
+ avoid a copy if possible but may copy if needed.
63
104
64
105
Returns:
65
- A dlpack PyCapsule object.
106
+ A DLPack PyCapsule object.
66
107
67
108
Note:
68
- While JAX arrays are always immutable, dlpack buffers cannot be marked as
69
- immutable, and it is possible for processes external to JAX to mutate them
70
- in-place. If a dlpack buffer derived from a JAX array is mutated, it may
71
- lead to undefined behavior when using the associated JAX array.
109
+ While JAX arrays are always immutable, ``DLPackManagedTensor`` buffers
110
+ cannot be marked as immutable, and it is possible for processes external
111
+ to JAX to mutate them in-place. If a DLPack buffer derived from a JAX array
112
+ is mutated, it may lead to undefined behavior when using the associated JAX
113
+ array. When JAX eventually supports ``DLManagedTensorVersioned``
114
+ (DLPack 1.0), it will be possible to specify that a buffer is read-only.
72
115
"""
73
116
if not isinstance (x , array .ArrayImpl ):
74
117
raise TypeError ("Argument to to_dlpack must be a jax.Array, "
75
118
f"got { type (x )} " )
76
- assert len (x .devices ()) == 1
77
119
if take_ownership :
78
120
warnings .warn (
79
121
"take_ownership in to_dlpack is deprecated and it is a no-op."
80
122
)
81
- return xla_client ._xla .buffer_to_dlpack_managed_tensor (
82
- x .addressable_data (0 ), stream = stream
83
- ) # type: ignore
123
+
124
+ device = None
125
+ dl_device_type , local_hardware_id = dl_device if dl_device else (None , None )
126
+ if dl_device_type :
127
+ try :
128
+ dl_device_platform = {
129
+ DLDeviceType .kDLCPU : "cpu" ,
130
+ DLDeviceType .kDLCUDA : "cuda" ,
131
+ DLDeviceType .kDLROCM : "rocm" ,
132
+ }[dl_device_type ]
133
+ backend = xla_bridge .get_backend (dl_device_platform )
134
+ device = backend .device_from_local_hardware_id (local_hardware_id )
135
+ except TypeError :
136
+ # https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
137
+ # recommends using BufferError.
138
+ raise BufferError (
139
+ "The device specification passed to to_dlpack contains an unsupported "
140
+ f"device type (DLDeviceType: { dl_device_type } )" )
141
+
142
+ # As new versions are adopted over time, we can maintain some legacy paths
143
+ # for compatability mediated through the max_version parameter.
144
+ # TODO(micky774): Deprecate default usage of DLPackManagedTensor when XLA
145
+ # supports DLManagedTensorVersioned (DLPack version 1.0) and repurpose the
146
+ # current _to_dlpack as a legacy path for (0,5) <= max_version < (1,0).
147
+ if max_version is None or max_version >= DLPACK_VERSION :
148
+ # Latest
149
+ return _to_dlpack (
150
+ x , stream = stream ,
151
+ src_device = src_device ,
152
+ device = device ,
153
+ copy = copy
154
+ )
155
+ elif max_version >= MIN_DLPACK_VERSION :
156
+ # Oldest supported
157
+ return _to_dlpack (
158
+ x , stream = stream ,
159
+ src_device = src_device ,
160
+ device = device ,
161
+ copy = copy
162
+ )
163
+ else :
164
+ raise BufferError (
165
+ f"JAX does not support any version below { MIN_DLPACK_VERSION } but "
166
+ f"version ({ max_version } ) was requested."
167
+ )
84
168
85
169
86
170
def from_dlpack (external_array ):
@@ -110,12 +194,12 @@ def from_dlpack(external_array):
110
194
DLDeviceType .kDLCUDA : "cuda" ,
111
195
DLDeviceType .kDLROCM : "rocm" ,
112
196
}[dl_device_type ]
113
- except TypeError :
197
+ except TypeError as err :
114
198
# https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
115
199
# TypeError.
116
- raise TypeError (
200
+ raise BufferError (
117
201
"Array passed to from_dlpack is on unsupported device type "
118
- f"(DLDeviceType: { dl_device_type } , array: { external_array } " )
202
+ f"(DLDeviceType: { dl_device_type } , array: { external_array } " ) from err
119
203
120
204
backend = xla_bridge .get_backend (device_platform )
121
205
device = backend .device_from_local_hardware_id (device_id )
0 commit comments