@@ -239,6 +239,43 @@ cdef inline bint _check_peer_access(SyclDevice dev, SyclDevice peer) except *:
239
239
return False
240
240
241
241
242
+ cdef inline void _raise_invalid_peer_access(
243
+ SyclDevice dev,
244
+ SyclDevice peer,
245
+ ) except * :
246
+ """
247
+ Check peer access ahead of time and raise errors for invalid cases.
248
+ """
249
+ cdef list _peer_access_backends = [
250
+ _backend_type._CUDA,
251
+ _backend_type._HIP,
252
+ _backend_type._LEVEL_ZERO
253
+ ]
254
+ cdef _backend_type BTy1 = DPCTLDevice_GetBackend(dev._device_ref)
255
+ cdef _backend_type BTy2 = DPCTLDevice_GetBackend(peer.get_device_ref())
256
+ if (BTy1 != BTy2):
257
+ raise ValueError (
258
+ f" Device with backend {_backend_type_to_filter_string_part(BTy1)} "
259
+ " cannot peer access device with backend "
260
+ f" {_backend_type_to_filter_string_part(BTy2)}"
261
+ )
262
+ if (BTy1 not in _peer_access_backends):
263
+ raise ValueError (
264
+ " Peer access not supported for backend "
265
+ f" {_backend_type_to_filter_string_part(BTy1)}"
266
+ )
267
+ if (BTy2 not in _peer_access_backends):
268
+ raise ValueError (
269
+ " Peer access not supported for backend "
270
+ f" {_backend_type_to_filter_string_part(BTy2)}"
271
+ )
272
+ if (dev == peer):
273
+ raise ValueError (
274
+ " Peer access cannot be enabled between a device and itself"
275
+ )
276
+ return
277
+
278
+
242
279
@ functools.lru_cache (maxsize = None )
243
280
def _cached_filter_string (d : SyclDevice ):
244
281
"""
@@ -1850,7 +1887,6 @@ cdef class SyclDevice(_SyclDevice):
1850
1887
f" {type(peer)}"
1851
1888
)
1852
1889
p_dev = < SyclDevice> peer
1853
-
1854
1890
if _check_peer_access(self , p_dev):
1855
1891
return DPCTLDevice_CanAccessPeer(
1856
1892
self ._device_ref,
@@ -1893,7 +1929,6 @@ cdef class SyclDevice(_SyclDevice):
1893
1929
f" {type(peer)}"
1894
1930
)
1895
1931
p_dev = < SyclDevice> peer
1896
-
1897
1932
if _check_peer_access(self , p_dev):
1898
1933
return DPCTLDevice_CanAccessPeer(
1899
1934
self ._device_ref,
@@ -1931,14 +1966,11 @@ cdef class SyclDevice(_SyclDevice):
1931
1966
f" {type(peer)}"
1932
1967
)
1933
1968
p_dev = < SyclDevice> peer
1934
-
1935
- if _check_peer_access(self , p_dev):
1936
- DPCTLDevice_EnablePeerAccess(
1937
- self ._device_ref,
1938
- p_dev.get_device_ref()
1939
- )
1940
- else :
1941
- raise ValueError (" Peer access cannot be enabled for these devices" )
1969
+ _raise_invalid_peer_access(self , p_dev)
1970
+ DPCTLDevice_EnablePeerAccess(
1971
+ self ._device_ref,
1972
+ p_dev.get_device_ref()
1973
+ )
1942
1974
return
1943
1975
1944
1976
def disable_peer_access (self , peer ):
@@ -1969,14 +2001,12 @@ cdef class SyclDevice(_SyclDevice):
1969
2001
f" {type(peer)}"
1970
2002
)
1971
2003
p_dev = < SyclDevice> peer
1972
-
2004
+ _raise_invalid_peer_access( self , p_dev)
1973
2005
if _check_peer_access(self , p_dev):
1974
2006
DPCTLDevice_DisablePeerAccess(
1975
2007
self ._device_ref,
1976
2008
p_dev.get_device_ref()
1977
2009
)
1978
- else :
1979
- raise ValueError (" Peer access cannot be enabled for these devices" )
1980
2010
return
1981
2011
1982
2012
@property
0 commit comments