Skip to content

Commit 186bfb9

Browse files
authored
Merge pull request #1121 from kswiecicki/val-use-after-free
[UR] Add lifetime validation to validation layer
2 parents 32e2533 + c0f0a70 commit 186bfb9

File tree

12 files changed

+1312
-122
lines changed

12 files changed

+1312
-122
lines changed

scripts/core/INTRO.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,8 @@ Layers currently included with the runtime are as follows:
268268
- Enables non-adapter-specific parameter validation (e.g. checking for null values).
269269
* - UR_LAYER_LEAK_CHECKING
270270
- Performs some leak checking for API calls involving object creation/destruction.
271+
* - UR_LAYER_LIFETIME_VALIDATION
272+
- Performs lifetime validation on objects used in API calls (checks that each object is used only between its creation and its destruction). Automatically enables UR_LAYER_LEAK_CHECKING.
271273
* - UR_LAYER_FULL_VALIDATION
272274
- Enables UR_LAYER_PARAMETER_VALIDATION, UR_LAYER_LEAK_CHECKING and UR_LAYER_LIFETIME_VALIDATION.
273275
* - UR_LAYER_TRACING

scripts/templates/helper.py

Lines changed: 59 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Copyright (C) 2022-2023 Intel Corporation
2+
Copyright (C) 2022-2024 Intel Corporation
33
44
Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions.
55
See LICENSE.TXT
@@ -1486,45 +1486,79 @@ def get_loader_epilogue(specs, namespace, tags, obj, meta):
14861486

14871487
return epilogue
14881488

1489+
1490+
def get_event_wait_list_functions(specs, namespace, tags):
1491+
funcs = []
1492+
for s in specs:
1493+
for obj in s['objects']:
1494+
if re.match(r"function", obj['type']):
1495+
if any(x['name'] == 'phEventWaitList' for x in obj['params']) and any(
1496+
x['name'] == 'numEventsInWaitList' for x in obj['params']):
1497+
funcs.append(make_func_name(namespace, tags, obj))
1498+
return funcs
1499+
1500+
14891501
"""
1490-
Public:
1491-
returns a dictionary with lists of create, retain and release functions
1502+
Private:
1503+
returns a dictionary with lists of create, get, retain and release functions
14921504
"""
1493-
def get_create_retain_release_functions(specs, namespace, tags):
1505+
def _get_create_get_retain_release_functions(specs, namespace, tags):
14941506
funcs = []
14951507
for s in specs:
14961508
for obj in s['objects']:
14971509
if re.match(r"function", obj['type']):
14981510
funcs.append(make_func_name(namespace, tags, obj))
14991511

1500-
create_suffixes = r"(Create[A-Za-z]*){1}"
1501-
retain_suffixes = r"(Retain){1}"
1502-
release_suffixes = r"(Release){1}"
1512+
create_suffixes = r"(Create[A-Za-z]*){1}$"
1513+
get_suffixes = r"(Get){1}$"
1514+
retain_suffixes = r"(Retain){1}$"
1515+
release_suffixes = r"(Release){1}$"
1516+
common_prefix = r"^" + namespace
15031517

1504-
create_exp = namespace + r"([A-Za-z]+)" + create_suffixes
1505-
retain_exp = namespace + r"([A-Za-z]+)" + retain_suffixes
1506-
release_exp = namespace + r"([A-Za-z]+)" + release_suffixes
1518+
create_exp = common_prefix + r"[A-Za-z]+" + create_suffixes
1519+
get_exp = common_prefix + r"[A-Za-z]+" + get_suffixes
1520+
retain_exp = common_prefix + r"[A-Za-z]+" + retain_suffixes
1521+
release_exp = common_prefix + r"[A-Za-z]+" + release_suffixes
15071522

