Skip to content

Commit 72831ac

Browse files
Micky774copybara-github
authored andcommitted
PR #10433: Fixes type annotation overload of dlpack_managed_tensor_to_buffer in python/xla_extension
Imported from GitHub PR #10433 Encountered bug in jax-ml/jax#20175 (see this [comment](jax-ml/jax#20175 (comment))). This adjusts the stub file to properly overload `dlpack_managed_tensor_to_buffer` so that both signatures can be checked against. Copybara import of the project: -- 75cabb5 by Meekail Zain <zainmeekail@gmail.com>: Update Merging this change closes #10433 COPYBARA_INTEGRATE_REVIEW=#10433 from Micky774:type_update 75cabb5 PiperOrigin-RevId: 615973838
1 parent 5e83299 commit 72831ac

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

xla/python/xla_extension/__init__.pyi

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -700,9 +700,16 @@ class DeviceTopology:
700700
def buffer_to_dlpack_managed_tensor(
701701
buffer: ArrayImpl, stream: int | None = None
702702
) -> Any: ...
703+
@overload
703704
def dlpack_managed_tensor_to_buffer(
704705
tensor: Any, device: Device, stream: int | None
705706
) -> ArrayImpl: ...
707+
@overload
708+
def dlpack_managed_tensor_to_buffer( # Legacy overload
709+
tensor: Any,
710+
cpu_backend: Optional[Client] = ...,
711+
gpu_backend: Optional[Client] = ...,
712+
) -> ArrayImpl: ...
706713

707714
def cuda_array_interface_to_buffer(
708715
cai: Dict[str, Union[
@@ -714,12 +721,6 @@ def cuda_array_interface_to_buffer(
714721
gpu_backend: Optional[Client] = ...,
715722
) -> ArrayImpl: ...
716723

717-
# Legacy overload
718-
def dlpack_managed_tensor_to_buffer(
719-
tensor: Any,
720-
cpu_backend: Optional[Client] = ...,
721-
gpu_backend: Optional[Client] = ...,
722-
) -> ArrayImpl: ...
723724

724725
# === BEGIN py_traceback.cc
725726

0 commit comments

Comments
 (0)