Skip to content

Commit ea9e1bd

Browse files
authored
[SYCLomatic] Fix the issue that cub type not processed in using statement (#2704)
Signed-off-by: intwanghao <hao3.wang@intel.com>
1 parent b2a6445 commit ea9e1bd

File tree

3 files changed

+35
-20
lines changed

3 files changed

+35
-20
lines changed

clang/lib/DPCT/RulesLangLib/CUBAPIMigration.cpp

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -760,6 +760,13 @@ void CubRule::registerMatcher(ast_matchers::MatchFinder &MF) {
760760
.bind("TypeDefDecl"),
761761
this);
762762

763+
MF.addMatcher(
764+
typeAliasDecl(
765+
hasType(hasCanonicalType(qualType(hasDeclaration(namedDecl(hasAnyName(
766+
"WarpScan", "WarpReduce", "BlockScan", "BlockReduce")))))))
767+
.bind("UsingDecl"),
768+
this);
769+
763770
auto isTempStorage = hasDeclaration(namedDecl(hasAnyName("TempStorage")));
764771
MF.addMatcher(declStmt(has(varDecl(anyOf(
765772
hasType(hasCanonicalType(qualType(isTempStorage))),
@@ -919,7 +926,7 @@ void CubRule::processCubDeclStmt(const DeclStmt *DS) {
919926
}
920927
}
921928
}
922-
void CubRule::processCubTypeDef(const TypedefDecl *TD) {
929+
void CubRule::processCubTypeDefOrUsing(const TypedefNameDecl *TD) {
923930
auto CanonicalType = TD->getUnderlyingType().getCanonicalType();
924931
std::string CanonicalTypeStr = CanonicalType.getAsString();
925932
if (isTypeInAnalysisScope(CanonicalType.getTypePtr()))
@@ -944,20 +951,9 @@ void CubRule::processCubTypeDef(const TypedefDecl *TD) {
944951
// Currently, typedef decl can be deleted in following cases
945952
for (auto &Element : TypeLocMatchResult) {
946953
if (auto TL = Element.getNodeAs<TypeLoc>("typeLoc")) {
947-
// 1. Used in TempStorage variable declaration
948-
if (auto AncestorVD = DpctGlobalInfo::findAncestor<VarDecl>(TL)) {
949-
auto VarType = AncestorVD->getType().getCanonicalType();
950-
std::string VarTypeStr =
951-
AncestorVD->getType().getCanonicalType().getAsString();
952-
if (isTypeInAnalysisScope(VarType.getTypePtr()) ||
953-
!(VarTypeStr.find("TempStorage") != std::string::npos &&
954-
VarTypeStr.find("struct cub::") == 0)) {
955-
DeleteFlag = false;
956-
break;
957-
}
958-
} // 2. Used in temporary class constructor
959-
else if (auto AncestorMTE =
960-
DpctGlobalInfo::findAncestor<MaterializeTemporaryExpr>(TL)) {
954+
// 1. Used in temporary class constructor
955+
if (auto AncestorMTE =
956+
DpctGlobalInfo::findAncestor<MaterializeTemporaryExpr>(TL)) {
961957
auto MC = DpctGlobalInfo::findAncestor<CXXMemberCallExpr>(AncestorMTE);
962958
if (MC) {
963959
auto ObjType = MC->getObjectType().getCanonicalType();
@@ -968,12 +964,27 @@ void CubRule::processCubTypeDef(const TypedefDecl *TD) {
968964
ObjTypeStr.find("class cub::BlockScan") == 0 ||
969965
ObjTypeStr.find("class cub::BlockReduce") == 0)) {
970966
DeleteFlag = false;
967+
std::cout << "2" << std::endl;
971968
break;
972969
}
973970
}
974-
} // 3. Used in self typedef decl
971+
} // 2. Used in TempStorage variable declaration
972+
else if (auto AncestorVD = DpctGlobalInfo::findAncestor<VarDecl>(TL)) {
973+
auto VarType = AncestorVD->getType().getCanonicalType();
974+
std::string VarTypeStr =
975+
AncestorVD->getType().getCanonicalType().getAsString();
976+
std::cout << VarTypeStr << std::endl;
977+
if (isTypeInAnalysisScope(VarType.getTypePtr()) ||
978+
!(VarTypeStr.find("TempStorage") != std::string::npos &&
979+
VarTypeStr.find("struct cub::") == 0)) {
980+
DeleteFlag = false;
981+
std::cout << "1" << std::endl;
982+
break;
983+
}
984+
}
985+
// 3. Used in self typedef decl
975986
else if (auto AncestorTD =
976-
DpctGlobalInfo::findAncestor<TypedefDecl>(TL)) {
987+
DpctGlobalInfo::findAncestor<TypedefNameDecl>(TL)) {
977988
if (AncestorTD != TD) {
978989
DeleteFlag = false;
979990
break;
@@ -1686,7 +1697,11 @@ void CubRule::runRule(const ast_matchers::MatchFinder::MatchResult &Result) {
16861697
processCubFuncCall(CE, true);
16871698
} else if (const TypedefDecl *TD =
16881699
getNodeAsType<TypedefDecl>(Result, "TypeDefDecl")) {
1689-
processCubTypeDef(TD);
1700+
processCubTypeDefOrUsing(TD);
1701+
} else if (const TypeAliasDecl *TAD =
1702+
getNodeAsType<TypeAliasDecl>(Result, "UsingDecl")) {
1703+
// std::cout << "found" << std::endl;
1704+
processCubTypeDefOrUsing(TAD);
16901705
} else if (auto TL = getNodeAsType<TypeLoc>(Result, "cudaTypeDef")) {
16911706
processTypeLoc(TL);
16921707
} else if (auto *UDD = getNodeAsType<UsingDirectiveDecl>(

clang/lib/DPCT/RulesLangLib/CUBAPIMigration.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class CubRule : public NamedMigrationRule<CubRule> {
110110
static int PlaceholderIndex;
111111
std::string getOpRepl(const Expr *Operator);
112112
void processCubDeclStmt(const DeclStmt *DS);
113-
void processCubTypeDef(const TypedefDecl *TD);
113+
void processCubTypeDefOrUsing(const TypedefNameDecl *TD);
114114
void processCubFuncCall(const CallExpr *CE, bool FuncCallUsed = false);
115115
void processCubMemberCall(const CXXMemberCallExpr *MC);
116116
void processTypeLoc(const TypeLoc *TL);

clang/test/dpct/cub/blocklevel/blockreduce.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ __global__ void SumKernel(int* data) {
5959
//CHECK-NEXT: data[threadid] = output;
6060
//CHECK-NEXT:}
6161
__global__ void ReduceKernel(int* data) {
62-
typedef cub::BlockReduce<int, 4> BlockReduce;
62+
using BlockReduce = cub::BlockReduce<int, 4>;
6363

6464
__shared__ typename BlockReduce::TempStorage temp1;
6565

0 commit comments

Comments
 (0)