Skip to content

[SYCLomatic] Add explicit cast if CUDA enum type is migrated scoped type and the value is compared to a different type value. #2829

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: SYCLomatic
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clang/lib/DPCT/ASTTraversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ void MigrationRule::emplaceTransformation(TextModification *TM) {
// RuleLang
REGISTER_RULE(IterationSpaceBuiltinRule, PassKind::PK_Analysis)
REGISTER_RULE(ErrorHandlingIfStmtRule, PassKind::PK_Migration)
REGISTER_RULE(CastScopedEnumTypeRule, PassKind::PK_Migration)
REGISTER_RULE(ErrorHandlingHostAPIRule, PassKind::PK_Migration)
REGISTER_RULE(AtomicFunctionRule, PassKind::PK_Migration)
REGISTER_RULE(ZeroLengthArrayRule, PassKind::PK_Migration)
Expand Down
44 changes: 44 additions & 0 deletions clang/lib/DPCT/RulesLang/RulesLang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4534,6 +4534,50 @@ void StreamAPICallRule::runRule(const MatchFinder::MatchResult &Result) {
}
}

void CastScopedEnumTypeRule::registerMatcher(ast_matchers::MatchFinder &MF) {
MF.addMatcher(binaryOperator(isComparisonOperator()).bind("binOp"), this);
}
void CastScopedEnumTypeRule::runRule(
const ast_matchers::MatchFinder::MatchResult &Result) {
auto BO = getNodeAsType<BinaryOperator>(Result, "binOp");
if (!BO)
return;

// List the types don't need to explicit cast type after migration.
const std::unordered_set<std::string> TypeNoCast = {
"int", MapNames::getDpctNamespace() + "err0",
MapNames::getDpctNamespace() + "err1",
MapNames::getDpctNamespace() + "pointer_attributes"};

auto InsertEnumCast = [&](const Expr *E) {
const clang::EnumDecl *EnumDecl =
E->getType().getCanonicalType()->getAs<clang::EnumType>()->getDecl();

std::string EnumName = EnumDecl->getNameAsString();
std::string ReplacedName =
MapNames::findReplacedName(MapNames::TypeNamesMap, EnumName);

if (TypeNoCast.count(ReplacedName) || ReplacedName == EnumName ||
EnumName.empty() ||
ReplacedName.empty()) // EnumName Empty means the enum is Anonymous
return;
if (dpct::DpctGlobalInfo::isInCudaPath(EnumDecl->getLocation())) {
insertAroundStmt(E, "static_cast<int>(", ")");
}
};
auto LHSExpr = BO->getLHS()->IgnoreImpCasts();
auto RHSExpr = BO->getRHS()->IgnoreImpCasts();

if (LHSExpr->getType()->isEnumeralType() && !dyn_cast<CallExpr>(LHSExpr) &&
!RHSExpr->getType()->isEnumeralType()) {
InsertEnumCast(LHSExpr);
} else if (!LHSExpr->getType()->isEnumeralType() &&
RHSExpr->getType()->isEnumeralType() &&
!dyn_cast<CallExpr>(RHSExpr)) {
InsertEnumCast(RHSExpr);
}
}

void KernelCallRefRule::registerMatcher(ast_matchers::MatchFinder &MF) {
MF.addMatcher(
functionDecl(
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/DPCT/RulesLang/RulesLang.h
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,14 @@ class StreamAPICallRule : public NamedMigrationRule<StreamAPICallRule> {
void runRule(const ast_matchers::MatchFinder::MatchResult &Result);
};

/// Migration rule for binary operator calls
class CastScopedEnumTypeRule
: public NamedMigrationRule<CastScopedEnumTypeRule> {
public:
void registerMatcher(ast_matchers::MatchFinder &MF) override;
void runRule(const ast_matchers::MatchFinder::MatchResult &Result);
};

/// Migration rule for kernel API calls
class KernelCallRule : public NamedMigrationRule<KernelCallRule> {
std::unordered_set<unsigned> Insertions;
Expand Down
24 changes: 24 additions & 0 deletions clang/test/dpct/enum_type.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: dpct --format-range=none --use-experimental-features=matrix -out-root %T/enum_type %s --cuda-include-path="%cuda-path/include" -- -std=c++14 -x cuda --cuda-host-only
// RUN: FileCheck --input-file %T/enum_type/enum_type.dp.cpp --match-full-lines %s
// RUN: %if build_lit %{icpx -c -fsycl -DNO_BUILD_TEST %T/enum_type/enum_type.dp.cpp -o %T/enum_type/enum_type.dp.o %}
#include <cuda.h>
int main() {
// CHECK: sycl::usm::alloc mem_type;
CUmemorytype mem_type;
// CHECK: if (static_cast<int>(mem_type) == 0)
if (mem_type == 0)
;
// CHECK: if (0 == static_cast<int>(mem_type))
if (0 == mem_type)
;
// CHECK: if (sycl::usm::alloc::host == mem_type)
if (CU_MEMORYTYPE_HOST == mem_type)
;
// CHECK: if (0 <= static_cast<int>(mem_type))
if (0 <= mem_type)
;
// CHECK: if (static_cast<int>(mem_type) > 0)
if (mem_type > 0)
;
return 0;
}