Skip to content

Commit f344c93

Browse files
[SYCLomatic][PTX] Added support for stmatrix migration (#2801)
1 parent 8a00958 commit f344c93

File tree

7 files changed

+355
-13
lines changed

7 files changed

+355
-13
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -600,9 +600,10 @@ bool SYCLGenBase::emitVariableDeclaration(const InlineAsmVarDecl *D) {
600600

601601
bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
602602
// Address expression only support ld/st/red & atom instructions.
603-
if (!CurrInst || !CurrInst->is(asmtok::op_st, asmtok::op_ld, asmtok::op_atom,
604-
asmtok::op_prefetch, asmtok::op_red,
605-
asmtok::op_cp, asmtok::op_ldmatrix)) {
603+
if (!CurrInst ||
604+
!CurrInst->is(asmtok::op_st, asmtok::op_ld, asmtok::op_atom,
605+
asmtok::op_prefetch, asmtok::op_red, asmtok::op_cp,
606+
asmtok::op_ldmatrix, asmtok::op_stmatrix)) {
606607
return SYCLGenError();
607608
}
608609
std::string Type;
@@ -635,7 +636,8 @@ bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
635636
if (CurrInst->is(asmtok::op_prefetch, asmtok::op_red) ||
636637
CanSuppressCast(Dst->getSymbol()))
637638
OS() << llvm::formatv("{0}", Reg);
638-
else if (CurrInst->is(asmtok::op_ldmatrix))
639+
else if (CurrInst->is(asmtok::op_ldmatrix) ||
640+
CurrInst->is(asmtok::op_stmatrix))
639641
OS() << llvm::formatv("(uintptr_t){0}", Reg);
640642
else
641643
OS() << llvm::formatv("(({0} *)(uintptr_t){1})", Type, Reg);
@@ -1376,6 +1378,64 @@ class SYCLGen : public SYCLGenBase {
13761378
return SYCLGenSuccess();
13771379
}
13781380

1381+
bool handle_stmatrix(const InlineAsmInstruction *Inst) override {
1382+
if (Inst->getNumInputOperands() != 1)
1383+
return SYCLGenError();
1384+
1385+
const auto *Type = dyn_cast<InlineAsmBuiltinType>(Inst->getType(0));
1386+
1387+
if (!Type || Type->getKind() != InlineAsmBuiltinType::b16)
1388+
return SYCLGenError();
1389+
1390+
const InlineAsmVectorExpr *VE;
1391+
if (VE = dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand(0))) {
1392+
auto numOutputOperands = VE->getNumElements();
1393+
if (Inst->hasAttr(InstAttr::x1)) {
1394+
if (numOutputOperands != 1)
1395+
return SYCLGenError();
1396+
} else if (Inst->hasAttr(InstAttr::x2)) {
1397+
if (numOutputOperands != 2)
1398+
return SYCLGenError();
1399+
} else if (Inst->hasAttr(InstAttr::x4)) {
1400+
if (numOutputOperands != 4)
1401+
return SYCLGenError();
1402+
}
1403+
} else {
1404+
return SYCLGenError();
1405+
}
1406+
1407+
llvm::SaveAndRestore<const InlineAsmInstruction *> Store(CurrInst);
1408+
CurrInst = Inst;
1409+
const auto *Dst =
1410+
dyn_cast_or_null<InlineAsmAddressExpr>(Inst->getOutputOperand());
1411+
if (!Dst)
1412+
return false;
1413+
1414+
OS() << MapNames::getDpctNamespace() << "experimental::matrix::stmatrix(";
1415+
if (emitStmt(Dst)) {
1416+
return SYCLGenError();
1417+
}
1418+
for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) {
1419+
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
1420+
continue;
1421+
OS() << ", ";
1422+
if (emitStmt(VE->getElement(Inst)))
1423+
return SYCLGenError();
1424+
}
1425+
if (Inst->hasAttr(InstAttr::trans))
1426+
OS() << ", true";
1427+
OS() << ");";
1428+
const auto *KernelDecl = getImmediateOuterFuncDecl(GAS);
1429+
if (KernelDecl) {
1430+
auto FuncInfo = DeviceFunctionDecl::LinkRedecls(KernelDecl);
1431+
if (FuncInfo)
1432+
FuncInfo->addSubGroupSizeRequest(32, GAS->getBeginLoc(),
1433+
DpctGlobalInfo::getSubGroup(GAS));
1434+
}
1435+
1436+
return SYCLGenSuccess();
1437+
}
1438+
13791439
bool handle_mma(const InlineAsmInstruction *Inst) override {
13801440
if (Inst->getNumInputOperands() != 3)
13811441
return SYCLGenError();

clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,8 @@ class InlineAsmInstruction : public InlineAsmStmt {
340340
/// therest are input operands.
341341
SmallVector<InlineAsmExpr *, 4> InputOps;
342342

343+
SmallVector<InlineAsmExpr *, 4> OutputOps;
344+
343345
public:
344346
InlineAsmInstruction(InlineAsmIdentifierInfo *Op,
345347
SmallVector<AsmStateSpace, 4> AsmStateSpaces,

clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -736,10 +736,11 @@ InlineAsmExprResult InlineAsmParser::ActOnParenExpr(InlineAsmExpr *SubExpr) {
736736
InlineAsmExprResult
737737
InlineAsmParser::ActOnVectorExpr(ArrayRef<InlineAsmExpr *> Vec) {
738738

739-
// Vector size for ldmatrix are 1, 2, 4
739+
// Vector size for ldmatrix/stmatrix are 1, 2, 4
740740
// size(x) = 2 * sizeof(v).
741741
InlineAsmVectorType::VecKind Kind;
742-
if (Opcode->getTokenID() == asmtok::op_ldmatrix) {
742+
if (Opcode->getTokenID() == asmtok::op_ldmatrix ||
743+
Opcode->getTokenID() == asmtok::op_stmatrix) {
743744
switch (Vec.size()) {
744745
case 1:
745746
Kind = InlineAsmVectorType::x1;

clang/lib/DPCT/SrcAPI/APINames_ASM.inc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ ENTRY("griddepcontrol", "griddepcontrol", false, NO_FLAG, P1, "Comment")
7575
ENTRY("isspacep", "isspacep", false, NO_FLAG, P1, "Comment")
7676
ENTRY("istypep", "istypep", false, NO_FLAG, P1, "Comment")
7777
ENTRY("ld", "ld", true, NO_FLAG, P1, "Partial")
78-
ENTRY("ldmatrix", "ldmatrix", true, NO_FLAG, P1, "Successful")
78+
ENTRY("ldmatrix", "ldmatrix", true, NO_FLAG, P1, "Partial")
7979
ENTRY("ldu", "ldu", false, NO_FLAG, P1, "Comment")
8080
ENTRY("lg2", "lg2", true, NO_FLAG, P1, "Successful")
8181
ENTRY("lop3", "lop3", true, NO_FLAG, P1, "Successful")
@@ -123,7 +123,7 @@ ENTRY("sqrt", "sqrt", true, NO_FLAG, P1, "Successful")
123123
ENTRY("st", "st", true, NO_FLAG, P1, "Partial")
124124
ENTRY("stackrestore", "stackrestore", false, NO_FLAG, P1, "Comment")
125125
ENTRY("stacksave", "stacksave", false, NO_FLAG, P1, "Comment")
126-
ENTRY("stmatrix", "stmatrix", false, NO_FLAG, P1, "Comment")
126+
ENTRY("stmatrix", "stmatrix", true, NO_FLAG, P1, "Partial")
127127
ENTRY("sub", "sub", true, NO_FLAG, P1, "Partial")
128128
ENTRY("subc", "subc", false, NO_FLAG, P1, "Comment")
129129
ENTRY("suld", "suld", false, NO_FLAG, P1, "Comment")

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 163 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2058,7 +2058,7 @@ class joint_matrix {
20582058
const size_t num_elements;
20592059
};
20602060

2061-
/// Collectively loads 1 8x8 b16 (128 bytes) matrix from private memory to local
2061+
/// Collectively loads 1 8x8 b16 (128 bytes) matrix from local memory to private
20622062
/// memory per sub-group. Requires the sub-group size of kernel calling this
20632063
/// function to be 32.
20642064
/// 'mat' specifies the matrix index to be loaded. The first '(mat + 1) * 8'
@@ -2135,7 +2135,7 @@ void ldmatrix(uintptr_t addr, T *m, bool trans = false, unsigned mat = 0) {
21352135
}
21362136
}
21372137

2138-
/// Collectively loads 2 8x8 b16 (256 bytes) matrix from private memory to local
2138+
/// Collectively loads 2 8x8 b16 (256 bytes) matrix from local memory to private
21392139
/// memory per sub-group. Requires the sub-group size of kernel calling this
21402140
/// function to be 32.
21412141
/// The first 16 work items of sub-group contain the starting address of their
@@ -2172,7 +2172,7 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, bool trans = false) {
21722172
ldmatrix(addr, m2, trans, 1);
21732173
}
21742174

2175-
/// Collectively loads 4 8x8 b16 (512 bytes) matrix from private memory to local
2175+
/// Collectively loads 4 8x8 b16 (512 bytes) matrix from local memory to private
21762176
/// memory per sub-group. Requires the sub-group size of kernel calling this
21772177
/// function to be 32.
21782178
/// Each work item of sub-group contains the starting address of their
@@ -2218,6 +2218,166 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) {
22182218
ldmatrix(addr, m4, trans, 3);
22192219
}
22202220

2221+
/// Collectively stores 1 8x8 b16 (128 bytes) matrix from private memory to
2222+
/// local memory per sub-group.
2223+
/// Requires the sub-group size of kernel calling this function to be 32.
2224+
/// 'mat' specifies the matrix index to be stored. The first '(mat + 1) * 8'
2225+
/// work items of sub-group contain the starting address of their respective
2226+
/// matrix row in 'addr'.
2227+
/// After distributing addresses to other work items, each of the 32 work items
2228+
/// store 32-bits (2 packed 16-bit data) into 'm' for a total of 128 bytes.
2229+
/// 'trans' specifies to perform a transposed/non-transposed store by each work
2230+
/// item like below
2231+
/// Row Major: Each row of the matrix is stored by a group of 4 work items(wi)
2232+
/// row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2233+
/// row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2234+
/// ...
2235+
/// row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2236+
/// row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2237+
/// Col Major: Each col of the matrix is stored by a group of 4 work items(wi)
2238+
/// row-0: wi0 wi4 wi8 ... wi28
2239+
/// row-1: wi0 wi4 wi8 ... wi28
2240+
/// ...
2241+
/// row-6: wi3 wi7 wi11 ... wi31
2242+
/// row-7: wi3 wi7 wi11 ... wi31
2243+
/// \tparam [in] T Type of result variable (currently only supports 16-bit type)
2244+
/// \param [in] addr The starting address of corresponding matrix row for a work
2245+
/// item in local memory
2246+
/// \param [in] m The local memory to store the matrix. It points to 2 b16
2247+
/// type elements.
2248+
/// \param [in] trans Indicates whether the matrix to be stored transposed
2249+
/// \param [in] mat The matrix index to be stored
2250+
template <typename T>
2251+
void stmatrix(uintptr_t addr, T m, bool trans = false, unsigned mat = 0) {
2252+
auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
2253+
int lane = sg.get_local_linear_id();
2254+
2255+
int lane_group8_row = lane / 8;
2256+
int lane_group8_col = lane % 8;
2257+
2258+
if (!trans) {
2259+
// calculate the source lane
2260+
int src_lane = 2 * lane_group8_row;
2261+
if (lane_group8_col >= 4)
2262+
src_lane += 1;
2263+
2264+
// Broadcast the address from the source lane
2265+
auto recv_addr_uintp =
2266+
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
2267+
2268+
// Cast the received address from uintptr_t to the type of 'm'
2269+
auto recv_addr = reinterpret_cast<T *>(recv_addr_uintp);
2270+
2271+
// Non-transposed store
2272+
recv_addr[lane_group8_col % 4] = m;
2273+
} else {
2274+
// calculate the source lane
2275+
int src_lane = (lane % 4) * 2;
2276+
2277+
// Broadcast the address from the source lane
2278+
auto recv_addr_uintp_1 =
2279+
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
2280+
auto recv_addr_uintp_2 =
2281+
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane + 1);
2282+
2283+
// Cast the received address from uintptr_t to 'half *'
2284+
auto recv_addr_1 = reinterpret_cast<sycl::half *>(recv_addr_uintp_1);
2285+
auto recv_addr_2 = reinterpret_cast<sycl::half *>(recv_addr_uintp_2);
2286+
2287+
// Split the 32-bit value of 'm' into two 16-bits
2288+
sycl::half *val = reinterpret_cast<sycl::half *>(&m);
2289+
2290+
// Transposed store
2291+
int index = lane / 4;
2292+
recv_addr_1[index] = val[0];
2293+
recv_addr_2[index] = val[1];
2294+
}
2295+
}
2296+
2297+
/// Collectively stores 2 8x8 b16 (256 bytes) matrix from private memory to
2298+
/// local memory per sub-group.
2299+
/// Requires the sub-group size of kernel calling this function to be 32.
2300+
/// The first 16 work items of sub-group contain the starting address of their
2301+
/// respective matrix row in 'addr'.
2302+
/// After distributing addresses to other work items, each of the 32 work items
2303+
/// store 64-bits (32-bits per matrix) into 'm1' & 'm2' for a total of 256
2304+
/// bytes.
2305+
/// 'trans' specifies to perform a transposed/non-transposed store by each work
2306+
/// item like below
2307+
/// Row Major: Each row of the matrices is stored by a group of 4 work items(wi)
2308+
/// row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2309+
/// row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2310+
/// ...
2311+
/// row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2312+
/// row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2313+
/// Col Major: Each col of the matrices is stored by a group of 4 work items(wi)
2314+
/// row-0: wi0 wi4 wi8 ... wi28
2315+
/// row-1: wi0 wi4 wi8 ... wi28
2316+
/// ...
2317+
/// row-6: wi3 wi7 wi11 ... wi31
2318+
/// row-7: wi3 wi7 wi11 ... wi31
2319+
/// \tparam [in] T Type of result variable (currently only supports 16-bit type)
2320+
/// \param [in] addr The starting address of corresponding matrix row for a work
2321+
/// item in local memory
2322+
/// \param [in] m1 The local memory to store the data of 1st matrix. It points
2323+
/// to 2 b16 type elements.
2324+
/// \param [in] m2 The local memory to store the data of 2nd matrix. It points
2325+
/// to 2 b16 type elements.
2326+
/// \param [in] trans Indicates whether the matrix to be stored transposed
2327+
template <typename T>
2328+
void stmatrix(uintptr_t addr, T m1, T m2, bool trans = false) {
2329+
// Store 1st matrix
2330+
stmatrix(addr, m1, trans, 0);
2331+
// Store 2nd matrix
2332+
stmatrix(addr, m2, trans, 1);
2333+
}
2334+
2335+
/// Collectively stores 4 8x8 b16 (512 bytes) matrix from private memory to
2336+
/// local memory per sub-group.
2337+
/// Requires the sub-group size of kernel calling this function to be 32.
2338+
/// Each work item of sub-group contains the starting address of their
2339+
/// respective matrix row in 'addr'.
2340+
/// After distributing addresses to other work items, each of the 32 work items
2341+
/// store 128-bits (32-bits per matrix) into 'm1', 'm2', 'm3' & 'm4' for a total
2342+
/// of 512 bytes.
2343+
/// 'trans' specifies to perform a transposed/non-transposed store by each work
2344+
/// item like below
2345+
/// Row Major: Each row of the matrices is stored by a group of 4 work items(wi)
2346+
/// row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2347+
/// row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2348+
/// ...
2349+
/// row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2350+
/// row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2351+
/// Col Major: Each col of the matrices is stored by a group of 4 work items(wi)
2352+
/// row-0: wi0 wi4 wi8 ... wi28
2353+
/// row-1: wi0 wi4 wi8 ... wi28
2354+
/// ...
2355+
/// row-6: wi3 wi7 wi11 ... wi31
2356+
/// row-7: wi3 wi7 wi11 ... wi31
2357+
/// \tparam [in] T Type of result variable (currently only supports 16-bit type)
2358+
/// \param [in] addr The starting address of corresponding matrix row for a work
2359+
/// item in local memory
2360+
/// \param [in] m1 The local memory to store the data of 1st matrix. It points
2361+
/// to 2 b16 type elements.
2362+
/// \param [in] m2 The local memory to store the data of 2nd matrix. It points
2363+
/// to 2 b16 type elements.
2364+
/// \param [in] m3 The local memory to store the data of 3rd matrix. It points
2365+
/// to 2 b16 type elements.
2366+
/// \param [in] m4 The local memory to store the data of 4th matrix. It points
2367+
/// to 2 b16 type elements.
2368+
/// \param [in] trans Indicates whether the matrix to be stored transposed
2369+
template <typename T>
2370+
void stmatrix(uintptr_t addr, T m1, T m2, T m3, T m4, bool trans = false) {
2371+
// Store 1st matrix
2372+
stmatrix(addr, m1, trans, 0);
2373+
// Store 2nd matrix
2374+
stmatrix(addr, m2, trans, 1);
2375+
// Store 3rd matrix
2376+
stmatrix(addr, m3, trans, 2);
2377+
// Store 4th matrix
2378+
stmatrix(addr, m4, trans, 3);
2379+
}
2380+
22212381
/// A helper struct that defines the pack type for the input matrix fragments
22222382
/// of mma() function based on the type of input matrix fragments.
22232383
/// The MMAType struct is specialized for different types of input matrices.

0 commit comments

Comments
 (0)