@@ -760,6 +760,13 @@ void CubRule::registerMatcher(ast_matchers::MatchFinder &MF) {
760
760
.bind (" TypeDefDecl" ),
761
761
this );
762
762
763
+ MF.addMatcher (
764
+ typeAliasDecl (
765
+ hasType (hasCanonicalType (qualType (hasDeclaration (namedDecl (hasAnyName (
766
+ " WarpScan" , " WarpReduce" , " BlockScan" , " BlockReduce" )))))))
767
+ .bind (" UsingDecl" ),
768
+ this );
769
+
763
770
auto isTempStorage = hasDeclaration (namedDecl (hasAnyName (" TempStorage" )));
764
771
MF.addMatcher (declStmt (has (varDecl (anyOf (
765
772
hasType (hasCanonicalType (qualType (isTempStorage))),
@@ -919,7 +926,7 @@ void CubRule::processCubDeclStmt(const DeclStmt *DS) {
919
926
}
920
927
}
921
928
}
922
- void CubRule::processCubTypeDef (const TypedefDecl *TD) {
929
+ void CubRule::processCubTypeDefOrUsing (const TypedefNameDecl *TD) {
923
930
auto CanonicalType = TD->getUnderlyingType ().getCanonicalType ();
924
931
std::string CanonicalTypeStr = CanonicalType.getAsString ();
925
932
if (isTypeInAnalysisScope (CanonicalType.getTypePtr ()))
@@ -944,20 +951,9 @@ void CubRule::processCubTypeDef(const TypedefDecl *TD) {
944
951
// Currently, typedef decl can be deleted in following cases
945
952
for (auto &Element : TypeLocMatchResult) {
946
953
if (auto TL = Element.getNodeAs <TypeLoc>(" typeLoc" )) {
947
- // 1. Used in TempStorage variable declaration
948
- if (auto AncestorVD = DpctGlobalInfo::findAncestor<VarDecl>(TL)) {
949
- auto VarType = AncestorVD->getType ().getCanonicalType ();
950
- std::string VarTypeStr =
951
- AncestorVD->getType ().getCanonicalType ().getAsString ();
952
- if (isTypeInAnalysisScope (VarType.getTypePtr ()) ||
953
- !(VarTypeStr.find (" TempStorage" ) != std::string::npos &&
954
- VarTypeStr.find (" struct cub::" ) == 0 )) {
955
- DeleteFlag = false ;
956
- break ;
957
- }
958
- } // 2. Used in temporary class constructor
959
- else if (auto AncestorMTE =
960
- DpctGlobalInfo::findAncestor<MaterializeTemporaryExpr>(TL)) {
954
+ // 1. Used in temporary class constructor
955
+ if (auto AncestorMTE =
956
+ DpctGlobalInfo::findAncestor<MaterializeTemporaryExpr>(TL)) {
961
957
auto MC = DpctGlobalInfo::findAncestor<CXXMemberCallExpr>(AncestorMTE);
962
958
if (MC) {
963
959
auto ObjType = MC->getObjectType ().getCanonicalType ();
@@ -968,12 +964,27 @@ void CubRule::processCubTypeDef(const TypedefDecl *TD) {
968
964
ObjTypeStr.find (" class cub::BlockScan" ) == 0 ||
969
965
ObjTypeStr.find (" class cub::BlockReduce" ) == 0 )) {
970
966
DeleteFlag = false ;
967
+ std::cout << " 2" << std::endl;
971
968
break ;
972
969
}
973
970
}
974
- } // 3. Used in self typedef decl
971
+ } // 2. Used in TempStorage variable declaration
972
+ else if (auto AncestorVD = DpctGlobalInfo::findAncestor<VarDecl>(TL)) {
973
+ auto VarType = AncestorVD->getType ().getCanonicalType ();
974
+ std::string VarTypeStr =
975
+ AncestorVD->getType ().getCanonicalType ().getAsString ();
976
+ std::cout << VarTypeStr << std::endl;
977
+ if (isTypeInAnalysisScope (VarType.getTypePtr ()) ||
978
+ !(VarTypeStr.find (" TempStorage" ) != std::string::npos &&
979
+ VarTypeStr.find (" struct cub::" ) == 0 )) {
980
+ DeleteFlag = false ;
981
+ std::cout << " 1" << std::endl;
982
+ break ;
983
+ }
984
+ }
985
+ // 3. Used in self typedef decl
975
986
else if (auto AncestorTD =
976
- DpctGlobalInfo::findAncestor<TypedefDecl >(TL)) {
987
+ DpctGlobalInfo::findAncestor<TypedefNameDecl >(TL)) {
977
988
if (AncestorTD != TD) {
978
989
DeleteFlag = false ;
979
990
break ;
@@ -1686,7 +1697,11 @@ void CubRule::runRule(const ast_matchers::MatchFinder::MatchResult &Result) {
1686
1697
processCubFuncCall (CE, true );
1687
1698
} else if (const TypedefDecl *TD =
1688
1699
getNodeAsType<TypedefDecl>(Result, " TypeDefDecl" )) {
1689
- processCubTypeDef (TD);
1700
+ processCubTypeDefOrUsing (TD);
1701
+ } else if (const TypeAliasDecl *TAD =
1702
+ getNodeAsType<TypeAliasDecl>(Result, " UsingDecl" )) {
1703
+ // std::cout << "found" << std::endl;
1704
+ processCubTypeDefOrUsing (TAD);
1690
1705
} else if (auto TL = getNodeAsType<TypeLoc>(Result, " cudaTypeDef" )) {
1691
1706
processTypeLoc (TL);
1692
1707
} else if (auto *UDD = getNodeAsType<UsingDirectiveDecl>(
0 commit comments