Skip to content

Commit 16d76bc

Browse files
authored
[SYCLomatic][PTX] Support migration of PTX instruction st.cs.global.v2.s16 and st.cs.global.v4.s16 (#2748)
Signed-off-by: chenwei.sun <chenwei.sun@intel.com>
1 parent f39c71b commit 16d76bc

File tree

3 files changed

+114
-11
lines changed

3 files changed

+114
-11
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 89 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88

99
#include "AsmMigration.h"
1010
#include "AnalysisInfo.h"
11+
#include "Diagnostics/Diagnostics.h"
12+
#include "ErrorHandle/CrashRecovery.h"
13+
#include "RuleInfra/MapNames.h"
1114
#include "RulesAsm/Parser/AsmNodes.h"
1215
#include "RulesAsm/Parser/AsmParser.h"
1316
#include "RulesAsm/Parser/AsmTokenKinds.h"
14-
#include "ErrorHandle/CrashRecovery.h"
15-
#include "Diagnostics/Diagnostics.h"
16-
#include "RuleInfra/MapNames.h"
1717
#include "TextModification.h"
1818
#include "Utility.h"
1919
#include "clang/AST/Expr.h"
@@ -609,7 +609,7 @@ bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
609609
return false;
610610
};
611611

612-
if (CurrInst->is(asmtok::op_st, asmtok::op_ld, asmtok::op_red))
612+
if (CurrInst->is(asmtok::op_ld, asmtok::op_red))
613613
OS() << "*";
614614
switch (Dst->getMemoryOpKind()) {
615615
case InlineAsmAddressExpr::Imm:
@@ -632,8 +632,12 @@ bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
632632
std::string Reg;
633633
if (tryEmitStmt(Reg, Dst->getSymbol()))
634634
return SYCLGenSuccess();
635-
OS() << llvm::formatv("(({0} *)((uintptr_t){1} + {2}))", Type, Reg,
636-
Dst->getImmAddr()->getValue().getZExtValue());
635+
636+
if (CurrInst->is(asmtok::op_st))
637+
OS() << llvm::formatv("(uintptr_t){0}", Reg);
638+
else
639+
OS() << llvm::formatv("(({0} *)((uintptr_t){1} + {2}))", Type, Reg,
640+
Dst->getImmAddr()->getValue().getZExtValue());
637641
break;
638642
}
639643
case InlineAsmAddressExpr::Var: {
@@ -2690,24 +2694,98 @@ class SYCLGen : public SYCLGenBase {
26902694
return SYCLGenSuccess();
26912695
}
26922696

2697+
bool HandleStVec(const InlineAsmInstruction *Inst, int VecNum) {
2698+
std::string Ops;
2699+
if (tryEmitStmt(Ops, Inst->getInputOperand(0)))
2700+
return SYCLGenError();
2701+
2702+
// To extract the values from the string like "{x, y, z, w}" and store them
2703+
// int Values vector
2704+
std::vector<std::string> Values;
2705+
size_t start = 1; // Skip the '{' character
2706+
size_t end = Ops.find(',', start); // Find the first comma
2707+
2708+
while (end != std::string::npos) {
2709+
std::string Token = Ops.substr(start, end - start);
2710+
size_t First = Token.find_first_not_of(' ');
2711+
size_t Last = Token.find_last_not_of(' ');
2712+
if (First != std::string::npos && Last != std::string::npos) {
2713+
Values.push_back(Token.substr(First, Last - First + 1));
2714+
}
2715+
start = end + 1;
2716+
end = Ops.find(',', start);
2717+
}
2718+
2719+
// Extract the last value after the last comma
2720+
std::string token = Ops.substr(start, Ops.size() - start - 1);
2721+
size_t first = token.find_first_not_of(' ');
2722+
size_t last = token.find_last_not_of(' ');
2723+
2724+
if (first != std::string::npos && last != std::string::npos) {
2725+
Values.push_back(token.substr(first, last - first + 1));
2726+
}
2727+
2728+
std::string Output;
2729+
if (tryEmitStmt(Output, Inst->getOutputOperand()))
2730+
return SYCLGenError();
2731+
2732+
std::string Type;
2733+
if (tryEmitType(Type, Inst->getType(0)))
2734+
return SYCLGenError();
2735+
2736+
const auto *Dst =
2737+
dyn_cast_or_null<InlineAsmAddressExpr>(Inst->getOutputOperand());
2738+
if (!Dst)
2739+
return SYCLGenError();
2740+
2741+
for (int Index = 0; Index < VecNum; Index++) {
2742+
OS() << llvm::formatv("*(({0} *)({1}) + {2}) = {3}{4}", Type, Output,
2743+
Index, Values[Index],
2744+
Index == VecNum - 1 ? "" : ";\n");
2745+
}
2746+
2747+
endstmt();
2748+
return SYCLGenSuccess();
2749+
}
2750+
26932751
bool handle_st(const InlineAsmInstruction *Inst) override {
// Emits the SYCL statement for a PTX `st` instruction in the form
// "<dereferenced destination> = <source>", finished by endstmt().
// NOTE(review): this span is a rendered diff hunk. A line that follows an
// `NNNN-` gutter marker belongs to the removed version, a line that follows
// an `NNNN+` marker belongs to the added version, and a line that follows a
// doubled number is unchanged context.
26942752
if (Inst->getNumInputOperands() != 1)
26952753
return SYCLGenError();
2696-
llvm::SaveAndRestore<const InlineAsmInstruction *> Store(CurrInst);
2697-
CurrInst = Inst;
2754+
2755+
// Added form: CurrInst and Inst are both handed to SaveAndRestore, in place
// of the removed save-then-assign pair above.
llvm::SaveAndRestore<const InlineAsmInstruction *> Store(CurrInst, Inst);
2756+
2757+
// Stores carrying the `cs` attribute together with `v4` or `v2` are
// delegated to HandleStVec with the matching element count.
if (Inst->hasAttr(InstAttr::cs)) {
2758+
if (Inst->hasAttr(InstAttr::v4))
2759+
return HandleStVec(Inst, 4);
2760+
if (Inst->hasAttr(InstAttr::v2))
2761+
return HandleStVec(Inst, 2);
2762+
}
2763+
26982764
const auto *Src = Inst->getInputOperand(0);
26992765
const auto *Dst =
27002766
dyn_cast_or_null<InlineAsmAddressExpr>(Inst->getOutputOperand());
27012767
if (!Dst)
2702-
return false;
2768+
// Added form: a destination that is not an InlineAsmAddressExpr returns
// SYCLGenError(), like the other failure paths in this function.
return SYCLGenError();
2769+
27032770
std::string Type;
27042771
if (tryEmitType(Type, Inst->getType(0)))
27052772
return SYCLGenError();
2706-
if (emitStmt(Dst))
2773+
2774+
// Added form: the destination is first emitted into a string so that the
// dereference and cast can be wrapped around it below.
std::string OutOp;
2775+
if (tryEmitStmt(OutOp, Inst->getOutputOperand()))
27072776
return SYCLGenError();
2777+
2778+
// For a register+immediate address the immediate is added to the emitted
// operand before the cast to `Type *`; any other address kind is
// dereferenced exactly as emitted.
if (Dst->getMemoryOpKind() == InlineAsmAddressExpr::RegImm) {
2779+
OS() << llvm::formatv("*(({0} *)({1} + {2}))", Type, OutOp,
2780+
Dst->getImmAddr()->getValue().getZExtValue());
2781+
} else {
2782+
OS() << "*" << OutOp;
2783+
}
2784+
27082785
OS() << " = ";
27092786
if (emitStmt(Src))
27102787
return SYCLGenError();
2788+
27112789
endstmt();
27122790
return SYCLGenSuccess();
27132791
}
@@ -2722,7 +2800,7 @@ class SYCLGen : public SYCLGenBase {
27222800
const auto *Dst = Inst->getOutputOperand();
27232801

27242802
if (!Src)
2725-
return false;
2803+
return SYCLGenError();
27262804
std::string Type;
27272805
if (tryEmitType(Type, Inst->getType(0)))
27282806
return SYCLGenError();

clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,7 @@ MODIFIER(rc8, ".rc8")
418418
MODIFIER(ecl, ".ecl")
419419
MODIFIER(ecr, ".ecr")
420420
MODIFIER(rc16, ".rc16")
421+
MODIFIER(cs, ".cs")
421422

422423
#undef LINKAGE
423424
#undef TARGET

clang/test/dpct/asm/st.cu

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,28 @@ __device__ void shared_address_store32(uint32_t addr, uint32_t val) {
3737
asm volatile("{st.shared.b32 [%0], %1;}" : : "r"(__addr), "r"(__val) : "memory");
3838
}
3939

40+
#if (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__)
41+
#define __PTR "l"
42+
#else
43+
#define __PTR "r"
44+
#endif
45+
46+
// CHECK: inline void store_streaming_short4(sycl::short4 *addr, short x, short y, short z, short w) {
47+
// CHECK-NEXT: *((int16_t *)((uintptr_t)addr) + 0) = x;
48+
// CHECK-NEXT: *((int16_t *)((uintptr_t)addr) + 1) = y;
49+
// CHECK-NEXT: *((int16_t *)((uintptr_t)addr) + 2) = z;
50+
// CHECK-NEXT: *((int16_t *)((uintptr_t)addr) + 3) = w;
51+
// CHECK-NEXT: }
52+
__device__ inline void store_streaming_short4(short4 *addr, short x, short y, short z, short w) {
53+
asm("st.cs.global.v4.s16 [%0+0], {%1, %2, %3, %4};" ::__PTR(addr), "h"(x), "h"(y), "h"(z), "h"(w));
54+
}
55+
56+
// CHECK: inline void store_streaming_short2(sycl::short2 *addr, short x, short y) {
57+
// CHECK-NEXT: *((int16_t *)((uintptr_t)addr) + 0) = x;
58+
// CHECK-NEXT: *((int16_t *)((uintptr_t)addr) + 1) = y;
59+
// CHECK-NEXT: }
60+
__device__ inline void store_streaming_short2(short2 *addr, short x, short y) {
61+
asm("st.cs.global.v2.s16 [%0+0], {%1, %2};" ::__PTR(addr), "h"(x), "h"(y));
62+
}
63+
4064
// clang-format on

0 commit comments

Comments
 (0)