Skip to content

Commit 6468e49

Browse files
test3
Signed-off-by: Daiyaan Ahmed <daiyaan.ahmed@intel.com>
1 parent 98fcde3 commit 6468e49

File tree

9 files changed

+263
-86
lines changed

9 files changed

+263
-86
lines changed

clang/lib/DPCT/ASTTraversal.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ REGISTER_RULE(TypeRemoveRule, PassKind::PK_Analysis)
157157
REGISTER_RULE(CompatWithClangRule, PassKind::PK_Migration)
158158
REGISTER_RULE(AssertRule, PassKind::PK_Migration)
159159
REGISTER_RULE(GraphRule, PassKind::PK_Migration)
160+
REGISTER_RULE(GraphAnalysisRule, PassKind::PK_Analysis)
160161
REGISTER_RULE(GraphicsInteropRule, PassKind::PK_Migration)
161162
REGISTER_RULE(RulesLangAddrSpaceConvRule, PassKind::PK_Migration)
162163

clang/lib/DPCT/AnalysisInfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2504,7 +2504,7 @@ unsigned DpctGlobalInfo::ExperimentalFlag = 0;
25042504
unsigned DpctGlobalInfo::HelperFuncPreferenceFlag = 0;
25052505
bool DpctGlobalInfo::AnalysisModeFlag = false;
25062506
bool DpctGlobalInfo::UseSYCLCompatFlag = false;
2507-
bool DpctGlobalInfo::CVersionCUDALaunchUsedFlag = false;
2507+
bool DpctGlobalInfo::UseWrapperRegisterFnPtrFlag = false;
25082508
unsigned int DpctGlobalInfo::ColorOption = 1;
25092509
std::unordered_map<int, std::shared_ptr<DeviceFunctionInfo>>
25102510
DpctGlobalInfo::CubPlaceholderIndexMap;

clang/lib/DPCT/AnalysisInfo.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1353,8 +1353,8 @@ class DpctGlobalInfo {
13531353
static bool useNoQueueDevice() {
13541354
return getHelperFuncPreference(HelperFuncPreference::NoQueueDevice);
13551355
}
1356-
static void setCVersionCUDALaunchUsed() { CVersionCUDALaunchUsedFlag = true; }
1357-
static bool isCVersionCUDALaunchUsed() { return CVersionCUDALaunchUsedFlag; }
1356+
static void setUseWrapperRegisterFnPtr() { UseWrapperRegisterFnPtrFlag = true; }
1357+
static bool useWrapperRegisterFnPtr() { return UseWrapperRegisterFnPtrFlag; }
13581358
static void setUseSYCLCompat(bool Flag = true) { UseSYCLCompatFlag = Flag; }
13591359
static bool useSYCLCompat() { return UseSYCLCompatFlag; }
13601360
static bool useEnqueueBarrier() {
@@ -1684,7 +1684,7 @@ class DpctGlobalInfo {
16841684
static unsigned HelperFuncPreferenceFlag;
16851685
static bool AnalysisModeFlag;
16861686
static bool UseSYCLCompatFlag;
1687-
static bool CVersionCUDALaunchUsedFlag;
1687+
static bool UseWrapperRegisterFnPtrFlag;
16881688
static unsigned int ColorOption;
16891689
static std::unordered_map<int, std::shared_ptr<DeviceFunctionInfo>>
16901690
CubPlaceholderIndexMap;

clang/lib/DPCT/RulesLang/APINamesGraph.inc

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,16 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
CONDITIONAL_FACTORY_ENTRY(
9+
ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
1010
UseExtGraph,
11-
ASSIGNABLE_FACTORY(ASSIGN_FACTORY_ENTRY(
12-
"cudaGraphInstantiate", DEREF(0),
13-
NEW(MapNames::getClNamespace() +
14-
"ext::oneapi::experimental::command_graph<" +
15-
MapNames::getClNamespace() +
16-
"ext::oneapi::experimental::graph_state::executable>",
17-
MEMBER_CALL(ARG(1), true, "finalize")))),
11+
CALL_FACTORY_ENTRY("cudaGraphInstantiate",
12+
CALL(MapNames::getDpctNamespace() +
13+
"experimental::instantiate",
14+
ARG(0), ARG(1))),
1815
UNSUPPORT_FACTORY_ENTRY("cudaGraphInstantiate",
1916
Diagnostics::TRY_EXPERIMENTAL_FEATURE,
2017
ARG("cudaGraphInstantiate"),
21-
ARG("--use-experimental-features=graph")))
18+
ARG("--use-experimental-features=graph"))))
2219

