Skip to content

Commit 84b40b6

Browse files
Added support for stmatrix migration
1 parent f5fff21 commit 84b40b6

File tree

9 files changed

+332
-22
lines changed

9 files changed

+332
-22
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -556,10 +556,15 @@ bool SYCLGenBase::emitVectorType(const InlineAsmVectorType *T) {
556556
return SYCLGenError();
557557
OS() << ", ";
558558
switch (T->getKind()) {
559+
case InlineAsmVectorType::x1:
560+
OS() << 1;
561+
break;
559562
case InlineAsmVectorType::v2:
563+
case InlineAsmVectorType::x2:
560564
OS() << 2;
561565
break;
562566
case InlineAsmVectorType::v4:
567+
case InlineAsmVectorType::x4:
563568
OS() << 4;
564569
break;
565570
case InlineAsmVectorType::v8:
@@ -589,9 +594,9 @@ bool SYCLGenBase::emitVariableDeclaration(const InlineAsmVarDecl *D) {
589594

590595
bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
591596
// Address expressions are only supported for the ld, st, atom, prefetch, red,
// cp and stmatrix instructions.
592-
if (!CurrInst ||
593-
!CurrInst->is(asmtok::op_st, asmtok::op_ld, asmtok::op_atom,
594-
asmtok::op_prefetch, asmtok::op_red, asmtok::op_cp)) {
597+
if (!CurrInst || !CurrInst->is(asmtok::op_st, asmtok::op_ld, asmtok::op_atom,
598+
asmtok::op_prefetch, asmtok::op_red,
599+
asmtok::op_cp, asmtok::op_stmatrix)) {
595600
return SYCLGenError();
596601
}
597602
std::string Type;
@@ -624,6 +629,8 @@ bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
624629
if (CurrInst->is(asmtok::op_prefetch, asmtok::op_red) ||
625630
CanSuppressCast(Dst->getSymbol()))
626631
OS() << llvm::formatv("{0}", Reg);
632+
else if (CurrInst->is(asmtok::op_stmatrix))
633+
OS() << llvm::formatv("(uintptr_t){0}", Reg);
627634
else
628635
OS() << llvm::formatv("(({0} *)(uintptr_t){1})", Type, Reg);
629636
break;
@@ -1305,6 +1312,61 @@ class SYCLGen : public SYCLGenBase {
13051312
return SYCLGenSuccess();
13061313
}
13071314

1315+
bool handle_stmatrix(const InlineAsmInstruction *Inst) override {
1316+
if (Inst->getNumInputOperands() != 1)
1317+
return SYCLGenError();
1318+
1319+
llvm::SaveAndRestore<const InlineAsmInstruction *> Store(CurrInst);
1320+
CurrInst = Inst;
1321+
const auto *Dst =
1322+
dyn_cast_or_null<InlineAsmAddressExpr>(Inst->getOutputOperand());
1323+
if (!Dst)
1324+
return false;
1325+
1326+
const InlineAsmVectorExpr *VE;
1327+
if (VE = dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand(0))) {
1328+
auto numOutputOperands = VE->getNumElements();
1329+
if (Inst->hasAttr(InstAttr::x1)) {
1330+
if (numOutputOperands != 1)
1331+
return SYCLGenError();
1332+
} else if (Inst->hasAttr(InstAttr::x2)) {
1333+
if (numOutputOperands != 2)
1334+
return SYCLGenError();
1335+
} else if (Inst->hasAttr(InstAttr::x4)) {
1336+
if (numOutputOperands != 4)
1337+
return SYCLGenError();
1338+
}
1339+
} else {
1340+
return SYCLGenError();
1341+
}
1342+
1343+
OS() << MapNames::getDpctNamespace() << "experimental::matrix::stmatrix(";
1344+
if (emitStmt(Dst)) {
1345+
return SYCLGenError();
1346+
}
1347+
OS() << ", ";
1348+
for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) {
1349+
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
1350+
continue;
1351+
if (emitStmt(VE->getElement(Inst)))
1352+
return SYCLGenError();
1353+
OS() << ", ";
1354+
}
1355+
OS() << DpctGlobalInfo::getItem(GAS);
1356+
if (Inst->hasAttr(InstAttr::trans))
1357+
OS() << ", true";
1358+
OS() << ");";
1359+
const auto *KernelDecl = getImmediateOuterFuncDecl(GAS);
1360+
if (KernelDecl) {
1361+
auto FuncInfo = DeviceFunctionDecl::LinkRedecls(KernelDecl);
1362+
if (FuncInfo)
1363+
FuncInfo->addSubGroupSizeRequest(32, GAS->getBeginLoc(),
1364+
DpctGlobalInfo::getSubGroup(GAS));
1365+
}
1366+
1367+
return SYCLGenSuccess();
1368+
}
1369+
13081370
bool handle_prefetch(const InlineAsmInstruction *Inst) override {
13091371
if (!DpctGlobalInfo::useExtPrefetch() || Inst->getNumInputOperands() != 1)
13101372
return SYCLGenError();
@@ -2881,6 +2943,7 @@ class SYCLGen : public SYCLGenBase {
28812943
bool handle_ld(const InlineAsmInstruction *Inst) override {
28822944
if (Inst->getNumInputOperands() != 1)
28832945
return SYCLGenError();
2946+
28842947
llvm::SaveAndRestore<const InlineAsmInstruction *> Store(CurrInst);
28852948
CurrInst = Inst;
28862949
const auto *Src =

clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class InlineAsmBuiltinType : public InlineAsmType {
116116
// This class is used for device asm vector types.
117117
class InlineAsmVectorType : public InlineAsmType {
118118
public:
119-
enum VecKind { v2, v4, v8 };
119+
enum VecKind { v2, v4, v8, x1, x2, x4 };
120120

121121
private:
122122
VecKind Kind;
@@ -340,6 +340,8 @@ class InlineAsmInstruction : public InlineAsmStmt {
340340
/// the rest are input operands.
341341
SmallVector<InlineAsmExpr *, 4> InputOps;
342342

343+
SmallVector<InlineAsmExpr *, 4> OutputOps;
344+
343345
public:
344346
InlineAsmInstruction(InlineAsmIdentifierInfo *Op,
345347
SmallVector<AsmStateSpace, 4> AsmStateSpaces,

clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ InlineAsmStmtResult InlineAsmParser::ParseInstruction() {
327327
if (!Tok.getIdentifier() || !Tok.getIdentifier()->isInstruction())
328328
return AsmStmtError();
329329

330-
InlineAsmIdentifierInfo *Opcode = Tok.getIdentifier();
330+
Opcode = Tok.getIdentifier();
331331
ConsumeToken();
332332

333333
SmallVector<InstAttr, 4> Attrs;
@@ -736,20 +736,38 @@ InlineAsmExprResult InlineAsmParser::ActOnParenExpr(InlineAsmExpr *SubExpr) {
736736
InlineAsmExprResult
737737
InlineAsmParser::ActOnVectorExpr(ArrayRef<InlineAsmExpr *> Vec) {
738738

739-
// Vector size must be 2, 4, or 8.
739+
// Valid vector sizes for stmatrix are 1, 2, and 4.
740+
// size(x) = 2 * sizeof(v).
740741
InlineAsmVectorType::VecKind Kind;
741-
switch (Vec.size()) {
742-
case 2:
743-
Kind = InlineAsmVectorType::v2;
744-
break;
745-
case 4:
746-
Kind = InlineAsmVectorType::v4;
747-
break;
748-
case 8:
749-
Kind = InlineAsmVectorType::v8;
750-
break;
751-
default:
752-
return AsmExprError();
742+
if (Opcode->getTokenID() == asmtok::op_stmatrix) {
743+
switch (Vec.size()) {
744+
case 1:
745+
Kind = InlineAsmVectorType::x1;
746+
break;
747+
case 2:
748+
Kind = InlineAsmVectorType::x2;
749+
break;
750+
case 4:
751+
Kind = InlineAsmVectorType::x4;
752+
break;
753+
default:
754+
return AsmExprError();
755+
}
756+
} else {
757+
// Vector size must be 2, 4, or 8.
758+
switch (Vec.size()) {
759+
case 2:
760+
Kind = InlineAsmVectorType::v2;
761+
break;
762+
case 4:
763+
Kind = InlineAsmVectorType::v4;
764+
break;
765+
case 8:
766+
Kind = InlineAsmVectorType::v8;
767+
break;
768+
default:
769+
return AsmExprError();
770+
}
753771
}
754772

755773
InlineAsmBuiltinType *ElementType = nullptr;

clang/lib/DPCT/RulesAsm/Parser/AsmParser.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,8 @@ class InlineAsmParser {
247247
};
248248

249249
public:
250+
InlineAsmIdentifierInfo *Opcode;
251+
250252
InlineAsmParser(InlineAsmContext &Ctx, SourceMgr &Mgr)
251253
: Lexer(*Mgr.getMemoryBuffer(Mgr.getMainFileID())), Context(Ctx),
252254
SrcMgr(Mgr), CurScope(nullptr) {

clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,14 @@ MODIFIER(v2, ".v2")
274274
MODIFIER(v4, ".v4")
275275
MODIFIER(v8, ".v8")
276276

277+
// Matrix modifiers
278+
MODIFIER(x1, ".x1")
279+
MODIFIER(x2, ".x2")
280+
MODIFIER(x4, ".x4")
281+
282+
// Matrix shape
283+
MODIFIER(m8n8, ".m8n8")
284+
277285
STATE_SPACE(reg, ".reg")
278286
STATE_SPACE(sreg, ".sreg")
279287
STATE_SPACE(const, ".const")
@@ -420,6 +428,8 @@ MODIFIER(ecr, ".ecr")
420428
MODIFIER(rc16, ".rc16")
421429
MODIFIER(cs, ".cs")
422430
MODIFIER(to, ".to")
431+
MODIFIER(aligned, ".aligned")
432+
MODIFIER(trans, ".trans")
423433

424434
#undef LINKAGE
425435
#undef TARGET

clang/lib/DPCT/SrcAPI/APINames_ASM.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ ENTRY("sqrt", "sqrt", true, NO_FLAG, P1, "Successful")
123123
ENTRY("st", "st", true, NO_FLAG, P1, "Partial")
124124
ENTRY("stackrestore", "stackrestore", false, NO_FLAG, P1, "Comment")
125125
ENTRY("stacksave", "stacksave", false, NO_FLAG, P1, "Comment")
126-
ENTRY("stmatrix", "stmatrix", false, NO_FLAG, P1, "Comment")
126+
ENTRY("stmatrix", "stmatrix", true, NO_FLAG, P1, "Successful")
127127
ENTRY("sub", "sub", true, NO_FLAG, P1, "Partial")
128128
ENTRY("subc", "subc", false, NO_FLAG, P1, "Comment")
129129
ENTRY("suld", "suld", false, NO_FLAG, P1, "Comment")

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 98 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
#ifndef __DPCT_MATH_HPP__
1010
#define __DPCT_MATH_HPP__
1111

12-
#include <limits>
1312
#include <climits>
13+
#include <limits>
1414
#include <sycl/sycl.hpp>
1515
#include <type_traits>
1616

@@ -425,7 +425,7 @@ max(T1 a, T2 b) {
425425
return sycl::fmax(static_cast<common_t>(a), static_cast<common_t>(b));
426426
}
427427

428-
// pow functions overload.
428+
// pow functions overload.
429429
inline float pow(const float a, const int b) { return sycl::pown(a, b); }
430430
inline double pow(const double a, const int b) { return sycl::pown(a, b); }
431431
inline float pow(const float a, const float b) { return sycl::pow(a, b); }
@@ -2055,6 +2055,102 @@ class joint_matrix {
20552055
matrix_accessor x;
20562056
const size_t num_elements;
20572057
};
2058+
2059+
/// Stores 1 8x8 b16 matrix from local memory to shared memory (32-bits per wi)
2060+
/// Requires the sub-group size of kernel calling this function to be 32
2061+
/// \tparam [in] T The type of matrix elements
2062+
/// \param [in] addr The address of the matrix in shared memory
2063+
/// \param [in] m The local memory containing data of matrix
2064+
/// \param [in] item The sycl::nd_item index space class
2065+
/// \param [in] trans Indicates whether the matrix to be stored transposed
2066+
/// \param [in] mat The matrix index to be stored
2067+
template <typename T, typename ItemT>
2068+
void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false,
2069+
unsigned mat = 0) {
2070+
int lane = item.get_sub_group().get_local_linear_id();
2071+
2072+
int lane_group8_row = lane / 8;
2073+
int lane_group8_col = lane % 8;
2074+
2075+
if (!trans) {
2076+
// calculate the source lane
2077+
int src_lane = 2 * lane_group8_row;
2078+
if (lane_group8_col >= 4)
2079+
src_lane += 1;
2080+
2081+
// Broadcast the address from the source lane
2082+
auto recv_addr_uintp = dpct::select_from_sub_group(
2083+
item.get_sub_group(), addr, mat * 8 + src_lane);
2084+
2085+
// Cast the received address from uintptr_t to the type of 'm'
2086+
auto recv_addr = reinterpret_cast<T *>(recv_addr_uintp);
2087+
2088+
// Non-transposed store
2089+
recv_addr[lane_group8_col % 4] = m;
2090+
} else {
2091+
// calculate the source lane
2092+
int src_lane = (lane % 4) * 2;
2093+
2094+
// Broadcast the address from the source lane
2095+
auto recv_addr_uintp_1 = dpct::select_from_sub_group(
2096+
item.get_sub_group(), addr, mat * 8 + src_lane);
2097+
auto recv_addr_uintp_2 = dpct::select_from_sub_group(
2098+
item.get_sub_group(), addr, mat * 8 + src_lane + 1);
2099+
2100+
// Cast the received address from uintptr_t to 'half *'
2101+
auto recv_addr_1 = reinterpret_cast<sycl::half *>(recv_addr_uintp_1);
2102+
auto recv_addr_2 = reinterpret_cast<sycl::half *>(recv_addr_uintp_2);
2103+
2104+
// Split the 32-bit value of 'm' into two 16-bits
2105+
sycl::half *val = reinterpret_cast<sycl::half *>(&m);
2106+
2107+
// Transposed store
2108+
int index = lane / 4;
2109+
recv_addr_1[index] = val[0];
2110+
recv_addr_2[index] = val[1];
2111+
}
2112+
}
2113+
2114+
/// Stores two 8x8 b16 matrices to shared memory (32 bits per work-item and
/// matrix). The kernel calling this function must use a sub-group size of 32.
///
/// Forwards to the single-matrix stmatrix once per matrix, using matrix
/// indices 0 and 1.
///
/// \tparam T The type of the per-work-item values
/// \tparam ItemT The type of \p item
/// \param [in] addr The address of the matrix in shared memory
/// \param [in] m1 The value held by this work-item for the 1st matrix
/// \param [in] m2 The value held by this work-item for the 2nd matrix
/// \param [in] item The sycl::nd_item index space class
/// \param [in] trans Indicates whether the matrices are to be stored
///   transposed
template <typename T, typename ItemT>
void stmatrix(uintptr_t addr, T m1, T m2, const ItemT &item,
              bool trans = false) {
  // Stores one value as the matrix with the given index.
  const auto store_one = [&](const T &value, unsigned index) {
    stmatrix(addr, value, item, trans, index);
  };

  store_one(m1, 0);
  store_one(m2, 1);
}
2130+
2131+
/// Stores four 8x8 b16 matrices to shared memory (32 bits per work-item and
/// matrix). The kernel calling this function must use a sub-group size of 32.
///
/// Forwards to the single-matrix stmatrix once per matrix, using matrix
/// indices 0 through 3.
///
/// \tparam T The type of the per-work-item values
/// \tparam ItemT The type of \p item
/// \param [in] addr The address of the matrix in shared memory
/// \param [in] m1 The value held by this work-item for the 1st matrix
/// \param [in] m2 The value held by this work-item for the 2nd matrix
/// \param [in] m3 The value held by this work-item for the 3rd matrix
/// \param [in] m4 The value held by this work-item for the 4th matrix
/// \param [in] item The sycl::nd_item index space class
/// \param [in] trans Indicates whether the matrices are to be stored
///   transposed
template <typename T, typename ItemT>
void stmatrix(uintptr_t addr, T m1, T m2, T m3, T m4, const ItemT &item,
              bool trans = false) {
  // Stores one value as the matrix with the given index.
  const auto store_one = [&](const T &value, unsigned index) {
    stmatrix(addr, value, item, trans, index);
  };

  store_one(m1, 0);
  store_one(m2, 1);
  store_one(m3, 2);
  store_one(m4, 3);
}
2153+
20582154
} // namespace matrix
20592155
} // namespace experimental
20602156

0 commit comments

Comments
 (0)