Skip to content

[SYCLomatic] Add explicit cast if CUDA enum type is migrated scoped type and the value is compared to a different type value. #2829

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: SYCLomatic
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clang/lib/DPCT/ASTTraversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ void MigrationRule::emplaceTransformation(TextModification *TM) {
// RuleLang
REGISTER_RULE(IterationSpaceBuiltinRule, PassKind::PK_Analysis)
REGISTER_RULE(ErrorHandlingIfStmtRule, PassKind::PK_Migration)
REGISTER_RULE(BinaryOperatorCallRule, PassKind::PK_Migration)
REGISTER_RULE(ErrorHandlingHostAPIRule, PassKind::PK_Migration)
REGISTER_RULE(AtomicFunctionRule, PassKind::PK_Migration)
REGISTER_RULE(ZeroLengthArrayRule, PassKind::PK_Migration)
Expand Down
54 changes: 54 additions & 0 deletions clang/lib/DPCT/RulesLang/RulesLang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4534,6 +4534,60 @@ void StreamAPICallRule::runRule(const MatchFinder::MatchResult &Result) {
}
}

void BinaryOperatorCallRule::registerMatcher(ast_matchers::MatchFinder &MF) {
MF.addMatcher(binaryOperator(isComparisonOperator()).bind("binOp"), this);
}
void BinaryOperatorCallRule::runRule(
const ast_matchers::MatchFinder::MatchResult &Result) {
auto BO = getNodeAsType<BinaryOperator>(Result, "binOp");
if (!BO)
return;

// The migration rule covered these test type, no need to cast.
Copy link
Preview

Copilot AI May 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] There appears to be a typographical error in the comment. Consider revising the comment to clarify the intended meaning (e.g., "The migration rule covers these enum types so no cast is required.").

Suggested change
// The migration rule covered these test type, no need to cast.
// The migration rule covers these enum types, so no cast is required.

Copilot uses AI. Check for mistakes.

std::vector<std::string> MigratedEnumType = {"cudaMemoryType", "cufftResult_t",
"cudaError", "CUresult",
"cudaError_enum", "cudaComputeMode"};

auto InsertEnumCast = [&](const Expr *E) {
QualType EType = E->getType();
if (EType->isEnumeralType()) {

if (const auto *EnumType =
EType.getCanonicalType()->getAs<clang::EnumType>()) {

const clang::EnumDecl *EnumDecl = EnumType->getDecl();
std::string EnumName = EnumDecl->getNameAsString();
clang::SourceLocation EnumLoc = EnumDecl->getLocation();
auto it = std::find(MigratedEnumType.begin(), MigratedEnumType.end(),
EnumName);
if (it != MigratedEnumType.end() ||
EnumName.empty()) // Empty means the enum is Anonymous
return;
if (dpct::DpctGlobalInfo::isInCudaPath(EnumLoc) &&
!EnumDecl->isScoped()) {
SourceLocation EndLoc = Lexer::getLocForEndOfToken(
E->getEndLoc(), 0, DpctGlobalInfo::getSourceManager(),
LangOptions());
DpctGlobalInfo::getInstance().addReplacement(
std::make_shared<ExtReplacement>(
DpctGlobalInfo::getSourceManager(), E->getBeginLoc(), 0,
"static_cast<int>(", nullptr));
DpctGlobalInfo::getInstance().addReplacement(
std::make_shared<ExtReplacement>(
DpctGlobalInfo::getSourceManager(), EndLoc, 0, ")", nullptr));
}
}
}
};
auto LHSType = BO->getLHS()->IgnoreImpCasts()->getType();
auto RHSType = BO->getRHS()->IgnoreImpCasts()->getType();
if (LHSType->isEnumeralType() && RHSType->isEnumeralType()) {
return;
}
InsertEnumCast(BO->getLHS()->IgnoreImpCasts());
InsertEnumCast(BO->getRHS()->IgnoreImpCasts());
}

void KernelCallRefRule::registerMatcher(ast_matchers::MatchFinder &MF) {
MF.addMatcher(
functionDecl(
Expand Down
7 changes: 7 additions & 0 deletions clang/lib/DPCT/RulesLang/RulesLang.h
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,13 @@ class StreamAPICallRule : public NamedMigrationRule<StreamAPICallRule> {
void runRule(const ast_matchers::MatchFinder::MatchResult &Result);
};

/// Migration rule for binary operator calls
class BinaryOperatorCallRule : public NamedMigrationRule<BinaryOperatorCallRule> {
public:
void registerMatcher(ast_matchers::MatchFinder &MF) override;
void runRule(const ast_matchers::MatchFinder::MatchResult &Result);
};

/// Migration rule for kernel API calls
class KernelCallRule : public NamedMigrationRule<KernelCallRule> {
std::unordered_set<unsigned> Insertions;
Expand Down
21 changes: 21 additions & 0 deletions clang/test/dpct/enum_type.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// RUN: dpct --format-range=none --use-experimental-features=matrix -out-root %T/enum_type %s --cuda-include-path="%cuda-path/include" -- -std=c++14 -x cuda --cuda-host-only
// RUN: FileCheck --input-file %T/enum_type/enum_type.dp.cpp --match-full-lines %s
// RUN: %if build_lit %{icpx -c -fsycl -DNO_BUILD_TEST %T/enum_type/enum_type.dp.cpp -o %T/enum_type/enum_type.dp.o %}
#include <cuda.h>
int main() {
// CHECK: sycl::usm::alloc mem_type;
CUmemorytype mem_type;
// CHECK: if (static_cast<int>(mem_type) == 0)
if (mem_type == 0)
;
// CHECK: if (0 == static_cast<int>(mem_type))
if (0 == mem_type)
;
// CHECK: if (0 <= static_cast<int>(mem_type))
if (0 <= mem_type)
;
// CHECK: if (static_cast<int>(mem_type) > 0)
if (mem_type > 0)
;
return 0;
}
Loading