|
14 | 14 |
|
15 | 15 | from __future__ import annotations
|
16 | 16 |
|
17 |
| -import enum |
18 | 17 | from typing import Any
|
19 |
| -import warnings |
20 | 18 |
|
21 | 19 | from jax._src.api import device_put
|
22 | 20 | from jax import numpy as jnp
|
23 | 21 | from jax._src import array
|
24 | 22 | from jax._src import xla_bridge
|
| 23 | +from jax._src.lax.lax import _array_copy |
25 | 24 | from jax._src.lib import xla_client
|
26 | 25 | from jax._src.lib import xla_extension_version
|
27 |
| -from jax._src.typing import Array |
| 26 | +from jax._src.typing import Array, DLDeviceType |
28 | 27 | from jax._src.sharding import Sharding
|
29 | 28 |
|
| 29 | +DLPACK_VERSION = (0, 8) |
| 30 | +MIN_DLPACK_VERSION = (0, 5) |
| 31 | + |
30 | 32 | # A set of dtypes that dlpack supports.
|
31 | 33 | # Note: Make sure to use a "type", not a dtype instance, when looking up this set
|
32 | 34 | # because their hashes are different.
|
|
43 | 45 | SUPPORTED_DTYPES = SUPPORTED_DTYPES | frozenset({jnp.bool_})
|
44 | 46 |
|
45 | 47 |
|
46 |
| -# Mirror of dlpack.h enum |
47 |
| -class DLDeviceType(enum.IntEnum): |
48 |
| - kDLCPU = 1 |
49 |
| - kDLCUDA = 2 |
50 |
| - kDLROCM = 10 |
| 48 | +def _to_dlpack(x: Array, stream: int | Any | None, |
| 49 | + src_device: xla_client.Device | None = None, |
| 50 | + device: xla_client.Device | None = None, |
| 51 | + copy: bool | None = None): |
51 | 52 |
|
| 53 | + if src_device is None: |
| 54 | + src_device, = x.devices() |
| 55 | + if device and (src_device is None or device != src_device): |
| 56 | + if copy is not None and not copy: |
| 57 | + raise ValueError( |
| 58 | + f"Specified {device=} which requires a copy since the source device " |
| 59 | + f"is {repr(src_device)}, however copy=False. Set copy=True or " |
| 60 | + "copy=None to perform the requested operation." |
| 61 | + ) |
| 62 | + else: |
| 63 | + arr = device_put(x, device) |
| 64 | + else: |
| 65 | + arr = _array_copy(x) if copy else x |
| 66 | + return xla_client._xla.buffer_to_dlpack_managed_tensor( |
| 67 | + arr.addressable_data(0), stream=stream |
| 68 | + ) |
52 | 69 |
|
53 |
| -def to_dlpack(x: Array, take_ownership: bool = False, |
54 |
| - stream: int | Any | None = None): |
| 70 | +def to_dlpack(x: Array, stream: int | Any | None = None, |
| 71 | + src_device: xla_client.Device | None = None, |
| 72 | + dl_device: tuple[DLDeviceType, int] | None = None, |
| 73 | + max_version: tuple[int, int] | None = None, |
| 74 | + copy : bool | None = None): |
55 | 75 | """Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.
|
56 | 76 |
|
57 | 77 | Args:
|
58 | 78 | x: a :class:`~jax.Array`, on either CPU or GPU.
|
59 |
| - take_ownership: Deprecated. It is a no-op to set take_ownership. Will be |
60 |
| - deleted in 01/2024. |
61 | 79 | stream: optional platform-dependent stream to wait on until the buffer is
|
62 | 80 | ready. This corresponds to the `stream` argument to ``__dlpack__``
|
63 | 81 | documented in https://dmlc.github.io/dlpack/latest/python_spec.html.
|
| 82 | + src_device: either a CPU or GPU :class:`~jax.Device`. |
| 83 | + dl_device: a tuple of ``(dl_device_type, local_hardware_id)`` in DLPack |
| 84 | + format e.g. as produced by ``__dlpack_device__``. |
| 85 | + max_version: the maximum DLPack version that the consumer (i.e. caller of |
| 86 | + ``__dlpack__``) supports in the form of a 2-tuple of ``(major, minor)``. |
| 87 | + This function is not guaranteed to return a capsule of version |
| 88 | + ``max_version``. |
| 89 | + copy: a boolean indicating whether or not to copy the input. If |
| 90 | + ``copy=True`` then the function must always copy. When |
| 91 | + ``copy=False`` then the function must never copy, and must raise an error |
| 92 | + when a copy is deemed necessary. If ``copy=None`` then the function must |
| 93 | + avoid a copy if possible but may copy if needed. |
64 | 94 |
|
65 | 95 | Returns:
|
66 |
| - A dlpack PyCapsule object. |
| 96 | + A DLPack PyCapsule object. |
67 | 97 |
|
68 | 98 | Note:
|
69 |
| - While JAX arrays are always immutable, dlpack buffers cannot be marked as |
70 |
| - immutable, and it is possible for processes external to JAX to mutate them |
71 |
| - in-place. If a dlpack buffer derived from a JAX array is mutated, it may |
72 |
| - lead to undefined behavior when using the associated JAX array. |
| 99 | + While JAX arrays are always immutable, ``DLPackManagedTensor`` buffers |
| 100 | + cannot be marked as immutable, and it is possible for processes external |
| 101 | + to JAX to mutate them in-place. If a DLPack buffer derived from a JAX array |
| 102 | + is mutated, it may lead to undefined behavior when using the associated JAX |
| 103 | + array. When JAX eventually supports ``DLManagedTensorVersioned`` |
| 104 | + (DLPack 1.0), it will be possible to specify that a buffer is read-only. |
73 | 105 | """
|
74 | 106 | if not isinstance(x, array.ArrayImpl):
|
75 | 107 | raise TypeError("Argument to to_dlpack must be a jax.Array, "
|
76 | 108 | f"got {type(x)}")
|
77 |
| - assert len(x.devices()) == 1 |
78 |
| - if take_ownership: |
79 |
| - warnings.warn( |
80 |
| - "take_ownership in to_dlpack is deprecated and it is a no-op." |
| 109 | + |
| 110 | + device = None |
| 111 | + dl_device_type, local_hardware_id = dl_device if dl_device else (None, None) |
| 112 | + if dl_device_type: |
| 113 | + try: |
| 114 | + dl_device_platform = { |
| 115 | + DLDeviceType.kDLCPU: "cpu", |
| 116 | + DLDeviceType.kDLCUDA: "cuda", |
| 117 | + DLDeviceType.kDLROCM: "rocm", |
| 118 | + }[dl_device_type] |
| 119 | + backend = xla_bridge.get_backend(dl_device_platform) |
| 120 | + device = backend.device_from_local_hardware_id(local_hardware_id) |
| 121 | + except TypeError: |
| 122 | + # https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html |
| 123 | + # recommends using BufferError. |
| 124 | + raise BufferError( |
| 125 | + "The device specification passed to to_dlpack contains an unsupported " |
| 126 | + f"device type (DLDeviceType: {dl_device_type})") |
| 127 | + |
| 128 | + # As new versions are adopted over time, we can maintain some legacy paths |
| 129 | + # for compatability mediated through the max_version parameter. |
| 130 | + # TODO(micky774): Deprecate default usage of DLPackManagedTensor when XLA |
| 131 | + # supports DLManagedTensorVersioned (DLPack version 1.0) and repurpose the |
| 132 | + # current _to_dlpack as a legacy path for (0,5) <= max_version < (1,0). |
| 133 | + if max_version is None or max_version >= DLPACK_VERSION: |
| 134 | + # Latest |
| 135 | + return _to_dlpack( |
| 136 | + x, stream=stream, |
| 137 | + src_device=src_device, |
| 138 | + device=device, |
| 139 | + copy=copy |
| 140 | + ) |
| 141 | + elif max_version >= MIN_DLPACK_VERSION: |
| 142 | + # Oldest supported |
| 143 | + return _to_dlpack( |
| 144 | + x, stream=stream, |
| 145 | + src_device=src_device, |
| 146 | + device=device, |
| 147 | + copy=copy |
| 148 | + ) |
| 149 | + else: |
| 150 | + raise BufferError( |
| 151 | + f"JAX does not support any version below {MIN_DLPACK_VERSION} but " |
| 152 | + f"version ({max_version}) was requested." |
81 | 153 | )
|
82 |
| - return xla_client._xla.buffer_to_dlpack_managed_tensor( |
83 |
| - x.addressable_data(0), stream=stream |
84 |
| - ) # type: ignore |
85 | 154 |
|
86 | 155 | def _place_array(_arr, device, dlpack_device, copy):
|
87 | 156 | if device and dlpack_device != device:
|
|
0 commit comments