Skip to content

Commit a4c6e91

Browse files
committed
Change urDeviceCreateWithNativeHandle to take an adapter handle.
It currently takes a platform handle, which is problematic for the sycl RT because its make_device api only takes a native handle, so to figure out the correct platform handle to pass at best we'd need to do some backend specific querying of the native object, but even then that isn't always possible as not all backends have a platform equivalent. The platform handle does currently enable an optional (slightly) faster path to return the correct device in some adapter implementations but this isn't essential for them to work, so really its primary purpose is to serve as the wrapped UR handle for the loader to work. This purpose is equally well served by an adapter handle, which will also be a lot easier for the sycl rt to correctly provide.
1 parent 83f7ad9 commit a4c6e91

File tree

17 files changed

+53
-69
lines changed

17 files changed

+53
-69
lines changed

include/ur_api.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2062,15 +2062,15 @@ typedef struct ur_device_native_properties_t {
20622062
/// - ::UR_RESULT_ERROR_DEVICE_LOST
20632063
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
20642064
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
2065-
/// + `NULL == hPlatform`
2065+
/// + `NULL == hAdapter`
20662066
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
20672067
/// + `NULL == phDevice`
20682068
/// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE
20692069
/// + If the adapter has no underlying equivalent handle.
20702070
UR_APIEXPORT ur_result_t UR_APICALL
20712071
urDeviceCreateWithNativeHandle(
20722072
ur_native_handle_t hNativeDevice, ///< [in][nocheck] the native handle of the device.
2073-
ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance
2073+
ur_adapter_handle_t hAdapter, ///< [in] handle of the adapter to which `hNativeDevice` belongs
20742074
const ur_device_native_properties_t *pProperties, ///< [in][optional] pointer to native device properties struct.
20752075
ur_device_handle_t *phDevice ///< [out] pointer to the handle of the device object created.
20762076
);
@@ -11972,7 +11972,7 @@ typedef struct ur_device_get_native_handle_params_t {
1197211972
/// allowing the callback the ability to modify the parameter's value
1197311973
typedef struct ur_device_create_with_native_handle_params_t {
1197411974
ur_native_handle_t *phNativeDevice;
11975-
ur_platform_handle_t *phPlatform;
11975+
ur_adapter_handle_t *phAdapter;
1197611976
const ur_device_native_properties_t **ppProperties;
1197711977
ur_device_handle_t **pphDevice;
1197811978
} ur_device_create_with_native_handle_params_t;

include/ur_ddi.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2373,7 +2373,7 @@ typedef ur_result_t(UR_APICALL *ur_pfnDeviceGetNativeHandle_t)(
23732373
/// @brief Function-pointer for urDeviceCreateWithNativeHandle
23742374
typedef ur_result_t(UR_APICALL *ur_pfnDeviceCreateWithNativeHandle_t)(
23752375
ur_native_handle_t,
2376-
ur_platform_handle_t,
2376+
ur_adapter_handle_t,
23772377
const ur_device_native_properties_t *,
23782378
ur_device_handle_t *);
23792379

include/ur_print.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17357,10 +17357,10 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct
1735717357
*(params->phNativeDevice)));
1735817358

1735917359
os << ", ";
17360-
os << ".hPlatform = ";
17360+
os << ".hAdapter = ";
1736117361

1736217362
ur::details::printPtr(os,
17363-
*(params->phPlatform));
17363+
*(params->phAdapter));
1736417364

1736517365
os << ", ";
1736617366
os << ".pProperties = ";

scripts/core/device.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -820,9 +820,9 @@ params:
820820
- type: $x_native_handle_t
821821
name: hNativeDevice
822822
desc: "[in][nocheck] the native handle of the device."
823-
- type: $x_platform_handle_t
824-
name: hPlatform
825-
desc: "[in] handle of the platform instance"
823+
- type: $x_adapter_handle_t
824+
name: hAdapter
825+
desc: "[in] handle of the adapter to which `hNativeDevice` belongs"
826826
- type: const $x_device_native_properties_t*
827827
name: pProperties
828828
desc: "[in][optional] pointer to native device properties struct."

source/adapters/cuda/device.cpp

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,27 +1185,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle(
11851185
/// \return TBD
11861186

11871187
UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
1188-
ur_native_handle_t hNativeDevice, ur_platform_handle_t hPlatform,
1189-
const ur_device_native_properties_t *pProperties,
1188+
ur_native_handle_t hNativeDevice,
1189+
[[maybe_unused]] ur_adapter_handle_t hAdapter,
1190+
[[maybe_unused]] const ur_device_native_properties_t *pProperties,
11901191
ur_device_handle_t *phDevice) {
1191-
std::ignore = pProperties;
1192-
11931192
CUdevice CuDevice = static_cast<CUdevice>(hNativeDevice);
11941193

11951194
auto IsDevice = [=](std::unique_ptr<ur_device_handle_t_> &Dev) {
11961195
return Dev->get() == CuDevice;
11971196
};
11981197

1199-
// If a platform is provided just check if the device is in it
1200-
if (hPlatform) {
1201-
auto SearchRes = std::find_if(begin(hPlatform->Devices),
1202-
end(hPlatform->Devices), IsDevice);
1203-
if (SearchRes != end(hPlatform->Devices)) {
1204-
*phDevice = SearchRes->get();
1205-
return UR_RESULT_SUCCESS;
1206-
}
1207-
}
1208-
12091198
// Get list of platforms
12101199
uint32_t NumPlatforms = 0;
12111200
ur_adapter_handle_t AdapterHandle = &adapter;

source/adapters/hip/device.cpp

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -988,7 +988,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle(
988988
}
989989

990990
UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
991-
ur_native_handle_t hNativeDevice, ur_platform_handle_t hPlatform,
991+
ur_native_handle_t hNativeDevice,
992+
[[maybe_unused]] ur_adapter_handle_t hAdapter,
992993
[[maybe_unused]] const ur_device_native_properties_t *pProperties,
993994
ur_device_handle_t *phDevice) {
994995
// We can't cast between ur_native_handle_t and hipDevice_t, so memcpy the
@@ -1000,16 +1001,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
10001001
return Dev->get() == HIPDevice;
10011002
};
10021003

