Skip to content

Commit 974a7d6

Browse files
authored
Merge pull request #715 from callumfare/callum/adapter_handle
Implement adapter instance handles
2 parents e8e96ce + b279985 commit 974a7d6

40 files changed

+2402
-514
lines changed

examples/hello_world/hello_world.cpp

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,36 @@ int main(int argc, char *argv[]) {
2626
}
2727
std::cout << "Platform initialized.\n";
2828

29+
uint32_t adapterCount = 0;
30+
std::vector<ur_adapter_handle_t> adapters;
2931
uint32_t platformCount = 0;
3032
std::vector<ur_platform_handle_t> platforms;
3133

32-
status = urPlatformGet(1, nullptr, &platformCount);
34+
status = urAdapterGet(0, nullptr, &adapterCount);
35+
if (status != UR_RESULT_SUCCESS) {
36+
std::cout << "urAdapterGet failed with return code: " << status
37+
<< std::endl;
38+
return 1;
39+
}
40+
adapters.resize(adapterCount);
41+
status = urAdapterGet(adapterCount, adapters.data(), nullptr);
42+
if (status != UR_RESULT_SUCCESS) {
43+
std::cout << "urAdapterGet failed with return code: " << status
44+
<< std::endl;
45+
return 1;
46+
}
47+
48+
status = urPlatformGet(adapters.data(), adapterCount, 1, nullptr,
49+
&platformCount);
3350
if (status != UR_RESULT_SUCCESS) {
3451
std::cout << "urPlatformGet failed with return code: " << status
3552
<< std::endl;
3653
goto out;
3754
}
3855

3956
platforms.resize(platformCount);
40-
status = urPlatformGet(platformCount, platforms.data(), nullptr);
57+
status = urPlatformGet(adapters.data(), adapterCount, platformCount,
58+
platforms.data(), nullptr);
4159
if (status != UR_RESULT_SUCCESS) {
4260
std::cout << "urPlatformGet failed with return code: " << status
4361
<< std::endl;
@@ -98,6 +116,9 @@ int main(int argc, char *argv[]) {
98116
}
99117

100118
out:
119+
for (auto adapter : adapters) {
120+
urAdapterRelease(adapter);
121+
}
101122
urTearDown(nullptr);
102123
return status == UR_RESULT_SUCCESS ? 0 : 1;
103124
}

include/ur.py

Lines changed: 88 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,6 @@ class ur_function_v(IntEnum):
165165
BINDLESS_IMAGES_DESTROY_EXTERNAL_SEMAPHORE_EXP = 147## Enumerator for ::urBindlessImagesDestroyExternalSemaphoreExp
166166
BINDLESS_IMAGES_WAIT_EXTERNAL_SEMAPHORE_EXP = 148 ## Enumerator for ::urBindlessImagesWaitExternalSemaphoreExp
167167
BINDLESS_IMAGES_SIGNAL_EXTERNAL_SEMAPHORE_EXP = 149 ## Enumerator for ::urBindlessImagesSignalExternalSemaphoreExp
168-
PLATFORM_GET_LAST_ERROR = 150 ## Enumerator for ::urPlatformGetLastError
169168
ENQUEUE_USM_FILL_2D = 151 ## Enumerator for ::urEnqueueUSMFill2D
170169
ENQUEUE_USM_MEMCPY_2D = 152 ## Enumerator for ::urEnqueueUSMMemcpy2D
171170
VIRTUAL_MEM_GRANULARITY_GET_INFO = 153 ## Enumerator for ::urVirtualMemGranularityGetInfo
@@ -192,6 +191,11 @@ class ur_function_v(IntEnum):
192191
LOADER_CONFIG_RETAIN = 174 ## Enumerator for ::urLoaderConfigRetain
193192
LOADER_CONFIG_GET_INFO = 175 ## Enumerator for ::urLoaderConfigGetInfo
194193
LOADER_CONFIG_ENABLE_LAYER = 176 ## Enumerator for ::urLoaderConfigEnableLayer
194+
ADAPTER_RELEASE = 177 ## Enumerator for ::urAdapterRelease
195+
ADAPTER_GET = 178 ## Enumerator for ::urAdapterGet
196+
ADAPTER_RETAIN = 179 ## Enumerator for ::urAdapterRetain
197+
ADAPTER_GET_LAST_ERROR = 180 ## Enumerator for ::urAdapterGetLastError
198+
ADAPTER_GET_INFO = 181 ## Enumerator for ::urAdapterGetInfo
195199

196200
class ur_function_t(c_int):
197201
def __str__(self):
@@ -288,6 +292,11 @@ class ur_bool_t(c_ubyte):
288292
class ur_loader_config_handle_t(c_void_p):
289293
pass
290294

