diff --git a/clang/lib/DPCT/ASTTraversal.cpp b/clang/lib/DPCT/ASTTraversal.cpp index 427f0c4fcc8c..97127494bbb4 100644 --- a/clang/lib/DPCT/ASTTraversal.cpp +++ b/clang/lib/DPCT/ASTTraversal.cpp @@ -93,6 +93,7 @@ void MigrationRule::emplaceTransformation(TextModification *TM) { // RuleLang REGISTER_RULE(IterationSpaceBuiltinRule, PassKind::PK_Analysis) REGISTER_RULE(ErrorHandlingIfStmtRule, PassKind::PK_Migration) +REGISTER_RULE(CastScopedEnumTypeRule, PassKind::PK_Migration) REGISTER_RULE(ErrorHandlingHostAPIRule, PassKind::PK_Migration) REGISTER_RULE(AtomicFunctionRule, PassKind::PK_Migration) REGISTER_RULE(ZeroLengthArrayRule, PassKind::PK_Migration) diff --git a/clang/lib/DPCT/RulesLang/RulesLang.cpp b/clang/lib/DPCT/RulesLang/RulesLang.cpp index d40666018a31..6a56daedebab 100644 --- a/clang/lib/DPCT/RulesLang/RulesLang.cpp +++ b/clang/lib/DPCT/RulesLang/RulesLang.cpp @@ -4534,6 +4534,50 @@ void StreamAPICallRule::runRule(const MatchFinder::MatchResult &Result) { } } +void CastScopedEnumTypeRule::registerMatcher(ast_matchers::MatchFinder &MF) { + MF.addMatcher(binaryOperator(isComparisonOperator()).bind("binOp"), this); +} +void CastScopedEnumTypeRule::runRule( + const ast_matchers::MatchFinder::MatchResult &Result) { + auto BO = getNodeAsType(Result, "binOp"); + if (!BO) + return; + + // List the types don't need to explicit cast type after migration. + const std::unordered_set TypeNoCast = { + "int", MapNames::getDpctNamespace() + "err0", + MapNames::getDpctNamespace() + "err1", + MapNames::getDpctNamespace() + "pointer_attributes"}; + + auto InsertEnumCast = [&](const Expr *E) { + const clang::EnumDecl *EnumDecl = + E->getType().getCanonicalType()->getAs()->getDecl(); + + std::string EnumName = EnumDecl->getNameAsString(); + std::string ReplacedName = + MapNames::findReplacedName(MapNames::TypeNamesMap, EnumName); + + if (TypeNoCast.count(ReplacedName) || ReplacedName == EnumName || + EnumName.empty() || + ReplacedName.empty()) // EnumName Empty means the enum is Anonymous + return; + if (dpct::DpctGlobalInfo::isInCudaPath(EnumDecl->getLocation())) { + insertAroundStmt(E, "static_cast(", ")"); + } + }; + auto LHSExpr = BO->getLHS()->IgnoreImpCasts(); + auto RHSExpr = BO->getRHS()->IgnoreImpCasts(); + + if (LHSExpr->getType()->isEnumeralType() && !dyn_cast(LHSExpr) && + !RHSExpr->getType()->isEnumeralType()) { + InsertEnumCast(LHSExpr); + } else if (!LHSExpr->getType()->isEnumeralType() && + RHSExpr->getType()->isEnumeralType() && + !dyn_cast(RHSExpr)) { + InsertEnumCast(RHSExpr); + } +} + void KernelCallRefRule::registerMatcher(ast_matchers::MatchFinder &MF) { MF.addMatcher( functionDecl( diff --git a/clang/lib/DPCT/RulesLang/RulesLang.h b/clang/lib/DPCT/RulesLang/RulesLang.h index a9e83884103d..cba7ca4859a4 100644 --- a/clang/lib/DPCT/RulesLang/RulesLang.h +++ b/clang/lib/DPCT/RulesLang/RulesLang.h @@ -434,6 +434,14 @@ class StreamAPICallRule : public NamedMigrationRule { void runRule(const ast_matchers::MatchFinder::MatchResult &Result); }; +/// Migration rule for binary operator calls +class CastScopedEnumTypeRule + : public NamedMigrationRule { +public: + void registerMatcher(ast_matchers::MatchFinder &MF) override; + void runRule(const ast_matchers::MatchFinder::MatchResult &Result); +}; + /// Migration rule for kernel API calls class KernelCallRule : public NamedMigrationRule { std::unordered_set Insertions; diff --git a/clang/test/dpct/enum_type.cu b/clang/test/dpct/enum_type.cu new file mode 100644 index 000000000000..3b0654a715cf --- /dev/null +++ b/clang/test/dpct/enum_type.cu @@ -0,0 +1,24 @@ +// RUN: dpct --format-range=none --use-experimental-features=matrix -out-root %T/enum_type %s --cuda-include-path="%cuda-path/include" -- -std=c++14 -x cuda --cuda-host-only +// RUN: FileCheck --input-file %T/enum_type/enum_type.dp.cpp --match-full-lines %s +// RUN: %if build_lit %{icpx -c -fsycl -DNO_BUILD_TEST %T/enum_type/enum_type.dp.cpp -o %T/enum_type/enum_type.dp.o %} +#include +int main() { + // CHECK: sycl::usm::alloc mem_type; + CUmemorytype mem_type; + // CHECK: if (static_cast(mem_type) == 0) + if (mem_type == 0) + ; + // CHECK: if (0 == static_cast(mem_type)) + if (0 == mem_type) + ; + // CHECK: if (sycl::usm::alloc::host == mem_type) + if (CU_MEMORYTYPE_HOST == mem_type) + ; + // CHECK: if (0 <= static_cast(mem_type)) + if (0 <= mem_type) + ; + // CHECK: if (static_cast(mem_type) > 0) + if (mem_type > 0) + ; + return 0; +}