Skip to content

Commit 87fe8e6

Browse files
committed
[L0] Create/Destory Adapter Handle during lib init
Signed-off-by: Neil R. Spruit <neil.r.spruit@intel.com>
1 parent cc268e5 commit 87fe8e6

File tree

8 files changed

+104
-27
lines changed

8 files changed

+104
-27
lines changed

source/adapters/level_zero/CMakeLists.txt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,18 @@ add_ur_adapter(${TARGET_NAME}
122122
${CMAKE_CURRENT_SOURCE_DIR}/../../ur/ur.cpp
123123
)
124124

125+
if(WIN32)
126+
target_sources(ur_adapter_level_zero
127+
PRIVATE
128+
${CMAKE_CURRENT_SOURCE_DIR}/adapter_lib_init_windows.cpp
129+
)
130+
else()
131+
target_sources(ur_adapter_level_zero
132+
PRIVATE
133+
${CMAKE_CURRENT_SOURCE_DIR}/adapter_lib_init_linux.cpp
134+
)
135+
endif()
136+
125137
# TODO: fix level_zero adapter conversion warnings
126138
target_compile_options(${TARGET_NAME} PRIVATE
127139
$<$<CXX_COMPILER_ID:MSVC>:/wd4805 /wd4244>

source/adapters/level_zero/adapter.cpp

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
#include "adapter.hpp"
1212
#include "ur_level_zero.hpp"
1313

14+
ur_adapter_handle_t_ *Adapter;
15+
1416
ur_result_t initPlatforms(PlatformVec &platforms) noexcept try {
1517
uint32_t ZeDriverCount = 0;
1618
ZE2UR_CALL(zeDriverGet, (&ZeDriverCount, nullptr));
@@ -37,8 +39,7 @@ ur_result_t initPlatforms(PlatformVec &platforms) noexcept try {
3739
ur_result_t adapterStateInit() { return UR_RESULT_SUCCESS; }
3840

3941
ur_adapter_handle_t_::ur_adapter_handle_t_() {
40-
41-
Adapter.PlatformCache.Compute = [](Result<PlatformVec> &result) {
42+
PlatformCache.Compute = [](Result<PlatformVec> &result) {
4243
static std::once_flag ZeCallCountInitialized;
4344
try {
4445
std::call_once(ZeCallCountInitialized, []() {
@@ -52,7 +53,7 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() {
5253
}
5354

5455
// initialize level zero only once.
55-
if (Adapter.ZeResult == std::nullopt) {
56+
if (Adapter->ZeResult == std::nullopt) {
5657
// Setting these environment variables before running zeInit will enable
5758
// the validation layer in the Level Zero loader.
5859
if (UrL0Debug & UR_L0_DEBUG_VALIDATION) {
@@ -71,20 +72,20 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() {
7172
// We must only initialize the driver once, even if urPlatformGet() is
7273
// called multiple times. Declaring the return value as "static" ensures
7374
// it's only called once.
74-
Adapter.ZeResult = ZE_CALL_NOCHECK(zeInit, (ZE_INIT_FLAG_GPU_ONLY));
75+
Adapter->ZeResult = ZE_CALL_NOCHECK(zeInit, (ZE_INIT_FLAG_GPU_ONLY));
7576
}
76-
assert(Adapter.ZeResult !=
77+
assert(Adapter->ZeResult !=
7778
std::nullopt); // verify that level-zero is initialized
7879
PlatformVec platforms;
7980

8081
// Absorb the ZE_RESULT_ERROR_UNINITIALIZED and just return 0 Platforms.
81-
if (*Adapter.ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) {
82+
if (*Adapter->ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) {
8283
result = std::move(platforms);
8384
return;
8485
}
85-
if (*Adapter.ZeResult != ZE_RESULT_SUCCESS) {
86+
if (*Adapter->ZeResult != ZE_RESULT_SUCCESS) {
8687
urPrint("zeInit: Level Zero initialization failure\n");
87-
result = ze2urResult(*Adapter.ZeResult);
88+
result = ze2urResult(*Adapter->ZeResult);
8889
return;
8990
}
9091

@@ -97,8 +98,6 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() {
9798
};
9899
}
99100

100-
ur_adapter_handle_t_ Adapter{};
101-
102101
ur_result_t adapterStateTeardown() {
103102
bool LeakFound = false;
104103

@@ -203,11 +202,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
203202
///< adapters available.
204203
) {
205204
if (NumEntries > 0 && Adapters) {
206-
std::lock_guard<std::mutex> Lock{Adapter.Mutex};
207-
if (Adapter.RefCount++ == 0) {
208-
adapterStateInit();
205+
if (Adapter) {
206+
std::lock_guard<std::mutex> Lock{Adapter->Mutex};
207+
if (Adapter->RefCount++ == 0) {
208+
adapterStateInit();
209+
}
210+
*Adapters = Adapter;
211+
} else {
212+
return UR_RESULT_ERROR_UNINITIALIZED;
209213
}
210-
*Adapters = &Adapter;
211214
}
212215

213216
if (NumAdapters) {
@@ -218,17 +221,24 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
218221
}
219222

220223
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
221-
std::lock_guard<std::mutex> Lock{Adapter.Mutex};
222-
if (--Adapter.RefCount == 0) {
223-
return adapterStateTeardown();
224+
// Check first if the Adapter pointer is valid
225+
if (Adapter) {
226+
std::lock_guard<std::mutex> Lock{Adapter->Mutex};
227+
if (--Adapter->RefCount == 0) {
228+
return adapterStateTeardown();
229+
}
224230
}
225231

226232
return UR_RESULT_SUCCESS;
227233
}
228234

229235
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
230-
std::lock_guard<std::mutex> Lock{Adapter.Mutex};
231-
Adapter.RefCount++;
236+
if (Adapter) {
237+
std::lock_guard<std::mutex> Lock{Adapter->Mutex};
238+
Adapter->RefCount++;
239+
} else {
240+
return UR_RESULT_ERROR_UNINITIALIZED;
241+
}
232242

233243
return UR_RESULT_SUCCESS;
234244
}
@@ -257,7 +267,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
257267
case UR_ADAPTER_INFO_BACKEND:
258268
return ReturnValue(UR_ADAPTER_BACKEND_LEVEL_ZERO);
259269
case UR_ADAPTER_INFO_REFERENCE_COUNT:
260-
return ReturnValue(Adapter.RefCount.load());
270+
return ReturnValue(Adapter->RefCount.load());
261271
default:
262272
return UR_RESULT_ERROR_INVALID_ENUMERATION;
263273
}

source/adapters/level_zero/adapter.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,4 @@ struct ur_adapter_handle_t_ {
2525
ZeCache<Result<PlatformVec>> PlatformCache;
2626
};
2727

28-
extern ur_adapter_handle_t_ Adapter;
28+
extern ur_adapter_handle_t_ *Adapter;
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
*
3+
* Copyright (C) 2024 Intel Corporation
4+
*
5+
* Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions.
6+
* See LICENSE.TXT
7+
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
*
9+
* @file adapter_lib_init_linux.cpp
10+
*
11+
*/
12+
13+
#include "adapter.hpp"
14+
#include "ur_level_zero.hpp"
15+
16+
void __attribute__((constructor)) createAdapterHandle() {
17+
if (!Adapter) {
18+
Adapter = new ur_adapter_handle_t_();
19+
}
20+
}
21+
22+
void __attribute__((destructor)) deleteAdapterHandle() {
23+
if (Adapter) {
24+
delete Adapter;
25+
}
26+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
*
3+
* Copyright (C) 2024 Intel Corporation
4+
*
5+
* Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions.
6+
* See LICENSE.TXT
7+
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
*
9+
* @file adapter_lib_init_windows.cpp
10+
*
11+
*/
12+
13+
#include "adapter.hpp"
14+
#include "ur_level_zero.hpp"
15+
#include <windows.h>
16+
17+
extern "C" BOOL APIENTRY DllMain(HINSTANCE hinstDLL, DWORD fdwReason,
18+
LPVOID lpvReserved) {
19+
if (fdwReason == DLL_PROCESS_DETACH) {
20+
if (Adapter) {
21+
delete Adapter;
22+
}
23+
} else if (fdwReason == DLL_PROCESS_ATTACH) {
24+
if (!Adapter) {
25+
Adapter = new ur_adapter_handle_t_();
26+
}
27+
}
28+
return TRUE;
29+
}

source/adapters/level_zero/device.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1442,7 +1442,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
14421442
// a valid Level Zero device.
14431443

14441444
ur_device_handle_t Dev = nullptr;
1445-
if (const auto *platforms = Adapter.PlatformCache->get_value()) {
1445+
if (const auto *platforms = Adapter->PlatformCache->get_value()) {
14461446
for (const auto &p : *platforms) {
14471447
Dev = p->getDeviceFromNativeHandle(ZeDevice);
14481448
if (Dev) {
@@ -1453,7 +1453,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
14531453
}
14541454
}
14551455
} else {
1456-
return Adapter.PlatformCache->get_error();
1456+
return Adapter->PlatformCache->get_error();
14571457
}
14581458

14591459
if (Dev == nullptr)

source/adapters/level_zero/platform.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGet(
2929
) {
3030
// Platform handles are cached for reuse. This is to ensure consistent
3131
// handle pointers across invocations and to improve retrieval performance.
32-
if (const auto *cached_platforms = Adapter.PlatformCache->get_value();
32+
if (const auto *cached_platforms = Adapter->PlatformCache->get_value();
3333
cached_platforms) {
3434
uint32_t nplatforms = (uint32_t)cached_platforms->size();
3535
if (NumPlatforms) {
@@ -41,7 +41,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGet(
4141
}
4242
}
4343
} else {
44-
return Adapter.PlatformCache->get_error();
44+
return Adapter->PlatformCache->get_error();
4545
}
4646

4747
return UR_RESULT_SUCCESS;
@@ -133,7 +133,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformCreateWithNativeHandle(
133133
auto ZeDriver = ur_cast<ze_driver_handle_t>(NativePlatform);
134134

135135
uint32_t NumPlatforms = 0;
136-
ur_adapter_handle_t AdapterHandle = &Adapter;
136+
ur_adapter_handle_t AdapterHandle = Adapter;
137137
UR_CALL(urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms));
138138

139139
if (NumPlatforms) {

source/adapters/level_zero/queue.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
569569
// Maybe this is not completely correct.
570570
uint32_t NumEntries = 1;
571571
ur_platform_handle_t Platform{};
572-
ur_adapter_handle_t AdapterHandle = &Adapter;
572+
ur_adapter_handle_t AdapterHandle = Adapter;
573573
UR_CALL(urPlatformGet(&AdapterHandle, 1, NumEntries, &Platform, nullptr));
574574

575575
ur_device_handle_t UrDevice = Device;

0 commit comments

Comments
 (0)