295+
###############################################################################
296+
## @brief Handle of an adapter instance
297+
class ur_adapter_handle_t(c_void_p):
298+
pass
299+
291300
###############################################################################
292301
## @brief Handle of a platform instance
293302
class ur_platform_handle_t(c_void_p):
@@ -501,6 +510,36 @@ def __str__(self):
501510
return str(ur_loader_config_info_v(self.value))
502511

503512

513+
###############################################################################
514+
## @brief Supported adapter info
515+
class ur_adapter_info_v(IntEnum):
516+
BACKEND = 0 ## [::ur_adapter_backend_t] Identifies the native backend supported by
517+
## the adapter.
518+
REFERENCE_COUNT = 1 ## [uint32_t] Reference count of the adapter.
519+
## The reference count returned should be considered immediately stale.
520+
## It is unsuitable for general use in applications. This feature is
521+
## provided for identifying memory leaks.
522+
523+
class ur_adapter_info_t(c_int):
524+
def __str__(self):
525+
return str(ur_adapter_info_v(self.value))
526+
527+
528+
###############################################################################
529+
## @brief Identifies backend of the adapter
530+
class ur_adapter_backend_v(IntEnum):
531+
UNKNOWN = 0 ## The backend is not a recognized one
532+
LEVEL_ZERO = 1 ## The backend is Level Zero
533+
OPENCL = 2 ## The backend is OpenCL
534+
CUDA = 3 ## The backend is CUDA
535+
HIP = 4 ## The backend is HIP
536+
NATIVE_CPU = 5 ## The backend is Native CPU
537+
538+
class ur_adapter_backend_t(c_int):
539+
def __str__(self):
540+
return str(ur_adapter_backend_v(self.value))
541+
542+
504543
###############################################################################
505544
## @brief Supported platform info
506545
class ur_platform_info_v(IntEnum):
@@ -2273,9 +2312,9 @@ class ur_loader_config_dditable_t(Structure):
22732312
###############################################################################
22742313
## @brief Function-pointer for urPlatformGet
22752314
if __use_win_types:
2276-
_urPlatformGet_t = WINFUNCTYPE( ur_result_t, c_ulong, POINTER(ur_platform_handle_t), POINTER(c_ulong) )
2315+
_urPlatformGet_t = WINFUNCTYPE( ur_result_t, POINTER(ur_adapter_handle_t), c_ulong, c_ulong, POINTER(ur_platform_handle_t), POINTER(c_ulong) )
22772316
else:
2278-
_urPlatformGet_t = CFUNCTYPE( ur_result_t, c_ulong, POINTER(ur_platform_handle_t), POINTER(c_ulong) )
2317+
_urPlatformGet_t = CFUNCTYPE( ur_result_t, POINTER(ur_adapter_handle_t), c_ulong, c_ulong, POINTER(ur_platform_handle_t), POINTER(c_ulong) )
22792318

22802319
###############################################################################
22812320
## @brief Function-pointer for urPlatformGetInfo
@@ -2298,13 +2337,6 @@ class ur_loader_config_dditable_t(Structure):
22982337
else:
22992338
_urPlatformCreateWithNativeHandle_t = CFUNCTYPE( ur_result_t, ur_native_handle_t, POINTER(ur_platform_native_properties_t), POINTER(ur_platform_handle_t) )
23002339

2301-
###############################################################################
2302-
## @brief Function-pointer for urPlatformGetLastError
2303-
if __use_win_types:
2304-
_urPlatformGetLastError_t = WINFUNCTYPE( ur_result_t, ur_platform_handle_t, POINTER(c_char_p), POINTER(c_long) )
2305-
else:
2306-
_urPlatformGetLastError_t = CFUNCTYPE( ur_result_t, ur_platform_handle_t, POINTER(c_char_p), POINTER(c_long) )
2307-
23082340
###############################################################################
23092341
## @brief Function-pointer for urPlatformGetApiVersion
23102342
if __use_win_types:
@@ -2328,7 +2360,6 @@ class ur_platform_dditable_t(Structure):
23282360
("pfnGetInfo", c_void_p), ## _urPlatformGetInfo_t
23292361
("pfnGetNativeHandle", c_void_p), ## _urPlatformGetNativeHandle_t
23302362
("pfnCreateWithNativeHandle", c_void_p), ## _urPlatformCreateWithNativeHandle_t
2331-
("pfnGetLastError", c_void_p), ## _urPlatformGetLastError_t
23322363
("pfnGetApiVersion", c_void_p), ## _urPlatformGetApiVersion_t
23332364
("pfnGetBackendOption", c_void_p) ## _urPlatformGetBackendOption_t
23342365
]
@@ -3565,13 +3596,53 @@ class ur_usm_p2p_exp_dditable_t(Structure):
35653596
else:
35663597
_urTearDown_t = CFUNCTYPE( ur_result_t, c_void_p )
35673598