1508-
create_funcs, retain_funcs, release_funcs = (
1523+
create_funcs, get_funcs, retain_funcs, release_funcs = (
15091524
list(filter(lambda f: re.match(create_exp, f), funcs)),
1525+
list(filter(lambda f: re.match(get_exp, f), funcs)),
15101526
list(filter(lambda f: re.match(retain_exp, f), funcs)),
15111527
list(filter(lambda f: re.match(release_exp, f), funcs)),
15121528
)
15131529

1514-
create_funcs, retain_funcs = (
1515-
list(filter(lambda f: re.sub(create_suffixes, "Release", f) in release_funcs, create_funcs)),
1516-
list(filter(lambda f: re.sub(retain_suffixes, "Release", f) in release_funcs, retain_funcs)),
1517-
)
1530+
return {"create": create_funcs, "get": get_funcs, "retain": retain_funcs, "release": release_funcs}
15181531

1519-
return {"create": create_funcs, "retain": retain_funcs, "release": release_funcs}
15201532

1533+
"""
1534+
Public:
1535+
returns a list of dictionaries containing handle types and the corresponding create, get, retain and release functions
1536+
"""
1537+
def get_handle_create_get_retain_release_functions(specs, namespace, tags):
1538+
# Handles without release function
1539+
excluded_handles = ["$x_platform_handle_t", "$x_native_handle_t"]
1540+
# Handles from experimental features
1541+
exp_prefix = "$x_exp"
1542+
1543+
funcs = _get_create_get_retain_release_functions(specs, namespace, tags)
1544+
records = []
1545+
for h in get_adapter_handles(specs):
1546+
if h['name'] in excluded_handles or h['name'].startswith(exp_prefix):
1547+
continue
15211548

1522-
def get_event_wait_list_functions(specs, namespace, tags):
1523-
funcs = []
1524-
for s in specs:
1525-
for obj in s['objects']:
1526-
if re.match(r"function", obj['type']):
1527-
if any(x['name'] == 'phEventWaitList' for x in obj['params']) and any(
1528-
x['name'] == 'numEventsInWaitList' for x in obj['params']):
1529-
funcs.append(make_func_name(namespace, tags, obj))
1530-
return funcs
1549+
class_type = subt(namespace, tags, h['class'])
1550+
create_funcs = list(filter(lambda f: class_type in f, funcs['create']))
1551+
get_funcs = list(filter(lambda f: class_type in f, funcs['get']))
1552+
retain_funcs = list(filter(lambda f: class_type in f, funcs['retain']))
1553+
release_funcs = list(filter(lambda f: class_type in f, funcs['release']))
1554+
1555+
record = {}
1556+
record['handle'] = subt(namespace, tags, h['name'])
1557+
record['create'] = create_funcs
1558+
record['get'] = get_funcs
1559+
record['retain'] = retain_funcs
1560+
record['release'] = release_funcs
1561+
1562+
records.append(record)
1563+
1564+
return records

scripts/templates/valddi.cpp.mako

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@ from templates import helper as th
77
88
x=tags['$x']
99
X=x.upper()
10-
create_retain_release_funcs=th.get_create_retain_release_functions(specs, n, tags)
10+
11+
handle_create_get_retain_release_funcs=th.get_handle_create_get_retain_release_functions(specs, n, tags)
1112
%>/*
1213
*
13-
* Copyright (C) 2023 Intel Corporation
14+
* Copyright (C) 2023-2024 Intel Corporation
1415
*
1516
* Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions.
1617
* See LICENSE.TXT
@@ -27,11 +28,12 @@ namespace ur_validation_layer
2728
%for obj in th.get_adapter_functions(specs):
2829
<%
2930
func_name=th.make_func_name(n, tags, obj)
30-
object_param=th.make_param_lines(n, tags, obj, format=["name"])[-1]
31-
object_param_type=th.make_param_lines(n, tags, obj, format=["type"])[-1]
31+
3232
param_checks=th.make_param_checks(n, tags, obj, meta=meta).items()
3333
first_errors = [X + "_RESULT_ERROR_INVALID_NULL_POINTER", X + "_RESULT_ERROR_INVALID_NULL_HANDLE"]
3434
sorted_param_checks = sorted(param_checks, key=lambda pair: False if pair[0] in first_errors else True)
35+
36+
tracked_params = list(filter(lambda p: any(th.subt(n, tags, p['type']) in [hf['handle'], hf['handle'] + "*"] for hf in handle_create_get_retain_release_funcs), obj['params']))
3537
%>
3638
///////////////////////////////////////////////////////////////////////////////
3739
/// @brief Intercept function for ${th.make_func_name(n, tags, obj)}
@@ -72,39 +74,49 @@ namespace ur_validation_layer
7274

7375
}
7476

77+
%for tp in tracked_params:
78+
<%
79+
tp_input_handle_funcs = next((hf for hf in handle_create_get_retain_release_funcs if th.subt(n, tags, tp['type']) == hf['handle'] and "[in]" in tp['desc']), {})
80+
is_related_create_get_retain_release_func = any(func_name in funcs for funcs in tp_input_handle_funcs.values())
81+
%>
82+
%if tp_input_handle_funcs and not is_related_create_get_retain_release_func:
83+
if (context.enableLifetimeValidation && !refCountContext.isReferenceValid(${tp['name']})) {
84+
refCountContext.logInvalidReference(${tp['name']});
85+
}
86+
%endif
87+
%endfor
88+
7589
${x}_result_t result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );
7690

77-
%if func_name == n + "AdapterRelease":
91+
%for tp in tracked_params:
92+
<%
93+
tp_handle_funcs = next((hf for hf in handle_create_get_retain_release_funcs if th.subt(n, tags, tp['type']) in [hf['handle'], hf['handle'] + "*"]), None)
94+
is_handle_to_adapter = ("_adapter_handle_t" in tp['type'])
95+
%>
96+
%if func_name in tp_handle_funcs['create']:
7897
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
7998
{
80-
refCountContext.decrementRefCount(${object_param}, true);
99+
refCountContext.createRefCount(*${tp['name']});
81100
}
82-
%elif func_name == n + "AdapterRetain":
83-
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
101+
%elif func_name in tp_handle_funcs['get']:
102+
if( context.enableLeakChecking && ${tp['name']} && result == UR_RESULT_SUCCESS )
84103
{
85-
refCountContext.incrementRefCount(${object_param}, true);
86-
}
87-
%elif func_name == n + "AdapterGet":
88-
if( context.enableLeakChecking && phAdapters && result == UR_RESULT_SUCCESS )
89-
{
90-
refCountContext.createOrIncrementRefCount(*phAdapters, true);
91-
}
92-
%elif func_name in create_retain_release_funcs["create"]:
93-
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
94-
{
95-
refCountContext.createRefCount(*${object_param});
104+
for (uint32_t i = ${th.param_traits.range_start(tp)}; i < ${th.param_traits.range_end(tp)}; i++) {
105+
refCountContext.createOrIncrementRefCount(${tp['name']}[i], ${str(is_handle_to_adapter).lower()});
106+
}
96107
}
97-
%elif func_name in create_retain_release_funcs["retain"]:
108+
%elif func_name in tp_handle_funcs['retain']:
98109
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
99110
{
100-
refCountContext.incrementRefCount(${object_param});
111+
refCountContext.incrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
101112
}
102-
%elif func_name in create_retain_release_funcs["release"]:
113+
%elif func_name in tp_handle_funcs['release']:
103114
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
104115
{
105-
refCountContext.decrementRefCount(${object_param});
116+
refCountContext.decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
106117
}
107118
%endif
119+
%endfor
108120

109121
return result;
110122
}
@@ -167,16 +179,22 @@ namespace ur_validation_layer
167179
if (enabledLayerNames.count(nameFullValidation)) {
168180
enableParameterValidation = true;
169181
enableLeakChecking = true;
182+
enableLifetimeValidation = true;
170183
} else {
171184
if (enabledLayerNames.count(nameParameterValidation)) {
172185
enableParameterValidation = true;
173186
}
174187
if (enabledLayerNames.count(nameLeakChecking)) {
175188
enableLeakChecking = true;
176189
}
190+
if (enabledLayerNames.count(nameLifetimeValidation)) {
191+
// Handle lifetime validation requires leak checking feature.
192+
enableLifetimeValidation = true;
193+
enableLeakChecking = true;
194+
}
177195
}
178196

179-
if(!enableParameterValidation && !enableLeakChecking) {
197+
if (!enableParameterValidation && !enableLeakChecking && !enableLifetimeValidation) {
180198
return result;
181199
}
182200

source/loader/layers/validation/ur_leak_check.hpp

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (C) 2023 Intel Corporation
1+
// Copyright (C) 2023-2024 Intel Corporation
22
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions.
33
// See LICENSE.TXT
44
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -9,6 +9,7 @@
99
#include "ur_validation_layer.hpp"
1010

1111
#include <mutex>
12+
#include <typeindex>
1213
#include <unordered_map>
1314
#include <utility>
1415

@@ -20,7 +21,12 @@ struct RefCountContext {
2021
private:
2122
struct RefRuntimeInfo {
2223
int64_t refCount;
24+
std::type_index type;
2325
std::vector<BacktraceLine> backtrace;
26+
27+
RefRuntimeInfo(int64_t refCount, std::type_index type,
28+
std::vector<BacktraceLine> backtrace)
29+
: refCount(refCount), type(type), backtrace(backtrace) {}
2430
};
2531

2632
enum RefCountUpdateType {
@@ -34,26 +40,32 @@ struct RefCountContext {
3440
std::unordered_map<void *, struct RefRuntimeInfo> counts;
3541
int64_t adapterCount = 0;
3642

37-
void updateRefCount(void *ptr, enum RefCountUpdateType type,
43+
template <typename T>
44+
void updateRefCount(T handle, enum RefCountUpdateType type,
3845
bool isAdapterHandle = false) {
3946
std::unique_lock<std::mutex> ulock(mutex);
4047

48+
void *ptr = static_cast<void *>(handle);
4149
auto it = counts.find(ptr);
4250

4351
switch (type) {
4452
case REFCOUNT_CREATE_OR_INCREASE:
4553
if (it == counts.end()) {
46-
counts[ptr] = {1, getCurrentBacktrace()};
54+
std::tie(it, std::ignore) = counts.emplace(
55+
ptr, RefRuntimeInfo{1, std::type_index(typeid(handle)),
56+
getCurrentBacktrace()});
4757
if (isAdapterHandle) {
4858
adapterCount++;
4959
}
5060
} else {
51-
counts[ptr].refCount++;
61+
it->second.refCount++;
5262
}
5363
break;
5464
case REFCOUNT_CREATE:
5565
if (it == counts.end()) {
56-
counts[ptr] = {1, getCurrentBacktrace()};
66+
std::tie(it, std::ignore) = counts.emplace(
67+
ptr, RefRuntimeInfo{1, std::type_index(typeid(handle)),
68+
getCurrentBacktrace()});
5769
} else {
5870
context.logger.error("Handle {} already exists", ptr);
5971
return;
@@ -65,29 +77,31 @@ struct RefCountContext {
6577
"Attempting to retain nonexistent handle {}", ptr);
6678
return;
6779
} else {
68-
counts[ptr].refCount++;
80+
it->second.refCount++;
6981
}
7082
break;
7183
case REFCOUNT_DECREASE:
7284
if (it == counts.end()) {
73-
counts[ptr] = {-1, getCurrentBacktrace()};
85+
std::tie(it, std::ignore) = counts.emplace(
86+
ptr, RefRuntimeInfo{-1, std::type_index(typeid(handle)),
87+
getCurrentBacktrace()});
7488
} else {
75-
counts[ptr].refCount--;
89+
it->second.refCount--;
7690
}
7791

78-
if (counts[ptr].refCount < 0) {
92+
if (it->second.refCount < 0) {
7993
context.logger.error(
8094
"Attempting to release nonexistent handle {}", ptr);
81-
} else if (counts[ptr].refCount == 0 && isAdapterHandle) {
95+
} else if (it->second.refCount == 0 && isAdapterHandle) {
8296
adapterCount--;
8397
}
8498
break;
8599
}
86100

87101
context.logger.debug("Reference count for handle {} changed to {}", ptr,
88-
counts[ptr].refCount);
102+
it->second.refCount);
89103

90-
if (counts[ptr].refCount == 0) {
104+
if (it->second.refCount == 0) {
91105
counts.erase(ptr);
92106
}
93107

@@ -99,22 +113,36 @@ struct RefCountContext {
99113
}
100114

101115
public:
102-
void createRefCount(void *ptr) { updateRefCount(ptr, REFCOUNT_CREATE); }
116+
template <typename T> void createRefCount(T handle) {
117+
updateRefCount<T>(handle, REFCOUNT_CREATE);
118+
}
103119

104-
void incrementRefCount(void *ptr, bool isAdapterHandle = false) {
105-
updateRefCount(ptr, REFCOUNT_INCREASE, isAdapterHandle);
120+
template <typename T>
121+
void incrementRefCount(T handle, bool isAdapterHandle = false) {
122+
updateRefCount(handle, REFCOUNT_INCREASE, isAdapterHandle);
106123
}
107124

108-
void decrementRefCount(void *ptr, bool isAdapterHandle = false) {
109-
updateRefCount(ptr, REFCOUNT_DECREASE, isAdapterHandle);
125+
template <typename T>
126+
void decrementRefCount(T handle, bool isAdapterHandle = false) {
127+
updateRefCount(handle, REFCOUNT_DECREASE, isAdapterHandle);
110128
}
111129

112-
void createOrIncrementRefCount(void *ptr, bool isAdapterHandle = false) {
113-
updateRefCount(ptr, REFCOUNT_CREATE_OR_INCREASE, isAdapterHandle);
130+
template <typename T>
131+
void createOrIncrementRefCount(T handle, bool isAdapterHandle = false) {
132+
updateRefCount(handle, REFCOUNT_CREATE_OR_INCREASE, isAdapterHandle);
114133
}
115134

116135
void clear() { counts.clear(); }
117136

137+
template <typename T> bool isReferenceValid(T handle) {
138+
auto it = counts.find(static_cast<void *>(handle));
139+
if (it == counts.end() || it->second.refCount < 1) {
140+
return false;
141+
}
142+
143+
return (it->second.type == std::type_index(typeid(handle)));
144+
}
145+
118146
void logInvalidReferences() {
119147
for (auto &[ptr, refRuntimeInfo] : counts) {
120148
context.logger.error("Retained {} reference(s) to handle {}",
@@ -128,6 +156,10 @@ struct RefCountContext {
128156
}
129157
}
130158

159+
void logInvalidReference(void *ptr) {
160+
context.logger.error("There are no valid references to handle {}", ptr);
161+
}
162+
131163
} refCountContext;
132164

133165
} // namespace ur_validation_layer

0 commit comments

Comments
 (0)