Skip to content

Commit 0ae1059

Browse files
authored
[SYCLomatic] Support migration for type __half2_raw (#2713)
Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
1 parent 1122e0c commit 0ae1059

File tree

5 files changed

+72
-13
lines changed

5 files changed

+72
-13
lines changed

clang/lib/DPCT/RuleInfra/MapNames.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -813,6 +813,8 @@ void MapNames::setExplicitNamespaceMap(
813813
{"cusparseSpGEMMAlg_t", std::make_shared<TypeNameRule>("int")},
814814
{"cusparseSpSVAlg_t", std::make_shared<TypeNameRule>("int")},
815815
{"__half_raw", std::make_shared<TypeNameRule>("uint16_t")},
816+
{"__half2_raw",
817+
std::make_shared<TypeNameRule>(MapNames::getClNamespace() + "ushort2")},
816818
{"cudaFuncAttributes",
817819
std::make_shared<TypeNameRule>(MapNames::getDpctNamespace() +
818820
"kernel_function_info")},

clang/lib/DPCT/RulesLang/MapNamesLang.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ const std::map<std::string, int> MapNamesLang::VectorTypeMigratedTypeSizeMap{
9191
{"ulonglong4", 32}, {"float1", 4}, {"float2", 8},
9292
{"float3", 16}, {"float4", 16}, {"double1", 8},
9393
{"double2", 16}, {"double3", 32}, {"double4", 32},
94-
{"__half", 2}, {"__half2", 4}, {"__half_raw", 2}};
94+
{"__half", 2}, {"__half2", 4}, {"__half_raw", 2},
95+
{"__half2_raw", 4}};
9596

9697
const std::map<clang::dpct::KernelArgType, int>
9798
MapNamesLang::KernelArgTypeSizeMap{

clang/lib/DPCT/RulesLang/MapNamesLang.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ const std::string StringLiteralUnsupported{"UNSUPPORTED"};
2929
"ulonglong1", "longlong2", "ulonglong2", "longlong3", "ulonglong3", \
3030
"longlong4", "ulonglong4", "double1", "double2", "double3", "double4", \
3131
"__half", "__half2", "half", "half2", "__nv_bfloat16", "nv_bfloat16", \
32-
"__nv_bfloat162", "nv_bfloat162", "__half_raw"
32+
"__nv_bfloat162", "nv_bfloat162", "__half_raw", "__half2_raw"
3333
#define VECTORTYPE2MARRAYNAMES "__nv_bfloat162", "nv_bfloat162"
3434

3535
/// Record mapping between names

clang/lib/DPCT/RulesLang/RulesLang.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,12 +1218,12 @@ void VectorTypeNamespaceRule::registerMatcher(MatchFinder &MF) {
12181218
"longlong1", "ulonglong1", "double1", "__half_raw")))
12191219
.bind("inherit"),
12201220
this);
1221-
// Matcher for __half_raw implicitly convert to half.
1221+
// Matcher for __half_raw/__half2_raw implicitly convert to half/half2.
12221222
MF.addMatcher(
12231223
declRefExpr(allOf(unless(hasParent(memberExpr())),
12241224
unless(hasParent(unaryOperator(hasOperatorName("&")))),
1225-
to(varDecl(hasType(qualType(hasDeclaration(
1226-
namedDecl(hasAnyName("__half_raw"))))))),
1225+
to(varDecl(hasType(qualType(hasDeclaration(namedDecl(
1226+
hasAnyName("__half_raw", "__half2_raw"))))))),
12271227
hasParent(implicitCastExpr())))
12281228
.bind("halfRawExpr"),
12291229
this);
@@ -1405,7 +1405,7 @@ void VectorTypeNamespaceRule::runRule(const MatchFinder::MatchResult &Result) {
14051405
UETT, Diagnostics::SIZEOF_WARNING, true, argTypeName,
14061406
"Check that the allocated memory size in the migrated code is correct");
14071407
}
1408-
// Runrule for __half_raw implicitly convert to half.
1408+
// Run rule for __half_raw/__half2_raw implicitly convert to half/half2.
14091409
if (auto DRE = getNodeAsType<DeclRefExpr>(Result, "halfRawExpr")) {
14101410
if (const auto *RT =
14111411
DRE->getType().getCanonicalType()->getAs<RecordType>()) {
@@ -1414,13 +1414,17 @@ void VectorTypeNamespaceRule::runRule(const MatchFinder::MatchResult &Result) {
14141414
}
14151415
ExprAnalysis EA;
14161416
std::string Replacement;
1417-
llvm::raw_string_ostream OS(Replacement);
1418-
OS << MapNames::getClNamespace() + "bit_cast<" +
1419-
MapNames::getClNamespace() + "half>(";
14201417
EA.analyze(DRE);
1421-
OS << EA.getReplacedString();
1422-
OS << ")";
1423-
OS.flush();
1418+
if (DRE->getType().getCanonicalType().getAsString() == "__half2_raw") {
1419+
llvm::raw_string_ostream OS(Replacement);
1420+
OS << EA.getReplacedString() << ".as<" << MapNames::getClNamespace()
1421+
<< "half2>()";
1422+
} else {
1423+
llvm::raw_string_ostream OS(Replacement);
1424+
OS << MapNames::getClNamespace() << "bit_cast<"
1425+
<< MapNames::getClNamespace() << "half>(" << EA.getReplacedString()
1426+
<< ")";
1427+
}
14241428
emplaceTransformation(new ReplaceStmt(DRE, Replacement));
14251429
return;
14261430
}

clang/test/dpct/half_raw.cu

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
#include <cuda_fp16.h>
88
#include <cuda_runtime.h>
9-
int main() {
9+
10+
void foo1() {
1011
// CHECK: uint16_t one_h{0x3C00};
1112
__half_raw one_h{0x3C00};
1213
// CHECK: uint16_t zero_h{0};
@@ -40,3 +41,54 @@ int main() {
4041
// CHECK: uint16_t *ptr1 = &one_h;
4142
__half_raw *ptr1 = &one_h;
4243
}
44+
45+
void foo2() {
46+
// CHECK: sycl::ushort2 one_h{0x3C00};
47+
__half2_raw one_h{0x3C00};
48+
// CHECK: sycl::ushort2 zero_h{0};
49+
__half2_raw zero_h{0};
50+
// CHECK: sycl::ushort2 *ptr = new sycl::ushort2{0};
51+
__half2_raw *ptr = new __half2_raw{0};
52+
// clang-format off
53+
// CHECK: ptr->x() = 0x3C00;
54+
ptr->x = 0x3C00;
55+
// CHECK: ptr ->x() = 0x3C00;
56+
ptr ->x = 0x3C00;
57+
// CHECK: ptr-> x() = 0x3C00;
58+
ptr-> x = 0x3C00;
59+
// CHECK: ptr -> x() = 0x3C00;
60+
ptr -> x = 0x3C00;
61+
// CHECK: zero_h.x() = 0x3C00;
62+
zero_h.x = 0x3C00;
63+
// CHECK: zero_h .x() = 0x3C00;
64+
zero_h .x = 0x3C00;
65+
// CHECK: zero_h. x() = 0x3C00;
66+
zero_h. x = 0x3C00;
67+
// CHECK: zero_h . x() = 0x3C00;
68+
zero_h . x = 0x3C00;
69+
// CHECK: ptr->y() = 0x3C00;
70+
ptr->y = 0x3C00;
71+
// CHECK: ptr ->y() = 0x3C00;
72+
ptr ->y = 0x3C00;
73+
// CHECK: ptr-> y() = 0x3C00;
74+
ptr-> y = 0x3C00;
75+
// CHECK: ptr -> y() = 0x3C00;
76+
ptr -> y = 0x3C00;
77+
// CHECK: zero_h.y() = 0x3C00;
78+
zero_h.y = 0x3C00;
79+
// CHECK: zero_h .y() = 0x3C00;
80+
zero_h .y = 0x3C00;
81+
// CHECK: zero_h. y() = 0x3C00;
82+
zero_h. y = 0x3C00;
83+
// CHECK: zero_h . y() = 0x3C00;
84+
zero_h . y = 0x3C00;
85+
// clang-format on
86+
// CHECK: sycl::half2 alpha = one_h.as<sycl::half2>();
87+
half2 alpha = one_h;
88+
// CHECK: alpha = one_h.as<sycl::half2>();
89+
alpha = one_h;
90+
// CHECK: uint16_t as = zero_h.x();
91+
uint16_t as = zero_h.x;
92+
// CHECK: sycl::ushort2 *ptr1 = &one_h;
93+
__half2_raw *ptr1 = &one_h;
94+
}

0 commit comments

Comments
 (0)