Skip to content

Commit 164819e

Browse files
test4
Signed-off-by: Daiyaan Ahmed <daiyaan.ahmed@intel.com>
1 parent 6468e49 commit 164819e

File tree

6 files changed

+102
-175
lines changed

6 files changed

+102
-175
lines changed

clang/lib/DPCT/RulesLang/APINamesGraph.inc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,9 @@ ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
2626

2727
ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
2828
UseExtGraph,
29-
CALL_FACTORY_ENTRY("cudaGraphLaunch",
30-
CALL(MapNames::getDpctNamespace() +
31-
"experimental::launch",
32-
ARG(0), ARG(1))),
29+
CALL_FACTORY_ENTRY("cudaGraphLaunch", CALL(MapNames::getDpctNamespace() +
30+
"experimental::launch",
31+
ARG(0), ARG(1))),
3332
UNSUPPORT_FACTORY_ENTRY("cudaGraphLaunch",
3433
Diagnostics::TRY_EXPERIMENTAL_FEATURE,
3534
ARG("cudaGraphLaunch"),
@@ -109,7 +108,7 @@ ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
109108
CALL_FACTORY_ENTRY("cudaGraphAddKernelNode",
110109
CALL(MapNames::getDpctNamespace() +
111110
"experimental::add_kernel_node",
112-
ARG(0), ARG(1), ARG(2))),
111+
ARG(0), ARG(1), ARG(2), ARG(3), ARG(4))),
113112
UNSUPPORT_FACTORY_ENTRY("cudaGraphAddKernelNode",
114113
Diagnostics::TRY_EXPERIMENTAL_FEATURE,
115114
ARG("cudaGraphAddKernelNode"),

clang/lib/DPCT/RulesLang/MapNamesLang.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,4 +371,4 @@ MapNamesLang::MapTy GraphRule::KernelNodeParamNames{
371371
{"func", "func"}};
372372

373373
} // namespace dpct
374-
} // namespace clang
374+
} // namespace clang

