Skip to content

Commit 9390279

Browse files
authored
Merge pull request #1419 from nrspruit/main_l0_adapter_release_lib
[L0] Create/Destroy Adapter Handle during lib init
2 parents cc268e5 + 5518b48 commit 9390279

File tree

7 files changed

+93
-27
lines changed

7 files changed

+93
-27
lines changed

source/adapters/level_zero/CMakeLists.txt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (C) 2022 Intel Corporation
1+
# Copyright (C) 2022-2024 Intel Corporation
22
# Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions.
33
# See LICENSE.TXT
44
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -122,6 +122,13 @@ add_ur_adapter(${TARGET_NAME}
122122
${CMAKE_CURRENT_SOURCE_DIR}/../../ur/ur.cpp
123123
)
124124

125+
if(NOT WIN32)
126+
target_sources(ur_adapter_level_zero
127+
PRIVATE
128+
${CMAKE_CURRENT_SOURCE_DIR}/adapter_lib_init_linux.cpp
129+
)
130+
endif()
131+
125132
# TODO: fix level_zero adapter conversion warnings
126133
target_compile_options(${TARGET_NAME} PRIVATE
127134
$<$<CXX_COMPILER_ID:MSVC>:/wd4805 /wd4244>

source/adapters/level_zero/adapter.cpp

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@
1111
#include "adapter.hpp"
1212
#include "ur_level_zero.hpp"
1313

14+
// Due to multiple DLLMain definitions with SYCL, Global Adapter is init at
15+
// variable creation.
16+
#if defined(_WIN32)
17+
ur_adapter_handle_t_ *GlobalAdapter = new ur_adapter_handle_t_();
18+
#else
19+
ur_adapter_handle_t_ *GlobalAdapter;
20+
#endif
21+
1422
ur_result_t initPlatforms(PlatformVec &platforms) noexcept try {
1523
uint32_t ZeDriverCount = 0;
1624
ZE2UR_CALL(zeDriverGet, (&ZeDriverCount, nullptr));
@@ -37,8 +45,7 @@ ur_result_t initPlatforms(PlatformVec &platforms) noexcept try {
3745
ur_result_t adapterStateInit() { return UR_RESULT_SUCCESS; }
3846

3947
ur_adapter_handle_t_::ur_adapter_handle_t_() {
40-
41-
Adapter.PlatformCache.Compute = [](Result<PlatformVec> &result) {
48+
PlatformCache.Compute = [](Result<PlatformVec> &result) {
4249
static std::once_flag ZeCallCountInitialized;
4350
try {
4451
std::call_once(ZeCallCountInitialized, []() {
@@ -52,7 +59,7 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() {
5259
}
5360

5461
// initialize level zero only once.
55-
if (Adapter.ZeResult == std::nullopt) {
62+
if (GlobalAdapter->ZeResult == std::nullopt) {
5663
// Setting these environment variables before running zeInit will enable
5764
// the validation layer in the Level Zero loader.
5865
if (UrL0Debug & UR_L0_DEBUG_VALIDATION) {
@@ -71,20 +78,21 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() {
7178
// We must only initialize the driver once, even if urPlatformGet() is
7279
// called multiple times. Declaring the return value as "static" ensures
7380
// it's only called once.
74-
Adapter.ZeResult = ZE_CALL_NOCHECK(zeInit, (ZE_INIT_FLAG_GPU_ONLY));
81+
GlobalAdapter->ZeResult =
82+
ZE_CALL_NOCHECK(zeInit, (ZE_INIT_FLAG_GPU_ONLY));
7583
}
76-
assert(Adapter.ZeResult !=
84+
assert(GlobalAdapter->ZeResult !=
7785
std::nullopt); // verify that level-zero is initialized
7886
PlatformVec platforms;
7987

8088
// Absorb the ZE_RESULT_ERROR_UNINITIALIZED and just return 0 Platforms.
81-
if (*Adapter.ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) {
89+
if (*GlobalAdapter->ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) {
8290
result = std::move(platforms);
8391
return;
8492
}
85-
if (*Adapter.ZeResult != ZE_RESULT_SUCCESS) {
93+
if (*GlobalAdapter->ZeResult != ZE_RESULT_SUCCESS) {
8694
urPrint("zeInit: Level Zero initialization failure\n");
87-
result = ze2urResult(*Adapter.ZeResult);
95+
result = ze2urResult(*GlobalAdapter->ZeResult);
8896
return;
8997
}
9098

@@ -97,7 +105,11 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() {
97105
};
98106
}
99107

100-
ur_adapter_handle_t_ Adapter{};
108+
void globalAdapterOnDemandCleanup() {
109+
if (GlobalAdapter) {
110+
delete GlobalAdapter;
111+
}
112+
}
101113

102114
ur_result_t adapterStateTeardown() {
103115
bool LeakFound = false;
@@ -184,6 +196,11 @@ ur_result_t adapterStateTeardown() {
184196
}
185197
if (LeakFound)
186198
return UR_RESULT_ERROR_INVALID_MEM_OBJECT;
199+
// Due to multiple DLLMain definitions with SYCL, register to cleanup the
200+
// Global Adapter after refcnt is 0
201+
#if defined(_WIN32)
202+
std::atexit(globalAdapterOnDemandCleanup);
203+
#endif
187204

188205
return UR_RESULT_SUCCESS;
189206
}
@@ -203,11 +220,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
203220
///< adapters available.
204221
) {
205222
if (NumEntries > 0 && Adapters) {
206-
std::lock_guard<std::mutex> Lock{Adapter.Mutex};
207-
if (Adapter.RefCount++ == 0) {
208-
adapterStateInit();
223+
if (GlobalAdapter) {
224+
std::lock_guard<std::mutex> Lock{GlobalAdapter->Mutex};
225+
if (GlobalAdapter->RefCount++ == 0) {
226+
adapterStateInit();
227+
}
228+
} else {
229+
// If the GetAdapter is called after the Library began or was torn down,
230+
// then temporarily create a new Adapter handle and register a new
231+
// cleanup.
232+
GlobalAdapter = new ur_adapter_handle_t_();
233+
std::lock_guard<std::mutex> Lock{GlobalAdapter->Mutex};
234+
if (GlobalAdapter->RefCount++ == 0) {
235+
adapterStateInit();
236+
}
237+
std::atexit(globalAdapterOnDemandCleanup);
209238
}
210-
*Adapters = &Adapter;
239+
*Adapters = GlobalAdapter;
211240
}
212241

213242
if (NumAdapters) {
@@ -218,17 +247,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
218247
}
219248

220249
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
221-
std::lock_guard<std::mutex> Lock{Adapter.Mutex};
222-
if (--Adapter.RefCount == 0) {
223-
return adapterStateTeardown();
250+
// Check first if the Adapter pointer is valid
251+
if (GlobalAdapter) {
252+
std::lock_guard<std::mutex> Lock{GlobalAdapter->Mutex};
253+
if (--GlobalAdapter->RefCount == 0) {
254+
return adapterStateTeardown();
255+
}
224256
}
225257

226258
return UR_RESULT_SUCCESS;
227259
}
228260

229261
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
230-
std::lock_guard<std::mutex> Lock{Adapter.Mutex};
231-
Adapter.RefCount++;
262+
if (GlobalAdapter) {
263+
std::lock_guard<std::mutex> Lock{GlobalAdapter->Mutex};
264+
GlobalAdapter->RefCount++;
265+
}
232266

233267
return UR_RESULT_SUCCESS;
234268
}
@@ -257,7 +291,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
257291
case UR_ADAPTER_INFO_BACKEND:
258292
return ReturnValue(UR_ADAPTER_BACKEND_LEVEL_ZERO);
259293
case UR_ADAPTER_INFO_REFERENCE_COUNT:
260-
return ReturnValue(Adapter.RefCount.load());
294+
return ReturnValue(GlobalAdapter->RefCount.load());
261295
default:
262296
return UR_RESULT_ERROR_INVALID_ENUMERATION;
263297
}

source/adapters/level_zero/adapter.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,4 @@ struct ur_adapter_handle_t_ {
2525
ZeCache<Result<PlatformVec>> PlatformCache;
2626
};
2727

28-
extern ur_adapter_handle_t_ Adapter;
28+
extern ur_adapter_handle_t_ *GlobalAdapter;
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//===--------- adapter_lib_init_linux.cpp - Level Zero Adapter ------------===//
2+
//
3+
// Copyright (C) 2023 Intel Corporation
4+
//
5+
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
6+
// Exceptions. See LICENSE.TXT
7+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
#include "adapter.hpp"
12+
#include "ur_level_zero.hpp"
13+
14+
void __attribute__((constructor)) createAdapterHandle() {
15+
if (!GlobalAdapter) {
16+
GlobalAdapter = new ur_adapter_handle_t_();
17+
}
18+
}
19+
20+
void __attribute__((destructor)) deleteAdapterHandle() {
21+
if (GlobalAdapter) {
22+
delete GlobalAdapter;
23+
GlobalAdapter = nullptr;
24+
}
25+
}

source/adapters/level_zero/device.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1442,7 +1442,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
14421442
// a valid Level Zero device.
14431443

14441444
ur_device_handle_t Dev = nullptr;
1445-
if (const auto *platforms = Adapter.PlatformCache->get_value()) {
1445+
if (const auto *platforms = GlobalAdapter->PlatformCache->get_value()) {
14461446
for (const auto &p : *platforms) {
14471447
Dev = p->getDeviceFromNativeHandle(ZeDevice);
14481448
if (Dev) {
@@ -1453,7 +1453,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
14531453
}
14541454
}
14551455
} else {
1456-
return Adapter.PlatformCache->get_error();
1456+
return GlobalAdapter->PlatformCache->get_error();
14571457
}
14581458

14591459
if (Dev == nullptr)

source/adapters/level_zero/platform.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGet(
2929
) {
3030
// Platform handles are cached for reuse. This is to ensure consistent
3131
// handle pointers across invocations and to improve retrieval performance.
32-
if (const auto *cached_platforms = Adapter.PlatformCache->get_value();
32+
if (const auto *cached_platforms = GlobalAdapter->PlatformCache->get_value();
3333
cached_platforms) {
3434
uint32_t nplatforms = (uint32_t)cached_platforms->size();
3535
if (NumPlatforms) {
@@ -41,7 +41,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGet(
4141
}
4242
}
4343
} else {
44-
return Adapter.PlatformCache->get_error();
44+
return GlobalAdapter->PlatformCache->get_error();
4545
}
4646

4747
return UR_RESULT_SUCCESS;
@@ -133,7 +133,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformCreateWithNativeHandle(
133133
auto ZeDriver = ur_cast<ze_driver_handle_t>(NativePlatform);
134134

135135
uint32_t NumPlatforms = 0;
136-
ur_adapter_handle_t AdapterHandle = &Adapter;
136+
ur_adapter_handle_t AdapterHandle = GlobalAdapter;
137137
UR_CALL(urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms));
138138

139139
if (NumPlatforms) {

source/adapters/level_zero/queue.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
569569
// Maybe this is not completely correct.
570570
uint32_t NumEntries = 1;
571571
ur_platform_handle_t Platform{};
572-
ur_adapter_handle_t AdapterHandle = &Adapter;
572+
ur_adapter_handle_t AdapterHandle = GlobalAdapter;
573573
UR_CALL(urPlatformGet(&AdapterHandle, 1, NumEntries, &Platform, nullptr));
574574

575575
ur_device_handle_t UrDevice = Device;

0 commit comments

Comments
 (0)