Skip to content

Commit 7294170

Browse files
committed
[HIP] Implement urDeviceGetNativeHandle
This is mostly just a copy of the CUDA version of this implementation.
1 parent 3fca424 commit 7294170

File tree

2 files changed

+52
-4
lines changed

2 files changed

+52
-4
lines changed

source/adapters/hip/device.cpp

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//===----------------------------------------------------------------------===//
1010

1111
#include "device.hpp"
12+
#include "adapter.hpp"
1213
#include "context.hpp"
1314
#include "event.hpp"
1415

@@ -950,8 +951,57 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle(
950951
}
951952

952953
UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
953-
ur_native_handle_t, ur_platform_handle_t,
954-
const ur_device_native_properties_t *, ur_device_handle_t *) {
954+
ur_native_handle_t hNativeDevice, ur_platform_handle_t hPlatform,
955+
[[maybe_unused]] const ur_device_native_properties_t *pProperties,
956+
ur_device_handle_t *phDevice) {
957+
// We can't cast between ur_native_handle_t and hipDevice_t, so memcpy the
958+
// bits instead
959+
hipDevice_t HIPDevice = 0;
960+
memcpy(&HIPDevice, &hNativeDevice, sizeof(hipDevice_t));
961+
962+
auto IsDevice = [=](std::unique_ptr<ur_device_handle_t_> &Dev) {
963+
return Dev->get() == HIPDevice;
964+
};
965+
966+
// If a platform is provided just check if the device is in it
967+
if (hPlatform) {
968+
auto SearchRes = std::find_if(begin(hPlatform->Devices),
969+
end(hPlatform->Devices), IsDevice);
970+
if (SearchRes != end(hPlatform->Devices)) {
971+
*phDevice = SearchRes->get();
972+
return UR_RESULT_SUCCESS;
973+
}
974+
}
975+
976+
// Get list of platforms
977+
uint32_t NumPlatforms = 0;
978+
ur_adapter_handle_t AdapterHandle = &adapter;
979+
ur_result_t Result =
980+
urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms);
981+
if (Result != UR_RESULT_SUCCESS)
982+
return Result;
983+
984+
// We can only have a maximum of one platform.
985+
if (NumPlatforms != 1)
986+
return UR_RESULT_ERROR_INVALID_OPERATION;
987+
988+
ur_platform_handle_t Platform = nullptr;
989+
990+
Result = urPlatformGet(&AdapterHandle, 1, NumPlatforms, &Platform, nullptr);
991+
if (Result != UR_RESULT_SUCCESS)
992+
return Result;
993+
994+
// Iterate through the platform's devices to find the device that matches
995+
// nativeHandle
996+
auto SearchRes = std::find_if(std::begin(Platform->Devices),
997+
std::end(Platform->Devices), IsDevice);
998+
if (SearchRes != end(Platform->Devices)) {
999+
*phDevice = static_cast<ur_device_handle_t>((*SearchRes).get());
1000+
return UR_RESULT_SUCCESS;
1001+
}
1002+
1003+
// If the provided nativeHandle cannot be matched to an
1004+
// existing device return error
9551005
return UR_RESULT_ERROR_INVALID_OPERATION;
9561006
}
9571007

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,2 @@
1-
urDeviceCreateWithNativeHandleTest.Success
2-
urDeviceCreateWithNativeHandleTest.SuccessWithOwnedNativeHandle
31
urDeviceCreateWithNativeHandleTest.SuccessWithUnOwnedNativeHandle
42
{{OPT}}urDeviceGetGlobalTimestampTest.SuccessSynchronizedTime

0 commit comments

Comments
 (0)