Skip to content

Commit 2b2e515

Browse files
authored
[SYCL] Duplicate address-taken functions in split device modules. (#6452)
* [SYCL] Duplicate address-taken functions in split device code modules. We conservatively assume that address-taken function can be used in any split module, so it needs to be duplicated. Signed-off-by: Konstantin S Bobrovsky <konstantin.s.bobrovsky@intel.com>
1 parent a5485f4 commit 2b2e515

File tree

2 files changed

+100
-1
lines changed

2 files changed

+100
-1
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
; This test checks that functions which are neither SYCL external functions nor
2+
; part of any call graph, but have their address taken, are retained in split
3+
; modules.
4+
5+
; -- Per-source split
6+
; RUN: sycl-post-link -split=source -emit-param-info -symbols -emit-exported-symbols -split-esimd -lower-esimd -O2 -spec-const=rt -S %s -o %tA.table
7+
; RUN: FileCheck %s -input-file=%tA_0.ll --check-prefixes CHECK-A0
8+
; RUN: FileCheck %s -input-file=%tA_1.ll --check-prefixes CHECK-A1
9+
; -- No split
10+
; RUN: sycl-post-link -emit-param-info -symbols -emit-exported-symbols -split-esimd -lower-esimd -O2 -spec-const=rt -S %s -o %tB.table
11+
; RUN: FileCheck %s -input-file=%tB_0.ll --check-prefixes CHECK-B0
12+
; -- Per-kernel split
13+
; RUN: sycl-post-link -split=kernel -emit-param-info -symbols -emit-exported-symbols -split-esimd -lower-esimd -O2 -spec-const=rt -S %s -o %tC.table
14+
; RUN: FileCheck %s -input-file=%tC_0.ll --check-prefixes CHECK-C0
15+
; RUN: FileCheck %s -input-file=%tC_1.ll --check-prefixes CHECK-C1
16+
17+
18+
; ModuleID = 'in.bc'
19+
source_filename = "llvm-link"
20+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
21+
target triple = "spir64-unknown-unknown"
22+
23+
$foo = comdat any
24+
$bar = comdat any
25+
26+
@"tableX" = weak global [1 x void ()*] [void ()* @foo], align 8
27+
@"tableY" = weak global [1 x void ()*] [void ()* @bar], align 8
28+
29+
30+
; Function Attrs: mustprogress norecurse nounwind
31+
define linkonce_odr dso_local spir_func void @foo() unnamed_addr #0 comdat align 2 {
32+
; CHECK-A0: define linkonce_odr dso_local spir_func void @foo
33+
; CHECK-A1: define linkonce_odr dso_local spir_func void @foo
34+
; CHECK-B0: define linkonce_odr dso_local spir_func void @foo
35+
; CHECK-B1: define linkonce_odr dso_local spir_func void @foo
36+
; CHECK-C0: define linkonce_odr dso_local spir_func void @foo
37+
; CHECK-C1: define linkonce_odr dso_local spir_func void @foo
38+
ret void
39+
}
40+
41+
; Function Attrs: mustprogress norecurse nounwind
42+
define linkonce_odr dso_local spir_func void @bar() unnamed_addr #1 comdat align 2 {
43+
; CHECK-A0: define linkonce_odr dso_local spir_func void @bar
44+
; CHECK-A1: define linkonce_odr dso_local spir_func void @bar
45+
; CHECK-B0: define linkonce_odr dso_local spir_func void @bar
46+
; CHECK-B1: define linkonce_odr dso_local spir_func void @bar
47+
; CHECK-C0: define linkonce_odr dso_local spir_func void @bar
48+
; CHECK-C1: define linkonce_odr dso_local spir_func void @bar
49+
ret void
50+
}
51+
52+
define weak_odr dso_local spir_kernel void @Kernel1() #2 {
53+
ret void
54+
}
55+
56+
define weak_odr dso_local spir_kernel void @Kernel2() #3 {
57+
ret void
58+
}
59+
60+
attributes #0 = { mustprogress norecurse nounwind "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "vector_function_ptrs"="tableX()" }
61+
attributes #1 = { mustprogress norecurse nounwind "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "vector_function_ptrs"="tableY()" }
62+
attributes #2 = { "sycl-module-id"="module1.cpp" "uniform-work-group-size"="true" }
63+
attributes #3 = { "sycl-module-id"="module2.cpp" "uniform-work-group-size"="true" }
64+
65+
66+
!0 = !{i32 1, i32 2}
67+
!1 = !{i32 4, i32 100000}
68+
!2 = !{}
69+
!3 = !{!"<ID>"}
70+
!4 = !{i32 1, !"wchar_size", i32 4}
71+
!5 = !{i32 7, !"frame-pointer", i32 2}
72+
!6 = !{!7, !8, i64 8}
73+
!7 = !{!"_ZTS4Base", !8, i64 8}
74+
!8 = !{!"int", !9, i64 0}
75+
!9 = !{!"omnipotent char", !10, i64 0}
76+
!10 = !{!"Simple C++ TBAA"}

