Skip to content

Commit 68b7d31

Browse files
Added support for stmatrix migration
1 parent e251b18 commit 68b7d31

File tree

7 files changed

+287
-8
lines changed

7 files changed

+287
-8
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,8 @@ bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
602602
// Address expressions are only supported for memory-access instructions (see the opcode list below).
603603
if (!CurrInst || !CurrInst->is(asmtok::op_st, asmtok::op_ld, asmtok::op_atom,
604604
asmtok::op_prefetch, asmtok::op_red,
605-
asmtok::op_cp, asmtok::op_ldmatrix)) {
605+
asmtok::op_cp, asmtok::op_ldmatrix,
606+
asmtok::op_stmatrix)) {
606607
return SYCLGenError();
607608
}
608609
std::string Type;
@@ -635,7 +636,8 @@ bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
635636
if (CurrInst->is(asmtok::op_prefetch, asmtok::op_red) ||
636637
CanSuppressCast(Dst->getSymbol()))
637638
OS() << llvm::formatv("{0}", Reg);
638-
else if (CurrInst->is(asmtok::op_ldmatrix))
639+
else if (CurrInst->is(asmtok::op_ldmatrix) ||
640+
CurrInst->is(asmtok::op_stmatrix))
639641
OS() << llvm::formatv("(uintptr_t){0}", Reg);
640642
else
641643
OS() << llvm::formatv("(({0} *)(uintptr_t){1})", Type, Reg);
@@ -1376,6 +1378,66 @@ class SYCLGen : public SYCLGenBase {
13761378
return SYCLGenSuccess();
13771379
}
13781380

