Skip to content

Commit b6ca5ce

Browse files
[SYCLomatic] Added unsupported migration rules & lit test case for all cuTensor APIs (#2785)
1 parent f119a76 commit b6ca5ce

12 files changed

+444
-62
lines changed

clang/lib/DPCT/ASTTraversal.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "RulesSHMEM/NVSHMEMAPIMigration.h"
2828
#include "RulesSecurity/Homoglyph.h"
2929
#include "RulesSecurity/MisleadingBidirectional.h"
30+
#include "RulesTensor/CUTensorAPIMigration.h"
3031
#include "TextModification.h"
3132
#include "Utility.h"
3233

@@ -197,5 +198,7 @@ REGISTER_RULE(CuDNNAPIRule, PassKind::PK_Migration, RuleGroupKind::RK_DNN)
197198

198199
REGISTER_RULE(NVSHMEMRule, PassKind::PK_Migration, RuleGroupKind::RK_NVSHMEM)
199200

201+
REGISTER_RULE(CUTensorRule, PassKind::PK_Migration, RuleGroupKind::RK_CUTensor)
202+
200203
} // namespace dpct
201204
} // namespace clang

clang/lib/DPCT/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ add_clang_library(DPCT
203203
RulesLang/CallExprRewriterCG.cpp
204204
RulesLang/CallExprRewriterWmma.cpp
205205
RulesSHMEM/CallExprRewriterNvshmem.cpp
206+
RulesTensor/CallExprRewriterCUTensor.cpp
206207
ErrorHandle/CrashRecovery.cpp
207208
Diagnostics/Diagnostics.cpp
208209
ErrorHandle/Error.cpp
@@ -242,6 +243,7 @@ add_clang_library(DPCT
242243
RulesCCL/NCCLAPIMigration.cpp
243244
RuleInfra/TypeLocRewriters.cpp
244245
RulesSHMEM/NVSHMEMAPIMigration.cpp
246+
RulesTensor/CUTensorAPIMigration.cpp
245247
Linux/AutoComplete.cpp
246248
RulesAsm/AsmMigration.cpp
247249
QueryAPIMapping/QueryAPIMapping.cpp

clang/lib/DPCT/RuleInfra/CallExprRewriter.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,15 @@ std::optional<std::string> FuncCallExprRewriter::buildRewriteString() {
134134

135135
std::unique_ptr<std::unordered_map<
136136
std::string, std::shared_ptr<CallExprRewriterFactoryBase>>>
137-
CallExprRewriterFactoryBase::RewriterMap = std::make_unique<std::unordered_map<
138-
std::string, std::shared_ptr<CallExprRewriterFactoryBase>>>();
137+
CallExprRewriterFactoryBase::RewriterMap =
138+
std::make_unique<std::unordered_map<
139+
std::string, std::shared_ptr<CallExprRewriterFactoryBase>>>();
139140

140141
std::unique_ptr<std::unordered_map<
141142
std::string, std::shared_ptr<CallExprRewriterFactoryBase>>>
142-
CallExprRewriterFactoryBase::MethodRewriterMap = std::make_unique<std::unordered_map<
143-
std::string, std::shared_ptr<CallExprRewriterFactoryBase>>>();
143+
CallExprRewriterFactoryBase::MethodRewriterMap =
144+
std::make_unique<std::unordered_map<
145+
std::string, std::shared_ptr<CallExprRewriterFactoryBase>>>();
144146

145147
void CallExprRewriterFactoryBase::initRewriterMap() {
146148
if (DpctGlobalInfo::useSYCLCompat()) {
@@ -162,6 +164,7 @@ void CallExprRewriterFactoryBase::initRewriterMap() {
162164
initRewriterMapMisc();
163165
initRewriterMapNccl();
164166
initRewriterMapNvshmem();
167+
initRewriterMapCUTensor();
165168
initRewriterMapStream();
166169
initRewriterMapTexture();
167170
initRewriterMapThrust();

clang/lib/DPCT/RuleInfra/CallExprRewriter.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class CallExprRewriterFactoryBase {
7070
static void initRewriterMapMisc();
7171
static void initRewriterMapNccl();
7272
static void initRewriterMapNvshmem();
73+
static void initRewriterMapCUTensor();
7374
static void initRewriterMapStream();
7475
static void initRewriterMapTexture();
7576
static void initRewriterMapThrust();

clang/lib/DPCT/RulesInclude/InclusionHeaders.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ enum class RuleGroupKind : uint8_t {
3636
RK_CUB,
3737
RK_WMMA,
3838
RK_NVSHMEM,
39+
RK_CUTensor,
3940
NUM
4041
};
4142

clang/lib/DPCT/RulesInclude/InclusionHeaders.inc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,6 @@ REGIST_INCLUSION("nvshmem.h", FullMatch, NVSHMEM, Replace, false,
112112
HeaderType::HT_SHMEM)
113113
REGIST_INCLUSION("nvshmemx.h", FullMatch, NVSHMEM, Replace, false,
114114
HeaderType::HT_SHMEMX)
115+
116+
REGIST_INCLUSION("cutensor.h", FullMatch, CUTensor, Remove, true)
117+
REGIST_INCLUSION("cutensorMg.h", FullMatch, CUTensor, Remove, true)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
//===------------------------ APINamesCUTensor.inc ------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// Helper Functions
10+
ENTRY_UNSUPPORTED("cutensorCreate", Diagnostics::API_NOT_MIGRATED)
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
//===---------------------- CUTensorAPIMigration.cpp ----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===-----------------------------------------------------------------------===//
8+
9+
#include "CUTensorAPIMigration.h"
10+
#include "RuleInfra/ExprAnalysis.h"
11+
12+
using namespace clang::dpct;
13+
using namespace clang::ast_matchers;
14+
15+
void clang::dpct::CUTensorRule::registerMatcher(ast_matchers::MatchFinder &MF) {
16+
auto CutensorAPIs = [&]() {
17+
return hasAnyName(
18+
// Helper Functions
19+
"cutensorCreate", "cutensorDestroy", "cutensorCreateTensorDescriptor",
20+
"cutensorDestroyTensorDescriptor", "cutensorGetErrorString",
21+
"cutensorGetVersion", "cutensorGetCudartVersion",
22+
// Element-wise Operations
23+
"cutensorCreateElementwiseTrinary", "cutensorElementwiseTrinaryExecute",
24+
"cutensorCreateElementwiseBinary", "cutensorElementwiseBinaryExecute",
25+
"cutensorCreatePermutation", "cutensorPermute",
26+
// Contraction Operations
27+
"cutensorCreateContraction", "cutensorContract",
28+
"cutensorCreateContractionTrinary", "cutensorContractTrinary",
29+
// Reduction Operations
30+
"cutensorCreateReduction", "cutensorReduce",
31+
// Generic Operation Functions
32+
"cutensorDestroyOperationDescriptor",
33+
"cutensorOperationDescriptorGetAttribute",
34+
"cutensorOperationDescriptorSetAttribute",
35+
"cutensorCreatePlanPreference", "cutensorDestroyPlanPreference",
36+
"cutensorPlanPreferenceSetAttribute", "cutensorEstimateWorkspaceSize",
37+
"cutensorCreatePlan", "cutensorDestroyPlan", "cutensorPlanGetAttribute",
38+
// Cache-related Operations
39+
"cutensorHandleResizePlanCache", "cutensorHandleReadPlanCacheFromFile",
40+
"cutensorHandleWritePlanCacheToFile", "cutensorReadKernelCacheFromFile",
41+
"cutensorWriteKernelCacheToFile",
42+
// Logger Functions
43+
"cutensorLoggerSetCallback", "cutensorLoggerSetFile",
44+
"cutensorLoggerOpenFile", "cutensorLoggerSetLevel",
45+
"cutensorLoggerSetMask", "cutensorLoggerForceDisable",
46+
// cuTENSORMg - General Operations
47+
"cutensorMgCreate", "cutensorMgDestroy",
48+
"cutensorMgCreateTensorDescriptor", "cutensorMgDestroyTensorDescriptor",
49+
"cutensorMgCreateCopyDescriptor", "cutensorMgDestroyCopyDescriptor",
50+
"cutensorMgCopyGetWorkspace", "cutensorMgCreateCopyPlan",
51+
"cutensorMgDestroyCopyPlan", "cutensorMgCopy",
52+
// cuTENSORMg - Contraction Operations
53+
"cutensorMgCreateContractionDescriptor",
54+
"cutensorMgDestroyContractionDescriptor",
55+
"cutensorMgCreateContractionFind", "cutensorMgDestroyContractionFind",
56+
"cutensorMgContractionGetWorkspace", "cutensorMgCreateContractionPlan",
57+
"cutensorMgDestroyContractionPlan", "cutensorMgContraction");
58+
};
59+
60+
MF.addMatcher(callExpr(callee(functionDecl(CutensorAPIs()))).bind("call"),
61+
this);
62+
}
63+
64+
void clang::dpct::CUTensorRule::runRule(
65+
const ast_matchers::MatchFinder::MatchResult &Result) {
66+
if (const CallExpr *CE = getNodeAsType<CallExpr>(Result, "call")) {
67+
std::string FuncName = "";
68+
const FunctionDecl *FD = CE->getDirectCallee();
69+
if (FD) {
70+
FuncName = FD->getNameInfo().getName().getAsString();
71+
}
72+
73+
report(CE->getBeginLoc(), Diagnostics::API_NOT_MIGRATED, false, FuncName);
74+
}
75+
76+
return;
77+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===----------------------- CUTensorAPIMigration.h -----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef CUTENSOR_API_MIGRATION_H
10+
#define CUTENSOR_API_MIGRATION_H
11+
12+
#include "ASTTraversal.h"
13+
14+
using namespace clang::ast_matchers;
15+
16+
namespace clang {
17+
namespace dpct {
18+
19+
class CUTensorRule : public NamedMigrationRule<CUTensorRule> {
20+
public:
21+
void registerMatcher(ast_matchers::MatchFinder &MF) override;
22+
void runRule(const ast_matchers::MatchFinder::MatchResult &Result);
23+
};
24+
25+
} // namespace dpct
26+
} // namespace clang
27+
28+
#endif // CUTENSOR_API_MIGRATION_H
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
//===-------------------- CallExprRewriterCUTensor.cpp --------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "RuleInfra/CallExprRewriter.h"
10+
#include "RuleInfra/CallExprRewriterCommon.h"
11+
12+
namespace clang {
13+
namespace dpct {
14+
15+
#define REWRITER_FACTORY_ENTRY(FuncName, RewriterFactory, ...) \
16+
{FuncName, std::make_shared<RewriterFactory>(FuncName, __VA_ARGS__)},
17+
#define UNSUPPORTED_FACTORY_ENTRY(FuncName, MsgID) \
18+
REWRITER_FACTORY_ENTRY(FuncName, \
19+
UnsupportFunctionRewriterFactory<std::string>, MsgID, \
20+
FuncName)
21+
#define ENTRY_UNSUPPORTED(SOURCEAPINAME, MSGID) \
22+
UNSUPPORTED_FACTORY_ENTRY(SOURCEAPINAME, MSGID)
23+
24+
void CallExprRewriterFactoryBase::initRewriterMapCUTensor() {
25+
RewriterMap->merge(
26+
std::unordered_map<std::string,
27+
std::shared_ptr<CallExprRewriterFactoryBase>>({
28+
#include "APINamesCUTensor.inc"
29+
}));
30+
}
31+
32+
#undef ENTRY_UNSUPPORTED
33+
#undef UNSUPPORTED_FACTORY_ENTRY
34+
#undef REWRITER_FACTORY_ENTRY
35+
36+
} // namespace dpct
37+
} // namespace clang

0 commit comments

Comments
 (0)