Skip to content

Commit 7ef0e44

Browse files
committed
Fix malformed tests in test_usm_ndarray_dlpack
These tests would fail on machines with more than 2 devices for a given platform due to an incorrect asusmption that the DLPack device ID would match that of the cached root devices, of which only 2 are kept per platform
1 parent bfd4d57 commit 7ef0e44

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

dpctl/tests/test_usm_ndarray_dlpack.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_dlpack_device(usm_type, all_root_devices):
8484
assert type(dev) is tuple
8585
assert len(dev) == 2
8686
assert dev[0] == device_oneAPI
87-
assert sycl_dev == all_root_devices[dev[1]]
87+
assert dev[1] == sycl_dev.get_device_id()
8888

8989

9090
def test_dlpack_exporter(typestr, usm_type, all_root_devices):
@@ -834,15 +834,15 @@ def test_sycl_device_to_dldevice(all_root_devices):
834834
assert type(dev) is tuple
835835
assert len(dev) == 2
836836
assert dev[0] == device_oneAPI
837-
assert dev[1] == all_root_devices.index(sycl_dev)
837+
assert dev[1] == sycl_dev.get_device_id()
838838

839839

840840
def test_dldevice_to_sycl_device(all_root_devices):
841841
for sycl_dev in all_root_devices:
842842
dldev = dpt.empty(0, device=sycl_dev).__dlpack_device__()
843843
dev = dpt.dldevice_to_sycl_device(dldev)
844844
assert type(dev) is dpctl.SyclDevice
845-
assert dev == all_root_devices[dldev[1]]
845+
assert dev.get_device_id() == sycl_dev.get_device_id()
846846

847847

848848
def test_dldevice_conversion_arg_validation():

0 commit comments

Comments
 (0)