Skip to content

Commit 267d8ed

Browse files
authored
Merge pull request #1632 from frasercrmck/fraser/hip-device-get-native-handle
[HIP] Implement urDeviceGetNativeHandle
2 parents 5255031 + 7294170 commit 267d8ed

File tree

2 files changed

+52
-4
lines changed

2 files changed

+52
-4
lines changed

source/adapters/hip/device.cpp

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//===----------------------------------------------------------------------===//
1010

1111
#include "device.hpp"
12+
#include "adapter.hpp"
1213
#include "context.hpp"
1314
#include "event.hpp"
1415

@@ -954,8 +955,57 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle(
954955
}
955956

956957
UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
957-
ur_native_handle_t, ur_platform_handle_t,
958-
const ur_device_native_properties_t *, ur_device_handle_t *) {
958+
ur_native_handle_t hNativeDevice, ur_platform_handle_t hPlatform,
959+
[[maybe_unused]] const ur_device_native_properties_t *pProperties,
960+
ur_device_handle_t *phDevice) {
961+
// We can't cast between ur_native_handle_t and hipDevice_t, so memcpy the
962+
// bits instead
963+
hipDevice_t HIPDevice = 0;
964+
memcpy(&HIPDevice, &hNativeDevice, sizeof(hipDevice_t));
965+
966+
auto IsDevice = [=](std::unique_ptr<ur_device_handle_t_> &Dev) {
967+
return Dev->get() == HIPDevice;
968+
};
969+
970+
// If a platform is provided just check if the device is in it
971+
if (hPlatform) {
972+
auto SearchRes = std::find_if(begin(hPlatform->Devices),
973+
end(hPlatform->Devices), IsDevice);
974+
if (SearchRes != end(hPlatform->Devices)) {
975+
*phDevice = SearchRes->get();
976+
return UR_RESULT_SUCCESS;
977+
}
978+
}
979+
980+
// Get list of platforms
981+
uint32_t NumPlatforms = 0;
982+
ur_adapter_handle_t AdapterHandle = &adapter;
983+
ur_result_t Result =
984+
urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms);
985+
if (Result != UR_RESULT_SUCCESS)
986+
return Result;
987+
988+
// We can only have a maximum of one platform.
989+
if (NumPlatforms != 1)
990+
return UR_RESULT_ERROR_INVALID_OPERATION;
991+
992+
ur_platform_handle_t Platform = nullptr;
993+
994+
Result = urPlatformGet(&AdapterHandle, 1, NumPlatforms, &Platform, nullptr);
995+
if (Result != UR_RESULT_SUCCESS)
996+
return Result;
997+
998+
// Iterate through the platform's devices to find the device that matches
999+
// nativeHandle
1000+
auto SearchRes = std::find_if(std::begin(Platform->Devices),
1001+
std::end(Platform->Devices), IsDevice);
1002+
if (SearchRes != end(Platform->Devices)) {
1003+
*phDevice = static_cast<ur_device_handle_t>((*SearchRes).get());
1004+
return UR_RESULT_SUCCESS;
1005+
}
1006+
1007+
// If the provided nativeHandle cannot be matched to an
1008+
// existing device return error
9591009
return UR_RESULT_ERROR_INVALID_OPERATION;
9601010
}
9611011

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,2 @@
1-
urDeviceCreateWithNativeHandleTest.Success
2-
urDeviceCreateWithNativeHandleTest.SuccessWithOwnedNativeHandle
31
urDeviceCreateWithNativeHandleTest.SuccessWithUnOwnedNativeHandle
42
{{OPT}}urDeviceGetGlobalTimestampTest.SuccessSynchronizedTime

0 commit comments

Comments
 (0)