Skip to content

Commit c56b44e

Browse files
committed
[L0] move platform cache into the adapter structure
The platform cache is a global variable used exclusively by the L0 adapter, and it's protected by a loosely-associated spin lock. However, its lifetime is tied to that of the adapter structure: it is deleted the first time the adapter refcount reaches 0. This was causing issues whenever the adapter was initialized and destroyed multiple times within a single process, which happens, for example, during tests. This patch fixes the above problem by moving the platform cache from the global state into the adapter structure. This allows for a simpler implementation that no longer requires an explicit lock and instead uses lazy loading (std::call_once). With this patch, all platform tests are now passing for L0. Closes #824
1 parent f086f36 commit c56b44e

File tree

7 files changed

+140
-139
lines changed

7 files changed

+140
-139
lines changed

.github/workflows/cmake.yml

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -213,15 +213,7 @@ jobs:
213213
working-directory: ${{github.workspace}}/build
214214
run: ctest -C ${{matrix.build_type}} --output-on-failure -L "adapter-specific" --timeout 180
215215

216-
# Temporarily disabling platform test for L0, because of hang
217-
# See issue: #824
218-
- name: Test L0 adapter
219-
if: matrix.adapter.name == 'L0'
220-
working-directory: ${{github.workspace}}/build
221-
run: ctest -C ${{matrix.build_type}} --output-on-failure -L "conformance" -E "platform-adapter_level_zero" --timeout 180
222-
223216
- name: Test adapters
224-
if: matrix.adapter.name != 'L0'
225217
working-directory: ${{github.workspace}}/build
226218
run: env UR_CTS_ADAPTER_PLATFORM="${{matrix.adapter.platform}}" ctest -C ${{matrix.build_type}} --output-on-failure -L "conformance" --timeout 180
227219

source/adapters/level_zero/adapter.cpp

Lines changed: 87 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,92 @@
1313

1414
ur_adapter_handle_t_ Adapter{};
1515

