Skip to content

Add array-api copy semantics to DLPackManagedTensorToBuffer #29963

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

pearu
Copy link
Collaborator

@pearu pearu commented Jul 3, 2025

As in the title.

Fixes #29810

@pearu pearu changed the title Add array-api copy semantics to dlpack MakePjrtBuffer Add array-api copy semantics to dlpack DLPackManagedTensorToBuffer Jul 3, 2025
@pearu pearu force-pushed the pearu/from_dlpack-forced-copy branch from 604aea5 to 51573a4 Compare July 3, 2025 21:05
@pearu pearu changed the title Add array-api copy semantics to dlpack DLPackManagedTensorToBuffer Add array-api copy semantics to DLPackManagedTensorToBuffer Jul 3, 2025
@pearu pearu force-pushed the pearu/from_dlpack-forced-copy branch 3 times, most recently from e226d31 to d2c005f Compare July 4, 2025 09:22
@pearu pearu force-pushed the pearu/from_dlpack-forced-copy branch from d2c005f to 11ed1a3 Compare July 4, 2025 09:41
@pearu pearu requested review from superbobry and hawkinsp July 4, 2025 09:42
@pearu pearu marked this pull request as ready for review July 4, 2025 09:42
Copy link
Collaborator

@superbobry superbobry left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

se = str(e)
i = se.index("is not aligned to")
if i > 0:
raise ValueError(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume we cannot raise this error directly in C++ because it talks about Python-level arguments?

_arr = jnp.asarray(_buf)
if copy and jaxlib_version >= (0, 6, 3):
# dlpack_managed_tensor_to_buffer implements array-api copy
# semantics, so resetting copy to avoid an unnecessary copy
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be more clear to say that copy was already handled by dlpack_managed_tensor_to_buffer, I think.

Also, nit: add a "." :)

// require a view (copy == False) or try create a view (copy == None)
auto result = device.client()->CreateViewOfDeviceBuffer(
data, shape, *device.default_memory_space(), on_delete_callback, stream);
if ((result.status().code() != absl::StatusCode::kInvalidArgument)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: drop redundant parens?

TF_ASSIGN_OR_RETURN(auto* memory_space, device.default_memory_space());

// Create a copy.
auto result = device.client()->BufferFromHostBuffer(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: directly return?

bool fallback_to_copy = !copy.has_value() && dlmt->dl_tensor.device.device_type == kDLCPU;

if (!copy.value_or(false)) {
// require a view (copy == False) or try create a view (copy == None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I'd nuke this and other similar comments personally, and instead try to change the code to make the conditions more obvious (e.g. on L233).

@@ -738,7 +738,7 @@ def buffer_to_dlpack_managed_tensor(
) -> Any: ...
@overload
def dlpack_managed_tensor_to_buffer(
tensor: Any, device: Device, stream: int | None
tensor: Any, device: Device, stream: int | None, copy: bool | None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add = None, since copy has a default.

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Jul 7, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
kokoro:force-run pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

jax.numpy.from_dlpack warns for unaligned data
2 participants