|
9 | 9 | //===----------------------------------------------------------------------===//
|
10 | 10 |
|
11 | 11 | #include "common.hpp"
|
| 12 | +#include "context.hpp" |
12 | 13 |
|
13 |
| -UR_APIEXPORT ur_result_t UR_APICALL |
14 |
| -urUsmP2PEnablePeerAccessExp(ur_device_handle_t, ur_device_handle_t) { |
15 |
| - detail::ur::die( |
16 |
| - "urUsmP2PEnablePeerAccessExp is not implemented for HIP adapter."); |
17 |
| - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; |
| 14 | +UR_APIEXPORT ur_result_t UR_APICALL urUsmP2PEnablePeerAccessExp( |
| 15 | + ur_device_handle_t commandDevice, ur_device_handle_t peerDevice) { |
| 16 | + try { |
| 17 | + ScopedContext active(commandDevice); |
| 18 | + UR_CHECK_ERROR(hipDeviceEnablePeerAccess(peerDevice->get(), 0)); |
| 19 | + } catch (ur_result_t err) { |
| 20 | + return err; |
| 21 | + } |
| 22 | + return UR_RESULT_SUCCESS; |
18 | 23 | }
|
19 | 24 |
|
20 |
| -UR_APIEXPORT ur_result_t UR_APICALL |
21 |
| -urUsmP2PDisablePeerAccessExp(ur_device_handle_t, ur_device_handle_t) { |
22 |
| - detail::ur::die( |
23 |
| - "urUsmP2PDisablePeerAccessExp is not implemented for HIP adapter."); |
24 |
| - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; |
| 25 | +UR_APIEXPORT ur_result_t UR_APICALL urUsmP2PDisablePeerAccessExp( |
| 26 | + ur_device_handle_t commandDevice, ur_device_handle_t peerDevice) { |
| 27 | + try { |
| 28 | + ScopedContext active(commandDevice); |
| 29 | + UR_CHECK_ERROR(hipDeviceDisablePeerAccess(peerDevice->get())); |
| 30 | + } catch (ur_result_t err) { |
| 31 | + return err; |
| 32 | + } |
| 33 | + return UR_RESULT_SUCCESS; |
25 | 34 | }
|
26 | 35 |
|
27 | 36 | UR_APIEXPORT ur_result_t UR_APICALL urUsmP2PPeerAccessGetInfoExp(
|
28 |
| - ur_device_handle_t, ur_device_handle_t, ur_exp_peer_info_t, size_t propSize, |
29 |
| - void *pPropValue, size_t *pPropSizeRet) { |
| 37 | + ur_device_handle_t commandDevice, ur_device_handle_t peerDevice, |
| 38 | + ur_exp_peer_info_t propName, size_t propSize, void *pPropValue, |
| 39 | + size_t *pPropSizeRet) { |
30 | 40 | UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);
|
31 |
| - // Zero return value indicates that all of the queries currently return false. |
32 |
| - return ReturnValue(uint32_t{0}); |
| 41 | + |
| 42 | + int value; |
| 43 | + hipDeviceP2PAttr hipAttr; |
| 44 | + try { |
| 45 | + ScopedContext active(commandDevice); |
| 46 | + switch (propName) { |
| 47 | + case UR_EXP_PEER_INFO_UR_PEER_ACCESS_SUPPORTED: { |
| 48 | + hipAttr = hipDevP2PAttrAccessSupported; |
| 49 | + break; |
| 50 | + } |
| 51 | + case UR_EXP_PEER_INFO_UR_PEER_ATOMICS_SUPPORTED: { |
| 52 | + hipAttr = hipDevP2PAttrNativeAtomicSupported; |
| 53 | + break; |
| 54 | + } |
| 55 | + default: { |
| 56 | + return UR_RESULT_ERROR_INVALID_ENUMERATION; |
| 57 | + } |
| 58 | + } |
| 59 | + UR_CHECK_ERROR(hipDeviceGetP2PAttribute( |
| 60 | + &value, hipAttr, commandDevice->get(), peerDevice->get())); |
| 61 | + } catch (ur_result_t err) { |
| 62 | + return err; |
| 63 | + } |
| 64 | + return ReturnValue(value); |
33 | 65 | }
|
0 commit comments