1003-
// If a platform is provided just check if the device is in it
1004-
if (hPlatform) {
1005-
auto SearchRes = std::find_if(begin(hPlatform->Devices),
1006-
end(hPlatform->Devices), IsDevice);
1007-
if (SearchRes != end(hPlatform->Devices)) {
1008-
*phDevice = SearchRes->get();
1009-
return UR_RESULT_SUCCESS;
1010-
}
1011-
}
1012-
10131004
// Get list of platforms
10141005
uint32_t NumPlatforms = 0;
10151006
ur_adapter_handle_t AdapterHandle = &adapter;

source/adapters/level_zero/device.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1602,14 +1602,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle(
16021602

16031603
UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
16041604
ur_native_handle_t NativeDevice, ///< [in] the native handle of the device.
1605-
ur_platform_handle_t Platform, ///< [in] handle of the platform instance
1606-
const ur_device_native_properties_t
1605+
[[maybe_unused]] ur_adapter_handle_t
1606+
Adapter, ///< [in] handle of the platform instance
1607+
[[maybe_unused]] const ur_device_native_properties_t
16071608
*Properties, ///< [in][optional] pointer to native device properties
16081609
///< struct.
16091610
ur_device_handle_t
16101611
*Device ///< [out] pointer to the handle of the device object created.
16111612
) {
1612-
std::ignore = Properties;
16131613
auto ZeDevice = ur_cast<ze_device_handle_t>(NativeDevice);
16141614

16151615
// The SYCL spec requires that the set of devices must remain fixed for the
@@ -1622,12 +1622,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
16221622
if (const auto *platforms = GlobalAdapter->PlatformCache->get_value()) {
16231623
for (const auto &p : *platforms) {
16241624
Dev = p->getDeviceFromNativeHandle(ZeDevice);
1625-
if (Dev) {
1626-
// Check that the input Platform, if was given, matches the found one.
1627-
UR_ASSERT(!Platform || Platform == p.get(),
1628-
UR_RESULT_ERROR_INVALID_PLATFORM);
1629-
break;
1630-
}
16311625
}
16321626
} else {
16331627
return GlobalAdapter->PlatformCache->get_error();

source/adapters/mock/ur_mockddi.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -921,7 +921,8 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetNativeHandle(
921921
__urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
922922
ur_native_handle_t
923923
hNativeDevice, ///< [in][nocheck] the native handle of the device.
924-
ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance
924+
ur_adapter_handle_t
925+
hAdapter, ///< [in] handle of the adapter to which `hNativeDevice` belongs
925926
const ur_device_native_properties_t *
926927
pProperties, ///< [in][optional] pointer to native device properties struct.
927928
ur_device_handle_t
@@ -930,7 +931,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
930931
ur_result_t result = UR_RESULT_SUCCESS;
931932

932933
ur_device_create_with_native_handle_params_t params = {
933-
&hNativeDevice, &hPlatform, &pProperties, &phDevice};
934+
&hNativeDevice, &hAdapter, &pProperties, &phDevice};
934935

935936
auto beforeCallback = reinterpret_cast<ur_mock_callback_t>(
936937
mock::getCallbacks().get_before_callback(

source/adapters/native_cpu/device.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,11 +366,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle(
366366
}
367367

368368
UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
369-
ur_native_handle_t hNativeDevice, ur_platform_handle_t hPlatform,
369+
ur_native_handle_t hNativeDevice, ur_adapter_handle_t hAdapter,
370370
const ur_device_native_properties_t *pProperties,
371371
ur_device_handle_t *phDevice) {
372372
std::ignore = hNativeDevice;
373-
std::ignore = hPlatform;
373+
std::ignore = hAdapter;
374374
std::ignore = pProperties;
375375
std::ignore = phDevice;
376376

source/adapters/opencl/device.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1125,7 +1125,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle(
11251125
}
11261126

11271127
UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
1128-
ur_native_handle_t hNativeDevice, ur_platform_handle_t,
1128+
ur_native_handle_t hNativeDevice, ur_adapter_handle_t,
11291129
const ur_device_native_properties_t *, ur_device_handle_t *phDevice) {
11301130

11311131
*phDevice = reinterpret_cast<ur_device_handle_t>(hNativeDevice);

0 commit comments

Comments
 (0)