Skip to content

Commit a13f72a

Browse files
authored
[SYCLomatic] Fix the issue that kernel function reference not replaced by wrapper and wrapper_register when used as function pointer (#2765)
Signed-off-by: intwanghao <hao3.wang@intel.com>
1 parent 32d47ab commit a13f72a

File tree

3 files changed

+74
-25
lines changed

3 files changed

+74
-25
lines changed

clang/lib/DPCT/RulesLang/RulesLang.cpp

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4512,10 +4512,14 @@ void StreamAPICallRule::runRule(const MatchFinder::MatchResult &Result) {
45124512
}
45134513

45144514
void KernelCallRefRule::registerMatcher(ast_matchers::MatchFinder &MF) {
4515-
MF.addMatcher(declRefExpr(allOf(to(functionDecl(hasAttr(attr::CUDAGlobal))),
4516-
unless(hasAncestor(cudaKernelCallExpr()))))
4517-
.bind("kernelRef"),
4518-
this);
4515+
MF.addMatcher(
4516+
functionDecl(
4517+
forEachDescendant(
4518+
declRefExpr(allOf(to(functionDecl(hasAttr(attr::CUDAGlobal))),
4519+
unless(hasAncestor(cudaKernelCallExpr()))))
4520+
.bind("kernelRef")))
4521+
.bind("outerFunc"),
4522+
this);
45194523
MF.addMatcher(unresolvedLookupExpr(unless(hasAncestor(cudaKernelCallExpr())))
45204524
.bind("unresolvedRef"),
45214525
this);
@@ -4591,6 +4595,11 @@ void KernelCallRefRule::insertWrapperPostfix(const T *Node,
45914595
void KernelCallRefRule::runRule(
45924596
const ast_matchers::MatchFinder::MatchResult &Result) {
45934597
if (auto DRE = getAssistNodeAsType<DeclRefExpr>(Result, "kernelRef")) {
4598+
const FunctionDecl *OuterFD =
4599+
getAssistNodeAsType<FunctionDecl>(Result, "outerFunc");
4600+
if (!OuterFD) {
4601+
return;
4602+
}
45944603
if (auto ParentCE = DpctGlobalInfo::findAncestor<CallExpr>(DRE)) {
45954604
if (auto Callee = ParentCE->getDirectCallee()) {
45964605
if (dpct::DpctGlobalInfo::isInCudaPath(Callee->getBeginLoc())) {
@@ -4614,31 +4623,25 @@ void KernelCallRefRule::runRule(
46144623
DFI->collectInfoForWrapper(FD);
46154624
}
46164625
}
4617-
if (auto *OuterFD = DpctGlobalInfo::findAncestor<FunctionDecl>(DRE)) {
4618-
if ((OuterFD->getTemplatedKind() ==
4619-
FunctionDecl::TemplatedKind::TK_NonTemplate) ||
4620-
(OuterFD->getTemplatedKind() ==
4621-
FunctionDecl::TemplatedKind::TK_FunctionTemplate)) {
4622-
std::string TypeRepl;
4623-
if (DpctGlobalInfo::isCVersionCUDALaunchUsed()) {
4624-
if ((IsTemplateRelated &&
4625-
(!DRE->hasExplicitTemplateArgs() ||
4626-
(DRE->getNumTemplateArgs() <= TemplateParamNum))) ||
4627-
DRE->hadMultipleCandidates()) {
4628-
TypeRepl = getTypeRepl(DRE);
4629-
}
4626+
if ((OuterFD->getTemplatedKind() ==
4627+
FunctionDecl::TemplatedKind::TK_NonTemplate) ||
4628+
(OuterFD->getTemplatedKind() ==
4629+
FunctionDecl::TemplatedKind::TK_FunctionTemplate)) {
4630+
std::string TypeRepl;
4631+
if (DpctGlobalInfo::isCVersionCUDALaunchUsed()) {
4632+
if ((IsTemplateRelated &&
4633+
(!DRE->hasExplicitTemplateArgs() ||
4634+
(DRE->getNumTemplateArgs() <= TemplateParamNum))) ||
4635+
DRE->hadMultipleCandidates()) {
4636+
TypeRepl = getTypeRepl(DRE);
46304637
}
4631-
insertWrapperPostfix<DeclRefExpr>(
4632-
DRE, std::move(TypeRepl),
4633-
DpctGlobalInfo::isCVersionCUDALaunchUsed());
46344638
}
4639+
insertWrapperPostfix<DeclRefExpr>(
4640+
DRE, std::move(TypeRepl), DpctGlobalInfo::isCVersionCUDALaunchUsed());
46354641
}
46364642
}
46374643
if (auto ULE =
46384644
getAssistNodeAsType<UnresolvedLookupExpr>(Result, "unresolvedRef")) {
4639-
if (!DpctGlobalInfo::isCVersionCUDALaunchUsed()) {
4640-
return;
4641-
}
46424645
bool KernelRefFound = false;
46434646
for (auto *D : ULE->decls()) {
46444647
const FunctionDecl *FD = dyn_cast<FunctionDecl>(D);
@@ -4670,7 +4673,8 @@ void KernelCallRefRule::runRule(
46704673
}
46714674
}
46724675
}
4673-
insertWrapperPostfix<UnresolvedLookupExpr>(ULE, getTypeRepl(ULE), true);
4676+
insertWrapperPostfix<UnresolvedLookupExpr>(
4677+
ULE, getTypeRepl(ULE), DpctGlobalInfo::isCVersionCUDALaunchUsed());
46744678
}
46754679
}
46764680

clang/test/dpct/function_pointer2.cu

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// RUN: dpct --format-range=none -out-root %T/function_pointer2 %s --cuda-include-path="%cuda-path/include" -- -x cuda --cuda-host-only
2+
// RUN: FileCheck --input-file %T/function_pointer2/function_pointer2.dp.cpp --match-full-lines %s
3+
// RUN: %if build_lit %{icpx -c -fsycl %T/function_pointer2/function_pointer2.dp.cpp -o %T/function_pointer2/function_pointer2.dp.o %}
4+
5+
#include <cuda_runtime.h>
6+
#include <iostream>
7+
8+
template<typename T>
9+
__global__ static inline void vectorTemplateAdd(const T *A, T *B, T *C, int N) {
10+
int i = blockIdx.x * blockDim.x + threadIdx.x;
11+
if (i < N) {
12+
C[i] = A[i] + B[i];
13+
}
14+
}
15+
16+
template <typename T>
17+
using fpt = void(*)(const T *, T*, T*, int);
18+
19+
template<typename T>
20+
void foo() {
21+
int *d_A, *d_B, *d_C;
22+
// CHECK: fpt<T> fp = &vectorTemplateAdd_wrapper<T>;
23+
fpt<T> fp = &vectorTemplateAdd<T>;
24+
// CHECK: dpct::kernel_launcher::launch(fp, 1, 10, 0, 0, d_A, d_B, d_C, 10);
25+
fp<<<1, 10>>>(d_A, d_B, d_C, 10);
26+
}
27+
28+
static __global__ void setup_kernel(int p){}
29+
30+
template<typename T>
31+
void goo();
32+
33+
template<typename T>
34+
void goo() {
35+
// CHECK: auto a = (void *)setup_kernel_wrapper;
36+
auto a = (void *)setup_kernel;
37+
}
38+
39+
template void goo<int>();
40+
41+
int main() {
42+
foo<int>();
43+
std::cout << "test success" << std::endl;
44+
return 0;
45+
}

clang/test/dpct/kernel_without_name.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ template <typename T> struct kernel_type_t {
293293
// CHECK-NEXT: int b,
294294
// CHECK-NEXT: Tk *mem) {
295295
// CHECK-NEXT: using Tk = typename kernel_type_t<T>::Type;
296-
template <typename T> __global__
296+
template <typename T> __device__
297297
void foo_device7(int a,
298298
int b) {
299299
using Tk = typename kernel_type_t<T>::Type;

0 commit comments

Comments
 (0)