Skip to content

Commit 089ff92

Browse files
test4
Signed-off-by: Daiyaan Ahmed <daiyaan.ahmed@intel.com>
1 parent 8199999 commit 089ff92

File tree

6 files changed

+102
-175
lines changed

6 files changed

+102
-175
lines changed

clang/lib/DPCT/RulesLang/APINamesGraph.inc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,9 @@ ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
2626

2727
ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
2828
UseExtGraph,
29-
CALL_FACTORY_ENTRY("cudaGraphLaunch",
30-
CALL(MapNames::getDpctNamespace() +
31-
"experimental::launch",
32-
ARG(0), ARG(1))),
29+
CALL_FACTORY_ENTRY("cudaGraphLaunch", CALL(MapNames::getDpctNamespace() +
30+
"experimental::launch",
31+
ARG(0), ARG(1))),
3332
UNSUPPORT_FACTORY_ENTRY("cudaGraphLaunch",
3433
Diagnostics::TRY_EXPERIMENTAL_FEATURE,
3534
ARG("cudaGraphLaunch"),
@@ -109,7 +108,7 @@ ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
109108
CALL_FACTORY_ENTRY("cudaGraphAddKernelNode",
110109
CALL(MapNames::getDpctNamespace() +
111110
"experimental::add_kernel_node",
112-
ARG(0), ARG(1), ARG(2))),
111+
ARG(0), ARG(1), ARG(2), ARG(3), ARG(4))),
113112
UNSUPPORT_FACTORY_ENTRY("cudaGraphAddKernelNode",
114113
Diagnostics::TRY_EXPERIMENTAL_FEATURE,
115114
ARG("cudaGraphAddKernelNode"),

clang/lib/DPCT/RulesLang/MapNamesLang.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,4 +371,4 @@ MapNamesLang::MapTy GraphRule::KernelNodeParamNames{
371371
{"func", "func"}};
372372

373373
} // namespace dpct
374-
} // namespace clang
374+
} // namespace clang