clang/lib/DPCT/RulesLang/RulesLang.cpp

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4509,22 +4509,14 @@ void StreamAPICallRule::runRule(const MatchFinder::MatchResult &Result) {
45094509

45104510
void KernelCallRefRule::registerMatcher(ast_matchers::MatchFinder &MF) {
45114511

4512-
auto cudaKernelNodeParamsMatcher = memberExpr(hasObjectExpression(hasType(
4513-
type(hasUnqualifiedDesugaredType(recordType(hasDeclaration(recordDecl(hasAnyName("cudaKernelNodeParams")))))))));
45144512
MF.addMatcher(
4515-
functionDecl(
4516-
forEachDescendant(
4517-
declRefExpr(
4518-
allOf(
4519-
to(functionDecl(hasAttr(attr::CUDAGlobal))),
4520-
unless(hasAncestor(cudaKernelCallExpr()))
4521-
)
4522-
).bind("kernelRef")
4523-
),
4524-
unless(hasDescendant(cudaKernelNodeParamsMatcher))
4525-
).bind("outerFunc"),
4526-
this);
4527-
4513+
functionDecl(
4514+
forEachDescendant(
4515+
declRefExpr(allOf(to(functionDecl(hasAttr(attr::CUDAGlobal))),
4516+
unless(hasAncestor(cudaKernelCallExpr()))))
4517+
.bind("kernelRef")))
4518+
.bind("outerFunc"),
4519+
this);
45284520

45294521
MF.addMatcher(unresolvedLookupExpr(unless(hasAncestor(cudaKernelCallExpr())))
45304522
.bind("unresolvedRef"),
@@ -4573,13 +4565,11 @@ void KernelCallRefRule::insertWrapperPostfix(const T *Node,
45734565
bool isInsertWrapperRegister) {
45744566
auto NLoc = DpctGlobalInfo::getSourceManager().getSpellingLoc(
45754567
Node->getNameInfo().getBeginLoc());
4576-
std::cout << "Inserting _wrapper at location: " << NLoc.printToString(DpctGlobalInfo::getSourceManager()) << "\n";
45774568
emplaceTransformation(new InsertText(
45784569
NLoc.getLocWithOffset(Node->getNameInfo().getAsString().length()),
45794570
"_wrapper"));
45804571

45814572
if (!isInsertWrapperRegister) {
4582-
std::cout << "Not inserting wrapper_register\n";
45834573
return;
45844574
}
45854575
const Expr *E = Node;
@@ -4595,7 +4585,6 @@ void KernelCallRefRule::insertWrapperPostfix(const T *Node,
45954585
E = COC;
45964586
}
45974587
}
4598-
std::cout << "Inserting wrapper_register with TypeRepl: " << TypeRepl << "\n";
45994588
emplaceTransformation(new InsertBeforeStmt(
46004589
E, MapNames::getDpctNamespace() + "wrapper_register" + TypeRepl + "("));
46014590
emplaceTransformation(new InsertAfterStmt(E, ").get()"));
@@ -4604,7 +4593,6 @@ void KernelCallRefRule::insertWrapperPostfix(const T *Node,
46044593
void KernelCallRefRule::runRule(
46054594
const ast_matchers::MatchFinder::MatchResult &Result) {
46064595
if (auto DRE = getAssistNodeAsType<DeclRefExpr>(Result, "kernelRef")) {
4607-
std::cout << "KernelRef matched\n";
46084596
const FunctionDecl *OuterFD =
46094597
getAssistNodeAsType<FunctionDecl>(Result, "outerFunc");
46104598
if (!OuterFD) {
@@ -7179,23 +7167,13 @@ TextModification *
71797167
ReplaceMemberAssignAsSetMethod(const Expr *E, const MemberExpr *ME,
71807168
StringRef MethodName, StringRef ReplacedArg,
71817169
StringRef ExtraArg, StringRef ExtraFeild) {
7182-
std::cout << "Entering ReplaceMemberAssignAsSetMethod (overloaded)\n";
7183-
std::cout << "Expr: " << E->getStmtClassName() << "\n";
7184-
std::cout << "MemberExpr: " << ME->getMemberNameInfo().getAsString() << "\n";
7185-
std::cout << "MethodName: " << MethodName.str() << "\n";
7186-
std::cout << "ReplacedArg: " << ReplacedArg.str() << "\n";
7187-
std::cout << "ExtraArg: " << ExtraArg.str() << "\n";
7188-
std::cout << "ExtraFeild: " << ExtraFeild.str() << "\n";
71897170
if (ReplacedArg.empty()) {
71907171
if (auto RHS = getRhs(E)) {
7191-
std::cout << "RHS found: " << ExprAnalysis::ref(RHS) << "\n";
7192-
StringRef c = ExprAnalysis::ref(RHS);
71937172
return ReplaceMemberAssignAsSetMethod(
71947173
getStmtExpansionSourceRange(E).getEnd(), ME, MethodName,
71957174
ExprAnalysis::ref(RHS), ExtraArg, ExtraFeild);
71967175
}
71977176
}
7198-
std::cout << "ReplacedArg is not empty or RHS not found\n";
71997177
return ReplaceMemberAssignAsSetMethod(getStmtExpansionSourceRange(E).getEnd(),
72007178
ME, MethodName, ReplacedArg, ExtraArg);
72017179
}

clang/lib/DPCT/RulesLang/RulesLang.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -998,17 +998,17 @@ class CompatWithClangRule : public NamedMigrationRule<CompatWithClangRule> {
998998
void runRule(const ast_matchers::MatchFinder::MatchResult &Result);
999999
};
10001000

1001-
class GraphRule : public NamedMigrationRule<GraphRule> {
1002-
static MapNames::MapTy KernelNodeParamNames;
1003-
const Expr *getAssignedBO(const Expr *E, ASTContext &Context);
1004-
const Expr *getParentAsAssignedBO(const Expr *E, ASTContext &Context);
1005-
1001+
class GraphAnalysisRule : public NamedMigrationRule<GraphAnalysisRule> {
10061002
public:
10071003
void registerMatcher(ast_matchers::MatchFinder &MF) override;
10081004
void runRule(const ast_matchers::MatchFinder::MatchResult &Result);
10091005
};
10101006

1011-
class GraphAnalysisRule : public NamedMigrationRule<GraphAnalysisRule> {
1007+
class GraphRule : public NamedMigrationRule<GraphRule> {
1008+
static MapNames::MapTy KernelNodeParamNames;
1009+
const Expr *getAssignedBO(const Expr *E, ASTContext &Context);
1010+
const Expr *getParentAsAssignedBO(const Expr *E, ASTContext &Context);
1011+
10121012
public:
10131013
void registerMatcher(ast_matchers::MatchFinder &MF) override;
10141014
void runRule(const ast_matchers::MatchFinder::MatchResult &Result);

clang/lib/DPCT/RulesLang/RulesLangGraph.cpp

Lines changed: 9 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -91,46 +91,20 @@ void GraphRule::runRule(const MatchFinder::MatchResult &Result) {
9191
getParentAsAssignedBO(ME, *Result.Context))) {
9292
auto *LHS = BO->getLHS()->IgnoreCasts();
9393
if (auto *ME = dyn_cast<MemberExpr>(LHS)) {
94-
std::cout << "Member Expr\n";
95-
// Get the base expression of the MemberExpr
9694
auto *Base = ME->getBase()->IgnoreImpCasts();
97-
98-
// Check if the base is a DeclRefExpr
9995
if (auto *DRE = dyn_cast<DeclRefExpr>(Base)) {
100-
std::cout << "DeclRef Expr\n";
101-
// Get the variable declaration
10296
if (auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
103-
std::cout << "Base VarDecl Expr\n";
104-
// Get the variable name
105-
std::string varName = VD->getNameAsString();
106-
107-
// Get the RHS of the assignment
108-
clang::Expr *RHS = BO->getRHS()->IgnoreCasts();
109-
110-
// Check if RHS is a DeclRefExpr referring to a function
97+
std::string VarName = VD->getNameAsString();
98+
auto *RHS = BO->getRHS()->IgnoreCasts();
11199
if (auto *RHS_DRE = dyn_cast<DeclRefExpr>(RHS)) {
112-
std::cout << "RHS DRE Expr\n";
113100
if (auto *FD = dyn_cast<FunctionDecl>(RHS_DRE->getDecl())) {
114-
std::cout << "RHS FunctionDecl Expr\n";
115-
// Get the function name
116-
std::string funcName = FD->getNameAsString();
117-
std::string wrapperName = funcName + "_wrapper";
118-
119-
// Construct the replacement expression
120-
std::string ReplacementExpr =
121-
varName + ".set_func((void*) dpct::wrapper_register(&" +
122-
wrapperName + ").get());";
123-
std::cout << "Replacement String: " << ReplacementExpr
124-
<< "\n";
125-
std::string rp = "(void*) dpct::wrapper_register(&" +
126-
wrapperName + ").get()";
127-
StringRef ReplacedArg = rp;
128-
emplaceTransformation(ReplaceMemberAssignAsSetMethod(
129-
BO, ME, FieldName, ReplacedArg));
130-
// Replace the original assignment with the new expression
131-
// emplaceTransformation(
132-
// new ReplaceToken(ME->getBeginLoc(), ME->getEndLoc(),
133-
// std ::move(ReplacementExpr)));
101+
std::string FuncName = FD->getNameAsString();
102+
std::string WrapperName = FuncName;
103+
std::string AccessOperator = VD->getType()->isPointerType() ? "->" : ".";
104+
std::string ReplacementStr = VarName + AccessOperator + "set_func("
105+
"(void*) dpct::wrapper_register(&" + WrapperName ;
106+
emplaceTransformation(new ReplaceToken(BO->getBeginLoc(), BO->getEndLoc(), std::move(ReplacementStr)));
107+
emplaceTransformation(new InsertAfterStmt(BO, ")"));
134108
return;
135109
}
136110
}
@@ -139,7 +113,6 @@ void GraphRule::runRule(const MatchFinder::MatchResult &Result) {
139113
}
140114
}
141115
}
142-
std::cout << "Coming here\n";
143116
if (auto BO = getParentAsAssignedBO(ME, *Result.Context)) {
144117
StringRef ReplacedArg = "";
145118
emplaceTransformation(

0 commit comments

Comments
 (0)