Skip to content

Commit 251a7e5

Browse files
authored
[SYCLomatic][PTX] Support migration of PTX instruction mul.f16x2 (#2771)
Signed-off-by: chenwei.sun <chenwei.sun@intel.com>
1 parent 9690a19 commit 251a7e5

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1363,6 +1363,7 @@ class SYCLGen : public SYCLGenBase {
13631363
Type->getKind() == InlineAsmBuiltinType::s32 ||
13641364
Type->getKind() == InlineAsmBuiltinType::u32 ||
13651365
Type->getKind() == InlineAsmBuiltinType::s64 ||
1366+
Type->getKind() == InlineAsmBuiltinType::f16x2 ||
13661367
Type->getKind() == InlineAsmBuiltinType::u64;
13671368
}
13681369

@@ -1419,6 +1420,13 @@ class SYCLGen : public SYCLGenBase {
14191420
OS() << Cast(GetWiderTypeAsString(Type), Op[0]) << " * "
14201421
<< Cast(GetWiderTypeAsString(Type), Op[1]);
14211422
// mul.lo
1423+
} else if (Type->getKind() == InlineAsmBuiltinType::f16x2) {
1424+
std::string FormatTemp =
1425+
"((sycl::vec<int, 1>({0})).as<sycl::vec<sycl::half, 2>>() * "
1426+
"(sycl::vec<int, 1>({1})).as<sycl::vec<sycl::half, "
1427+
"2>>()).as<sycl::vec<int, 1>>().x()";
1428+
1429+
OS() << llvm::formatv(FormatTemp.c_str(), Op[0], Op[1]);
14221430
} else {
14231431
// Need to add a new help function.
14241432
// OS() << Op[0] << " * " << Op[1];

clang/test/dpct/asm/mul.cu

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,4 +54,14 @@ __global__ void mul() {
5454
asm("mul.wide.u64 %0, %1, %2;" : "=r"(u64) : "r"(x), "r"(y));
5555
}
5656

57+
// CHECK: inline uint32_t mul_f16x2(uint32_t a, uint32_t b) {
58+
// CHECK-NEXT: uint32_t c;
59+
// CHECK-NEXT: c = ((sycl::vec<int, 1>(a)).as<sycl::vec<sycl::half, 2>>() * (sycl::vec<int, 1>(b)).as<sycl::vec<sycl::half, 2>>()).as<sycl::vec<int, 1>>().x();
60+
// CHECK-NEXT: return c;
61+
// CHECK-NEXT: }
62+
// Test input for migrating the PTX `mul.f16x2` instruction from inline asm.
// The two 32-bit register operands `a` and `b` are fed to `mul.f16x2` and the
// 32-bit result is returned through `c`. The CHECK/CHECK-NEXT lines directly
// above pin the expected migrated text line by line (each operand viewed as
// sycl::vec<sycl::half, 2>, multiplied, and viewed back as a single int), so
// the body is kept to one statement per line with the same local name.
inline __device__ uint32_t mul_f16x2(uint32_t a, uint32_t b) {
63+
uint32_t c;
64+
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
65+
return c;
66+
}
5767
// clang-format on

0 commit comments

Comments
 (0)