Skip to content

Commit 9690a19

Browse files
authored
[SYCLomatic][PTX] Support migration of PTX instruction ld.cg.global.v2.s16 and ld.cg.global.v4.s16 (#2769)
Signed-off-by: chenwei.sun <chenwei.sun@intel.com>
1 parent dead51d commit 9690a19

File tree

2 files changed

+109
-13
lines changed

2 files changed

+109
-13
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 75 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
609609
return false;
610610
};
611611

612-
if (CurrInst->is(asmtok::op_ld, asmtok::op_red))
612+
if (CurrInst->is(asmtok::op_red))
613613
OS() << "*";
614614
switch (Dst->getMemoryOpKind()) {
615615
case InlineAsmAddressExpr::Imm:
@@ -633,7 +633,7 @@ bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
633633
if (tryEmitStmt(Reg, Dst->getSymbol()))
634634
return SYCLGenSuccess();
635635

636-
if (CurrInst->is(asmtok::op_st))
636+
if (CurrInst->is(asmtok::op_st, asmtok::op_ld))
637637
OS() << llvm::formatv("(uintptr_t){0}", Reg);
638638
else
639639
OS() << llvm::formatv("(({0} *)((uintptr_t){1} + {2}))", Type, Reg,
@@ -2694,16 +2694,10 @@ class SYCLGen : public SYCLGenBase {
26942694
return SYCLGenSuccess();
26952695
}
26962696

2697-
bool HandleStVec(const InlineAsmInstruction *Inst, int VecNum) {
2698-
std::string Ops;
2699-
if (tryEmitStmt(Ops, Inst->getInputOperand(0)))
2700-
return SYCLGenError();
2701-
2702-
// To extract the values from the string like "{x, y, z, w}" and store them
2703-
// int Values vector
2697+
std::vector<std::string> extractValues(const std::string &Ops) {
27042698
std::vector<std::string> Values;
2705-
size_t start = 1; // Skip the '{' character
2706-
size_t end = Ops.find(',', start); // Find the first comma
2699+
size_t start = 1; // Skip the '{' character
2700+
size_t end = Ops.find(',', start);
27072701

27082702
while (end != std::string::npos) {
27092703
std::string Token = Ops.substr(start, end - start);
@@ -2720,11 +2714,24 @@ class SYCLGen : public SYCLGenBase {
27202714
std::string token = Ops.substr(start, Ops.size() - start - 1);
27212715
size_t first = token.find_first_not_of(' ');
27222716
size_t last = token.find_last_not_of(' ');
2723-
27242717
if (first != std::string::npos && last != std::string::npos) {
27252718
Values.push_back(token.substr(first, last - first + 1));
27262719
}
27272720

2721+
return Values;
2722+
}
2723+
2724+
bool HandleStVec(const InlineAsmInstruction *Inst, int VecNum) {
2725+
std::string Ops;
2726+
if (tryEmitStmt(Ops, Inst->getInputOperand(0)))
2727+
return SYCLGenError();
2728+
2729+
// Extract the values from a string like "{x, y, z, w}" and store them
2730+
// in the Values vector
2731+
std::vector<std::string> Values = extractValues(Ops);
2732+
if (Values.size() != (size_t)VecNum)
2733+
return SYCLGenError();
2734+
27282735
std::string Output;
27292736
if (tryEmitStmt(Output, Inst->getOutputOperand()))
27302737
return SYCLGenError();
@@ -2790,6 +2797,43 @@ class SYCLGen : public SYCLGenBase {
27902797
return SYCLGenSuccess();
27912798
}
27922799

2800+
bool HandleLdVec(const InlineAsmInstruction *Inst, int VecNum) {
2801+
std::string Ops;
2802+
if (tryEmitStmt(Ops, Inst->getOutputOperand()))
2803+
return SYCLGenError();
2804+
2805+
// To extract the values from the string like "{x, y, z, w}" and store them
2806+
// int Values vector
2807+
std::vector<std::string> Values = extractValues(Ops);
2808+
if (Values.size() != (size_t)VecNum)
2809+
return SYCLGenError();
2810+
2811+
std::string Input;
2812+
if (tryEmitStmt(Input, Inst->getInputOperand(0)))
2813+
return SYCLGenError();
2814+
2815+
std::string Type;
2816+
if (tryEmitType(Type, Inst->getType(0)))
2817+
return SYCLGenError();
2818+
2819+
const auto *Src =
2820+
dyn_cast_or_null<InlineAsmAddressExpr>(Inst->getInputOperand(0));
2821+
if (!Src)
2822+
return SYCLGenError();
2823+
2824+
for (int Index = 0; Index < VecNum; Index++) {
2825+
OS() << llvm::formatv("{0} = *(({1} *){2} + {3}){4}", Values[Index],
2826+
Type, Input, Index,
2827+
Index == VecNum - 1 ? "" : ";\n");
2828+
if (Index < VecNum - 1) {
2829+
indent();
2830+
}
2831+
}
2832+
2833+
endstmt();
2834+
return SYCLGenSuccess();
2835+
}
2836+
27932837
bool handle_ld(const InlineAsmInstruction *Inst) override {
27942838
if (Inst->getNumInputOperands() != 1)
27952839
return SYCLGenError();
@@ -2801,14 +2845,32 @@ class SYCLGen : public SYCLGenBase {
28012845

28022846
if (!Src)
28032847
return SYCLGenError();
2848+
2849+
if (Inst->hasAttr(InstAttr::cg)) {
2850+
if (Inst->hasAttr(InstAttr::v4))
2851+
return HandleLdVec(Inst, 4);
2852+
if (Inst->hasAttr(InstAttr::v2))
2853+
return HandleLdVec(Inst, 2);
2854+
}
2855+
28042856
std::string Type;
28052857
if (tryEmitType(Type, Inst->getType(0)))
28062858
return SYCLGenError();
28072859
if (emitStmt(Dst))
28082860
return SYCLGenError();
28092861
OS() << " = ";
2810-
if (emitStmt(Src))
2862+
2863+
std::string InOp;
2864+
if (tryEmitStmt(InOp, Inst->getInputOperand(0)))
28112865
return SYCLGenError();
2866+
2867+
if (Src->getMemoryOpKind() == InlineAsmAddressExpr::RegImm) {
2868+
OS() << llvm::formatv("*(({0} *)({1} + {2}))", Type, InOp,
2869+
Src->getImmAddr()->getValue().getZExtValue());
2870+
} else {
2871+
OS() << "*" << InOp;
2872+
}
2873+
28122874
endstmt();
28132875
return SYCLGenSuccess();
28142876
}

clang/test/dpct/asm/ld.cu

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,38 @@ __device__ void shared_address_load32(uint32_t addr, uint32_t &val) {
3636
asm volatile("{ld.shared.b32 %0, [%1];}" : : "r"(val), "r"(addr) : "memory");
3737
}
3838

39+
// CHECK: inline void load_global_short2(sycl::short2 &a, const sycl::short2 *addr) {
40+
// CHECK-NEXT: short x, y, z, w;
41+
// CHECK-NEXT: x = *((int16_t *)(uintptr_t)addr + 0);
42+
// CHECK-NEXT: y = *((int16_t *)(uintptr_t)addr + 1);
43+
// CHECK-NEXT: a.x() = x;
44+
// CHECK-NEXT: a.y() = y;
45+
// CHECK-NEXT:}
46+
// Migration test input for the PTX instruction ld.cg.global.v2.s16: loads two
// 16-bit values through inline asm from `addr` (address operand written as
// [%2+0]) into x and y, then copies them into a.x and a.y. z and w are
// declared but never used in this function; the CHECK-NEXT line above expects
// the declaration "short x, y, z, w;" verbatim.
// NOTE(review): the body is presumably matched line-by-line by the CHECK-NEXT
// directives above, so keep it free of additional lines.
__device__ inline void load_global_short2(short2 &a, const short2 *addr) {
47+
short x, y, z, w;
48+
asm("ld.cg.global.v2.s16 {%0, %1}, [%2+0];" : "=h"(x), "=h"(y) : "l"(addr));
49+
a.x = x;
50+
a.y = y;
51+
}
52+
53+
// CHECK: inline void load_global_short4(sycl::short4 &a, const sycl::short4 *addr) {
54+
// CHECK-NEXT: short x, y, z, w;
55+
// CHECK-NEXT: x = *((int16_t *)(uintptr_t)addr + 0);
56+
// CHECK-NEXT: y = *((int16_t *)(uintptr_t)addr + 1);
57+
// CHECK-NEXT: z = *((int16_t *)(uintptr_t)addr + 2);
58+
// CHECK-NEXT: w = *((int16_t *)(uintptr_t)addr + 3);
59+
// CHECK-NEXT: a.x() = x;
60+
// CHECK-NEXT: a.y() = y;
61+
// CHECK-NEXT: a.z() = z;
62+
// CHECK-NEXT: a.w() = w;
63+
// CHECK-NEXT:}
64+
// Migration test input for the PTX instruction ld.cg.global.v4.s16: loads four
// 16-bit values through inline asm from `addr` (address operand written as
// [%4+0]) into x, y, z and w, then copies them into the four components of a.
// NOTE(review): the body is presumably matched line-by-line by the CHECK-NEXT
// directives above, so keep it free of additional lines.
__device__ inline void load_global_short4(short4 &a, const short4 *addr) {
65+
short x, y, z, w;
66+
asm("ld.cg.global.v4.s16 {%0, %1, %2, %3}, [%4+0];" : "=h"(x), "=h"(y), "=h"(z), "=h"(w) : "l"(addr));
67+
a.x = x;
68+
a.y = y;
69+
a.z = z;
70+
a.w = w;
71+
}
72+
3973
// clang-format on

0 commit comments

Comments
 (0)