@@ -406,13 +406,13 @@ def __array__(self, dtype=None, context=None, copy=None):
406
406
407
407
def __dlpack__ (self , * , stream : int | Any | None = None ):
408
408
if len (self ._arrays ) != 1 :
409
- raise ValueError ("__dlpack__ only supported for unsharded arrays." )
409
+ raise BufferError ("__dlpack__ only supported for unsharded arrays." )
410
410
from jax ._src .dlpack import to_dlpack # pylint: disable=g-import-not-at-top
411
411
return to_dlpack (self , stream = stream )
412
412
413
413
def __dlpack_device__ (self ) -> tuple [enum .Enum , int ]:
414
414
if len (self ._arrays ) != 1 :
415
- raise ValueError ("__dlpack__ only supported for unsharded arrays." )
415
+ raise BufferError ("__dlpack__ only supported for unsharded arrays." )
416
416
417
417
from jax ._src .dlpack import DLDeviceType # pylint: disable=g-import-not-at-top
418
418
@@ -426,17 +426,17 @@ def __dlpack_device__(self) -> tuple[enum.Enum, int]:
426
426
elif "rocm" in platform_version :
427
427
dl_device_type = DLDeviceType .kDLROCM
428
428
else :
429
- raise ValueError ("Unknown GPU platform for __dlpack__: "
429
+ raise BufferError ("Unknown GPU platform for __dlpack__: "
430
430
f"{ platform_version } " )
431
431
432
432
local_hardware_id = _get_device (self ).local_hardware_id
433
433
if local_hardware_id is None :
434
- raise ValueError ("Couldn't get local_hardware_id for __dlpack__" )
434
+ raise BufferError ("Couldn't get local_hardware_id for __dlpack__" )
435
435
436
436
return dl_device_type , local_hardware_id
437
437
438
438
else :
439
- raise ValueError (
439
+ raise BufferError (
440
440
"__dlpack__ device only supported for CPU and GPU, got platform: "
441
441
f"{ self .platform ()} "
442
442
)
0 commit comments