Skip to content

Commit 4820d71

Browse files
authored
[SYCLomatic] Fix the issue that wmma type is not processed well if namespace alias used (#2820)
Signed-off-by: intwanghao <hao3.wang@intel.com>
1 parent 1206a52 commit 4820d71

File tree

3 files changed

+41
-1
lines changed

3 files changed

+41
-1
lines changed

clang/lib/DPCT/RuleInfra/ExprAnalysis.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,12 @@ void ExprAnalysis::analyzeExpr(const DeclRefExpr *DRE) {
525525
dyn_cast<NamespaceDecl>(Qualifier->getAsNamespace())) {
526526
CTSName = getNameSpace(NSD) + "::" + DRE->getNameInfo().getAsString();
527527
}
528+
} else if (auto NA = Qualifier->getAsNamespaceAlias()) {
529+
auto ND = NA->getNamespace();
530+
if (ND && (ND->getName() == "wmma") &&
531+
dpct::DpctGlobalInfo::isInCudaPath(ND->getBeginLoc())) {
532+
CTSName = getNameSpace(ND) + "::" + DRE->getNameInfo().getAsString();
533+
}
528534
} else if (!IsNamespaceOrAlias || !IsSpecicalAPI) {
529535
if (DRE->getDecl()->isCXXClassMember()) {
530536
std::string Result;

clang/lib/DPCT/RulesLang/RulesLangNoneAPIAndType.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1123,8 +1123,10 @@ void NamespaceRule::runRule(const MatchFinder::MatchResult &Result) {
11231123
} else if (auto NAD = getAssistNodeAsType<NamespaceAliasDecl>(
11241124
Result, "namespaceAlias")) {
11251125
std::string Namespace = NAD->getNamespace()->getNameAsString();
1126-
if (Namespace == "cooperative_groups" || Namespace == "placeholders")
1126+
if (Namespace == "cooperative_groups" || Namespace == "placeholders" ||
1127+
Namespace == "wmma") {
11271128
emplaceTransformation(new ReplaceDecl(NAD, ""));
1129+
}
11281130
} else if (auto UD = getAssistNodeAsType<UsingDecl>(Result, "using")) {
11291131
auto &SM = DpctGlobalInfo::getSourceManager();
11301132
SourceLocation Beg, End;

clang/test/dpct/wmma2.cu

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// clang-format off
2+
// UNSUPPORTED: cuda-8.0, cuda-9.0, cuda-9.1, cuda-9.2, cuda-10.0
3+
// UNSUPPORTED: v8.0, v9.0, v9.1, v9.2, v10.0
4+
// RUN: dpct --format-range=none --use-experimental-features=matrix -out-root %T/wmma2 %s --cuda-include-path="%cuda-path/include" -- -std=c++14 -x cuda --cuda-host-only
5+
// RUN: FileCheck --input-file %T/wmma2/wmma2.dp.cpp --match-full-lines %s
6+
// RUN: %if build_lit %{icpx -c -fsycl -DNO_BUILD_TEST %T/wmma2/wmma2.dp.cpp -o %T/wmma2/wmma2.dp.o %}
7+
8+
#include <assert.h>
9+
#include <cuda.h>
10+
#include <iostream>
11+
#include <mma.h>
12+
// CHECK: #include <sycl/sycl.hpp>
13+
// CHECK: #include <dpct/dpct.hpp>
14+
namespace wmmaa = nvcuda::wmma;
15+
16+
template<typename T>
17+
__global__ void simple_wmma_gemm(T *d) {
18+
wmmaa::fragment<wmmaa::accumulator, 16, 16, 16, T> c_frag;
19+
// CHECK: sycl::ext::oneapi::experimental::matrix::joint_matrix_store(sycl::ext::oneapi::this_work_item::get_sub_group(), c_frag.get(), sycl::address_space_cast<sycl::access::address_space::generic_space, sycl::access::decorated::no, T>(d), 1, sycl::ext::oneapi::experimental::matrix::layout::row_major);
20+
wmmaa::store_matrix_sync(d, c_frag, 1, wmmaa::mem_row_major);
21+
}
22+
23+
int main() {
24+
25+
simple_wmma_gemm<half><<<1, 1>>>(nullptr);
26+
27+
simple_wmma_gemm<float><<<1, 1>>>(nullptr);
28+
29+
return 0;
30+
}
31+
32+

0 commit comments

Comments
 (0)