3599+
###############################################################################
3600+
## @brief Function-pointer for urAdapterGet
3601+
if __use_win_types:
3602+
_urAdapterGet_t = WINFUNCTYPE( ur_result_t, c_ulong, POINTER(ur_adapter_handle_t), POINTER(c_ulong) )
3603+
else:
3604+
_urAdapterGet_t = CFUNCTYPE( ur_result_t, c_ulong, POINTER(ur_adapter_handle_t), POINTER(c_ulong) )
3605+
3606+
###############################################################################
3607+
## @brief Function-pointer for urAdapterRelease
3608+
if __use_win_types:
3609+
_urAdapterRelease_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t )
3610+
else:
3611+
_urAdapterRelease_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t )
3612+
3613+
###############################################################################
3614+
## @brief Function-pointer for urAdapterRetain
3615+
if __use_win_types:
3616+
_urAdapterRetain_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t )
3617+
else:
3618+
_urAdapterRetain_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t )
3619+
3620+
###############################################################################
3621+
## @brief Function-pointer for urAdapterGetLastError
3622+
if __use_win_types:
3623+
_urAdapterGetLastError_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t, POINTER(c_char_p), POINTER(c_long) )
3624+
else:
3625+
_urAdapterGetLastError_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t, POINTER(c_char_p), POINTER(c_long) )
3626+
3627+
###############################################################################
3628+
## @brief Function-pointer for urAdapterGetInfo
3629+
if __use_win_types:
3630+
_urAdapterGetInfo_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t, ur_adapter_info_t, c_size_t, c_void_p, POINTER(c_size_t) )
3631+
else:
3632+
_urAdapterGetInfo_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t, ur_adapter_info_t, c_size_t, c_void_p, POINTER(c_size_t) )
3633+
35683634

35693635
###############################################################################
35703636
## @brief Table of Global functions pointers
35713637
class ur_global_dditable_t(Structure):
35723638
_fields_ = [
35733639
("pfnInit", c_void_p), ## _urInit_t
3574-
("pfnTearDown", c_void_p) ## _urTearDown_t
3640+
("pfnTearDown", c_void_p), ## _urTearDown_t
3641+
("pfnAdapterGet", c_void_p), ## _urAdapterGet_t
3642+
("pfnAdapterRelease", c_void_p), ## _urAdapterRelease_t
3643+
("pfnAdapterRetain", c_void_p), ## _urAdapterRetain_t
3644+
("pfnAdapterGetLastError", c_void_p), ## _urAdapterGetLastError_t
3645+
("pfnAdapterGetInfo", c_void_p) ## _urAdapterGetInfo_t
35753646
]
35763647

35773648
###############################################################################
@@ -3768,7 +3839,6 @@ def __init__(self, version : ur_api_version_t):
37683839
self.urPlatformGetInfo = _urPlatformGetInfo_t(self.__dditable.Platform.pfnGetInfo)
37693840
self.urPlatformGetNativeHandle = _urPlatformGetNativeHandle_t(self.__dditable.Platform.pfnGetNativeHandle)
37703841
self.urPlatformCreateWithNativeHandle = _urPlatformCreateWithNativeHandle_t(self.__dditable.Platform.pfnCreateWithNativeHandle)
3771-
self.urPlatformGetLastError = _urPlatformGetLastError_t(self.__dditable.Platform.pfnGetLastError)
37723842
self.urPlatformGetApiVersion = _urPlatformGetApiVersion_t(self.__dditable.Platform.pfnGetApiVersion)
37733843
self.urPlatformGetBackendOption = _urPlatformGetBackendOption_t(self.__dditable.Platform.pfnGetBackendOption)
37743844

@@ -4048,6 +4118,11 @@ def __init__(self, version : ur_api_version_t):
40484118
# attach function interface to function address
40494119
self.urInit = _urInit_t(self.__dditable.Global.pfnInit)
40504120
self.urTearDown = _urTearDown_t(self.__dditable.Global.pfnTearDown)
4121+
self.urAdapterGet = _urAdapterGet_t(self.__dditable.Global.pfnAdapterGet)
4122+
self.urAdapterRelease = _urAdapterRelease_t(self.__dditable.Global.pfnAdapterRelease)
4123+
self.urAdapterRetain = _urAdapterRetain_t(self.__dditable.Global.pfnAdapterRetain)
4124+
self.urAdapterGetLastError = _urAdapterGetLastError_t(self.__dditable.Global.pfnAdapterGetLastError)
4125+
self.urAdapterGetInfo = _urAdapterGetInfo_t(self.__dditable.Global.pfnAdapterGetInfo)
40514126

40524127
# call driver to get function pointers
40534128
VirtualMem = ur_virtual_mem_dditable_t()

0 commit comments

Comments
 (0)