Skip to content

Commit 3943e7e

Browse files
authored
Merge pull request #1252 from pbalcer/l0-platform-hang
[L0] move platform cache into the adapter structure
2 parents 1cd402e + c56b44e commit 3943e7e

File tree

7 files changed

+140
-139
lines changed

7 files changed

+140
-139
lines changed

.github/workflows/cmake.yml

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -213,15 +213,7 @@ jobs:
213213
working-directory: ${{github.workspace}}/build
214214
run: ctest -C ${{matrix.build_type}} --output-on-failure -L "adapter-specific" --timeout 180
215215

216-
# Temporarily disabling platform test for L0, because of hang
217-
# See issue: #824
218-
- name: Test L0 adapter
219-
if: matrix.adapter.name == 'L0'
220-
working-directory: ${{github.workspace}}/build
221-
run: ctest -C ${{matrix.build_type}} --output-on-failure -L "conformance" -E "platform-adapter_level_zero" --timeout 180
222-
223216
- name: Test adapters
224-
if: matrix.adapter.name != 'L0'
225217
working-directory: ${{github.workspace}}/build
226218
run: env UR_CTS_ADAPTER_PLATFORM="${{matrix.adapter.platform}}" ctest -C ${{matrix.build_type}} --output-on-failure -L "conformance" --timeout 180
227219

source/adapters/level_zero/adapter.cpp

Lines changed: 87 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,92 @@
1313

1414
ur_adapter_handle_t_ Adapter{};
1515

