Skip to content

Commit c03712d

Browse files
authored
[SYCLomatic] Fix migration issue for free-function-queries (#2804)
Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
1 parent 0c9fd43 commit c03712d

File tree

8 files changed

+113
-78
lines changed

8 files changed

+113
-78
lines changed

clang/lib/DPCT/AnalysisInfo.cpp

Lines changed: 27 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -326,16 +326,22 @@ class FreeQueriesInfo {
326326

327327
static const FreeQueriesNames &getNames(FreeQueriesKind);
328328
static std::shared_ptr<FreeQueriesInfo> getInfo(const FunctionDecl *);
329-
static void printFreeQueriesFunctionName(llvm::raw_ostream &OS,
330-
FreeQueriesKind K,
331-
unsigned Dimension) {
329+
template <typename T>
330+
static typename std::enable_if<std::is_same_v<T, unsigned> ||
331+
std::is_same_v<T, std::string>>::type
332+
printFreeQueriesFunctionName(llvm::raw_ostream &OS, FreeQueriesKind K,
333+
T Dimension) {
332334
OS << getNames(K).FreeQueriesFuncName;
333335
if (K != FreeQueriesKind::SubGroup) {
334336
OS << '<';
335-
if (Dimension) {
336-
OS << Dimension;
337+
if constexpr (std::is_same_v<T, unsigned>) {
338+
if (Dimension) {
339+
OS << Dimension;
340+
} else {
341+
OS << "dpct_placeholder /* Fix the dimension manually */";
342+
}
337343
} else {
338-
OS << "dpct_placeholder /* Fix the dimension manually */";
344+
OS << Dimension;
339345
}
340346
OS << '>';
341347
}
@@ -7515,11 +7521,13 @@ void FreeQueriesInfo::printImmediateText(llvm::raw_ostream &OS, const Node *S,
75157521
return Info->printImmediateText(OS, S->getBeginLoc(), K);
75167522
}
75177523

7518-
#ifdef DPCT_DEBUG_BUILD
7519-
llvm::errs() << "Can not get FreeQueriesInfo for this FunctionDecl\n";
7520-
assert(0);
7521-
#endif // DPCT_DEBUG_BUILD
7522-
7524+
auto DFI = DeviceFunctionDecl::LinkRedecls(FD);
7525+
if (!DFI)
7526+
return;
7527+
auto Index = DpctGlobalInfo::getCudaKernelDimDFIIndexThenInc();
7528+
DpctGlobalInfo::insertCudaKernelDimDFIMap(Index, DFI);
7529+
printFreeQueriesFunctionName<std::string>(
7530+
OS, K, "{{NEEDREPLACEG" + std::to_string(Index) + "}}");
75237531
} else {
75247532
if (auto DFI = DeviceFunctionDecl::LinkRedecls(FD))
75257533
DFI->setItem();
@@ -7549,6 +7557,7 @@ void FreeQueriesInfo::printImmediateText(llvm::raw_ostream &OS,
75497557
(*Iter)->Infos.push_back(Idx);
75507558
Index = Iter - MacroInfos.begin();
75517559
} else {
7560+
IsMacro = false;
75527561
auto SLocInfo = DpctGlobalInfo::getLocInfo(SL);
75537562
if (SLocInfo.first != FilePath)
75547563
return;
@@ -7579,7 +7588,7 @@ void FreeQueriesInfo::emplaceExtraDecl() {
75797588
auto &KindNames =
75807589
getNames(static_cast<FreeQueriesKind>(FreeQueriesKind::NdItem));
75817590
OS << "auto " << KindNames.ExtraVariableName << " = ";
7582-
printFreeQueriesFunctionName(
7591+
printFreeQueriesFunctionName<unsigned>(
75837592
OS, static_cast<FreeQueriesKind>(FreeQueriesKind::NdItem), Dimension);
75847593
OS << ';' << NL << Indent;
75857594
}
@@ -7593,28 +7602,29 @@ std::string FreeQueriesInfo::getReplaceString(unsigned Num) {
75937602
bool IsMacro = isMacro(Num);
75947603
if (IsMacro) {
75957604
if (Index < MacroInfos.size()) {
7596-
return buildStringFromPrinter(printFreeQueriesFunctionName, Kind,
7597-
MacroInfos[Index]->Dimension);
7605+
return buildStringFromPrinter(printFreeQueriesFunctionName<unsigned>,
7606+
Kind, MacroInfos[Index]->Dimension);
75987607
}
75997608
#ifdef DPCT_DEBUG_BUILD
76007609
llvm::errs() << "FreeQueriesInfo index[" << Index
7601-
<< "]is larger than list size[" << InfoList.size() << "]\n";
7610+
<< "] is larger than list size[" << MacroInfos.size() << "]\n";
76027611
assert(0);
76037612
#endif // DPCT_DEBUG_BUILD
76047613
}
76057614
if (Index < InfoList.size())
76067615
return InfoList[Index]->getReplaceString(getKind(Num));
76077616
#ifdef DPCT_DEBUG_BUILD
76087617
llvm::errs() << "FreeQueriesInfo index[" << Index
7609-
<< "]is larger than list size[" << InfoList.size() << "]\n";
7618+
<< "] is larger than list size[" << InfoList.size() << "]\n";
76107619
assert(0);
76117620
#endif // DPCT_DEBUG_BUILD
76127621
return "";
76137622
}
76147623

76157624
std::string FreeQueriesInfo::getReplaceString(FreeQueriesKind K) {
76167625
if (K != FreeQueriesKind::NdItem || Counter[K] < 2)
7617-
return buildStringFromPrinter(printFreeQueriesFunctionName, K, Dimension);
7626+
return buildStringFromPrinter(printFreeQueriesFunctionName<unsigned>, K,
7627+
Dimension);
76187628
else
76197629
return getNames(K).ExtraVariableName;
76207630
}

clang/test/dpct/builtin_warpSize.cu

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
// RUN: dpct --no-dpcpp-extensions=free-function-queries --format-range=none -out-root %T/builtin_warpSize %s --cuda-include-path="%cuda-path/include" -- -x cuda --cuda-host-only
1+
// RUN: dpct --format-range=none -out-root %T/builtin_warpSize %s --cuda-include-path="%cuda-path/include" -- -x cuda --cuda-host-only
22
// RUN: FileCheck --input-file %T/builtin_warpSize/builtin_warpSize.dp.cpp --match-full-lines %s
33
// RUN: %if build_lit %{icpx -c -fsycl %T/builtin_warpSize/builtin_warpSize.dp.cpp -o %T/builtin_warpSize/builtin_warpSize.dp.o %}
44

@@ -8,7 +8,7 @@
88

99

1010
__global__ void foo(){
11-
// CHECK: int a = item_ct1.get_sub_group().get_local_range().get(0);
11+
// CHECK: int a = sycl::ext::oneapi::this_work_item::get_sub_group().get_local_range().get(0);
1212
// CHECK-NEXT: int warpSize = 1;
1313
// CHECK-NEXT: warpSize = 2;
1414
// CHECK-NEXT: int c= warpSize;
@@ -18,8 +18,8 @@ __global__ void foo(){
1818
int c= warpSize;
1919
}
2020

21-
// CHECK: void bar(const sycl::nd_item<3> &item_ct1){
22-
// CHECK-NEXT: int a = sycl::max((int)item_ct1.get_sub_group().get_local_range().get(0), 0);
21+
// CHECK: void bar(){
22+
// CHECK-NEXT: int a = sycl::max((int)sycl::ext::oneapi::this_work_item::get_sub_group().get_local_range().get(0), 0);
2323
// CHECK-NEXT: int warpSize = 1;
2424
// CHECK-NEXT: int b = sycl::max(warpSize, 0);
2525
// CHECK-NEXT: }
@@ -29,8 +29,8 @@ __global__ void bar(){
2929
int b = max(warpSize, 0);
3030
}
3131

32-
// CHECK: int tensorPos(const int ct, const sycl::nd_item<3> &item_ct1, int numLane = 0) {
33-
// CHECK-NEXT: if (!numLane) numLane = item_ct1.get_sub_group().get_local_range().get(0);
32+
// CHECK: int tensorPos(const int ct, int numLane = 0) {
33+
// CHECK-NEXT: if (!numLane) numLane = sycl::ext::oneapi::this_work_item::get_sub_group().get_local_range().get(0);
3434
// CHECK-NEXT: int r = ct * numLane;
3535
// CHECK-NEXT: return r;
3636
// CHECK-NEXT: }
@@ -39,18 +39,18 @@ __device__ int tensorPos(const int ct, const int numLane = warpSize) {
3939
return r;
4040
}
4141

42-
// CHECK: int tensorPos(const int ct, const sycl::nd_item<3> &item_ct1, int numLane);
42+
// CHECK: int tensorPos(const int ct, int numLane);
4343
__device__ int tensorPos(const int ct, const int numLane);
4444

4545

4646

4747

4848

49-
// CHECK: int tensorPos2(const int ct, const sycl::nd_item<3> &item_ct1, int numLane);
49+
// CHECK: int tensorPos2(const int ct, int numLane);
5050
__device__ int tensorPos2(const int ct, const int numLane);
5151

52-
// CHECK: int tensorPos2(const int ct, const sycl::nd_item<3> &item_ct1, int numLane) {
53-
// CHECK-NEXT: if (!numLane) numLane = item_ct1.get_sub_group().get_local_range().get(0);
52+
// CHECK: int tensorPos2(const int ct, int numLane) {
53+
// CHECK-NEXT: if (!numLane) numLane = sycl::ext::oneapi::this_work_item::get_sub_group().get_local_range().get(0);
5454
// CHECK-NEXT: int r = ct * numLane;
5555
// CHECK-NEXT: return r;
5656
// CHECK-NEXT: }
@@ -59,9 +59,9 @@ __device__ int tensorPos2(const int ct, const int numLane) {
5959
return r;
6060
}
6161

62-
// CHECK: int tensorPos2(const int ct, const sycl::nd_item<3> &item_ct1, int numLane = 0);
62+
// CHECK: int tensorPos2(const int ct, int numLane = 0);
6363
__device__ int tensorPos2(const int ct, const int numLane = warpSize);
6464

6565

66-
// CHECK: int tensorPos3(const int ct, const sycl::nd_item<3> &item_ct1, int numLane = 0) {}
66+
// CHECK: int tensorPos3(const int ct, int numLane = 0) {}
6767
__device__ int tensorPos3(const int ct, const int numLane = warpSize) {}

clang/test/dpct/macro_test.cu

Lines changed: 25 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
// RUN: cd %T
66
// RUN: rm -rf %T/macro_test_output
77
// RUN: mkdir %T/macro_test_output
8-
// RUN: dpct --no-dpcpp-extensions=free-function-queries -out-root %T/macro_test_output macro_test.cu --cuda-include-path="%cuda-path/include" -- -x cuda --cuda-host-only
8+
// RUN: dpct -out-root %T/macro_test_output macro_test.cu --cuda-include-path="%cuda-path/include" -- -x cuda --cuda-host-only
99
// RUN: FileCheck --input-file %T/macro_test_output/macro_test.dp.cpp --match-full-lines macro_test.cu
1010
// RUN: %if build_lit %{icpx -c -fsycl -DNO_BUILD_TEST %T/macro_test_output/macro_test.dp.cpp -o %T/macro_test_output/macro_test.dp.o %}
1111
// RUN: FileCheck --input-file %T/macro_test_output/macro_test.h --match-full-lines macro_test.h
@@ -67,8 +67,8 @@ __global__ void foo_kernel() {}
6767
//CHECK-NEXT: #ifdef MACRO_CC
6868
//CHECK-NEXT: , int c
6969
//CHECK-NEXT: #endif
70-
//CHECK-NEXT: , const sycl::nd_item<3> &item_ct1) {
71-
//CHECK-NEXT: int x = item_ct1.get_group(2);
70+
//CHECK-NEXT: ) {
71+
//CHECK-NEXT: int x = sycl::ext::oneapi::this_work_item::get_nd_item<3>().get_group(2);
7272
//CHECK-NEXT: }
7373
__global__ void foo_kernel2(int a, int b
7474
#ifdef MACRO_CC
@@ -288,7 +288,7 @@ int b;
288288
//CHECK-NEXT: sycl::nd_range<3>(sycl::range<3>(1, 1, 2) * sycl::range<3>(1, 1, 2),
289289
//CHECK-NEXT: sycl::range<3>(1, 1, 2)),
290290
//CHECK-NEXT: [=](sycl::nd_item<3> item_ct1) {
291-
//CHECK-NEXT: foo_kernel2(3, 3, item_ct1);
291+
//CHECK-NEXT: foo_kernel2(3, 3);
292292
//CHECK-NEXT: });
293293
foo_kernel2<<<2, 2, 0>>>(3,3
294294
#ifdef MACRO_CC
@@ -435,13 +435,14 @@ FFF
435435

436436
}
437437

438-
// CHECK: #define FFFFF(aaa,bbb) void foo4(const int * __restrict__ aaa, const float * __restrict__ bbb, int *c, BBB, const sycl::nd_item<3> &item_ct1, float *sp_lj, float *sp_coul, int *ljd, double la[8][1])
438+
// CHECK: #define FFFFF(aaa,bbb) void foo4(const int * __restrict__ aaa, const float * __restrict__ bbb, int *c, BBB, float *sp_lj, float *sp_coul, int *ljd, double la[8][1])
439439
#define FFFFF(aaa,bbb) __device__ void foo4(const int * __restrict__ aaa, const float * __restrict__ bbb, int *c, BBB)
440440

441441
// CHECK: FFFFF(pos, q)
442442
// CHECK-NEXT: {
443443
// CHECK-EMPTY:
444-
// CHECK-NEXT: const int tid = item_ct1.get_local_id(2);
444+
// CHECK-NEXT: const int tid =
445+
// CHECK-NEXT: sycl::ext::oneapi::this_work_item::get_nd_item<3>().get_local_id(2);
445446
// CHECK-NEXT: }
446447
FFFFF(pos, q)
447448
{
@@ -452,13 +453,14 @@ FFFFF(pos, q)
452453
const int tid = threadIdx.x;
453454
}
454455

455-
// CHECK: #define FFFFFF(aaa,bbb) void foo5(const int * __restrict__ aaa, const float * __restrict__ bbb, const sycl::nd_item<3> &item_ct1, float *sp_lj, float *sp_coul, int *ljd, double la[8][1])
456+
// CHECK: #define FFFFFF(aaa,bbb) void foo5(const int * __restrict__ aaa, const float * __restrict__ bbb, float *sp_lj, float *sp_coul, int *ljd, double la[8][1])
456457
#define FFFFFF(aaa,bbb) __device__ void foo5(const int * __restrict__ aaa, const float * __restrict__ bbb)
457458

458459
// CHECK: FFFFFF(pos, q)
459460
// CHECK-NEXT: {
460461
// CHECK-EMPTY:
461-
// CHECK-NEXT: const int tid = item_ct1.get_local_id(2);
462+
// CHECK-NEXT: const int tid =
463+
// CHECK-NEXT: sycl::ext::oneapi::this_work_item::get_nd_item<3>().get_local_id(2);
462464
// CHECK-NEXT: }
463465
FFFFFF(pos, q)
464466
{
@@ -483,9 +485,13 @@ __device__ void foo6(AAA, BBB)
483485

484486
//CHECK: #define MM __umul24
485487
//CHECK-NEXT: #define MUL(a, b) sycl::mul24((unsigned int)a, (unsigned int)b)
486-
//CHECK-NEXT: void foo7(const sycl::nd_item<3> &item_ct1) {
487-
//CHECK-NEXT: unsigned int tid = MUL(item_ct1.get_local_range(2), item_ct1.get_group(2)) +
488-
//CHECK-NEXT: item_ct1.get_local_range(2);
488+
//CHECK-NEXT: void foo7() {
489+
//CHECK-NEXT: auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
490+
//CHECK-NEXT: unsigned int tid =
491+
//CHECK-NEXT: MUL(sycl::ext::oneapi::this_work_item::get_nd_item<3>().get_local_range(
492+
//CHECK-NEXT: 2),
493+
//CHECK-NEXT: sycl::ext::oneapi::this_work_item::get_nd_item<3>().get_group(2)) +
494+
//CHECK-NEXT: item_ct1.get_local_range(2);
489495
//CHECK-NEXT: unsigned int tid2 = sycl::mul24((unsigned int)item_ct1.get_local_range(2),
490496
//CHECK-NEXT: (unsigned int)item_ct1.get_group_range(2));
491497
//CHECK-NEXT: }
@@ -573,7 +579,7 @@ void templatefoo2(){
573579
CALL_KERNEL2(8, AAA)
574580
}
575581

576-
//CHECK: void foo11(const sycl::nd_item<3> &item_ct1){
582+
//CHECK: void foo11(){
577583
//CHECK-NEXT: sycl::exp((double)(THREAD_IDX_X));
578584
//CHECK-NEXT: }
579585
__global__ void foo11(){
@@ -915,13 +921,14 @@ void foo20() {
915921
}
916922

917923
//CHECK: #define CALLSHFLSYNC(x) \
918-
//CHECK-NEXT: dpct::select_from_sub_group(item_ct1.get_sub_group(), x, 3 ^ 1);
924+
//CHECK-NEXT: dpct::select_from_sub_group( \
925+
//CHECK-NEXT: sycl::ext::oneapi::this_work_item::get_sub_group(), x, 3 ^ 1);
919926
#define CALLSHFLSYNC(x) __shfl_sync(0xffffffff, x, 3 ^ 1);
920927
//CHECK: #define CALLANYSYNC(x) \
921928
//CHECK-NEXT: sycl::any_of_group( \
922-
//CHECK-NEXT: item_ct1.get_sub_group(), \
923-
//CHECK-NEXT: (0xffffffff & \
924-
//CHECK-NEXT: (0x1 << item_ct1.get_sub_group().get_local_linear_id())) && \
929+
//CHECK-NEXT: sycl::ext::oneapi::this_work_item::get_sub_group(), \
930+
//CHECK-NEXT: (0xffffffff & (0x1 << sycl::ext::oneapi::this_work_item::get_sub_group() \
931+
//CHECK-NEXT: .get_local_linear_id())) && \
925932
//CHECK-NEXT: x != 0.0f);
926933
#define CALLANYSYNC(x) __any_sync(0xffffffff, x != 0.0f);
927934

@@ -964,7 +971,8 @@ foo23(void)
964971
}
965972

966973
//CHECK: #define SHFL(x, y, z) \
967-
//CHECK-NEXT: dpct::select_from_sub_group(item_ct1.get_sub_group(), (x), (y), (z))
974+
//CHECK-NEXT: dpct::select_from_sub_group( \
975+
//CHECK-NEXT: sycl::ext::oneapi::this_work_item::get_sub_group(), (x), (y), (z))
968976
#define SHFL(x, y, z) __shfl((x), (y), (z))
969977
__global__ void foo24(){
970978
int i;

clang/test/dpct/macro_test.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
1-
//CHECK: #define THREAD_IDX_X item_ct1.get_local_id(2)
1+
//CHECK: #define THREAD_IDX_X \
2+
//CHECK-NEXT: sycl::ext::oneapi::this_work_item::get_nd_item<3>().get_local_id(2)
23
#define THREAD_IDX_X threadIdx.x
34

45
#define STRINGIFY_(...) #__VA_ARGS__

0 commit comments

Comments
 (0)