@@ -211,39 +211,49 @@ absl::StatusOr<std::vector<int64_t>> GetByteStrides(const DLTensor& dl_tensor) {
211
211
absl::StatusOr<std::unique_ptr<PjRtBuffer>> MakePjrtBuffer (
212
212
PjRtDevice& device, ::DLManagedTensor* dlmt, const Shape& shape,
213
213
PrimitiveType element_type, absl::Span<int64_t const > dimensions,
214
+ std::optional<bool > copy = std::nullopt,
214
215
std::optional<std::intptr_t > stream = std::nullopt) {
215
216
std::function<void ()> on_delete_callback;
216
217
if (dlmt->deleter ) {
217
218
on_delete_callback = [dlmt]() { dlmt->deleter (dlmt); };
218
219
}
219
220
220
- // First try to create a view.
221
221
void * data =
222
222
static_cast <char *>(dlmt->dl_tensor .data ) + dlmt->dl_tensor .byte_offset ;
223
- auto result = device.client ()->CreateViewOfDeviceBuffer (
224
- data, shape, *device.default_memory_space (), on_delete_callback, stream);
225
-
226
- // If that fails with invalid argument, it's possibly because of the incorrect
227
- // alignment. If we're on CPU, we can create a copy of buffer.
228
- if (result.status ().code () == absl::StatusCode::kInvalidArgument &&
229
- dlmt->dl_tensor .device .device_type == kDLCPU ) {
230
- LOG (WARNING) << " DLPack buffer is not aligned (data at: " << data
231
- << " ). Creating a copy." ;
232
-
233
- // Convert tensor strides (expressed in number of elements) to byte strides.
234
- std::optional<std::vector<int64_t >> byte_strides;
235
- if (dlmt->dl_tensor .strides ) {
236
- TF_ASSIGN_OR_RETURN (byte_strides, GetByteStrides (dlmt->dl_tensor ));
237
- }
238
223
239
- TF_ASSIGN_OR_RETURN (auto * memory_space, device.default_memory_space ());
224
+ // On CPU, creating a view may fail because of unaligned data buffer
225
+ // in which case we'll fallback to copy. On non-CPU, array-api copy
226
+ // semantics is handled in dlpack._place_array function.
227
+ bool fallback_to_copy = !copy.has_value () && dlmt->dl_tensor .device .device_type == kDLCPU ;
228
+
229
+ if (!copy.value_or (false )) {
230
+ // require a view (copy == False) or try create a view (copy == None)
231
+ auto result = device.client ()->CreateViewOfDeviceBuffer (
232
+ data, shape, *device.default_memory_space (), on_delete_callback, stream);
233
+ if ((result.status ().code () != absl::StatusCode::kInvalidArgument )
234
+ || (!fallback_to_copy))
235
+ {
236
+ // succesful view or return error when copy == False or copy ==
237
+ // None with fallback_to_copy == False
238
+ return result;
239
+ }
240
+ }
241
+ // require copy (copy == True) or fallback to copy (copy == None)
240
242
241
- // Create a copy.
242
- result = device.client ()->BufferFromHostBuffer (
243
- data, element_type, dimensions, byte_strides,
244
- PjRtClient::HostBufferSemantics::kMutableZeroCopy , on_delete_callback,
245
- memory_space, /* device_layout=*/ nullptr );
243
+ // Convert tensor strides (expressed in number of elements) to byte strides.
244
+ std::optional<std::vector<int64_t >> byte_strides;
245
+ if (dlmt->dl_tensor .strides ) {
246
+ TF_ASSIGN_OR_RETURN (byte_strides, GetByteStrides (dlmt->dl_tensor ));
246
247
}
248
+
249
+ TF_ASSIGN_OR_RETURN (auto * memory_space, device.default_memory_space ());
250
+
251
+ // Create a copy.
252
+ auto result = device.client ()->BufferFromHostBuffer (
253
+ data, element_type, dimensions, byte_strides,
254
+ PjRtClient::HostBufferSemantics::kMutableZeroCopy , on_delete_callback,
255
+ memory_space, /* device_layout=*/ nullptr );
256
+
247
257
return result;
248
258
}
249
259
@@ -424,7 +434,7 @@ absl::StatusOr<nb::object> DLPackManagedTensorToBuffer(
424
434
425
435
absl::StatusOr<nb::object> DLPackManagedTensorToBuffer (
426
436
const nb::capsule& tensor, ifrt::Device* ifrt_device,
427
- nb_class_ptr<PyClient> client, std::optional<std::intptr_t > stream) {
437
+ nb_class_ptr<PyClient> client, std::optional<std::intptr_t > stream, std::optional< bool > copy ) {
428
438
ifrt::PjRtDevice* device =
429
439
llvm::dyn_cast_or_null<ifrt::PjRtDevice>(ifrt_device);
430
440
if (device == nullptr ) {
@@ -469,7 +479,7 @@ absl::StatusOr<nb::object> DLPackManagedTensorToBuffer(
469
479
470
480
TF_ASSIGN_OR_RETURN (auto pjrt_buffer,
471
481
MakePjrtBuffer (*device->pjrt_device (), dlmt, shape,
472
- element_type, dimensions, stream));
482
+ element_type, dimensions, copy, stream));
473
483
474
484
// We have taken ownership of the array inside the capsule; make sure the
475
485
// capsule it cannot be used again.
0 commit comments