24
24
from jax ._src .lib import xla_client
25
25
from jax ._src .lib import xla_extension_version
26
26
from jax ._src .typing import Array
27
+ from jax ._src .api import device_put
27
28
29
+ DLPACK_VERSION = (0 , 1 )
30
+ MIN_DLPACK_VERSION = (0 , 1 )
28
31
29
32
# A set of dtypes that dlpack supports.
30
33
# Note: Make sure to use a "type", not a dtype instance, when looking up this set
@@ -48,9 +51,32 @@ class DLDeviceType(enum.IntEnum):
48
51
kDLCUDA = 2
49
52
kDLROCM = 10
50
53
54
+ def _to_dlpack (x : Array , stream : int | Any | None ,
55
+ device : xla_client .Device | None = None ,
56
+ dlpack_device : xla_client .Device | None = None ,
57
+ copy : bool | None = None ):
58
+ if dlpack_device and dlpack_device != device :
59
+ if copy is not None and not copy :
60
+ raise ValueError (
61
+ f"Specified { dlpack_device = } which requires a copy since the source device "
62
+ f"is { repr (device )} , however copy=False. Set copy=True or "
63
+ "copy=None to perform the requested operation."
64
+ )
65
+ else :
66
+ arr = device_put (x , dlpack_device )
67
+ else :
68
+ arr = x .copy () if copy else x
69
+
70
+ return xla_client ._xla .buffer_to_dlpack_managed_tensor (
71
+ arr .addressable_data (0 ), stream = stream
72
+ )
51
73
52
74
def to_dlpack (x : Array , take_ownership : bool = False ,
53
- stream : int | Any | None = None ):
75
+ stream : int | Any | None = None ,
76
+ device : xla_client .Device | None = None ,
77
+ dl_device : tuple [DLDeviceType , int ] | None = None ,
78
+ max_version : tuple [int , int ] | None = None ,
79
+ copy : bool | None = None ):
54
80
"""Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.
55
81
56
82
Args:
@@ -73,14 +99,40 @@ def to_dlpack(x: Array, take_ownership: bool = False,
73
99
if not isinstance (x , array .ArrayImpl ):
74
100
raise TypeError ("Argument to to_dlpack must be a jax.Array, "
75
101
f"got { type (x )} " )
76
- assert len (x .devices ()) == 1
77
102
if take_ownership :
78
103
warnings .warn (
79
104
"take_ownership in to_dlpack is deprecated and it is a no-op."
80
105
)
81
- return xla_client ._xla .buffer_to_dlpack_managed_tensor (
82
- x .addressable_data (0 ), stream = stream
83
- ) # type: ignore
106
+
107
+ dlpack_device = None
108
+ dl_device_type , local_hardware_id = dl_device if dl_device else (None , None )
109
+ if dl_device_type :
110
+ try :
111
+ dl_device_platform = {
112
+ DLDeviceType .kDLCPU : "cpu" ,
113
+ DLDeviceType .kDLCUDA : "cuda" ,
114
+ DLDeviceType .kDLROCM : "rocm" ,
115
+ }[dl_device_type ]
116
+ backend = xla_bridge .get_backend (dl_device_platform )
117
+ dlpack_device = backend .device_from_local_hardware_id (local_hardware_id )
118
+ except TypeError :
119
+ # https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
120
+ # recommends using BufferError.
121
+ raise BufferError (
122
+ "The device specification passed to to_dlpack contains an unsupported "
123
+ f"device type (DLDeviceType: { dl_device_type } )" )
124
+
125
+ if max_version is None or max_version [0 ] >= DLPACK_VERSION [0 ]:
126
+ return _to_dlpack (x , stream = stream , device = device , dlpack_device = dlpack_device , copy = copy )
127
+ elif max_version >= MIN_DLPACK_VERSION :
128
+ # Legacy path to be implemented when XLA adopts DLManagedTensorVersioned format
129
+ raise RuntimeError ("This branch should be unreachable. "
130
+ "Please open a bug if you see this." )
131
+ else :
132
+ raise BufferError (
133
+ f"JAX does not support any version below { MIN_DLPACK_VERSION } but "
134
+ f"version ({ max_version } ) was requested."
135
+ )
84
136
85
137
86
138
def from_dlpack (external_array ):
@@ -110,12 +162,12 @@ def from_dlpack(external_array):
110
162
DLDeviceType .kDLCUDA : "cuda" ,
111
163
DLDeviceType .kDLROCM : "rocm" ,
112
164
}[dl_device_type ]
113
- except TypeError :
165
+ except TypeError as err :
114
166
# https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
115
167
# TypeError.
116
- raise TypeError (
168
+ raise BufferError (
117
169
"Array passed to from_dlpack is on unsupported device type "
118
- f"(DLDeviceType: { dl_device_type } , array: { external_array } " )
170
+ f"(DLDeviceType: { dl_device_type } , array: { external_array } " ) from err
119
171
120
172
backend = xla_bridge .get_backend (device_platform )
121
173
device = backend .device_from_local_hardware_id (device_id )
0 commit comments