Skip to content

Commit 3bc4252

Browse files
authored
[SYCLomatic] Migrate the LibCU APIs in the template class. (#2880)
Signed-off-by: Chen, Sheng S <sheng.s.chen@intel.com>
1 parent 3980bfe commit 3bc4252

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

clang/lib/DPCT/RulesLangLib/LIBCUAPIMigration.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,13 @@ void LIBCURule::registerMatcher(ast_matchers::MatchFinder &MF) {
3737
"compare_exchange_strong", "fetch_add", "fetch_sub",
3838
"at");
3939
};
40+
auto LIBCUMemberHasNames = [&]() {
41+
return anyOf(
42+
hasMemberName("load"), hasMemberName("store"),
43+
hasMemberName("exchange"), hasMemberName("compare_exchange_weak"),
44+
hasMemberName("compare_exchange_strong"), hasMemberName("fetch_add"),
45+
hasMemberName("fetch_sub"), hasMemberName("at"));
46+
};
4047
auto LIBCUTypesHasNames = [&]() {
4148
return hasAnyName("cuda::atomic", "cuda::std::atomic",
4249
"cuda::std::array");
@@ -47,6 +54,9 @@ void LIBCURule::registerMatcher(ast_matchers::MatchFinder &MF) {
4754
callee(cxxMethodDecl(LIBCUMemberFuncHasNames()))))
4855
.bind("MemberCall"),
4956
this);
57+
MF.addMatcher(cxxDependentScopeMemberExpr(LIBCUMemberHasNames())
58+
.bind("DependentMemCall"),
59+
this);
5060
}
5161
{
5262
MF.addMatcher(dependentScopeDeclRefExpr().bind("DependentScope"),
@@ -88,6 +98,16 @@ void LIBCURule::runRule(const ast_matchers::MatchFinder::MatchResult &Result) {
8898
if (const CXXMemberCallExpr *MC =
8999
getNodeAsType<CXXMemberCallExpr>(Result, "MemberCall")) {
90100
EA.analyze(MC);
101+
} else if (const CXXDependentScopeMemberExpr *CDSE =
102+
getNodeAsType<CXXDependentScopeMemberExpr>(
103+
Result, "DependentMemCall")) {
104+
auto Parent = dpct::DpctGlobalInfo::getContext().getParents(*CDSE);
105+
auto *CE = Parent[0].get<CallExpr>();
106+
if (CE) {
107+
for (size_t i = 0; i < CE->getNumArgs(); i++) {
108+
EA.analyze(CE->getArg(i));
109+
}
110+
}
91111
} else if (const CallExpr *CE = getNodeAsType<CallExpr>(Result, "FuncCall")) {
92112
EA.analyze(CE);
93113
} else if (auto TL = getNodeAsType<TypeLoc>(Result, "TypeLoc")) {

clang/test/dpct/LibCU/libcu_atomic.cu

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,24 @@
1010
// CHECK: #include <dpct/atomic.hpp>
1111
#include <cuda/atomic>
1212

13+
template <class T> bool is_complete(const T &result) {
14+
return !(result == static_cast<T>(0.0) && std::signbit(result));
15+
}
16+
17+
template <typename T> struct ReduceArg {
18+
19+
private:
20+
// CHECK: dpct::atomic<T, sycl::memory_scope::system, sycl::memory_order::relaxed> *result_h;
21+
cuda::atomic<T, cuda::thread_scope_system> *result_h;
22+
23+
public:
24+
void complete() {
25+
//CHECK: while (!is_complete(result_h[0].load(sycl::memory_order::relaxed))) {
26+
while (!is_complete(result_h[0].load(cuda::std::memory_order_relaxed))) {
27+
}
28+
}
29+
};
30+
1331
int main(){
1432
// CHECK: sycl::atomic_fence(sycl::memory_order::release, sycl::memory_scope::system);
1533
cuda::atomic_thread_fence(cuda::std::memory_order_release, cuda::thread_scope_system);

0 commit comments

Comments
 (0)