1381+
bool handle_stmatrix(const InlineAsmInstruction *Inst) override {
1382+
if (Inst->getNumInputOperands() != 1)
1383+
return SYCLGenError();
1384+
1385+
const auto *Type = dyn_cast<InlineAsmBuiltinType>(Inst->getType(0));
1386+
1387+
if (!Type || Type->getKind() != InlineAsmBuiltinType::b16)
1388+
return SYCLGenError();
1389+
1390+
const InlineAsmVectorExpr *VE;
1391+
if (VE = dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand(0))) {
1392+
auto numOutputOperands = VE->getNumElements();
1393+
if (Inst->hasAttr(InstAttr::x1)) {
1394+
if (numOutputOperands != 1)
1395+
return SYCLGenError();
1396+
} else if (Inst->hasAttr(InstAttr::x2)) {
1397+
if (numOutputOperands != 2)
1398+
return SYCLGenError();
1399+
} else if (Inst->hasAttr(InstAttr::x4)) {
1400+
if (numOutputOperands != 4)
1401+
return SYCLGenError();
1402+
}
1403+
} else {
1404+
return SYCLGenError();
1405+
}
1406+
1407+
llvm::SaveAndRestore<const InlineAsmInstruction *> Store(CurrInst);
1408+
CurrInst = Inst;
1409+
const auto *Dst =
1410+
dyn_cast_or_null<InlineAsmAddressExpr>(Inst->getOutputOperand());
1411+
if (!Dst)
1412+
return false;
1413+
1414+
OS() << MapNames::getDpctNamespace() << "experimental::matrix::stmatrix(";
1415+
if (emitStmt(Dst)) {
1416+
return SYCLGenError();
1417+
}
1418+
OS() << ", ";
1419+
for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) {
1420+
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
1421+
continue;
1422+
if (emitStmt(VE->getElement(Inst)))
1423+
return SYCLGenError();
1424+
OS() << ", ";
1425+
}
1426+
OS() << DpctGlobalInfo::getItem(GAS);
1427+
if (Inst->hasAttr(InstAttr::trans))
1428+
OS() << ", true";
1429+
OS() << ");";
1430+
const auto *KernelDecl = getImmediateOuterFuncDecl(GAS);
1431+
if (KernelDecl) {
1432+
auto FuncInfo = DeviceFunctionDecl::LinkRedecls(KernelDecl);
1433+
if (FuncInfo)
1434+
FuncInfo->addSubGroupSizeRequest(32, GAS->getBeginLoc(),
1435+
DpctGlobalInfo::getSubGroup(GAS));
1436+
}
1437+
1438+
return SYCLGenSuccess();
1439+
}
1440+
13791441
bool handle_mma(const InlineAsmInstruction *Inst) override {
13801442
if (Inst->getNumInputOperands() != 3)
13811443
return SYCLGenError();

clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,8 @@ class InlineAsmInstruction : public InlineAsmStmt {
340340
/// the rest are input operands.
341341
SmallVector<InlineAsmExpr *, 4> InputOps;
342342

343+
SmallVector<InlineAsmExpr *, 4> OutputOps;
344+
343345
public:
344346
InlineAsmInstruction(InlineAsmIdentifierInfo *Op,
345347
SmallVector<AsmStateSpace, 4> AsmStateSpaces,

clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -736,10 +736,11 @@ InlineAsmExprResult InlineAsmParser::ActOnParenExpr(InlineAsmExpr *SubExpr) {
736736
InlineAsmExprResult
737737
InlineAsmParser::ActOnVectorExpr(ArrayRef<InlineAsmExpr *> Vec) {
738738

739-
// Vector size for ldmatrix are 1, 2, 4
739+
// Valid vector sizes for ldmatrix/stmatrix are 1, 2 and 4
740740
// size(x) = 2 * sizeof(v).
741741
InlineAsmVectorType::VecKind Kind;
742-
if (Opcode->getTokenID() == asmtok::op_ldmatrix) {
742+
if (Opcode->getTokenID() == asmtok::op_ldmatrix ||
743+
Opcode->getTokenID() == asmtok::op_stmatrix) {
743744
switch (Vec.size()) {
744745
case 1:
745746
Kind = InlineAsmVectorType::x1;

clang/lib/DPCT/SrcAPI/APINames_ASM.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ ENTRY("sqrt", "sqrt", true, NO_FLAG, P1, "Successful")
123123
ENTRY("st", "st", true, NO_FLAG, P1, "Partial")
124124
ENTRY("stackrestore", "stackrestore", false, NO_FLAG, P1, "Comment")
125125
ENTRY("stacksave", "stacksave", false, NO_FLAG, P1, "Comment")
126-
ENTRY("stmatrix", "stmatrix", false, NO_FLAG, P1, "Comment")
126+
ENTRY("stmatrix", "stmatrix", true, NO_FLAG, P1, "Successful")
127127
ENTRY("sub", "sub", true, NO_FLAG, P1, "Partial")
128128
ENTRY("subc", "subc", false, NO_FLAG, P1, "Comment")
129129
ENTRY("suld", "suld", false, NO_FLAG, P1, "Comment")

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ max(T1 a, T2 b) {
425425
return sycl::fmax(static_cast<common_t>(a), static_cast<common_t>(b));
426426
}
427427

428-
// pow functions overload.
428+
// pow functions overload.
429429
inline float pow(const float a, const int b) { return sycl::pown(a, b); }
430430
inline double pow(const double a, const int b) { return sycl::pown(a, b); }
431431
inline float pow(const float a, const float b) { return sycl::pow(a, b); }
@@ -2218,6 +2218,101 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) {
22182218
ldmatrix(addr, m4, trans, 3);
22192219
}
22202220

2221+
/// Stores 1 8x8 b16 matrix from local memory to shared memory (32-bits per wi)
2222+
/// Requires the sub-group size of kernel calling this function to be 32
2223+
/// \tparam [in] T The type of matrix elements
2224+
/// \param [in] addr The address of the matrix in shared memory
2225+
/// \param [in] m The local memory containing data of matrix
2226+
/// \param [in] item The sycl::nd_item index space class
2227+
/// \param [in] trans Indicates whether the matrix to be stored transposed
2228+
/// \param [in] mat The matrix index to be stored
2229+
template <typename T, typename ItemT>
2230+
void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false,
2231+
unsigned mat = 0) {
2232+
int lane = item.get_sub_group().get_local_linear_id();
2233+
2234+
int lane_group8_row = lane / 8;
2235+
int lane_group8_col = lane % 8;
2236+
2237+
if (!trans) {
2238+
// calculate the source lane
2239+
int src_lane = 2 * lane_group8_row;
2240+
if (lane_group8_col >= 4)
2241+
src_lane += 1;
2242+
2243+
// Broadcast the address from the source lane
2244+
auto recv_addr_uintp = dpct::select_from_sub_group(
2245+
item.get_sub_group(), addr, mat * 8 + src_lane);
2246+
2247+
// Cast the received address from uintptr_t to the type of 'm'
2248+
auto recv_addr = reinterpret_cast<T *>(recv_addr_uintp);
2249+
2250+
// Non-transposed store
2251+
recv_addr[lane_group8_col % 4] = m;
2252+
} else {
2253+
// calculate the source lane
2254+
int src_lane = (lane % 4) * 2;
2255+
2256+
// Broadcast the address from the source lane
2257+
auto recv_addr_uintp_1 = dpct::select_from_sub_group(
2258+
item.get_sub_group(), addr, mat * 8 + src_lane);
2259+
auto recv_addr_uintp_2 = dpct::select_from_sub_group(
2260+
item.get_sub_group(), addr, mat * 8 + src_lane + 1);
2261+
2262+
// Cast the received address from uintptr_t to 'half *'
2263+
auto recv_addr_1 = reinterpret_cast<sycl::half *>(recv_addr_uintp_1);
2264+
auto recv_addr_2 = reinterpret_cast<sycl::half *>(recv_addr_uintp_2);
2265+
2266+
// Split the 32-bit value of 'm' into two 16-bits
2267+
sycl::half *val = reinterpret_cast<sycl::half *>(&m);
2268+
2269+
// Transposed store
2270+
int index = lane / 4;
2271+
recv_addr_1[index] = val[0];
2272+
recv_addr_2[index] = val[1];
2273+
}
2274+
}
2275+
2276+
/// Stores two 8x8 b16 matrices to shared memory; every work-item contributes
/// one 32-bit value per matrix.
/// Requires the sub-group size of the calling kernel to be 32.
/// \tparam [in] T The type of the per-work-item values holding matrix elements
/// \tparam [in] ItemT The type of the index space class
/// \param [in] addr The row address supplied by this work-item (shared memory)
/// \param [in] m1 The value held by this work-item for the 1st matrix
/// \param [in] m2 The value held by this work-item for the 2nd matrix
/// \param [in] item The sycl::nd_item index space class
/// \param [in] trans Indicates whether the matrices are stored transposed
template <typename T, typename ItemT>
void stmatrix(uintptr_t addr, T m1, T m2, const ItemT &item,
              bool trans = false) {
  // Forward each fragment to the single-matrix overload, in matrix order.
  const T fragments[] = {m1, m2};
  for (unsigned mat = 0; mat != 2; ++mat)
    stmatrix(addr, fragments[mat], item, trans, mat);
}
2293+
/// Stores four 8x8 b16 matrices to shared memory; every work-item contributes
/// one 32-bit value per matrix.
/// Requires the sub-group size of the calling kernel to be 32.
/// \tparam [in] T The type of the per-work-item values holding matrix elements
/// \tparam [in] ItemT The type of the index space class
/// \param [in] addr The row address supplied by this work-item (shared memory)
/// \param [in] m1 The value held by this work-item for the 1st matrix
/// \param [in] m2 The value held by this work-item for the 2nd matrix
/// \param [in] m3 The value held by this work-item for the 3rd matrix
/// \param [in] m4 The value held by this work-item for the 4th matrix
/// \param [in] item The sycl::nd_item index space class
/// \param [in] trans Indicates whether the matrices are stored transposed
template <typename T, typename ItemT>
void stmatrix(uintptr_t addr, T m1, T m2, T m3, T m4, const ItemT &item,
              bool trans = false) {
  // Forward each fragment to the single-matrix overload, in matrix order.
  const T fragments[] = {m1, m2, m3, m4};
  for (unsigned mat = 0; mat != 4; ++mat)
    stmatrix(addr, fragments[mat], item, trans, mat);
}
2315+
22212316
/// A helper struct that defines the pack type for the input matrix fragments
22222317
/// of mma() function based on the type of input matrix fragments.
22232318
/// The MMAType struct is specialized for different types of input matrices.

clang/test/dpct/asm/stmatrix.cu

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// UNSUPPORTED: cuda-8.0, cuda-9.0, cuda-9.1, cuda-9.2, cuda-10.0, cuda-10.1, cuda-10.2
2+
// UNSUPPORTED: v8.0, v9.0, v9.1, v9.2, v10.0, v10.1, v10.2
3+
// RUN: dpct --format-range=none -out-root %T/stmatrix %s --cuda-include-path="%cuda-path/include" -- -std=c++14 -x cuda --cuda-host-only
4+
// RUN: FileCheck %s --match-full-lines --input-file %T/stmatrix/stmatrix.dp.cpp
5+
// RUN: %if build_lit %{icpx -c -DNO_BUILD_TEST -fsycl %T/stmatrix/stmatrix.dp.cpp -o %T/stmatrix/stmatrix.dp.o %}
6+
7+
// clang-format off
8+
#include <cuda_runtime.h>
9+
#include <cuda_fp16.h>
10+
11+
/*
12+
stmatrix.sync.aligned.shape.num{.trans}{.ss}.type [p], r;
13+
14+
Below are the currently supported configurations:
15+
.shape = {.m8n8};
16+
.num = {.x1, .x2, .x4};
17+
.ss = {.shared{::cta}};
18+
.type = {.b16};
19+
*/
20+
21+
__device__ void store_matrix_x1(void *sh_r_addr, int *r) {
22+
// CHECK: auto addr = sh_r_addr;
23+
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
24+
25+
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], item_ct1);
26+
asm volatile("stmatrix.sync.aligned.m8n8.x1.shared.b16 [%0], {%1};\n"
27+
:
28+
: "r"(addr), "r"(r[0]));
29+
}
30+
31+
__device__ void store_matrix_x2(void *sh_r_addr, int *r) {
32+
// CHECK: auto addr = sh_r_addr;
33+
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
34+
35+
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], r[1], item_ct1);
36+
asm volatile("stmatrix.sync.aligned.m8n8.x2.shared.b16 [%0], {%1, %2};\n"
37+
:
38+
: "r"(addr), "r"(r[0]), "r"(r[1]));
39+
}
40+
41+
__device__ void store_matrix_x4(void *sh_r_addr, int *r) {
42+
// CHECK: auto addr = sh_r_addr;
43+
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
44+
45+
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], r[1], r[2], r[3], item_ct1);
46+
asm volatile("stmatrix.sync.aligned.m8n8.x4.shared.b16 [%0], {%1, %2, %3, %4};\n"
47+
:
48+
: "r"(addr), "r"(r[0]), "r"(r[1]), "r"(r[2]), "r"(r[3]));
49+
}
50+
51+
__device__ void store_matrix_x1_trans(void *sh_r_addr, int *r) {
52+
// CHECK: auto addr = sh_r_addr;
53+
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
54+
55+
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], item_ct1, true);
56+
asm volatile("stmatrix.sync.aligned.m8n8.x1.trans.shared.b16 [%0], {%1};\n"
57+
:
58+
: "r"(addr), "r"(r[0]));
59+
}
60+
61+
__device__ void store_matrix_x2_trans(void *sh_r_addr, int *r) {
62+
// CHECK: auto addr = sh_r_addr;
63+
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
64+
65+
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], r[1], item_ct1, true);
66+
asm volatile("stmatrix.sync.aligned.m8n8.x2.trans.shared.b16 [%0], {%1, %2};\n"
67+
:
68+
: "r"(addr), "r"(r[0]), "r"(r[1]));
69+
}
70+
71+
__device__ void store_matrix_x4_trans(void *sh_r_addr, int *r) {
72+
// CHECK: auto addr = sh_r_addr;
73+
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
74+
75+
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], r[1], r[2], r[3], item_ct1, true);
76+
asm volatile("stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 [%0], {%1, %2, %3, %4};\n"
77+
:
78+
: "r"(addr), "r"(r[0]), "r"(r[1]), "r"(r[2]), "r"(r[3]));
79+
}
80+
81+
__global__ void store_kernel() {
82+
__shared__ half s_data[1024];
83+
int r[4];
84+
85+
store_matrix_x1(s_data, r);
86+
store_matrix_x2(s_data, r);
87+
store_matrix_x4(s_data, r);
88+
store_matrix_x1_trans(s_data, r);
89+
store_matrix_x2_trans(s_data, r);
90+
store_matrix_x4_trans(s_data, r);
91+
}
92+
93+
int main () {
94+
// CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} {
95+
store_kernel<<<1, 32>>>();
96+
97+
return 0;
98+
}
99+
100+
#ifndef NO_BUILD_TEST
101+
__device__ void test_xn(uint32_t addr, int *r) {
102+
// CHECK: DPCT1053:{{.*}}: Migration of device assembly code is not supported.
103+
asm volatile("stmatrix.sync.aligned.m8n8.x1.shared.b16 [%0], {%1, %2};\n"
104+
:
105+
: "r"(addr), "r"(r[0]), "r"(r[1]));
106+
107+
// CHECK: DPCT1053:{{.*}}: Migration of device assembly code is not supported.
108+
asm volatile("stmatrix.sync.aligned.m8n8.x2.shared.b16 [%0], {%0};\n"
109+
:
110+
: "r"(addr));
111+
112+
// CHECK: DPCT1053:{{.*}}: Migration of device assembly code is not supported.
113+
asm volatile("stmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2}, [%3];\n"
114+
:
115+
: "r"(addr), "r"(r[0]), "r"(r[1]), "r"(r[2]));
116+
}
117+
#endif // NO_BUILD_TEST
118+
119+
// clang-format on

docs/dev_guide/api-mapping-status/ASM_API_migration_status.csv

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ griddepcontrol,NO,
4141
isspacep,NO,
4242
istypep,NO,
4343
ld,YES, Partial
44-
ldmatrix,YES,
44+
ldmatrix,YES,Partial
4545
ldu,NO,
4646
lg2,YES,
4747
lop3,YES,
@@ -89,7 +89,7 @@ sqrt,YES,
8989
st,YES, Partial
9090
stackrestore,NO,
9191
stacksave,NO,
92-
stmatrix,NO,
92+
stmatrix,YES,Partial
9393
sub,YES, Partial
9494
subc,NO,
9595
suld,NO,

0 commit comments

Comments
 (0)