llvm/tools/sycl-post-link/ModuleSplitter.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ bool isSpirvSyclBuiltin(StringRef FName) {
107107
return FName.startswith("__spirv_") || FName.startswith("__sycl_");
108108
}
109109

110+
bool isKernel(const Function &F) {
111+
return F.getCallingConv() == CallingConv::SPIR_KERNEL;
112+
}
113+
110114
bool isEntryPoint(const Function &F, bool EmitOnlyKernelsAsEntryPoints) {
111115
// Skip declarations, if any: they should not be included into a vector of
112116
// entry points groups or otherwise we will end up with incorrectly generated
@@ -115,7 +119,7 @@ bool isEntryPoint(const Function &F, bool EmitOnlyKernelsAsEntryPoints) {
115119
return false;
116120

117121
// Kernels are always considered to be entry points
118-
if (CallingConv::SPIR_KERNEL == F.getCallingConv())
122+
if (isKernel(F))
119123
return true;
120124

121125
if (!EmitOnlyKernelsAsEntryPoints) {
@@ -285,6 +289,7 @@ class CallGraph {
285289
private:
286290
std::unordered_map<const Function *, FunctionSet> Graph;
287291
SmallPtrSet<const Function *, 1> EmptySet;
292+
FunctionSet AddrTakenFunctions;
288293

289294
public:
290295
CallGraph(const Module &M) {
@@ -297,6 +302,9 @@ class CallGraph {
297302
}
298303
}
299304
}
305+
if (F.hasAddressTaken()) {
306+
AddrTakenFunctions.insert(&F);
307+
}
300308
}
301309
}
302310

@@ -307,13 +315,28 @@ class CallGraph {
307315
? make_range(EmptySet.begin(), EmptySet.end())
308316
: make_range(It->second.begin(), It->second.end());
309317
}
318+
319+
iterator_range<FunctionSet::const_iterator> addrTakenFunctions() const {
320+
return make_range(AddrTakenFunctions.begin(), AddrTakenFunctions.end());
321+
}
310322
};
311323

312324
void collectFunctionsToExtract(SetVector<const GlobalValue *> &GVs,
313325
const EntryPointGroup &ModuleEntryPoints,
314326
const CallGraph &Deps) {
315327
for (const auto *F : ModuleEntryPoints.Functions)
316328
GVs.insert(F);
329+
// It is conservatively assumed that any address-taken function can be invoked
330+
// or otherwise used by any function in any module split from the initial one.
331+
// So such functions along with the call graphs they start are always
332+
// extracted (and duplicated in each split module).
333+
// TODO: try to determine which split modules really use address-taken
334+
// functions and only duplicate the functions in such modules. Note that usage
335+
// may include e.g. function address comparison w/o actual invocation.
336+
for (const auto *F : Deps.addrTakenFunctions()) {
337+
if (!isKernel(*F) && (isESIMDFunction(*F) == ModuleEntryPoints.isEsimd()))
338+
GVs.insert(F);
339+
}
317340

318341
// GVs has SetVector type. This type inserts a value only if it is not yet
319342
// present there. So, recursion is not expected here.

0 commit comments

Comments
 (0)