11
11
#include " adapter.hpp"
12
12
#include " ur_level_zero.hpp"
13
13
14
- ur_adapter_handle_t_ *Adapter;
14
+ // Due to multiple DLLMain definitions with SYCL, Global Adapter is init at
15
+ // variable creation.
16
+ #if defined(_WIN32)
17
+ ur_adapter_handle_t_ *GlobalAdapter = new ur_adapter_handle_t_();
18
+ #else
19
+ ur_adapter_handle_t_ *GlobalAdapter;
20
+ #endif
15
21
16
22
ur_result_t initPlatforms (PlatformVec &platforms) noexcept try {
17
23
uint32_t ZeDriverCount = 0 ;
@@ -53,7 +59,7 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() {
53
59
}
54
60
55
61
// initialize level zero only once.
56
- if (Adapter ->ZeResult == std::nullopt) {
62
+ if (GlobalAdapter ->ZeResult == std::nullopt) {
57
63
// Setting these environment variables before running zeInit will enable
58
64
// the validation layer in the Level Zero loader.
59
65
if (UrL0Debug & UR_L0_DEBUG_VALIDATION) {
@@ -72,20 +78,21 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() {
72
78
// We must only initialize the driver once, even if urPlatformGet() is
73
79
// called multiple times. Declaring the return value as "static" ensures
74
80
// it's only called once.
75
- Adapter->ZeResult = ZE_CALL_NOCHECK (zeInit, (ZE_INIT_FLAG_GPU_ONLY));
81
+ GlobalAdapter->ZeResult =
82
+ ZE_CALL_NOCHECK (zeInit, (ZE_INIT_FLAG_GPU_ONLY));
76
83
}
77
- assert (Adapter ->ZeResult !=
84
+ assert (GlobalAdapter ->ZeResult !=
78
85
std::nullopt); // verify that level-zero is initialized
79
86
PlatformVec platforms;
80
87
81
88
// Absorb the ZE_RESULT_ERROR_UNINITIALIZED and just return 0 Platforms.
82
- if (*Adapter ->ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) {
89
+ if (*GlobalAdapter ->ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) {
83
90
result = std::move (platforms);
84
91
return ;
85
92
}
86
- if (*Adapter ->ZeResult != ZE_RESULT_SUCCESS) {
93
+ if (*GlobalAdapter ->ZeResult != ZE_RESULT_SUCCESS) {
87
94
urPrint (" zeInit: Level Zero initialization failure\n " );
88
- result = ze2urResult (*Adapter ->ZeResult );
95
+ result = ze2urResult (*GlobalAdapter ->ZeResult );
89
96
return ;
90
97
}
91
98
@@ -98,6 +105,14 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() {
98
105
};
99
106
}
100
107
108
+ #if defined(_WIN32)
109
+ void globalAdapterWindowsCleanup () {
110
+ if (GlobalAdapter) {
111
+ delete GlobalAdapter;
112
+ }
113
+ }
114
+ #endif
115
+
101
116
ur_result_t adapterStateTeardown () {
102
117
bool LeakFound = false ;
103
118
@@ -183,6 +198,11 @@ ur_result_t adapterStateTeardown() {
183
198
}
184
199
if (LeakFound)
185
200
return UR_RESULT_ERROR_INVALID_MEM_OBJECT;
201
+ // Due to multiple DLLMain definitions with SYCL, register to cleanup the
202
+ // Global Adapter after refcnt is 0
203
+ #if defined(_WIN32)
204
+ std::atexit (globalAdapterWindowsCleanup);
205
+ #endif
186
206
187
207
return UR_RESULT_SUCCESS;
188
208
}
@@ -202,12 +222,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
202
222
// /< adapters available.
203
223
) {
204
224
if (NumEntries > 0 && Adapters) {
205
- if (Adapter ) {
206
- std::lock_guard<std::mutex> Lock{Adapter ->Mutex };
207
- if (Adapter ->RefCount ++ == 0 ) {
225
+ if (GlobalAdapter ) {
226
+ std::lock_guard<std::mutex> Lock{GlobalAdapter ->Mutex };
227
+ if (GlobalAdapter ->RefCount ++ == 0 ) {
208
228
adapterStateInit ();
209
229
}
210
- *Adapters = Adapter ;
230
+ *Adapters = GlobalAdapter ;
211
231
} else {
212
232
return UR_RESULT_ERROR_UNINITIALIZED;
213
233
}
@@ -222,9 +242,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
222
242
223
243
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease (ur_adapter_handle_t ) {
224
244
// Check first if the Adapter pointer is valid
225
- if (Adapter ) {
226
- std::lock_guard<std::mutex> Lock{Adapter ->Mutex };
227
- if (--Adapter ->RefCount == 0 ) {
245
+ if (GlobalAdapter ) {
246
+ std::lock_guard<std::mutex> Lock{GlobalAdapter ->Mutex };
247
+ if (--GlobalAdapter ->RefCount == 0 ) {
228
248
return adapterStateTeardown ();
229
249
}
230
250
}
@@ -233,9 +253,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
233
253
}
234
254
235
255
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain (ur_adapter_handle_t ) {
236
- if (Adapter ) {
237
- std::lock_guard<std::mutex> Lock{Adapter ->Mutex };
238
- Adapter ->RefCount ++;
256
+ if (GlobalAdapter ) {
257
+ std::lock_guard<std::mutex> Lock{GlobalAdapter ->Mutex };
258
+ GlobalAdapter ->RefCount ++;
239
259
} else {
240
260
return UR_RESULT_ERROR_UNINITIALIZED;
241
261
}
@@ -267,7 +287,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
267
287
case UR_ADAPTER_INFO_BACKEND:
268
288
return ReturnValue (UR_ADAPTER_BACKEND_LEVEL_ZERO);
269
289
case UR_ADAPTER_INFO_REFERENCE_COUNT:
270
- return ReturnValue (Adapter ->RefCount .load ());
290
+ return ReturnValue (GlobalAdapter ->RefCount .load ());
271
291
default :
272
292
return UR_RESULT_ERROR_INVALID_ENUMERATION;
273
293
}
0 commit comments