16-
ur_result_t adapterStateTeardown() {
17-
// reclaim ur_platform_handle_t objects here since we don't have
18-
// urPlatformRelease.
19-
for (ur_platform_handle_t Platform : *URPlatformsCache) {
20-
delete Platform;
16+
// Enumerates the available Level Zero drivers and appends one initialized
// platform handle per driver to `platforms`.
//
// Returns UR_RESULT_SUCCESS when every driver was turned into a platform (or
// when there are no drivers at all). Failures reported through ZE2UR_CALL /
// UR_CALL are returned to the caller, and any exception is translated to a
// ur_result_t by exceptionToResult().
// NOTE(review): platforms pushed before a failing one appear to stay in
// `platforms` -- confirm callers discard the vector on error.
ur_result_t initPlatforms(PlatformVec &platforms) noexcept try {
  // First query: only ask how many drivers exist.
  uint32_t DriverCount = 0;
  ZE2UR_CALL(zeDriverGet, (&DriverCount, nullptr));
  if (DriverCount == 0) {
    return UR_RESULT_SUCCESS;
  }

  // Second query: fetch the driver handles themselves.
  std::vector<ze_driver_handle_t> Drivers(DriverCount);
  ZE2UR_CALL(zeDriverGet, (&DriverCount, Drivers.data()));

  // Level Zero drivers are mapped one-to-one onto UR platforms.
  for (uint32_t Idx = 0; Idx != DriverCount; ++Idx) {
    auto NewPlatform = std::make_unique<ur_platform_handle_t_>(Drivers[Idx]);
    UR_CALL(NewPlatform->initialize());

    // Hand ownership to the caller's vector so the handle can be reused.
    platforms.push_back(std::move(NewPlatform));
  }
  return UR_RESULT_SUCCESS;
} catch (...) {
  return exceptionToResult(std::current_exception());
}
38+
39+
ur_result_t adapterStateInit() {
40+
static std::once_flag ZeCallCountInitialized;
41+
try {
42+
std::call_once(ZeCallCountInitialized, []() {
43+
if (UrL0LeaksDebug) {
44+
ZeCallCount = new std::map<std::string, int>;
45+
}
46+
});
47+
} catch (const std::bad_alloc &) {
48+
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
49+
} catch (...) {
50+
return UR_RESULT_ERROR_UNKNOWN;
2151
}
22-
delete URPlatformsCache;
23-
delete URPlatformsCacheMutex;
2452

53+
// initialize level zero only once.
54+
if (Adapter.ZeResult == std::nullopt) {
55+
// Setting these environment variables before running zeInit will enable the
56+
// validation layer in the Level Zero loader.
57+
if (UrL0Debug & UR_L0_DEBUG_VALIDATION) {
58+
setEnvVar("ZE_ENABLE_VALIDATION_LAYER", "1");
59+
setEnvVar("ZE_ENABLE_PARAMETER_VALIDATION", "1");
60+
}
61+
62+
if (getenv("SYCL_ENABLE_PCI") != nullptr) {
63+
urPrint("WARNING: SYCL_ENABLE_PCI is deprecated and no longer needed.\n");
64+
}
65+
66+
// TODO: We can still safely recover if something goes wrong during the
67+
// init. Implement handling segfault using sigaction.
68+
69+
// We must only initialize the driver once, even if urPlatformGet() is
70+
// called multiple times. Caching the result in Adapter.ZeResult (checked
71+
// above against std::nullopt) ensures zeInit is only called once.
72+
Adapter.ZeResult = ZE_CALL_NOCHECK(zeInit, (ZE_INIT_FLAG_GPU_ONLY));
73+
}
74+
75+
Adapter.PlatformCache.Compute = [](Result<PlatformVec> &result) {
76+
assert(Adapter.ZeResult !=
77+
std::nullopt); // verify that level-zero is initialized
78+
PlatformVec platforms;
79+
80+
// Absorb the ZE_RESULT_ERROR_UNINITIALIZED and just return 0 Platforms.
81+
if (*Adapter.ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) {
82+
result = std::move(platforms);
83+
return;
84+
}
85+
if (*Adapter.ZeResult != ZE_RESULT_SUCCESS) {
86+
urPrint("zeInit: Level Zero initialization failure\n");
87+
result = ze2urResult(*Adapter.ZeResult);
88+
return;
89+
}
90+
91+
ur_result_t err = initPlatforms(platforms);
92+
if (err == UR_RESULT_SUCCESS) {
93+
result = std::move(platforms);
94+
} else {
95+
result = err;
96+
}
97+
};
98+
return UR_RESULT_SUCCESS;
99+
}
100+
101+
ur_result_t adapterStateTeardown() {
25102
bool LeakFound = false;
26103

27104
// Print the balance of various create/destroy native calls.
@@ -126,9 +203,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
126203
) {
127204
if (NumEntries > 0 && Adapters) {
128205
std::lock_guard<std::mutex> Lock{Adapter.Mutex};
129-
// TODO: Some initialization that happens in urPlatformsGet could be moved
130-
// here for when RefCount reaches 1
131-
Adapter.RefCount++;
206+
if (Adapter.RefCount++ == 0) {
207+
adapterStateInit();
208+
}
132209
*Adapters = &Adapter;
133210
}
134211

source/adapters/level_zero/adapter.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,18 @@
1010

1111
#include <atomic>
1212
#include <mutex>
13+
#include <optional>
14+
#include <ur/ur.hpp>
15+
#include <ze_api.h>
16+
17+
using PlatformVec = std::vector<std::unique_ptr<ur_platform_handle_t_>>;
1318

1419
struct ur_adapter_handle_t_ {
1520
std::atomic<uint32_t> RefCount = 0;
1621
std::mutex Mutex;
22+
23+
std::optional<ze_result_t> ZeResult;
24+
ZeCache<Result<PlatformVec>> PlatformCache;
1725
};
1826

1927
extern ur_adapter_handle_t_ Adapter;

source/adapters/level_zero/device.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//===----------------------------------------------------------------------===//
1010

1111
#include "device.hpp"
12+
#include "adapter.hpp"
1213
#include "ur_level_zero.hpp"
1314
#include "ur_util.hpp"
1415
#include <algorithm>
@@ -1321,21 +1322,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
13211322
// Level Zero devices when we initialized the platforms/devices cache, so the
13221323
// "NativeHandle" must already be in the cache. If it is not, this must not be
13231324
// a valid Level Zero device.
1324-
//
1325-
// TODO: maybe we should populate cache of platforms if it wasn't already.
1326-
// For now assert that is was populated.
1327-
UR_ASSERT(URPlatformCachePopulated, UR_RESULT_ERROR_INVALID_VALUE);
1328-
const std::lock_guard<SpinLock> Lock{*URPlatformsCacheMutex};
13291325

13301326
ur_device_handle_t Dev = nullptr;
1331-
for (ur_platform_handle_t ThePlatform : *URPlatformsCache) {
1332-
Dev = ThePlatform->getDeviceFromNativeHandle(ZeDevice);
1333-
if (Dev) {
1334-
// Check that the input Platform, if was given, matches the found one.
1335-
UR_ASSERT(!Platform || Platform == ThePlatform,
1336-
UR_RESULT_ERROR_INVALID_PLATFORM);
1337-
break;
1327+
if (const auto *platforms = Adapter.PlatformCache->get_value()) {
1328+
for (const auto &p : *platforms) {
1329+
Dev = p->getDeviceFromNativeHandle(ZeDevice);
1330+
if (Dev) {
1331+
// Check that the input Platform, if was given, matches the found one.
1332+
UR_ASSERT(!Platform || Platform == p.get(),
1333+
UR_RESULT_ERROR_INVALID_PLATFORM);
1334+
break;
1335+
}
13381336
}
1337+
} else {
1338+
return Adapter.PlatformCache->get_error();
13391339
}
13401340

13411341
if (Dev == nullptr)

source/adapters/level_zero/platform.cpp

Lines changed: 12 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -27,101 +27,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGet(
2727
uint32_t *NumPlatforms ///< [out][optional] returns the total number of
2828
///< platforms available.
2929
) {
30-
static std::once_flag ZeCallCountInitialized;
31-
try {
32-
std::call_once(ZeCallCountInitialized, []() {
33-
if (UrL0LeaksDebug) {
34-
ZeCallCount = new std::map<std::string, int>;
35-
}
36-
});
37-
} catch (const std::bad_alloc &) {
38-
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
39-
} catch (...) {
40-
return UR_RESULT_ERROR_UNKNOWN;
41-
}
42-
43-
// Setting these environment variables before running zeInit will enable the
44-
// validation layer in the Level Zero loader.
45-
if (UrL0Debug & UR_L0_DEBUG_VALIDATION) {
46-
setEnvVar("ZE_ENABLE_VALIDATION_LAYER", "1");
47-
setEnvVar("ZE_ENABLE_PARAMETER_VALIDATION", "1");
48-
}
49-
50-
if (getenv("SYCL_ENABLE_PCI") != nullptr) {
51-
urPrint("WARNING: SYCL_ENABLE_PCI is deprecated and no longer needed.\n");
52-
}
53-
54-
// TODO: We can still safely recover if something goes wrong during the init.
55-
// Implement handling segfault using sigaction.
56-
57-
// We must only initialize the driver once, even if urPlatformGet() is called
58-
// multiple times. Declaring the return value as "static" ensures it's only
59-
// called once.
60-
static ze_result_t ZeResult =
61-
ZE_CALL_NOCHECK(zeInit, (ZE_INIT_FLAG_GPU_ONLY));
62-
63-
// Absorb the ZE_RESULT_ERROR_UNINITIALIZED and just return 0 Platforms.
64-
if (ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) {
65-
UR_ASSERT(NumEntries == 0, UR_RESULT_ERROR_INVALID_VALUE);
66-
if (NumPlatforms)
67-
*NumPlatforms = 0;
68-
return UR_RESULT_SUCCESS;
69-
}
70-
71-
if (ZeResult != ZE_RESULT_SUCCESS) {
72-
urPrint("zeInit: Level Zero initialization failure\n");
73-
return ze2urResult(ZeResult);
74-
}
75-
76-
// Cache ur_platform_handle_t for reuse in the future
77-
// It solves two problems;
78-
// 1. sycl::platform equality issue; we always return the same
79-
// ur_platform_handle_t
80-
// 2. performance; we can save time by immediately return from cache.
81-
//
82-
83-
const std::lock_guard<SpinLock> Lock{*URPlatformsCacheMutex};
84-
if (!URPlatformCachePopulated) {
85-
try {
86-
// Level Zero does not have concept of Platforms, but Level Zero driver is
87-
// the closest match.
88-
uint32_t ZeDriverCount = 0;
89-
ZE2UR_CALL(zeDriverGet, (&ZeDriverCount, nullptr));
90-
if (ZeDriverCount == 0) {
91-
URPlatformCachePopulated = true;
92-
} else {
93-
std::vector<ze_driver_handle_t> ZeDrivers;
94-
ZeDrivers.resize(ZeDriverCount);
95-
96-
ZE2UR_CALL(zeDriverGet, (&ZeDriverCount, ZeDrivers.data()));
97-
for (uint32_t I = 0; I < ZeDriverCount; ++I) {
98-
auto Platform = new ur_platform_handle_t_(ZeDrivers[I]);
99-
// Save a copy in the cache for future uses.
100-
URPlatformsCache->push_back(Platform);
101-
102-
UR_CALL(Platform->initialize());
103-
}
104-
URPlatformCachePopulated = true;
30+
// Platform handles are cached for reuse. This is to ensure consistent
31+
// handle pointers across invocations and to improve retrieval performance.
32+
if (const auto *cached_platforms = Adapter.PlatformCache->get_value()) {
33+
uint32_t nplatforms = (uint32_t)cached_platforms->size();
34+
if (NumPlatforms) {
35+
*NumPlatforms = nplatforms;
36+
}
37+
if (Platforms) {
38+
for (uint32_t i = 0; i < std::min(nplatforms, NumEntries); ++i) {
39+
Platforms[i] = cached_platforms->at(i).get();
10540
}
106-
} catch (const std::bad_alloc &) {
107-
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
108-
} catch (...) {
109-
return UR_RESULT_ERROR_UNKNOWN;
11041
}
111-
}
112-
113-
// Populate returned platforms from the cache.
114-
if (Platforms) {
115-
UR_ASSERT(NumEntries <= URPlatformsCache->size(),
116-
UR_RESULT_ERROR_INVALID_PLATFORM);
117-
std::copy_n(URPlatformsCache->begin(), NumEntries, Platforms);
118-
}
119-
120-
if (NumPlatforms) {
121-
if (*NumPlatforms == 0)
122-
*NumPlatforms = URPlatformsCache->size();
123-
else
124-
*NumPlatforms = (std::min)(URPlatformsCache->size(), (size_t)NumEntries);
42+
} else {
43+
return Adapter.PlatformCache->get_error();
12544
}
12645

12746
return UR_RESULT_SUCCESS;

source/ur/ur.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,3 @@ bool PrintTrace = [] {
2222
}
2323
return false;
2424
}();
25-
26-
// Apparatus for maintaining immutable cache of platforms.
27-
std::vector<ur_platform_handle_t> *URPlatformsCache =
28-
new std::vector<ur_platform_handle_t>;
29-
SpinLock *URPlatformsCacheMutex = new SpinLock;
30-
bool URPlatformCachePopulated = false;

source/ur/ur.hpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <shared_mutex>
2020
#include <string>
2121
#include <thread>
22+
#include <variant>
2223
#include <vector>
2324

2425
#include <ur_api.h>
@@ -191,16 +192,6 @@ struct _ur_platform {};
191192
// Controls tracing UR calls from within the UR itself.
192193
extern bool PrintTrace;
193194

194-
// Apparatus for maintaining immutable cache of platforms.
195-
//
196-
// Note we only create a simple pointer variables such that C++ RT won't
197-
// deallocate them automatically at the end of the main program.
198-
// The heap memory allocated for these global variables reclaimed only at
199-
// explicit tear-down.
200-
extern std::vector<ur_platform_handle_t> *URPlatformsCache;
201-
extern SpinLock *URPlatformsCacheMutex;
202-
extern bool URPlatformCachePopulated;
203-
204195
// The getInfo*/ReturnHelper facilities provide shortcut way of
205196
// writing return bytes for the various getInfo APIs.
206197
namespace ur {
@@ -310,3 +301,23 @@ class UrReturnHelper {
310301
void *param_value;
311302
size_t *param_value_size_ret;
312303
};
304+
305+
template <typename T> class Result {
306+
public:
307+
Result(ur_result_t err) : value_or_err(err) {}
308+
Result(T value) : value_or_err(std::move(value)) {}
309+
Result() : value_or_err(UR_RESULT_ERROR_UNINITIALIZED) {}
310+
311+
bool is_err() { return std::holds_alternative<ur_result_t>(value_or_err); }
312+
explicit operator bool() const { return !is_err(); }
313+
314+
const T *get_value() { return std::get_if<T>(&value_or_err); }
315+
316+
ur_result_t get_error() {
317+
auto *err = std::get_if<ur_result_t>(&value_or_err);
318+
return err ? *err : UR_RESULT_SUCCESS;
319+
}
320+
321+
private:
322+
std::variant<ur_result_t, T> value_or_err;
323+
};

0 commit comments

Comments
 (0)