Skip to content

Commit eff82ba

Browse files
committed
Add helper for raising more specific errors in peer access methods
1 parent 09a1ef6 commit eff82ba

File tree

1 file changed

+43
-13
lines changed

1 file changed

+43
-13
lines changed

dpctl/_sycl_device.pyx

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,43 @@ cdef inline bint _check_peer_access(SyclDevice dev, SyclDevice peer) except *:
239239
return False
240240

241241

242+
cdef inline void _raise_invalid_peer_access(
243+
SyclDevice dev,
244+
SyclDevice peer,
245+
) except *:
246+
"""
247+
Check peer access ahead of time and raise errors for invalid cases.
248+
"""
249+
cdef list _peer_access_backends = [
250+
_backend_type._CUDA,
251+
_backend_type._HIP,
252+
_backend_type._LEVEL_ZERO
253+
]
254+
cdef _backend_type BTy1 = DPCTLDevice_GetBackend(dev._device_ref)
255+
cdef _backend_type BTy2 = DPCTLDevice_GetBackend(peer.get_device_ref())
256+
if (BTy1 != BTy2):
257+
raise ValueError(
258+
f"Device with backend {_backend_type_to_filter_string_part(BTy1)} "
259+
"cannot peer access device with backend "
260+
f"{_backend_type_to_filter_string_part(BTy2)}"
261+
)
262+
if (BTy1 not in _peer_access_backends):
263+
raise ValueError(
264+
"Peer access not supported for backend "
265+
f"{_backend_type_to_filter_string_part(BTy1)}"
266+
)
267+
if (BTy2 not in _peer_access_backends):
268+
raise ValueError(
269+
"Peer access not supported for backend "
270+
f"{_backend_type_to_filter_string_part(BTy2)}"
271+
)
272+
if (dev == peer):
273+
raise ValueError(
274+
"Peer access cannot be enabled between a device and itself"
275+
)
276+
return
277+
278+
242279
@functools.lru_cache(maxsize=None)
243280
def _cached_filter_string(d : SyclDevice):
244281
"""
@@ -1850,7 +1887,6 @@ cdef class SyclDevice(_SyclDevice):
18501887
f"{type(peer)}"
18511888
)
18521889
p_dev = <SyclDevice>peer
1853-
18541890
if _check_peer_access(self, p_dev):
18551891
return DPCTLDevice_CanAccessPeer(
18561892
self._device_ref,
@@ -1893,7 +1929,6 @@ cdef class SyclDevice(_SyclDevice):
18931929
f"{type(peer)}"
18941930
)
18951931
p_dev = <SyclDevice>peer
1896-
18971932
if _check_peer_access(self, p_dev):
18981933
return DPCTLDevice_CanAccessPeer(
18991934
self._device_ref,
@@ -1931,14 +1966,11 @@ cdef class SyclDevice(_SyclDevice):
19311966
f"{type(peer)}"
19321967
)
19331968
p_dev = <SyclDevice>peer
1934-
1935-
if _check_peer_access(self, p_dev):
1936-
DPCTLDevice_EnablePeerAccess(
1937-
self._device_ref,
1938-
p_dev.get_device_ref()
1939-
)
1940-
else:
1941-
raise ValueError("Peer access cannot be enabled for these devices")
1969+
_raise_invalid_peer_access(self, p_dev)
1970+
DPCTLDevice_EnablePeerAccess(
1971+
self._device_ref,
1972+
p_dev.get_device_ref()
1973+
)
19421974
return
19431975

19441976
def disable_peer_access(self, peer):
@@ -1969,14 +2001,12 @@ cdef class SyclDevice(_SyclDevice):
19692001
f"{type(peer)}"
19702002
)
19712003
p_dev = <SyclDevice>peer
1972-
2004+
_raise_invalid_peer_access(self, p_dev)
19732005
if _check_peer_access(self, p_dev):
19742006
DPCTLDevice_DisablePeerAccess(
19752007
self._device_ref,
19762008
p_dev.get_device_ref()
19772009
)
1978-
else:
1979-
raise ValueError("Peer access cannot be enabled for these devices")
19802010
return
19812011

19822012
@property

0 commit comments

Comments
 (0)