Skip to content

Commit 604aea5

Browse files
committed
Add array-api copy semantics to dlpack MakePjrtBuffer
1 parent 36c4581 commit 604aea5

File tree

4 files changed

+45
-29
lines changed

4 files changed

+45
-29
lines changed

jax/_src/dlpack.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,12 @@ def _from_dlpack(external_array, device: xla_client.Device | None = None,
240240
dlpack = external_array.__dlpack__(stream=stream)
241241

242242
_arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
243-
dlpack, dlpack_device, stream))
243+
dlpack, dlpack_device, stream, copy))
244+
if copy:
245+
# dlpack_managed_tensor_to_buffer already handles the array-api
246+
# copy semantics, so reset copy to False here to avoid a second
247+
# copy
248+
copy = False
244249
return _place_array(_arr, device, dlpack_device, copy)
245250

246251
def from_dlpack(external_array,

jaxlib/dlpack.cc

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -211,39 +211,49 @@ absl::StatusOr<std::vector<int64_t>> GetByteStrides(const DLTensor& dl_tensor) {
211211
absl::StatusOr<std::unique_ptr<PjRtBuffer>> MakePjrtBuffer(
212212
PjRtDevice& device, ::DLManagedTensor* dlmt, const Shape& shape,
213213
PrimitiveType element_type, absl::Span<int64_t const> dimensions,
214+
std::optional<bool> copy = std::nullopt,
214215
std::optional<std::intptr_t> stream = std::nullopt) {
215216
std::function<void()> on_delete_callback;
216217
if (dlmt->deleter) {
217218
on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
218219
}
219220

220-
// First try to create a view.
221221
void* data =
222222
static_cast<char*>(dlmt->dl_tensor.data) + dlmt->dl_tensor.byte_offset;
223-
auto result = device.client()->CreateViewOfDeviceBuffer(
224-
data, shape, *device.default_memory_space(), on_delete_callback, stream);
225-
226-
// If that fails with invalid argument, it's possibly because of the incorrect
227-
// alignment. If we're on CPU, we can create a copy of buffer.
228-
if (result.status().code() == absl::StatusCode::kInvalidArgument &&
229-
dlmt->dl_tensor.device.device_type == kDLCPU) {
230-
LOG(WARNING) << "DLPack buffer is not aligned (data at: " << data
231-
<< "). Creating a copy.";
232-
233-
// Convert tensor strides (expressed in number of elements) to byte strides.
234-
std::optional<std::vector<int64_t>> byte_strides;
235-
if (dlmt->dl_tensor.strides) {
236-
TF_ASSIGN_OR_RETURN(byte_strides, GetByteStrides(dlmt->dl_tensor));
237-
}
238223

239-
TF_ASSIGN_OR_RETURN(auto* memory_space, device.default_memory_space());
224+
// On CPU, creating a view may fail because of unaligned data buffer
225+
// in which case we'll fall back to copying. On non-CPU, array-api copy
226+
// semantics is handled in dlpack._place_array function.
227+
bool fallback_to_copy = !copy.has_value() && dlmt->dl_tensor.device.device_type == kDLCPU;
228+
229+
if (!copy.value_or(false)) {
230+
// require a view (copy == False) or try to create a view (copy == None)
231+
auto result = device.client()->CreateViewOfDeviceBuffer(
232+
data, shape, *device.default_memory_space(), on_delete_callback, stream);
233+
if ((result.status().code() != absl::StatusCode::kInvalidArgument)
234+
|| (!fallback_to_copy))
235+
{
236+
// successful view or return error when copy == False or copy ==
237+
// None with fallback_to_copy == False
238+
return result;
239+
}
240+
}
241+
// require copy (copy == True) or fall back to copy (copy == None)
240242

241-
// Create a copy.
242-
result = device.client()->BufferFromHostBuffer(
243-
data, element_type, dimensions, byte_strides,
244-
PjRtClient::HostBufferSemantics::kMutableZeroCopy, on_delete_callback,
245-
memory_space, /*device_layout=*/nullptr);
243+
// Convert tensor strides (expressed in number of elements) to byte strides.
244+
std::optional<std::vector<int64_t>> byte_strides;
245+
if (dlmt->dl_tensor.strides) {
246+
TF_ASSIGN_OR_RETURN(byte_strides, GetByteStrides(dlmt->dl_tensor));
246247
}
248+
249+
TF_ASSIGN_OR_RETURN(auto* memory_space, device.default_memory_space());
250+
251+
// Create a copy.
252+
auto result = device.client()->BufferFromHostBuffer(
253+
data, element_type, dimensions, byte_strides,
254+
PjRtClient::HostBufferSemantics::kMutableZeroCopy, on_delete_callback,
255+
memory_space, /*device_layout=*/nullptr);
256+
247257
return result;
248258
}
249259

@@ -424,7 +434,7 @@ absl::StatusOr<nb::object> DLPackManagedTensorToBuffer(
424434

425435
absl::StatusOr<nb::object> DLPackManagedTensorToBuffer(
426436
const nb::capsule& tensor, ifrt::Device* ifrt_device,
427-
nb_class_ptr<PyClient> client, std::optional<std::intptr_t> stream) {
437+
nb_class_ptr<PyClient> client, std::optional<std::intptr_t> stream, std::optional<bool> copy) {
428438
ifrt::PjRtDevice* device =
429439
llvm::dyn_cast_or_null<ifrt::PjRtDevice>(ifrt_device);
430440
if (device == nullptr) {
@@ -469,7 +479,7 @@ absl::StatusOr<nb::object> DLPackManagedTensorToBuffer(
469479

470480
TF_ASSIGN_OR_RETURN(auto pjrt_buffer,
471481
MakePjrtBuffer(*device->pjrt_device(), dlmt, shape,
472-
element_type, dimensions, stream));
482+
element_type, dimensions, copy, stream));
473483

474484
// We have taken ownership of the array inside the capsule; make sure the
475485
// capsule cannot be used again.

jaxlib/dlpack.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ absl::StatusOr<nanobind::object> DLPackManagedTensorToBuffer(
4646

4747
absl::StatusOr<nanobind::object> DLPackManagedTensorToBuffer(
4848
const nanobind::capsule& tensor, ifrt::Device* device,
49-
nb_class_ptr<PyClient> client, std::optional<std::intptr_t> stream);
49+
nb_class_ptr<PyClient> client, std::optional<std::intptr_t> stream,
50+
std::optional<bool> copy);
5051

5152
// Converts a PrimitiveType to the nanobind specific implementation of
5253
// DLDataType.

jaxlib/xla.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -575,11 +575,11 @@ NB_MODULE(_jax, m) {
575575
m.def(
576576
"dlpack_managed_tensor_to_buffer",
577577
[](const nb::capsule& tensor, nb_class_ptr<PyDevice> device,
578-
std::optional<std::intptr_t> stream) {
578+
std::optional<std::intptr_t> stream, std::optional<bool> copy) {
579579
return xla::ValueOrThrow(DLPackManagedTensorToBuffer(
580-
tensor, device->device(), device->client(), stream));
580+
tensor, device->device(), device->client(), stream, copy));
581581
},
582-
nb::arg("dlpack"), nb::arg("device"), nb::arg("stream").none());
582+
nb::arg("dlpack"), nb::arg("device"), nb::arg("stream").none(), nb::arg("copy").none() = nb::none());
583583
// Legacy overload
584584
m.def(
585585
"dlpack_managed_tensor_to_buffer",

0 commit comments

Comments
 (0)