2320
ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
2421
UseExtGraph, DELETE_FACTORY_ENTRY("cudaGraphExecDestroy", ARG(0)),
@@ -29,8 +26,10 @@ ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
2926

3027
ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
3128
UseExtGraph,
32-
MEMBER_CALL_FACTORY_ENTRY("cudaGraphLaunch", ARG(1), true,
33-
"ext_oneapi_graph", DEREF(0)),
29+
CALL_FACTORY_ENTRY("cudaGraphLaunch",
30+
CALL(MapNames::getDpctNamespace() +
31+
"experimental::launch",
32+
ARG(0), ARG(1))),
3433
UNSUPPORT_FACTORY_ENTRY("cudaGraphLaunch",
3534
Diagnostics::TRY_EXPERIMENTAL_FEATURE,
3635
ARG("cudaGraphLaunch"),
@@ -104,3 +103,14 @@ ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
104103
Diagnostics::TRY_EXPERIMENTAL_FEATURE,
105104
ARG("cudaGraphDestroy"),
106105
ARG("--use-experimental-features=graph"))))
106+
107+
ASSIGNABLE_FACTORY(CONDITIONAL_FACTORY_ENTRY(
108+
UseExtGraph,
109+
CALL_FACTORY_ENTRY("cudaGraphAddKernelNode",
110+
CALL(MapNames::getDpctNamespace() +
111+
"experimental::add_kernel_node",
112+
ARG(0), ARG(1), ARG(2))),
113+
UNSUPPORT_FACTORY_ENTRY("cudaGraphAddKernelNode",
114+
Diagnostics::TRY_EXPERIMENTAL_FEATURE,
115+
ARG("cudaGraphAddKernelNode"),
116+
ARG("--use-experimental-features=graph"))))

clang/lib/DPCT/RulesLang/RulesLang.cpp

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4508,14 +4508,24 @@ void StreamAPICallRule::runRule(const MatchFinder::MatchResult &Result) {
45084508
}
45094509

