16
16
17
17
import enum
18
18
from typing import Any
19
- import warnings
20
19
21
20
from jax ._src .api import device_put
22
21
from jax import numpy as jnp
23
22
from jax ._src import array
24
23
from jax ._src import xla_bridge
24
+ from jax ._src .lax .lax import _array_copy
25
25
from jax ._src .lib import xla_client
26
26
from jax ._src .lib import xla_extension_version
27
27
from jax ._src .typing import Array
28
28
from jax ._src .sharding import Sharding
29
29
30
+ DLPACK_VERSION = (0 , 8 )
31
+ MIN_DLPACK_VERSION = (0 , 5 )
32
+
30
33
# A set of dtypes that dlpack supports.
31
34
# Note: Make sure to use a "type", not a dtype instance, when looking up this set
32
35
# because their hashes are different.
@@ -49,39 +52,112 @@ class DLDeviceType(enum.IntEnum):
49
52
kDLCUDA = 2
50
53
kDLROCM = 10
51
54
55
+ def _to_dlpack (x : Array , stream : int | Any | None ,
56
+ src_device : xla_client .Device | None = None ,
57
+ device : xla_client .Device | None = None ,
58
+ copy : bool | None = None ):
59
+
60
+ if src_device is None :
61
+ src_device , = x .devices ()
62
+ if device and (src_device is None or device != src_device ):
63
+ if copy is not None and not copy :
64
+ raise ValueError (
65
+ f"Specified { device = } which requires a copy since the source device "
66
+ f"is { repr (src_device )} , however copy=False. Set copy=True or "
67
+ "copy=None to perform the requested operation."
68
+ )
69
+ else :
70
+ arr = device_put (x , device )
71
+ else :
72
+ arr = _array_copy (x ) if copy else x
73
+ return xla_client ._xla .buffer_to_dlpack_managed_tensor (
74
+ arr .addressable_data (0 ), stream = stream
75
+ )
52
76
53
- def to_dlpack (x : Array , take_ownership : bool = False ,
54
- stream : int | Any | None = None ):
77
+ def to_dlpack (x : Array , stream : int | Any | None = None ,
78
+ src_device : xla_client .Device | None = None ,
79
+ dl_device : tuple [DLDeviceType , int ] | None = None ,
80
+ max_version : tuple [int , int ] | None = None ,
81
+ copy : bool | None = None ):
55
82
"""Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.
56
83
57
84
Args:
58
85
x: a :class:`~jax.Array`, on either CPU or GPU.
59
- take_ownership: Deprecated. It is a no-op to set take_ownership. Will be
60
- deleted in 01/2024.
61
86
stream: optional platform-dependent stream to wait on until the buffer is
62
87
ready. This corresponds to the `stream` argument to ``__dlpack__``
63
88
documented in https://dmlc.github.io/dlpack/latest/python_spec.html.
89
+ src_device: either a CPU or GPU :class:`~jax.Device`.
90
+ dl_device: a tuple of ``(dl_device_type, local_hardware_id)`` in DLPack
91
+ format e.g. as produced by ``__dlpack_device__``.
92
+ max_version: the maximum DLPack version that the consumer (i.e. caller of
93
+ ``__dlpack__``) supports in the form of a 2-tuple of ``(major, minor)``.
94
+ This function is not guaranteed to return a capsule of version
95
+ ``max_version``.
96
+ copy: a boolean indicating whether or not to copy the input. If
97
+ ``copy=True`` then the function must always copy. When
98
+ ``copy=False`` then the function must never copy, and must raise an error
99
+ when a copy is deemed necessary. If ``copy=None`` then the function must
100
+ avoid a copy if possible but may copy if needed.
64
101
65
102
Returns:
66
- A dlpack PyCapsule object.
103
+ A DLPack PyCapsule object.
67
104
68
105
Note:
69
- While JAX arrays are always immutable, dlpack buffers cannot be marked as
70
- immutable, and it is possible for processes external to JAX to mutate them
71
- in-place. If a dlpack buffer derived from a JAX array is mutated, it may
72
- lead to undefined behavior when using the associated JAX array.
106
+ While JAX arrays are always immutable, ``DLPackManagedTensor`` buffers
107
+ cannot be marked as immutable, and it is possible for processes external
108
+ to JAX to mutate them in-place. If a DLPack buffer derived from a JAX array
109
+ is mutated, it may lead to undefined behavior when using the associated JAX
110
+ array. When JAX eventually supports ``DLManagedTensorVersioned``
111
+ (DLPack 1.0), it will be possible to specify that a buffer is read-only.
73
112
"""
74
113
if not isinstance (x , array .ArrayImpl ):
75
114
raise TypeError ("Argument to to_dlpack must be a jax.Array, "
76
115
f"got { type (x )} " )
77
- assert len (x .devices ()) == 1
78
- if take_ownership :
79
- warnings .warn (
80
- "take_ownership in to_dlpack is deprecated and it is a no-op."
116
+
117
+ device = None
118
+ dl_device_type , local_hardware_id = dl_device if dl_device else (None , None )
119
+ if dl_device_type :
120
+ try :
121
+ dl_device_platform = {
122
+ DLDeviceType .kDLCPU : "cpu" ,
123
+ DLDeviceType .kDLCUDA : "cuda" ,
124
+ DLDeviceType .kDLROCM : "rocm" ,
125
+ }[dl_device_type ]
126
+ backend = xla_bridge .get_backend (dl_device_platform )
127
+ device = backend .device_from_local_hardware_id (local_hardware_id )
128
+ except TypeError :
129
+ # https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
130
+ # recommends using BufferError.
131
+ raise BufferError (
132
+ "The device specification passed to to_dlpack contains an unsupported "
133
+ f"device type (DLDeviceType: { dl_device_type } )" )
134
+
135
+ # As new versions are adopted over time, we can maintain some legacy paths
136
+ # for compatability mediated through the max_version parameter.
137
+ # TODO(micky774): Deprecate default usage of DLPackManagedTensor when XLA
138
+ # supports DLManagedTensorVersioned (DLPack version 1.0) and repurpose the
139
+ # current _to_dlpack as a legacy path for (0,5) <= max_version < (1,0).
140
+ if max_version is None or max_version >= DLPACK_VERSION :
141
+ # Latest
142
+ return _to_dlpack (
143
+ x , stream = stream ,
144
+ src_device = src_device ,
145
+ device = device ,
146
+ copy = copy
147
+ )
148
+ elif max_version >= MIN_DLPACK_VERSION :
149
+ # Oldest supported
150
+ return _to_dlpack (
151
+ x , stream = stream ,
152
+ src_device = src_device ,
153
+ device = device ,
154
+ copy = copy
155
+ )
156
+ else :
157
+ raise BufferError (
158
+ f"JAX does not support any version below { MIN_DLPACK_VERSION } but "
159
+ f"version ({ max_version } ) was requested."
81
160
)
82
- return xla_client ._xla .buffer_to_dlpack_managed_tensor (
83
- x .addressable_data (0 ), stream = stream
84
- ) # type: ignore
85
161
86
162
def _place_array (_arr , device , dlpack_device , copy ):
87
163
if device and dlpack_device != device :
0 commit comments