clang/lib/DPCT/RulesLang/RulesLang.cpp

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4518,22 +4518,14 @@ void StreamAPICallRule::runRule(const MatchFinder::MatchResult &Result) {
45184518

45194519
void KernelCallRefRule::registerMatcher(ast_matchers::MatchFinder &MF) {
45204520

4521-
auto cudaKernelNodeParamsMatcher = memberExpr(hasObjectExpression(hasType(
4522-
type(hasUnqualifiedDesugaredType(recordType(hasDeclaration(recordDecl(hasAnyName("cudaKernelNodeParams")))))))));
45234521
MF.addMatcher(
4524-
functionDecl(
4525-
forEachDescendant(
4526-
declRefExpr(
4527-
allOf(
4528-
to(functionDecl(hasAttr(attr::CUDAGlobal))),
4529-
unless(hasAncestor(cudaKernelCallExpr()))
4530-
)
4531-
).bind("kernelRef")
4532-
),
4533-
unless(hasDescendant(cudaKernelNodeParamsMatcher))
4534-
).bind("outerFunc"),
4535-
this);
4536-
4522+
functionDecl(
4523+
forEachDescendant(
4524+
declRefExpr(allOf(to(functionDecl(hasAttr(attr::CUDAGlobal))),
4525+
unless(hasAncestor(cudaKernelCallExpr()))))
4526+
.bind("kernelRef")))
4527+
.bind("outerFunc"),
4528+
this);
45374529

45384530
MF.addMatcher(unresolvedLookupExpr(unless(hasAncestor(cudaKernelCallExpr())))
45394531
.bind("unresolvedRef"),
@@ -4582,13 +4574,11 @@ void KernelCallRefRule::insertWrapperPostfix(const T *Node,
45824574
bool isInsertWrapperRegister) {
45834575
auto NLoc = DpctGlobalInfo::getSourceManager().getSpellingLoc(
45844576
Node->getNameInfo().getBeginLoc());
4585-
std::cout << "Inserting _wrapper at location: " << NLoc.printToString(DpctGlobalInfo::getSourceManager()) << "\n";
45864577
emplaceTransformation(new InsertText(
45874578
NLoc.getLocWithOffset(Node->getNameInfo().getAsString().length()),
45884579
"_wrapper"));
45894580

45904581
if (!isInsertWrapperRegister) {
4591-
std::cout << "Not inserting wrapper_register\n";
45924582
return;
45934583
}
45944584
const Expr *E = Node;
@@ -4604,7 +4594,6 @@ void KernelCallRefRule::insertWrapperPostfix(const T *Node,
46044594
E = COC;
46054595
}
46064596
}
4607-
std::cout << "Inserting wrapper_register with TypeRepl: " << TypeRepl << "\n";
46084597
emplaceTransformation(new InsertBeforeStmt(
46094598
E, MapNames::getDpctNamespace() + "wrapper_register" + TypeRepl + "("));
46104599
emplaceTransformation(new InsertAfterStmt(E, ").get()"));
@@ -4613,7 +4602,6 @@ void KernelCallRefRule::insertWrapperPostfix(const T *Node,
46134602
void KernelCallRefRule::runRule(
46144603
const ast_matchers::MatchFinder::MatchResult &Result) {
46154604
if (auto DRE = getAssistNodeAsType<DeclRefExpr>(Result, "kernelRef")) {
4616-
std::cout << "KernelRef matched\n";
46174605
const FunctionDecl *OuterFD =
46184606
getAssistNodeAsType<FunctionDecl>(Result, "outerFunc");
46194607
if (!OuterFD) {
@@ -7189,23 +7177,13 @@ TextModification *
71897177
ReplaceMemberAssignAsSetMethod(const Expr *E, const MemberExpr *ME,
71907178
StringRef MethodName, StringRef ReplacedArg,
71917179
StringRef ExtraArg, StringRef ExtraFeild) {
7192-
std::cout << "Entering ReplaceMemberAssignAsSetMethod (overloaded)\n";
7193-
std::cout << "Expr: " << E->getStmtClassName() << "\n";
7194-
std::cout << "MemberExpr: " << ME->getMemberNameInfo().getAsString() << "\n";
7195-
std::cout << "MethodName: " << MethodName.str() << "\n";
7196-
std::cout << "ReplacedArg: " << ReplacedArg.str() << "\n";
7197-
std::cout << "ExtraArg: " << ExtraArg.str() << "\n";
7198-
std::cout << "ExtraFeild: " << ExtraFeild.str() << "\n";
71997180
if (ReplacedArg.empty()) {
72007181
if (auto RHS = getRhs(E)) {
7201-
std::cout << "RHS found: " << ExprAnalysis::ref(RHS) << "\n";
7202-
StringRef c = ExprAnalysis::ref(RHS);
72037182
return ReplaceMemberAssignAsSetMethod(
72047183
getStmtExpansionSourceRange(E).getEnd(), ME, MethodName,
72057184
ExprAnalysis::ref(RHS), ExtraArg, ExtraFeild);
72067185
}
72077186
}
7208-
std::cout << "ReplacedArg is not empty or RHS not found\n";
72097187
return ReplaceMemberAssignAsSetMethod(getStmtExpansionSourceRange(E).getEnd(),
72107188
ME, MethodName, ReplacedArg, ExtraArg);
72117189
}

clang/lib/DPCT/RulesLang/RulesLang.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -998,17 +998,17 @@ class CompatWithClangRule : public NamedMigrationRule<CompatWithClangRule> {
998998
void runRule(const ast_matchers::MatchFinder::MatchResult &Result);
999999
};
10001000

1001-
class GraphRule : public NamedMigrationRule<GraphRule> {
1002-
static MapNames::MapTy KernelNodeParamNames;
1003-
const Expr *getAssignedBO(const Expr *E, ASTContext &Context);
1004-
const Expr *getParentAsAssignedBO(const Expr *E, ASTContext &Context);
1005-
1001+
class GraphAnalysisRule : public NamedMigrationRule<GraphAnalysisRule> {
10061002
public:
10071003
void registerMatcher(ast_matchers::MatchFinder &MF) override;
10081004
void runRule(const ast_matchers::MatchFinder::MatchResult &Result);
10091005
};
10101006

1011-
class GraphAnalysisRule : public NamedMigrationRule<GraphAnalysisRule> {
1007+
class GraphRule : public NamedMigrationRule<GraphRule> {
1008+
static MapNames::MapTy KernelNodeParamNames;
1009+
const Expr *getAssignedBO(const Expr *E, ASTContext &Context);
1010+
const Expr *getParentAsAssignedBO(const Expr *E, ASTContext &Context);
1011+
10121012
public:
10131013
void registerMatcher(ast_matchers::MatchFinder &MF) override;
10141014
void runRule(const ast_matchers::MatchFinder::MatchResult &Result);

clang/lib/DPCT/RulesLang/RulesLangGraph.cpp

Lines changed: 9 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -91,46 +91,20 @@ void GraphRule::runRule(const MatchFinder::MatchResult &Result) {
9191
getParentAsAssignedBO(ME, *Result.Context))) {
9292
auto *LHS = BO->getLHS()->IgnoreCasts();
9393
if (auto *ME = dyn_cast<MemberExpr>(LHS)) {
94-
std::cout << "Member Expr\n";
95-
// Get the base expression of the MemberExpr
9694
auto *Base = ME->getBase()->IgnoreImpCasts();
97-
98-
// Check if the base is a DeclRefExpr
9995
if (auto *DRE = dyn_cast<DeclRefExpr>(Base)) {
100-
std::cout << "DeclRef Expr\n";
101-
// Get the variable declaration
10296
if (auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
103-
std::cout << "Base VarDecl Expr\n";
104-
// Get the variable name
105-
std::string varName = VD->getNameAsString();
106-
107-
// Get the RHS of the assignment
108-
clang::Expr *RHS = BO->getRHS()->IgnoreCasts();
109-
110-
// Check if RHS is a DeclRefExpr referring to a function
97+
std::string VarName = VD->getNameAsString();
98+
auto *RHS = BO->getRHS()->IgnoreCasts();
11199
if (auto *RHS_DRE = dyn_cast<DeclRefExpr>(RHS)) {
112-
std::cout << "RHS DRE Expr\n";
113100
if (auto *FD = dyn_cast<FunctionDecl>(RHS_DRE->getDecl())) {
114-
std::cout << "RHS FunctionDecl Expr\n";
115-
// Get the function name
116-
std::string funcName = FD->getNameAsString();
117-
std::string wrapperName = funcName + "_wrapper";
118-
119-
// Construct the replacement expression
120-
std::string ReplacementExpr =
121-
varName + ".set_func((void*) dpct::wrapper_register(&" +
122-
wrapperName + ").get());";
123-
std::cout << "Replacement String: " << ReplacementExpr
124-
<< "\n";
125-
std::string rp = "(void*) dpct::wrapper_register(&" +
126-
wrapperName + ").get()";
127-
StringRef ReplacedArg = rp;
128-
emplaceTransformation(ReplaceMemberAssignAsSetMethod(
129-
BO, ME, FieldName, ReplacedArg));
130-
// Replace the original assignment with the new expression
131-
// emplaceTransformation(
132-
// new ReplaceToken(ME->getBeginLoc(), ME->getEndLoc(),
133-
// std ::move(ReplacementExpr)));
101+
std::string FuncName = FD->getNameAsString();
102+
std::string WrapperName = FuncName;
103+
std::string AccessOperator = VD->getType()->isPointerType() ? "->" : ".";
104+
std::string ReplacementStr = VarName + AccessOperator + "set_func("
105+
"(void*) dpct::wrapper_register(&" + WrapperName ;
106+
emplaceTransformation(new ReplaceToken(BO->getBeginLoc(), BO->getEndLoc(), std::move(ReplacementStr)));
107+
emplaceTransformation(new InsertAfterStmt(BO, ")"));
134108
return;
135109
}
136110
}
@@ -139,7 +113,6 @@ void GraphRule::runRule(const MatchFinder::MatchResult &Result) {
139113
}
140114
}
141115
}
142-
std::cout << "Coming here\n";
143116
if (auto BO = getParentAsAssignedBO(ME, *Result.Context)) {
144117
StringRef ReplacedArg = "";
145118
emplaceTransformation(

0 commit comments

Comments
 (0)