45104510
void KernelCallRefRule::registerMatcher(ast_matchers::MatchFinder &MF) {
4511+
4512+
auto cudaKernelNodeParamsMatcher = memberExpr(hasObjectExpression(hasType(
4513+
type(hasUnqualifiedDesugaredType(recordType(hasDeclaration(recordDecl(hasAnyName("cudaKernelNodeParams")))))))));
45114514
MF.addMatcher(
4512-
functionDecl(
4513-
forEachDescendant(
4514-
declRefExpr(allOf(to(functionDecl(hasAttr(attr::CUDAGlobal))),
4515-
unless(hasAncestor(cudaKernelCallExpr()))))
4516-
.bind("kernelRef")))
4517-
.bind("outerFunc"),
4518-
this);
4515+
functionDecl(
4516+
forEachDescendant(
4517+
declRefExpr(
4518+
allOf(
4519+
to(functionDecl(hasAttr(attr::CUDAGlobal))),
4520+
unless(hasAncestor(cudaKernelCallExpr()))
4521+
)
4522+
).bind("kernelRef")
4523+
),
4524+
unless(hasDescendant(cudaKernelNodeParamsMatcher))
4525+
).bind("outerFunc"),
4526+
this);
4527+
4528+
45194529
MF.addMatcher(unresolvedLookupExpr(unless(hasAncestor(cudaKernelCallExpr())))
45204530
.bind("unresolvedRef"),
45214531
this);
@@ -4563,14 +4573,13 @@ void KernelCallRefRule::insertWrapperPostfix(const T *Node,
45634573
bool isInsertWrapperRegister) {
45644574
auto NLoc = DpctGlobalInfo::getSourceManager().getSpellingLoc(
45654575
Node->getNameInfo().getBeginLoc());
4566-
4567-
std::cout <<"WRAPPER APPENDED: " << "\n";
4568-
4576+
std::cout << "Inserting _wrapper at location: " << NLoc.printToString(DpctGlobalInfo::getSourceManager()) << "\n";
45694577
emplaceTransformation(new InsertText(
45704578
NLoc.getLocWithOffset(Node->getNameInfo().getAsString().length()),
45714579
"_wrapper"));
45724580

45734581
if (!isInsertWrapperRegister) {
4582+
std::cout << "Not inserting wrapper_register\n";
45744583
return;
45754584
}
45764585
const Expr *E = Node;
@@ -4586,6 +4595,7 @@ void KernelCallRefRule::insertWrapperPostfix(const T *Node,
45864595
E = COC;
45874596
}
45884597
}
4598+
std::cout << "Inserting wrapper_register with TypeRepl: " << TypeRepl << "\n";
45894599
emplaceTransformation(new InsertBeforeStmt(
45904600
E, MapNames::getDpctNamespace() + "wrapper_register" + TypeRepl + "("));
45914601
emplaceTransformation(new InsertAfterStmt(E, ").get()"));
@@ -4594,6 +4604,7 @@ void KernelCallRefRule::insertWrapperPostfix(const T *Node,
45944604
void KernelCallRefRule::runRule(
45954605
const ast_matchers::MatchFinder::MatchResult &Result) {
45964606
if (auto DRE = getAssistNodeAsType<DeclRefExpr>(Result, "kernelRef")) {
4607+
std::cout << "KernelRef matched\n";
45974608
const FunctionDecl *OuterFD =
45984609
getAssistNodeAsType<FunctionDecl>(Result, "outerFunc");
45994610
if (!OuterFD) {
@@ -4627,7 +4638,7 @@ void KernelCallRefRule::runRule(
46274638
(OuterFD->getTemplatedKind() ==
46284639
FunctionDecl::TemplatedKind::TK_FunctionTemplate)) {
46294640
std::string TypeRepl;
4630-
if (DpctGlobalInfo::isCVersionCUDALaunchUsed()) {
4641+
if (DpctGlobalInfo::useWrapperRegisterFnPtr()) {
46314642
if ((IsTemplateRelated &&
46324643
(!DRE->hasExplicitTemplateArgs() ||
46334644
(DRE->getNumTemplateArgs() <= TemplateParamNum))) ||
@@ -4636,7 +4647,7 @@ void KernelCallRefRule::runRule(
46364647
}
46374648
}
46384649
insertWrapperPostfix<DeclRefExpr>(
4639-
DRE, std::move(TypeRepl), DpctGlobalInfo::isCVersionCUDALaunchUsed());
4650+
DRE, std::move(TypeRepl), DpctGlobalInfo::useWrapperRegisterFnPtr());
46404651
}
46414652
}
46424653
if (auto ULE =
@@ -4673,7 +4684,7 @@ void KernelCallRefRule::runRule(
46734684
}
46744685
}
46754686
insertWrapperPostfix<UnresolvedLookupExpr>(
4676-
ULE, getTypeRepl(ULE), DpctGlobalInfo::isCVersionCUDALaunchUsed());
4687+
ULE, getTypeRepl(ULE), DpctGlobalInfo::useWrapperRegisterFnPtr());
46774688
}
46784689
}
46794690

@@ -4946,7 +4957,7 @@ void KernelCallRule::runRule(
49464957

49474958
if (!getAddressedRef(CalleeDRE)) {
49484959
if (IsFuncTypeErased) {
4949-
DpctGlobalInfo::setCVersionCUDALaunchUsed();
4960+
DpctGlobalInfo::setUseWrapperRegisterFnPtr();
49504961
}
49514962
std::string ReplStr;
49524963
llvm::raw_string_ostream OS(ReplStr);
@@ -7168,16 +7179,23 @@ TextModification *
71687179
ReplaceMemberAssignAsSetMethod(const Expr *E, const MemberExpr *ME,
71697180
StringRef MethodName, StringRef ReplacedArg,
71707181
StringRef ExtraArg, StringRef ExtraFeild) {
7182+
std::cout << "Entering ReplaceMemberAssignAsSetMethod (overloaded)\n";
7183+
std::cout << "Expr: " << E->getStmtClassName() << "\n";
7184+
std::cout << "MemberExpr: " << ME->getMemberNameInfo().getAsString() << "\n";
7185+
std::cout << "MethodName: " << MethodName.str() << "\n";
7186+
std::cout << "ReplacedArg: " << ReplacedArg.str() << "\n";
7187+
std::cout << "ExtraArg: " << ExtraArg.str() << "\n";
7188+
std::cout << "ExtraFeild: " << ExtraFeild.str() << "\n";
71717189
if (ReplacedArg.empty()) {
71727190
if (auto RHS = getRhs(E)) {
7191+
std::cout << "RHS found: " << ExprAnalysis::ref(RHS) << "\n";
71737192
StringRef c = ExprAnalysis::ref(RHS);
7174-
std::cout <<"Replaced String: "<< c.str() <<"\n";
71757193
return ReplaceMemberAssignAsSetMethod(
71767194
getStmtExpansionSourceRange(E).getEnd(), ME, MethodName,
71777195
ExprAnalysis::ref(RHS), ExtraArg, ExtraFeild);
71787196
}
71797197
}
7180-
std::cout << "Coming her!!!!!!!!!e\n";
7198+
std::cout << "ReplacedArg is not empty or RHS not found\n";
71817199
return ReplaceMemberAssignAsSetMethod(getStmtExpansionSourceRange(E).getEnd(),
71827200
ME, MethodName, ReplacedArg, ExtraArg);
71837201
}

clang/lib/DPCT/RulesLang/RulesLang.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,12 @@ class GraphRule : public NamedMigrationRule<GraphRule> {
10081008
void runRule(const ast_matchers::MatchFinder::MatchResult &Result);
10091009
};
10101010

1011+
class GraphAnalysisRule : public NamedMigrationRule<GraphAnalysisRule> {
1012+
public:
1013+
void registerMatcher(ast_matchers::MatchFinder &MF) override;
1014+
void runRule(const ast_matchers::MatchFinder::MatchResult &Result);
1015+
};
1016+
10111017
class AssertRule : public NamedMigrationRule<AssertRule> {
10121018
public:
10131019
void registerMatcher(ast_matchers::MatchFinder &MF) override;

clang/lib/DPCT/RulesLang/RulesLangGraph.cpp

Lines changed: 85 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,37 @@ extern DpctOption<opt, bool> AsyncHandler;
2828
namespace clang {
2929
namespace dpct {
3030

31+
void GraphAnalysisRule::registerMatcher(MatchFinder &MF) {
32+
auto kernelNodeTypeName = [&]() {
33+
return hasAnyName("cudaKernelNodeParams");
34+
};
35+
MF.addMatcher(
36+
memberExpr(
37+
hasObjectExpression(hasType(type(hasUnqualifiedDesugaredType(
38+
recordType(hasDeclaration(recordDecl(kernelNodeTypeName()))))))))
39+
.bind("KernelNodeType"),
40+
this);
41+
}
42+
43+
void GraphAnalysisRule::runRule(const MatchFinder::MatchResult &Result) {
44+
if (auto ME = getNodeAsType<MemberExpr>(Result, "KernelNodeType")) {
45+
auto BaseTy = DpctGlobalInfo::getUnqualifiedTypeName(
46+
ME->getBase()->getType().getDesugaredType(*Result.Context),
47+
*Result.Context);
48+
auto MemberName = ME->getMemberNameInfo().getAsString();
49+
if (BaseTy == "cudaKernelNodeParams") {
50+
DpctGlobalInfo::setUseWrapperRegisterFnPtr();
51+
}
52+
}
53+
}
54+
3155
void GraphRule::registerMatcher(MatchFinder &MF) {
3256
auto functionName = [&]() {
33-
return hasAnyName("cudaGraphInstantiate", "cudaGraphLaunch",
34-
"cudaGraphExecDestroy", "cudaGraphAddEmptyNode",
35-
"cudaGraphAddDependencies", "cudaGraphExecUpdate",
36-
"cudaGraphNodeGetType", "cudaGraphGetNodes",
37-
"cudaGraphGetRootNodes", "cudaGraphDestroy");
57+
return hasAnyName(
58+
"cudaGraphInstantiate", "cudaGraphLaunch", "cudaGraphExecDestroy",
59+
"cudaGraphAddEmptyNode", "cudaGraphAddDependencies",
60+
"cudaGraphExecUpdate", "cudaGraphNodeGetType", "cudaGraphGetNodes",
61+
"cudaGraphGetRootNodes", "cudaGraphDestroy", "cudaGraphAddKernelNode");
3862
};
3963
MF.addMatcher(
4064
callExpr(callee(functionDecl(functionName()))).bind("FunctionCall"),
@@ -55,29 +79,67 @@ void GraphRule::runRule(const MatchFinder::MatchResult &Result) {
5579
*Result.Context);
5680
auto MemberName = ME->getMemberNameInfo().getAsString();
5781
if (BaseTy == "cudaKernelNodeParams") {
58-
std::cout <<"NODE PARAMS FOUND\n";
59-
DpctGlobalInfo::setCVersionCUDALaunchUsed();
6082
auto FieldName = KernelNodeParamNames[MemberName];
6183
if (FieldName.empty()) {
6284
report(ME->getBeginLoc(), Diagnostics::API_NOT_MIGRATED, false,
6385
DpctGlobalInfo::getOriginalTypeName(ME->getBase()->getType()) +
6486
"::" + ME->getMemberDecl()->getName().str());
6587
return;
66-
6788
}
68-
// if(FieldName == "func"){
69-
// Check for the binary operator and fetch the RHS
70-
// Strip the explicit typecast if it exists
71-
// Check for VarDecl on the StrippedRHS
72-
// If not a VarDecl, then insert user warning
73-
// Check for VarDecl Type to be a FunctionDecl
74-
// If FunctionDecl, then
75-
// VarDecl, get var name, Get kernel_node_params variable name
76-
// Create the expression, hardcoded strting
77-
// Create new replace object and emplace transformation (nodeParams.set_func((void*)dpct::wrapper_register(&incrementKernel_wrapper).get());)
78-
// If VarDecl and not a FunctionDecl and if type of VarDecl is function pointer
79-
// Create a hardcoded string (nodeParams.set_func(a.get()));
80-
// }
89+
if (FieldName == "func") {
90+
if (auto BO = dyn_cast<BinaryOperator>(
91+
getParentAsAssignedBO(ME, *Result.Context))) {
92+
auto *LHS = BO->getLHS()->IgnoreCasts();
93+
if (auto *ME = dyn_cast<MemberExpr>(LHS)) {
94+
std::cout << "Member Expr\n";
95+
// Get the base expression of the MemberExpr
96+
auto *Base = ME->getBase()->IgnoreImpCasts();
97+
98+
// Check if the base is a DeclRefExpr
99+
if (auto *DRE = dyn_cast<DeclRefExpr>(Base)) {
100+
std::cout << "DeclRef Expr\n";
101+
// Get the variable declaration
102+
if (auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
103+
std::cout << "Base VarDecl Expr\n";
104+
// Get the variable name
105+
std::string varName = VD->getNameAsString();
106+
107+
// Get the RHS of the assignment
108+
clang::Expr *RHS = BO->getRHS()->IgnoreCasts();
109+
110+
// Check if RHS is a DeclRefExpr referring to a function
111+
if (auto *RHS_DRE = dyn_cast<DeclRefExpr>(RHS)) {
112+
std::cout << "RHS DRE Expr\n";
113+
if (auto *FD = dyn_cast<FunctionDecl>(RHS_DRE->getDecl())) {
114+
std::cout << "RHS FunctionDecl Expr\n";
115+
// Get the function name
116+
std::string funcName = FD->getNameAsString();
117+
std::string wrapperName = funcName + "_wrapper";
118+
119+
// Construct the replacement expression
120+
std::string ReplacementExpr =
121+
varName + ".set_func((void*) dpct::wrapper_register(&" +
122+
wrapperName + ").get());";
123+
std::cout << "Replacement String: " << ReplacementExpr
124+
<< "\n";
125+
std::string rp = "(void*) dpct::wrapper_register(&" +
126+
wrapperName + ").get()";
127+
StringRef ReplacedArg = rp;
128+
emplaceTransformation(ReplaceMemberAssignAsSetMethod(
129+
BO, ME, FieldName, ReplacedArg));
130+
// Replace the original assignment with the new expression
131+
// emplaceTransformation(
132+
// new ReplaceToken(ME->getBeginLoc(), ME->getEndLoc(),
133+
// std ::move(ReplacementExpr)));
134+
return;
135+
}
136+
}
137+
}
138+
}
139+
}
140+
}
141+
}
142+
std::cout << "Coming here\n";
81143
if (auto BO = getParentAsAssignedBO(ME, *Result.Context)) {
82144
StringRef ReplacedArg = "";
83145
emplaceTransformation(
@@ -106,8 +168,8 @@ const Expr *GraphRule::getParentAsAssignedBO(const Expr *E,
106168
return nullptr;
107169
}
108170

109-
// Return the binary operator if E is the lhs of an assign expression, otherwise
110-
// nullptr.
171+
// Return the binary operator if E is the lhs of an assign expression,
172+
// otherwise nullptr.
111173
const Expr *GraphRule::getAssignedBO(const Expr *E, ASTContext &Context) {
112174
if (dyn_cast<MemberExpr>(E)) {
113175
// Continue finding parents when E is MemberExpr.

clang/lib/DPCT/SrcAPI/APINames.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ ENTRY(cudaGraphAddEventWaitNode, cudaGraphAddEventWaitNode, false, NO_FLAG, P4,
404404
ENTRY(cudaGraphAddExternalSemaphoresSignalNode, cudaGraphAddExternalSemaphoresSignalNode, false, NO_FLAG, P4, "comment")
405405
ENTRY(cudaGraphAddExternalSemaphoresWaitNode, cudaGraphAddExternalSemaphoresWaitNode, false, NO_FLAG, P4, "comment")
406406
ENTRY(cudaGraphAddHostNode, cudaGraphAddHostNode, false, NO_FLAG, P4, "comment")
407-
ENTRY(cudaGraphAddKernelNode, cudaGraphAddKernelNode, false, NO_FLAG, P4, "comment")
407+
ENTRY(cudaGraphAddKernelNode, cudaGraphAddKernelNode, true, NO_FLAG, P4, "Successful/DPCT1119")
408408
ENTRY(cudaGraphAddMemAllocNode, cudaGraphAddMemAllocNode, false, NO_FLAG, P4, "comment")
409409
ENTRY(cudaGraphAddMemFreeNode, cudaGraphAddMemFreeNode, false, NO_FLAG, P4, "comment")
410410
ENTRY(cudaGraphAddMemcpyNode, cudaGraphAddMemcpyNode, false, NO_FLAG, P4, "comment")

0 commit comments

Comments
 (0)