|
9 | 9 | //===----------------------------------------------------------------------===//
|
10 | 10 |
|
11 | 11 | #include "device.hpp"
|
| 12 | +#include "adapter.hpp" |
12 | 13 | #include "context.hpp"
|
13 | 14 | #include "event.hpp"
|
14 | 15 |
|
@@ -954,8 +955,57 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle(
|
954 | 955 | }
|
955 | 956 |
|
956 | 957 | UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
|
957 |
| - ur_native_handle_t, ur_platform_handle_t, |
958 |
| - const ur_device_native_properties_t *, ur_device_handle_t *) { |
| 958 | + ur_native_handle_t hNativeDevice, ur_platform_handle_t hPlatform, |
| 959 | + [[maybe_unused]] const ur_device_native_properties_t *pProperties, |
| 960 | + ur_device_handle_t *phDevice) { |
| 961 | + // We can't cast between ur_native_handle_t and hipDevice_t, so memcpy the |
| 962 | + // bits instead |
| 963 | + hipDevice_t HIPDevice = 0; |
| 964 | + memcpy(&HIPDevice, &hNativeDevice, sizeof(hipDevice_t)); |
| 965 | + |
| 966 | + auto IsDevice = [=](std::unique_ptr<ur_device_handle_t_> &Dev) { |
| 967 | + return Dev->get() == HIPDevice; |
| 968 | + }; |
| 969 | + |
| 970 | + // If a platform is provided just check if the device is in it |
| 971 | + if (hPlatform) { |
| 972 | + auto SearchRes = std::find_if(begin(hPlatform->Devices), |
| 973 | + end(hPlatform->Devices), IsDevice); |
| 974 | + if (SearchRes != end(hPlatform->Devices)) { |
| 975 | + *phDevice = SearchRes->get(); |
| 976 | + return UR_RESULT_SUCCESS; |
| 977 | + } |
| 978 | + } |
| 979 | + |
| 980 | + // Get list of platforms |
| 981 | + uint32_t NumPlatforms = 0; |
| 982 | + ur_adapter_handle_t AdapterHandle = &adapter; |
| 983 | + ur_result_t Result = |
| 984 | + urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms); |
| 985 | + if (Result != UR_RESULT_SUCCESS) |
| 986 | + return Result; |
| 987 | + |
| 988 | + // We can only have a maximum of one platform. |
| 989 | + if (NumPlatforms != 1) |
| 990 | + return UR_RESULT_ERROR_INVALID_OPERATION; |
| 991 | + |
| 992 | + ur_platform_handle_t Platform = nullptr; |
| 993 | + |
| 994 | + Result = urPlatformGet(&AdapterHandle, 1, NumPlatforms, &Platform, nullptr); |
| 995 | + if (Result != UR_RESULT_SUCCESS) |
| 996 | + return Result; |
| 997 | + |
| 998 | + // Iterate through the platform's devices to find the device that matches |
| 999 | + // nativeHandle |
| 1000 | + auto SearchRes = std::find_if(std::begin(Platform->Devices), |
| 1001 | + std::end(Platform->Devices), IsDevice); |
| 1002 | + if (SearchRes != end(Platform->Devices)) { |
| 1003 | + *phDevice = static_cast<ur_device_handle_t>((*SearchRes).get()); |
| 1004 | + return UR_RESULT_SUCCESS; |
| 1005 | + } |
| 1006 | + |
| 1007 | + // If the provided nativeHandle cannot be matched to an |
| 1008 | + // existing device return error |
959 | 1009 | return UR_RESULT_ERROR_INVALID_OPERATION;
|
960 | 1010 | }
|
961 | 1011 |
|
|
0 commit comments