Skip to content

Commit 7d3f6e9

Browse files
authored
[SYCLomatic] Add explicit cast if CUDA enum type is migrated to scoped type and the value is compared to a different type value. (#2829)
Signed-off-by: Chen, Sheng S <sheng.s.chen@intel.com>
1 parent 9f1b8c6 commit 7d3f6e9

File tree

4 files changed

+77
-0
lines changed

4 files changed

+77
-0
lines changed

clang/lib/DPCT/ASTTraversal.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ void MigrationRule::emplaceTransformation(TextModification *TM) {
9393
// RuleLang
9494
REGISTER_RULE(IterationSpaceBuiltinRule, PassKind::PK_Analysis)
9595
REGISTER_RULE(ErrorHandlingIfStmtRule, PassKind::PK_Migration)
96+
REGISTER_RULE(CastScopedEnumTypeRule, PassKind::PK_Migration)
9697
REGISTER_RULE(ErrorHandlingHostAPIRule, PassKind::PK_Migration)
9798
REGISTER_RULE(AtomicFunctionRule, PassKind::PK_Migration)
9899
REGISTER_RULE(ZeroLengthArrayRule, PassKind::PK_Migration)

clang/lib/DPCT/RulesLang/RulesLang.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4515,6 +4515,50 @@ void StreamAPICallRule::runRule(const MatchFinder::MatchResult &Result) {
45154515
}
45164516
}
45174517

4518+
void CastScopedEnumTypeRule::registerMatcher(ast_matchers::MatchFinder &MF) {
4519+
MF.addMatcher(binaryOperator(isComparisonOperator()).bind("binOp"), this);
4520+
}
4521+
void CastScopedEnumTypeRule::runRule(
4522+
const ast_matchers::MatchFinder::MatchResult &Result) {
4523+
auto BO = getNodeAsType<BinaryOperator>(Result, "binOp");
4524+
if (!BO)
4525+
return;
4526+
4527+
// List the types don't need to explicit cast type after migration.
4528+
const std::unordered_set<std::string> TypeNoCast = {
4529+
"int", MapNames::getDpctNamespace() + "err0",
4530+
MapNames::getDpctNamespace() + "err1",
4531+
MapNames::getDpctNamespace() + "pointer_attributes"};
4532+
4533+
auto InsertEnumCast = [&](const Expr *E) {
4534+
const clang::EnumDecl *EnumDecl =
4535+
E->getType().getCanonicalType()->getAs<clang::EnumType>()->getDecl();
4536+
4537+
std::string EnumName = EnumDecl->getNameAsString();
4538+
std::string ReplacedName =
4539+
MapNames::findReplacedName(MapNames::TypeNamesMap, EnumName);
4540+
4541+
if (TypeNoCast.count(ReplacedName) || ReplacedName == EnumName ||
4542+
EnumName.empty() ||
4543+
ReplacedName.empty()) // EnumName Empty means the enum is Anonymous
4544+
return;
4545+
if (dpct::DpctGlobalInfo::isInCudaPath(EnumDecl->getLocation())) {
4546+
insertAroundStmt(E, "static_cast<int>(", ")");
4547+
}
4548+
};
4549+
auto LHSExpr = BO->getLHS()->IgnoreImpCasts();
4550+
auto RHSExpr = BO->getRHS()->IgnoreImpCasts();
4551+
4552+
if (LHSExpr->getType()->isEnumeralType() && !dyn_cast<CallExpr>(LHSExpr) &&
4553+
!RHSExpr->getType()->isEnumeralType()) {
4554+
InsertEnumCast(LHSExpr);
4555+
} else if (!LHSExpr->getType()->isEnumeralType() &&
4556+
RHSExpr->getType()->isEnumeralType() &&
4557+
!dyn_cast<CallExpr>(RHSExpr)) {
4558+
InsertEnumCast(RHSExpr);
4559+
}
4560+
}
4561+
45184562
void KernelCallRefRule::registerMatcher(ast_matchers::MatchFinder &MF) {
45194563
MF.addMatcher(
45204564
functionDecl(

clang/lib/DPCT/RulesLang/RulesLang.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,14 @@ class StreamAPICallRule : public NamedMigrationRule<StreamAPICallRule> {
435435
void runRule(const ast_matchers::MatchFinder::MatchResult &Result);
436436
};
437437

438+
/// Migration rule for binary operator calls
439+
class CastScopedEnumTypeRule
440+
: public NamedMigrationRule<CastScopedEnumTypeRule> {
441+
public:
442+
void registerMatcher(ast_matchers::MatchFinder &MF) override;
443+
void runRule(const ast_matchers::MatchFinder::MatchResult &Result);
444+
};
445+
438446
/// Migration rule for kernel API calls
439447
class KernelCallRule : public NamedMigrationRule<KernelCallRule> {
440448
std::unordered_set<unsigned> Insertions;

clang/test/dpct/enum_type.cu

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: dpct --format-range=none -out-root %T/enum_type %s --cuda-include-path="%cuda-path/include" -- -std=c++14 -x cuda --cuda-host-only
2+
// RUN: FileCheck --input-file %T/enum_type/enum_type.dp.cpp --match-full-lines %s
3+
// RUN: %if build_lit %{icpx -c -fsycl -DNO_BUILD_TEST %T/enum_type/enum_type.dp.cpp -o %T/enum_type/enum_type.dp.o %}
4+
#include <cuda.h>
5+
int main() {
6+
// CHECK: sycl::usm::alloc mem_type;
7+
CUmemorytype mem_type;
8+
// CHECK: if (static_cast<int>(mem_type) == 0)
9+
if (mem_type == 0)
10+
;
11+
// CHECK: if (0 == static_cast<int>(mem_type))
12+
if (0 == mem_type)
13+
;
14+
// CHECK: if (sycl::usm::alloc::host == mem_type)
15+
if (CU_MEMORYTYPE_HOST == mem_type)
16+
;
17+
// CHECK: if (0 <= static_cast<int>(mem_type))
18+
if (0 <= mem_type)
19+
;
20+
// CHECK: if (static_cast<int>(mem_type) > 0)
21+
if (mem_type > 0)
22+
;
23+
return 0;
24+
}

0 commit comments

Comments
 (0)