21
21
from jax import numpy as jnp
22
22
from jax ._src import array
23
23
from jax ._src import xla_bridge
24
+ from jax ._src .lax .lax import _array_copy
24
25
from jax ._src .lib import xla_client
25
26
from jax ._src .lib import xla_extension_version
26
27
from jax ._src .typing import Array
28
+ from jax ._src .api import device_put
27
29
30
+ DLPACK_VERSION = (0 , 8 )
31
+ MIN_DLPACK_VERSION = (0 , 5 )
28
32
29
33
# A set of dtypes that dlpack supports.
30
34
# Note: Make sure to use a "type", not a dtype instance, when looking up this set
@@ -48,9 +52,34 @@ class DLDeviceType(enum.IntEnum):
48
52
kDLCUDA = 2
49
53
kDLROCM = 10
50
54
55
+ def _to_dlpack (x : Array , stream : int | Any | None ,
56
+ src_device : xla_client .Device | None = None ,
57
+ device : xla_client .Device | None = None ,
58
+ copy : bool | None = None ):
59
+
60
+ if src_device is None :
61
+ src_device , = x .devices ()
62
+ if device and (src_device is None or device != src_device ):
63
+ if copy is not None and not copy :
64
+ raise ValueError (
65
+ f"Specified { device = } which requires a copy since the source device "
66
+ f"is { repr (src_device )} , however copy=False. Set copy=True or "
67
+ "copy=None to perform the requested operation."
68
+ )
69
+ else :
70
+ arr = device_put (x , device )
71
+ else :
72
+ arr = _array_copy (x ) if copy else x
73
+ return xla_client ._xla .buffer_to_dlpack_managed_tensor (
74
+ arr .addressable_data (0 ), stream = stream
75
+ )
51
76
52
77
def to_dlpack (x : Array , take_ownership : bool = False ,
53
- stream : int | Any | None = None ):
78
+ stream : int | Any | None = None ,
79
+ src_device : xla_client .Device | None = None ,
80
+ dl_device : tuple [DLDeviceType , int ] | None = None ,
81
+ max_version : tuple [int , int ] | None = None ,
82
+ copy : bool | None = None ):
54
83
"""Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.
55
84
56
85
Args:
@@ -60,27 +89,98 @@ def to_dlpack(x: Array, take_ownership: bool = False,
60
89
stream: optional platform-dependent stream to wait on until the buffer is
61
90
ready. This corresponds to the `stream` argument to ``__dlpack__``
62
91
documented in https://dmlc.github.io/dlpack/latest/python_spec.html.
92
+ src_device: either a CPU or GPU :class:`~jax.Device`.
93
+ dl_device: a tuple of ``(dl_device_type, local_hardware_id)`` in DLPack
94
+ format e.g. as produced by ``__dlpack_device__``.
95
+ max_version: the maximum DLPack version that the consumer (i.e. caller of
96
+ ``__dlpack__``) supports in the form of a 2-tuple of ``(major, minor)``.
97
+ This function is not guaranteed to return a capsule of version
98
+ ``max_version``.
99
+ copy: a boolean indicating whether or not to copy the input.
100
+ - If ``True`` then the function must always copy
101
+ - If ``False`` then the function must never copy, and must raise an error
102
+ if a copy is deemed necessary.
103
+ - If ``None`` then the function must avoid a copy if possible but may
104
+ copy if needed.
63
105
64
106
Returns:
65
- A dlpack PyCapsule object.
107
+ A DLPack PyCapsule object.
66
108
67
109
Note:
68
- While JAX arrays are always immutable, dlpack buffers cannot be marked as
69
- immutable, and it is possible for processes external to JAX to mutate them
70
- in-place. If a dlpack buffer derived from a JAX array is mutated, it may
71
- lead to undefined behavior when using the associated JAX array.
110
+ While JAX arrays are always immutable, ``DLPackManagedTensor`` buffers
111
+ cannot be marked as immutable, and it is possible for processes external
112
+ to JAX to mutate them in-place. If a DLPack buffer derived from a JAX array
113
+ is mutated, it may lead to undefined behavior when using the associated JAX
114
+ array. When JAX eventually supports ``DLManagedTensorVersioned``
115
+ (DLPack 1.0), it will be possible to specify that a buffer is read-only.
72
116
"""
73
117
if not isinstance (x , array .ArrayImpl ):
74
118
raise TypeError ("Argument to to_dlpack must be a jax.Array, "
75
119
f"got { type (x )} " )
76
- assert len (x .devices ()) == 1
77
120
if take_ownership :
78
121
warnings .warn (
79
122
"take_ownership in to_dlpack is deprecated and it is a no-op."
80
123
)
81
- return xla_client ._xla .buffer_to_dlpack_managed_tensor (
82
- x .addressable_data (0 ), stream = stream
83
- ) # type: ignore
124
+
125
+ device = None
126
+ dl_device_type , local_hardware_id = dl_device if dl_device else (None , None )
127
+ if dl_device_type :
128
+ try :
129
+ dl_device_platform = {
130
+ DLDeviceType .kDLCPU : "cpu" ,
131
+ DLDeviceType .kDLCUDA : "cuda" ,
132
+ DLDeviceType .kDLROCM : "rocm" ,
133
+ }[dl_device_type ]
134
+ backend = xla_bridge .get_backend (dl_device_platform )
135
+ device = backend .device_from_local_hardware_id (local_hardware_id )
136
+ except TypeError :
137
+ # https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
138
+ # recommends using BufferError.
139
+ raise BufferError (
140
+ "The device specification passed to to_dlpack contains an unsupported "
141
+ f"device type (DLDeviceType: { dl_device_type } )" )
142
+
143
+ # TODO(micky774): Deprecate default usage of DLPackManagedTensor when XLA
144
+ # supports DLManagedTensorVersioned (DLPack version 1.0) and repurpose the
145
+ # current _to_dlpack as a legacy path for (0,5) <= max_version < (1,0)
146
+ if max_version is None :
147
+ # Backwards compatible default
148
+ return _to_dlpack (
149
+ x , stream = stream ,
150
+ src_device = src_device ,
151
+ device = device ,
152
+ copy = copy
153
+ )
154
+ else :
155
+ if max_version >= DLPACK_VERSION :
156
+ # Latest
157
+ return _to_dlpack (
158
+ x , stream = stream ,
159
+ src_device = src_device ,
160
+ device = device ,
161
+ copy = copy
162
+ )
163
+ if max_version [0 ] == DLPACK_VERSION [0 ]:
164
+ # ABI compatible
165
+ return _to_dlpack (
166
+ x , stream = stream ,
167
+ src_device = src_device ,
168
+ device = device ,
169
+ copy = copy
170
+ )
171
+ elif max_version >= MIN_DLPACK_VERSION :
172
+ # Oldest supported
173
+ return _to_dlpack (
174
+ x , stream = stream ,
175
+ src_device = src_device ,
176
+ device = device ,
177
+ copy = copy
178
+ )
179
+ else :
180
+ raise BufferError (
181
+ f"JAX does not support any version below { MIN_DLPACK_VERSION } but "
182
+ f"version ({ max_version } ) was requested."
183
+ )
84
184
85
185
86
186
def from_dlpack (external_array ):
@@ -110,12 +210,12 @@ def from_dlpack(external_array):
110
210
DLDeviceType .kDLCUDA : "cuda" ,
111
211
DLDeviceType .kDLROCM : "rocm" ,
112
212
}[dl_device_type ]
113
- except TypeError :
213
+ except TypeError as err :
114
214
# https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
115
215
# TypeError.
116
- raise TypeError (
216
+ raise BufferError (
117
217
"Array passed to from_dlpack is on unsupported device type "
118
- f"(DLDeviceType: { dl_device_type } , array: { external_array } " )
218
+ f"(DLDeviceType: { dl_device_type } , array: { external_array } " ) from err
119
219
120
220
backend = xla_bridge .get_backend (device_platform )
121
221
device = backend .device_from_local_hardware_id (device_id )
0 commit comments