Commit 24f219a

Update
1 parent 71ec6e3 commit 24f219a

4 files changed: +104, -51 lines

jax/_src/dlpack.py

Lines changed: 97 additions & 46 deletions
@@ -18,6 +18,7 @@
 from typing import Any
 import warnings
 
+from jax._src.api import device_put
 from jax import numpy as jnp
 from jax._src import array
 from jax._src import xla_bridge
@@ -82,16 +83,107 @@ def to_dlpack(x: Array, take_ownership: bool = False,
       x.addressable_data(0), stream=stream
   )  # type: ignore
 
+def _place_array(_arr, device, dlpack_device, copy):
+  if device and dlpack_device != device:
+    if copy is False:
+      raise ValueError(
+        f"Specified {device=} which requires a copy since the source device "
+        f"is {repr(dlpack_device)}, however copy=False. Set copy=True or "
+        "copy=None to perform the requested operation."
+      )
+    else:
+      return device_put(_arr, device)
+  if copy:
+    return jnp.array(_arr, copy=True)
+  return _arr
+
+def _legacy_from_dlpack(dlpack, device: xla_client.Device | None = None, copy: bool | None = None):
+  preferred_platform = getattr(device, "platform", None)
+  if device and preferred_platform == "gpu":
+    preferred_platform = "cuda" if "cuda" in device.client.platform_version else "rocm"
+
+  cpu_backend = xla_bridge.get_backend("cpu")
+  gpu_backend = None
+
+  if preferred_platform in {"cuda", "rocm"}:
+    try:
+      gpu_backend = xla_bridge.get_backend(preferred_platform)
+    except RuntimeError:
+      raise TypeError(
+        f"A {str.upper(preferred_platform)} device was specified, however no "
+        f"{str.upper(preferred_platform)} backend was found."
+      )
 
-def from_dlpack(external_array):
+  if preferred_platform is None:
+    try:
+      gpu_backend = xla_bridge.get_backend("cuda")
+    except RuntimeError:
+      pass
+    # Try ROCm if CUDA backend not found
+    if gpu_backend is None:
+      try:
+        gpu_backend = xla_bridge.get_backend("rocm")
+      except RuntimeError:
+        pass
+
+  _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
+      dlpack, cpu_backend, gpu_backend))
+
+  return _place_array(_arr, device, _arr.devices().pop(), copy)
+
+def _from_dlpack(external_array, device: xla_client.Device | None = None, copy: bool | None = None):
+  dl_device_type, device_id = external_array.__dlpack_device__()
+  try:
+    dl_device_platform = {
+        DLDeviceType.kDLCPU: "cpu",
+        DLDeviceType.kDLCUDA: "cuda",
+        DLDeviceType.kDLROCM: "rocm",
+    }[dl_device_type]
+  except TypeError:
+    # https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
+    # TypeError.
+    raise TypeError(
+        "Array passed to from_dlpack is on unsupported device type "
+        f"(DLDeviceType: {dl_device_type}, array: {external_array}")
+
+  backend = xla_bridge.get_backend(dl_device_platform)
+  dlpack_device = backend.device_from_local_hardware_id(device_id)
+  try:
+    stream = dlpack_device.get_stream_for_external_ready_events()
+  except xla_client.XlaRuntimeError as err:  # type: ignore
+    if "UNIMPLEMENTED" in str(err):
+      stream = None
+    else:
+      raise
+  dlpack = external_array.__dlpack__(stream=stream)
+
+  _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
+      dlpack, dlpack_device, stream))
+
+  return _place_array(_arr, device, dlpack_device, copy)
+
+def from_dlpack(external_array, device: xla_client.Device | None = None, copy: bool | None = None):
   """Returns a :class:`~jax.Array` representation of a DLPack tensor.
 
   The returned :class:`~jax.Array` shares memory with ``external_array``.
 
   Args:
-    external_array: an array object that has __dlpack__ and __dlpack_device__
+    external_array: An array object that has __dlpack__ and __dlpack_device__
       methods, or a DLPack tensor on either CPU or GPU (legacy API).
 
+    device: The (optional) :py:class:`Device`, representing the device on which
+      the returned array should be placed. If given, then the result is
+      committed to the device. If unspecified, the resulting array will be
+      unpacked onto the same device it originated from. Setting ``device`` to a
+      device different from the source of ``external_array`` requires a copy,
+      meaning ``copy`` must be set to either ``True`` or ``None``.
+
+    copy: An (optional) boolean, controlling whether a copy is performed. If
+      ``copy=True`` then a copy is always performed, even if unpacked onto the
+      same device. If ``copy=False`` then a copy is never performed and an
+      error is raised if one would be required. When ``copy=None`` a copy may
+      be performed if needed for a device transfer.
+
   Returns:
     A jax.Array
@@ -103,48 +195,7 @@ def from_dlpack(external_array):
     the associated JAX array.
   """
   if hasattr(external_array, "__dlpack__"):
-    dl_device_type, device_id = external_array.__dlpack_device__()
-    try:
-      device_platform = {
-          DLDeviceType.kDLCPU: "cpu",
-          DLDeviceType.kDLCUDA: "cuda",
-          DLDeviceType.kDLROCM: "rocm",
-      }[dl_device_type]
-    except TypeError:
-      # https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
-      # TypeError.
-      raise TypeError(
-          "Array passed to from_dlpack is on unsupported device type "
-          f"(DLDeviceType: {dl_device_type}, array: {external_array}")
-
-    backend = xla_bridge.get_backend(device_platform)
-    device = backend.device_from_local_hardware_id(device_id)
-    try:
-      stream = device.get_stream_for_external_ready_events()
-    except xla_client.XlaRuntimeError as err:  # type: ignore
-      if "UNIMPLEMENTED" in str(err):
-        stream = None
-      else:
-        raise
-    dlpack = external_array.__dlpack__(stream=stream)
-
-    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
-        dlpack, device, stream))
-  else:
-    # Legacy path
-    dlpack = external_array
-    cpu_backend = xla_bridge.get_backend("cpu")
-    try:
-      gpu_backend = xla_bridge.get_backend("cuda")
-    except RuntimeError:
-      gpu_backend = None
-
-    # Try ROCm if CUDA backend not found
-    if gpu_backend is None:
-      try:
-        gpu_backend = xla_bridge.get_backend("rocm")
-      except RuntimeError:
-        gpu_backend = None
+    return _from_dlpack(external_array, device, copy)
 
-    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
-        dlpack, cpu_backend, gpu_backend))
+  # Legacy path
+  return _legacy_from_dlpack(external_array, device, copy)
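
For reference, a minimal usage sketch of the new keyword-only arguments added above. This is not part of the commit; it assumes NumPy (>= 1.23) as the DLPack producer, and the ``cpu_device`` name is just an illustrative variable.

    import jax
    import jax.numpy as jnp
    import numpy as np

    # NumPy arrays expose __dlpack__/__dlpack_device__, so they take the
    # non-legacy path through _from_dlpack above.
    x = np.arange(4, dtype=np.float32)

    # Default: unpack onto the device the data came from, zero-copy when possible.
    a = jnp.from_dlpack(x)

    # Force a copy even when staying on the same device.
    b = jnp.from_dlpack(x, copy=True)

    # Commit the result to an explicit device. A cross-device transfer needs a
    # copy, so copy must be True or None (None copies only if required);
    # copy=False in that case raises ValueError, per _place_array above.
    cpu_device = jax.devices("cpu")[0]
    c = jnp.from_dlpack(x, device=cpu_device, copy=None)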

jax/_src/numpy/lax_numpy.py

Lines changed: 2 additions & 2 deletions
@@ -2436,9 +2436,9 @@ def fromiter(*args, **kwargs):
     is later modified in-place, it may lead to undefined behavior when using
     the associated JAX array.
     """)
-def from_dlpack(x: Any) -> Array:
+def from_dlpack(x: Any, /, *, device: xc.Device | None = None, copy: bool | None = None) -> Array:
   from jax.dlpack import from_dlpack  # pylint: disable=g-import-not-at-top
-  return from_dlpack(x)
+  return from_dlpack(x, device=device, copy=copy)
 
 @util.implements(np.fromfunction)
 def fromfunction(function: Callable[..., Array], shape: Any,

jax/experimental/array_api/_creation_functions.py

Lines changed: 3 additions & 2 deletions
@@ -14,6 +14,7 @@
 
 import jax
 import jax.numpy as jnp
+from jax._src.lib import xla_client as xc
 
 
 def arange(start, /, stop=None, step=1, *, dtype=None, device=None):
@@ -31,8 +32,8 @@ def empty_like(x, /, *, dtype=None, device=None):
 def eye(n_rows, n_cols=None, /, *, k=0, dtype=None, device=None):
   return jax.device_put(jnp.eye(n_rows, n_cols, k=k, dtype=dtype), device=device)
 
-def from_dlpack(x, /):
-  return jnp.from_dlpack(x)
+def from_dlpack(x, /, *, device: xc.Device | None = None, copy: bool | None = None):
+  return jnp.from_dlpack(x.__dlpack__(), device=device, copy=copy)
 
 def full(shape, fill_value, *, dtype=None, device=None):
   return jax.device_put(jnp.full(shape, fill_value, dtype=dtype), device=device)
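
Similarly, a hedged sketch of exercising the experimental array API wrapper changed above; it assumes ``jax.experimental.array_api`` is importable in this version of JAX and uses NumPy as the producer.

    import numpy as np
    import jax.experimental.array_api as xp

    x = np.ones((2, 3), dtype=np.float32)

    # device and copy are keyword-only here as well; the wrapper hands the
    # capsule from x.__dlpack__() to jnp.from_dlpack, i.e. the legacy path.
    y = xp.from_dlpack(x, copy=True)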

jax/numpy/__init__.pyi

Lines changed: 2 additions & 1 deletion
@@ -6,6 +6,7 @@ from typing import Any, Callable, Literal, NamedTuple, Optional, Sequence, TypeV
 
 from jax._src import core as _core
 from jax._src import dtypes as _dtypes
+from jax._src.lib import xla_client as xc
 from jax._src.lax.lax import PrecisionLike
 from jax._src.lax.slicing import GatherScatterMode
 from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClass as _RClass
@@ -353,7 +354,7 @@ def fmax(x: ArrayLike, y: ArrayLike, /) -> Array: ...
 def fmin(x: ArrayLike, y: ArrayLike, /) -> Array: ...
 def fmod(x: ArrayLike, y: ArrayLike, /) -> Array: ...
 def frexp(x: ArrayLike, /) -> tuple[Array, Array]: ...
-def from_dlpack(x: Any) -> Array: ...
+def from_dlpack(x: Any, /, *, device: xc.Device | None = None, copy: bool | None = None) -> Array: ...
 def frombuffer(buffer: Union[bytes, Any], dtype: DTypeLike = ...,
                count: int = ..., offset: int = ...) -> Array: ...
 def fromfile(*args, **kwargs): ...
