11
11
#include " adapter.hpp"
12
12
#include " ur_level_zero.hpp"
13
13
14
+ // Due to multiple DLLMain definitions with SYCL, Global Adapter is init at
15
+ // variable creation.
16
+ #if defined(_WIN32)
17
+ ur_adapter_handle_t_ *GlobalAdapter = new ur_adapter_handle_t_();
18
+ #else
19
+ ur_adapter_handle_t_ *GlobalAdapter;
20
+ #endif
21
+
14
22
ur_result_t initPlatforms (PlatformVec &platforms) noexcept try {
15
23
uint32_t ZeDriverCount = 0 ;
16
24
ZE2UR_CALL (zeDriverGet, (&ZeDriverCount, nullptr ));
@@ -37,8 +45,7 @@ ur_result_t initPlatforms(PlatformVec &platforms) noexcept try {
37
45
ur_result_t adapterStateInit () { return UR_RESULT_SUCCESS; }
38
46
39
47
ur_adapter_handle_t_::ur_adapter_handle_t_ () {
40
-
41
- Adapter.PlatformCache .Compute = [](Result<PlatformVec> &result) {
48
+ PlatformCache.Compute = [](Result<PlatformVec> &result) {
42
49
static std::once_flag ZeCallCountInitialized;
43
50
try {
44
51
std::call_once (ZeCallCountInitialized, []() {
@@ -52,7 +59,7 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() {
52
59
}
53
60
54
61
// initialize level zero only once.
55
- if (Adapter. ZeResult == std::nullopt) {
62
+ if (GlobalAdapter-> ZeResult == std::nullopt) {
56
63
// Setting these environment variables before running zeInit will enable
57
64
// the validation layer in the Level Zero loader.
58
65
if (UrL0Debug & UR_L0_DEBUG_VALIDATION) {
@@ -71,20 +78,21 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() {
71
78
// We must only initialize the driver once, even if urPlatformGet() is
72
79
// called multiple times. Declaring the return value as "static" ensures
73
80
// it's only called once.
74
- Adapter.ZeResult = ZE_CALL_NOCHECK (zeInit, (ZE_INIT_FLAG_GPU_ONLY));
81
+ GlobalAdapter->ZeResult =
82
+ ZE_CALL_NOCHECK (zeInit, (ZE_INIT_FLAG_GPU_ONLY));
75
83
}
76
- assert (Adapter. ZeResult !=
84
+ assert (GlobalAdapter-> ZeResult !=
77
85
std::nullopt); // verify that level-zero is initialized
78
86
PlatformVec platforms;
79
87
80
88
// Absorb the ZE_RESULT_ERROR_UNINITIALIZED and just return 0 Platforms.
81
- if (*Adapter. ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) {
89
+ if (*GlobalAdapter-> ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) {
82
90
result = std::move (platforms);
83
91
return ;
84
92
}
85
- if (*Adapter. ZeResult != ZE_RESULT_SUCCESS) {
93
+ if (*GlobalAdapter-> ZeResult != ZE_RESULT_SUCCESS) {
86
94
urPrint (" zeInit: Level Zero initialization failure\n " );
87
- result = ze2urResult (*Adapter. ZeResult );
95
+ result = ze2urResult (*GlobalAdapter-> ZeResult );
88
96
return ;
89
97
}
90
98
@@ -97,7 +105,11 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() {
97
105
};
98
106
}
99
107
100
- ur_adapter_handle_t_ Adapter{};
108
+ void globalAdapterOnDemandCleanup () {
109
+ if (GlobalAdapter) {
110
+ delete GlobalAdapter;
111
+ }
112
+ }
101
113
102
114
ur_result_t adapterStateTeardown () {
103
115
bool LeakFound = false ;
@@ -184,6 +196,11 @@ ur_result_t adapterStateTeardown() {
184
196
}
185
197
if (LeakFound)
186
198
return UR_RESULT_ERROR_INVALID_MEM_OBJECT;
199
+ // Due to multiple DLLMain definitions with SYCL, register to cleanup the
200
+ // Global Adapter after refcnt is 0
201
+ #if defined(_WIN32)
202
+ std::atexit (globalAdapterOnDemandCleanup);
203
+ #endif
187
204
188
205
return UR_RESULT_SUCCESS;
189
206
}
@@ -203,11 +220,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
203
220
// /< adapters available.
204
221
) {
205
222
if (NumEntries > 0 && Adapters) {
206
- std::lock_guard<std::mutex> Lock{Adapter.Mutex };
207
- if (Adapter.RefCount ++ == 0 ) {
208
- adapterStateInit ();
223
+ if (GlobalAdapter) {
224
+ std::lock_guard<std::mutex> Lock{GlobalAdapter->Mutex };
225
+ if (GlobalAdapter->RefCount ++ == 0 ) {
226
+ adapterStateInit ();
227
+ }
228
+ } else {
229
+ // If the GetAdapter is called after the Library began or was torndown,
230
+ // then temporarily create a new Adapter handle and register a new
231
+ // cleanup.
232
+ GlobalAdapter = new ur_adapter_handle_t_ ();
233
+ std::lock_guard<std::mutex> Lock{GlobalAdapter->Mutex };
234
+ if (GlobalAdapter->RefCount ++ == 0 ) {
235
+ adapterStateInit ();
236
+ }
237
+ std::atexit (globalAdapterOnDemandCleanup);
209
238
}
210
- *Adapters = &Adapter ;
239
+ *Adapters = GlobalAdapter ;
211
240
}
212
241
213
242
if (NumAdapters) {
@@ -218,17 +247,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
218
247
}
219
248
220
249
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease (ur_adapter_handle_t ) {
221
- std::lock_guard<std::mutex> Lock{Adapter.Mutex };
222
- if (--Adapter.RefCount == 0 ) {
223
- return adapterStateTeardown ();
250
+ // Check first if the Adapter pointer is valid
251
+ if (GlobalAdapter) {
252
+ std::lock_guard<std::mutex> Lock{GlobalAdapter->Mutex };
253
+ if (--GlobalAdapter->RefCount == 0 ) {
254
+ return adapterStateTeardown ();
255
+ }
224
256
}
225
257
226
258
return UR_RESULT_SUCCESS;
227
259
}
228
260
229
261
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain (ur_adapter_handle_t ) {
230
- std::lock_guard<std::mutex> Lock{Adapter.Mutex };
231
- Adapter.RefCount ++;
262
+ if (GlobalAdapter) {
263
+ std::lock_guard<std::mutex> Lock{GlobalAdapter->Mutex };
264
+ GlobalAdapter->RefCount ++;
265
+ }
232
266
233
267
return UR_RESULT_SUCCESS;
234
268
}
@@ -257,7 +291,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
257
291
case UR_ADAPTER_INFO_BACKEND:
258
292
return ReturnValue (UR_ADAPTER_BACKEND_LEVEL_ZERO);
259
293
case UR_ADAPTER_INFO_REFERENCE_COUNT:
260
- return ReturnValue (Adapter. RefCount .load ());
294
+ return ReturnValue (GlobalAdapter-> RefCount .load ());
261
295
default :
262
296
return UR_RESULT_ERROR_INVALID_ENUMERATION;
263
297
}
0 commit comments