16-
ur_result_t adapterStateTeardown() {
17-
// reclaim ur_platform_handle_t objects here since we don't have
18-
// urPlatformRelease.
19-
for (ur_platform_handle_t Platform : *URPlatformsCache) {
20-
delete Platform;
16+
// Enumerates the Level Zero drivers and appends one initialized platform
// object per driver to `platforms`.
//
// Returns UR_RESULT_SUCCESS once every reported driver has been wrapped and
// initialized, or immediately when zero drivers are reported (in which case
// `platforms` is left untouched). Any exception escaping the body is
// translated into a ur_result_t through exceptionToResult, which is what
// allows the function to be declared noexcept.
//
// NOTE(review): ZE2UR_CALL / UR_CALL presumably return early from this
// function when the wrapped call fails -- confirm against their definitions.
ur_result_t initPlatforms(PlatformVec &platforms) noexcept try {
  // First query: ask only for the number of available drivers.
  uint32_t DriverCount = 0;
  ZE2UR_CALL(zeDriverGet, (&DriverCount, nullptr));
  if (DriverCount == 0) {
    return UR_RESULT_SUCCESS;
  }

  // Second query: retrieve the driver handles themselves.
  std::vector<ze_driver_handle_t> DriverHandles(DriverCount);
  ZE2UR_CALL(zeDriverGet, (&DriverCount, DriverHandles.data()));

  // Iterate by the count written back by the second query rather than by the
  // vector size, exactly as many handles as the driver reported are consumed.
  for (uint32_t Idx = 0; Idx < DriverCount; ++Idx) {
    auto NewPlatform =
        std::make_unique<ur_platform_handle_t_>(DriverHandles[Idx]);
    UR_CALL(NewPlatform->initialize());

    // Hand ownership of the initialized platform over to the caller's list.
    platforms.push_back(std::move(NewPlatform));
  }
  return UR_RESULT_SUCCESS;
} catch (...) {
  return exceptionToResult(std::current_exception());
}
38+
39+
ur_result_t adapterStateInit() {
40+
static std::once_flag ZeCallCountInitialized;
41+
try {
42+
std::call_once(ZeCallCountInitialized, []() {
43+
if (UrL0LeaksDebug) {
44+
ZeCallCount = new std::map<std::string, int>;
45+
}
46+
});
47+
} catch (const std::bad_alloc &) {
48+
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
49+
} catch (...) {
50+
return UR_RESULT_ERROR_UNKNOWN;
2151
}
22-
delete URPlatformsCache;
23-
delete URPlatformsCacheMutex;
2452

53+
// initialize level zero only once.
54+
if (Adapter.ZeResult == std::nullopt) {
55+
// Setting these environment variables before running zeInit will enable the
56+
// validation layer in the Level Zero loader.
57+
if (UrL0Debug & UR_L0_DEBUG_VALIDATION) {
58+
setEnvVar("ZE_ENABLE_VALIDATION_LAYER", "1");
59+
setEnvVar("ZE_ENABLE_PARAMETER_VALIDATION", "1");
60+
}
61+
62+
if (getenv("SYCL_ENABLE_PCI") != nullptr) {
63+
urPrint("WARNING: SYCL_ENABLE_PCI is deprecated and no longer needed.\n");
64+
}
65+
66+
// TODO: We can still safely recover if something goes wrong during the
67+
// init. Implement handling segfault using sigaction.
68+
69+
// We must only initialize the driver once, even if adapterStateInit() is
70+
// called multiple times. Caching the result in Adapter.ZeResult and
71+
// checking it against std::nullopt above ensures zeInit is only called once.
72+
Adapter.ZeResult = ZE_CALL_NOCHECK(zeInit, (ZE_INIT_FLAG_GPU_ONLY));
73+
}
74+
75+
Adapter.PlatformCache.Compute = [](Result<PlatformVec> &result) {
76+
assert(Adapter.ZeResult !=
77+
std::nullopt); // verify that level-zero is initialized
78+
PlatformVec platforms;
79+
80+
// Absorb the ZE_RESULT_ERROR_UNINITIALIZED and just return 0 Platforms.
81+
if (*Adapter.ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) {
82+
result = std::move(platforms);
83+
return;
84+
}
85+
if (*Adapter.ZeResult != ZE_RESULT_SUCCESS) {
86+
urPrint("zeInit: Level Zero initialization failure\n");
87+
result = ze2urResult(*Adapter.ZeResult);
88+
return;
89+
}
90+
91+
ur_result_t err = initPlatforms(platforms);
92+
if (err == UR_RESULT_SUCCESS) {
93+
result = std::move(platforms);
94+
} else {
95+
result = err;
96+
}
97+
};
98+
return UR_RESULT_SUCCESS;
99+
}
100+
101+
ur_result_t adapterStateTeardown() {
25102
bool LeakFound = false;
26103

27104
// Print the balance of various create/destroy native calls.
@@ -126,9 +203,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
126203
) {
127204
if (NumEntries > 0 && Adapters) {
128205
std::lock_guard<std::mutex> Lock{Adapter.Mutex};
129-
// TODO: Some initialization that happens in urPlatformsGet could be moved
130-
// here for when RefCount reaches 1
131-
Adapter.RefCount++;
206+
if (Adapter.RefCount++ == 0) {
207+
adapterStateInit();
208+
}
132209
*Adapters = &Adapter;
133210
}
134211

source/adapters/level_zero/adapter.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,18 @@
1010

1111
#include <atomic>
1212
#include <mutex>
13+
#include <optional>
14+
#include <ur/ur.hpp>
15+
#include <ze_api.h>
16+
17+
using PlatformVec = std::vector<std::unique_ptr<ur_platform_handle_t_>>;
1318

1419
struct ur_adapter_handle_t_ {
1520
std::atomic<uint32_t> RefCount = 0;
1621
std::mutex Mutex;
22+
23+
std::optional<ze_result_t> ZeResult;
24+
ZeCache<Result<PlatformVec>> PlatformCache;
1725
};
1826

1927
extern ur_adapter_handle_t_ Adapter;

source/adapters/level_zero/device.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//===----------------------------------------------------------------------===//
1010

1111
#include "device.hpp"
12+
#include "adapter.hpp"
1213
#include "ur_level_zero.hpp"
1314
#include "ur_util.hpp"
1415
#include <algorithm>
@@ -1321,21 +1322,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
13211322
// Level Zero devices when we initialized the platforms/devices cache, so the
13221323
// "NativeHandle" must already be in the cache. If it is not, this must not be
13231324
// a valid Level Zero device.
1324-
//
1325-
// TODO: maybe we should populate cache of platforms if it wasn't already.
1326-
// For now assert that is was populated.
1327-
UR_ASSERT(URPlatformCachePopulated, UR_RESULT_ERROR_INVALID_VALUE);
1328-
const std::lock_guard<SpinLock> Lock{*URPlatformsCacheMutex};
13291325

13301326
ur_device_handle_t Dev = nullptr;
1331-
for (ur_platform_handle_t ThePlatform : *URPlatformsCache) {
1332-
Dev = ThePlatform->getDeviceFromNativeHandle(ZeDevice);
1333-
if (Dev) {
1334-
// Check that the input Platform, if was given, matches the found one.
1335-
UR_ASSERT(!Platform || Platform == ThePlatform,
1336-
UR_RESULT_ERROR_INVALID_PLATFORM);
1337-
break;
1327+
if (const auto *platforms = Adapter.PlatformCache->get_value()) {
1328+
for (const auto &p : *platforms) {
1329+
Dev = p->getDeviceFromNativeHandle(ZeDevice);
1330+
if (Dev) {
1331+
// Check that the input Platform, if one was given, matches the found one.
1332+
UR_ASSERT(!Platform || Platform == p.get(),
1333+
UR_RESULT_ERROR_INVALID_PLATFORM);
1334+
break;
1335+
}
13381336
}
1337+
} else {
1338+
return Adapter.PlatformCache->get_error();
13391339
}
13401340

13411341
if (Dev == nullptr)

source/adapters/level_zero/platform.cpp

Lines changed: 12 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -27,101 +27,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGet(
2727
uint32_t *NumPlatforms ///< [out][optional] returns the total number of
2828
///< platforms available.
2929
) {
30-
static std::once_flag ZeCallCountInitialized;
31-
try {
32-
std::call_once(ZeCallCountInitialized, []() {
33-
if (UrL0LeaksDebug) {
34-
ZeCallCount = new std::map<std::string, int>;
35-
}
36-
});
37-
} catch (const std::bad_alloc &) {
38-
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
39-
} catch (...) {
40-
return UR_RESULT_ERROR_UNKNOWN;
41-
}
42-
43-
// Setting these environment variables before running zeInit will enable the
44-
// validation layer in the Level Zero loader.
45-
if (UrL0Debug & UR_L0_DEBUG_VALIDATION) {
46-
setEnvVar("ZE_ENABLE_VALIDATION_LAYER", "1");
47-
setEnvVar("ZE_ENABLE_PARAMETER_VALIDATION", "1");
48-
}
49-
50-
if (getenv("SYCL_ENABLE_PCI") != nullptr) {
51-
urPrint("WARNING: SYCL_ENABLE_PCI is deprecated and no longer needed.\n");
52-
}
53-
54-
// TODO: We can still safely recover if something goes wrong during the init.
55-
// Implement handling segfault using sigaction.
56-
57-
// We must only initialize the driver once, even if urPlatformGet() is called
58-
// multiple times. Declaring the return value as "static" ensures it's only
59-
// called once.
60-
static ze_result_t ZeResult =
61-
ZE_CALL_NOCHECK(zeInit, (ZE_INIT_FLAG_GPU_ONLY));
62-
63-
// Absorb the ZE_RESULT_ERROR_UNINITIALIZED and just return 0 Platforms.
64-
if (ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) {
65-
UR_ASSERT(NumEntries == 0, UR_RESULT_ERROR_INVALID_VALUE);
66-
if (NumPlatforms)
67-
*NumPlatforms = 0;
68-
return UR_RESULT_SUCCESS;
69-
}
70-
71-
if (ZeResult != ZE_RESULT_SUCCESS) {
72-
urPrint("zeInit: Level Zero initialization failure\n");
73-
return ze2urResult(ZeResult);
74-
}
75-
76-
// Cache ur_platform_handle_t for reuse in the future
77-
// It solves two problems;
78-
// 1. sycl::platform equality issue; we always return the same
79-
// ur_platform_handle_t
80-
// 2. performance; we can save time by immediately return from cache.
81-
//
82-
83-
const std::lock_guard<SpinLock> Lock{*URPlatformsCacheMutex};
84-
if (!URPlatformCachePopulated) {
85-
try {
86-
// Level Zero does not have concept of Platforms, but Level Zero driver is
87-
// the closest match.
88-
uint32_t ZeDriverCount = 0;
89-
ZE2UR_CALL(zeDriverGet, (&ZeDriverCount, nullptr));
90-
if (ZeDriverCount == 0) {
91-
URPlatformCachePopulated = true;
92-
} else {
93-
std::vector<ze_driver_handle_t> ZeDrivers;
94-
ZeDrivers.resize(ZeDriverCount);
95-
96-
ZE2UR_CALL(zeDriverGet, (&ZeDriverCount, ZeDrivers.data()));
97-
for (uint32_t I = 0; I < ZeDriverCount; ++I) {
98-
auto Platform = new ur_platform_handle_t_(ZeDrivers[I]);
99-
// Save a copy in the cache for future uses.
100-
URPlatformsCache->push_back(Platform);
101-
102-
UR_CALL(Platform->initialize());
103-
}
104-
URPlatformCachePopulated = true;
30+
// Platform handles are cached for reuse. This is to ensure consistent
31+
// handle pointers across invocations and to improve retrieval performance.
32+
if (const auto *cached_platforms = Adapter.PlatformCache->get_value()) {
33+
uint32_t nplatforms = (uint32_t)cached_platforms->size();
34+
if (NumPlatforms) {
35+
*NumPlatforms = nplatforms;
36+
}
37+
if (Platforms) {
38+
for (uint32_t i = 0; i < std::min(nplatforms, NumEntries); ++i) {
39+
Platforms[i] = cached_platforms->at(i).get();
10540
}
106-
} catch (const std::bad_alloc &) {
107-
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
108-
} catch (...) {
109-
return UR_RESULT_ERROR_UNKNOWN;
11041
}
111-
}
112-
113-
// Populate returned platforms from the cache.
114-
if (Platforms) {
115-
UR_ASSERT(NumEntries <= URPlatformsCache->size(),
116-
UR_RESULT_ERROR_INVALID_PLATFORM);
117-
std::copy_n(URPlatformsCache->begin(), NumEntries, Platforms);
118-
}
119-
120-
if (NumPlatforms) {
121-
if (*NumPlatforms == 0)
122-
*NumPlatforms = URPlatformsCache->size();
123-
else
124-
*NumPlatforms = (std::min)(URPlatformsCache->size(), (size_t)NumEntries);
42+
} else {
43+
return Adapter.PlatformCache->get_error();
12544
}
12645

12746
return UR_RESULT_SUCCESS;

source/ur/ur.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,3 @@ bool PrintTrace = [] {
2222
}
2323
return false;
2424
}();
25-
26-
// Apparatus for maintaining immutable cache of platforms.
27-
std::vector<ur_platform_handle_t> *URPlatformsCache =
28-
new std::vector<ur_platform_handle_t>;
29-
SpinLock *URPlatformsCacheMutex = new SpinLock;
30-
bool URPlatformCachePopulated = false;

source/ur/ur.hpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <shared_mutex>
2020
#include <string>
2121
#include <thread>
22+
#include <variant>
2223
#include <vector>
2324

2425
#include <ur_api.h>
@@ -191,16 +192,6 @@ struct _ur_platform {};
191192
// Controls tracing UR calls from within the UR itself.
192193
extern bool PrintTrace;
193194

194-
// Apparatus for maintaining immutable cache of platforms.
195-
//
196-
// Note we only create a simple pointer variables such that C++ RT won't
197-
// deallocate them automatically at the end of the main program.
198-
// The heap memory allocated for these global variables reclaimed only at
199-
// explicit tear-down.
200-
extern std::vector<ur_platform_handle_t> *URPlatformsCache;
201-
extern SpinLock *URPlatformsCacheMutex;
202-
extern bool URPlatformCachePopulated;
203-
204195
// The getInfo*/ReturnHelper facilities provide shortcut way of
205196
// writing return bytes for the various getInfo APIs.
206197
namespace ur {
@@ -310,3 +301,23 @@ class UrReturnHelper {
310301
void *param_value;
311302
size_t *param_value_size_ret;
312303
};
304+
305+
template <typename T> class Result {
306+
public:
307+
Result(ur_result_t err) : value_or_err(err) {}
308+
Result(T value) : value_or_err(std::move(value)) {}
309+
Result() : value_or_err(UR_RESULT_ERROR_UNINITIALIZED) {}
310+
311+
bool is_err() { return std::holds_alternative<ur_result_t>(value_or_err); }
312+
explicit operator bool() const { return !is_err(); }
313+
314+
const T *get_value() { return std::get_if<T>(&value_or_err); }
315+
316+
ur_result_t get_error() {
317+
auto *err = std::get_if<ur_result_t>(&value_or_err);
318+
return err ? *err : UR_RESULT_SUCCESS;
319+
}
320+
321+
private:
322+
std::variant<ur_result_t, T> value_or_err;
323+
};

0 commit comments

Comments
 (0)