Skip to content

Commit f8d6182

Browse files
committed
Address feedback: rethink layer::isEnabled
1 parent 7162e22 commit f8d6182

File tree

9 files changed

+51
-32
lines changed

9 files changed

+51
-32
lines changed

scripts/templates/trcddi.cpp.mako

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,15 @@ namespace ur_tracing_layer
102102
}
103103
%endfor
104104

105-
${x}_result_t context_t::init(ur_dditable_t *dditable) {
105+
${x}_result_t
106+
context_t::init(ur_dditable_t *dditable,
107+
const std::set<std::string> &enabledLayerNames) {
106108
${x}_result_t result = ${X}_RESULT_SUCCESS;
107109

110+
if(!enabledLayerNames.count(name)) {
111+
return result;
112+
}
113+
108114
%for tbl in th.get_pfntables(specs, meta, n, tags):
109115
if( ${X}_RESULT_SUCCESS == result )
110116
{

scripts/templates/valddi.cpp.mako

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,23 @@ namespace ur_validation_layer
138138
}
139139

140140
%endfor
141-
${x}_result_t context_t::init(${x}_dditable_t *dditable) {
141+
${x}_result_t
142+
context_t::init(ur_dditable_t *dditable,
143+
const std::set<std::string> &enabledLayerNames) {
142144
${x}_result_t result = ${X}_RESULT_SUCCESS;
143145

146+
if (enabledLayerNames.count(nameFullValidation)) {
147+
enableParameterValidation = true;
148+
enableLeakChecking = true;
149+
} else {
150+
if (enabledLayerNames.count(nameParameterValidation)) {
151+
enableParameterValidation = true;
152+
}
153+
if (enabledLayerNames.count(nameLeakChecking)) {
154+
enableLeakChecking = true;
155+
}
156+
}
157+
144158
if(!enableParameterValidation && !enableLeakChecking) {
145159
return result;
146160
}

source/loader/layers/tracing/ur_tracing_layer.hpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,9 @@ class __urdlllocal context_t : public proxy_layer_context_t {
3030

3131
bool isAvailable() const override;
3232

33-
bool isEnabled(const std::set<std::string> &enabledLayerNames) override {
34-
return enabledLayerNames.find(name) != enabledLayerNames.end();
35-
}
36-
3733
std::vector<std::string> getNames() const override { return {name}; }
38-
ur_result_t init(ur_dditable_t *dditable) override;
34+
ur_result_t init(ur_dditable_t *dditable,
35+
const std::set<std::string> &enabledLayerNames) override;
3936
uint64_t notify_begin(uint32_t id, const char *name, void *args);
4037
void notify_end(uint32_t id, const char *name, void *args,
4138
ur_result_t *resultp, uint64_t instance);

source/loader/layers/tracing/ur_trcddi.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6782,9 +6782,14 @@ __urdlllocal ur_result_t UR_APICALL urGetDeviceProcAddrTable(
67826782
return result;
67836783
}
67846784

6785-
ur_result_t context_t::init(ur_dditable_t *dditable) {
6785+
ur_result_t context_t::init(ur_dditable_t *dditable,
6786+
const std::set<std::string> &enabledLayerNames) {
67866787
ur_result_t result = UR_RESULT_SUCCESS;
67876788

6789+
if (!enabledLayerNames.count(name)) {
6790+
return result;
6791+
}
6792+
67886793
if (UR_RESULT_SUCCESS == result) {
67896794
result = ur_tracing_layer::urGetGlobalProcAddrTable(
67906795
UR_API_VERSION_CURRENT, &dditable->Global);

source/loader/layers/ur_proxy_layer.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ class __urdlllocal proxy_layer_context_t {
2424

2525
virtual std::vector<std::string> getNames() const = 0;
2626
virtual bool isAvailable() const = 0;
27-
virtual bool isEnabled(const std::set<std::string> &enabledLayerNames) = 0;
28-
virtual ur_result_t init(ur_dditable_t *dditable) = 0;
27+
virtual ur_result_t
28+
init(ur_dditable_t *dditable,
29+
const std::set<std::string> &enabledLayerNames) = 0;
2930
};
3031

3132
#endif /* UR_PROXY_LAYER_H */

source/loader/layers/validation/ur_valddi.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8145,9 +8145,22 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetDeviceProcAddrTable(
81458145
return result;
81468146
}
81478147

8148-
ur_result_t context_t::init(ur_dditable_t *dditable) {
8148+
ur_result_t context_t::init(ur_dditable_t *dditable,
8149+
const std::set<std::string> &enabledLayerNames) {
81498150
ur_result_t result = UR_RESULT_SUCCESS;
81508151

8152+
if (enabledLayerNames.count(nameFullValidation)) {
8153+
enableParameterValidation = true;
8154+
enableLeakChecking = true;
8155+
} else {
8156+
if (enabledLayerNames.count(nameParameterValidation)) {
8157+
enableParameterValidation = true;
8158+
}
8159+
if (enabledLayerNames.count(nameLeakChecking)) {
8160+
enableLeakChecking = true;
8161+
}
8162+
}
8163+
81518164
if (!enableParameterValidation && !enableLeakChecking) {
81528165
return result;
81538166
}

source/loader/layers/validation/ur_validation_layer.cpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,4 @@ context_t::context_t() : logger(logger::create_logger("validation")) {}
2020
///////////////////////////////////////////////////////////////////////////////
2121
context_t::~context_t() {}
2222

23-
bool context_t::isEnabled(const std::set<std::string> &enabledLayerNames) {
24-
if (enabledLayerNames.find(nameFullValidation) != enabledLayerNames.end()) {
25-
enableParameterValidation = true;
26-
enableLeakChecking = true;
27-
} else {
28-
if (enabledLayerNames.find(nameParameterValidation) !=
29-
enabledLayerNames.end()) {
30-
enableParameterValidation = true;
31-
}
32-
if (enabledLayerNames.find(nameLeakChecking) !=
33-
enabledLayerNames.end()) {
34-
enableLeakChecking = true;
35-
}
36-
}
37-
return enableParameterValidation || enableLeakChecking;
38-
}
39-
4023
} // namespace ur_validation_layer

source/loader/layers/validation/ur_validation_layer.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ class __urdlllocal context_t : public proxy_layer_context_t {
3030
~context_t();
3131

3232
bool isAvailable() const override { return true; }
33-
bool isEnabled(const std::set<std::string> &enabledLayerNames) override;
3433
std::vector<std::string> getNames() const override {
3534
return {nameFullValidation, nameParameterValidation, nameLeakChecking};
3635
}
37-
ur_result_t init(ur_dditable_t *dditable) override;
36+
ur_result_t init(ur_dditable_t *dditable,
37+
const std::set<std::string> &enabledLayerNames) override;
3838

3939
private:
4040
const std::string nameFullValidation = "UR_LAYER_FULL_VALIDATION";

source/loader/ur_lib.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ void context_t::parseEnvEnabledLayers() {
5454

5555
void context_t::initLayers() const {
5656
for (auto &l : layers) {
57-
if (l->isAvailable() && l->isEnabled(enabledLayerNames)) {
58-
l->init(&context->urDdiTable);
57+
if (l->isAvailable()) {
58+
l->init(&context->urDdiTable, enabledLayerNames);
5959
}
6060
}
6161
}

0 commit comments

Comments
 (0)