Skip to content

Commit 6719399

Browse files
Added support for mma migration
1 parent c1cb3cc commit 6719399

File tree

6 files changed

+212
-5
lines changed

6 files changed

+212
-5
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1301,6 +1301,83 @@ class SYCLGen : public SYCLGenBase {
13011301
return SYCLGenSuccess();
13021302
}
13031303

1304+
bool handle_mma(const InlineAsmInstruction *Inst) override {
1305+
if (Inst->getNumInputOperands() != 3)
1306+
return SYCLGenError();
1307+
1308+
if (!Inst->hasAttr(InstAttr::m16n8k16))
1309+
return SYCLGenError();
1310+
1311+
// Only row Layout is supported for of A matrix and
1312+
// only col Layout is supported for of B matrix
1313+
if (Inst->getAttr(3) != InstAttr::row ||
1314+
Inst->getAttr(4) != InstAttr::col) {
1315+
return SYCLGenError();
1316+
}
1317+
1318+
// Only f16 type is supported for A and B matrix data
1319+
const auto *AType = dyn_cast<InlineAsmBuiltinType>(Inst->getType(1));
1320+
const auto *BType = dyn_cast<InlineAsmBuiltinType>(Inst->getType(2));
1321+
1322+
std::string TypeStr;
1323+
if (!AType || !BType ||
1324+
(AType->getKind() != InlineAsmBuiltinType::f16 ||
1325+
BType->getKind() != InlineAsmBuiltinType::f16)) {
1326+
return SYCLGenError();
1327+
} else {
1328+
if (tryEmitType(TypeStr, AType))
1329+
return SYCLGenError();
1330+
}
1331+
1332+
const InlineAsmVectorExpr *VE =
1333+
dyn_cast<InlineAsmVectorExpr>(Inst->getOutputOperand());
1334+
if (VE && VE->getNumElements() != 4) {
1335+
return SYCLGenError();
1336+
}
1337+
1338+
OS() << MapNames::getDpctNamespace() << "experimental::matrix::mma";
1339+
OS() << "<" << TypeStr << ">(";
1340+
1341+
// Add D matrix address values to store the MAD result
1342+
for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) {
1343+
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
1344+
continue;
1345+
OS() << "&";
1346+
if (emitStmt(VE->getElement(Inst)))
1347+
return SYCLGenError();
1348+
OS() << ", ";
1349+
}
1350+
1351+
// Add A, B & C matrix values to compute MAD
1352+
for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands();
1353+
InputOp++) {
1354+
if (VE = dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand(InputOp))) {
1355+
for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) {
1356+
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
1357+
continue;
1358+
if (emitStmt(VE->getElement(Inst)))
1359+
return SYCLGenError();
1360+
OS() << ", ";
1361+
}
1362+
} else {
1363+
return SYCLGenError();
1364+
}
1365+
}
1366+
1367+
OS() << DpctGlobalInfo::getItem(GAS);
1368+
OS() << ");";
1369+
1370+
const auto *KernelDecl = getImmediateOuterFuncDecl(GAS);
1371+
if (KernelDecl) {
1372+
auto FuncInfo = DeviceFunctionDecl::LinkRedecls(KernelDecl);
1373+
if (FuncInfo)
1374+
FuncInfo->addSubGroupSizeRequest(32, GAS->getBeginLoc(),
1375+
DpctGlobalInfo::getSubGroup(GAS));
1376+
}
1377+
1378+
return SYCLGenSuccess();
1379+
}
1380+
13041381
bool handle_prefetch(const InlineAsmInstruction *Inst) override {
13051382
if (!DpctGlobalInfo::useExtPrefetch() || Inst->getNumInputOperands() != 1)
13061383
return SYCLGenError();

clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ class InlineAsmInstruction : public InlineAsmStmt {
322322

323323
/// This represents attributes like: comparison operator, rounding modifiers,
324324
/// ... e.g. instruction setp.eq.s32 has a comparison operator 'eq'.
325-
SmallSet<InstAttr, 4> Attributes;
325+
SmallVector<InstAttr, 4> Attributes;
326326

327327
/// This represents types in instruction, e.g. mov.u32.
328328
SmallVector<InlineAsmType *, 4> Types;
@@ -350,11 +350,11 @@ class InlineAsmInstruction : public InlineAsmStmt {
350350
OutputOp(Out), PredOutputOp(Pred), InputOps(InOps) {
351351
StateSpaces.insert(StateSpaces.begin(), AsmStateSpaces.begin(),
352352
AsmStateSpaces.end());
353-
Attributes.insert(Attrs.begin(), Attrs.end());
353+
Attributes.insert(Attributes.begin(), Attrs.begin(), Attrs.end());
354354
}
355355

356356
using attr_range =
357-
llvm::iterator_range<SmallSet<InstAttr, 4>::const_iterator>;
357+
llvm::iterator_range<SmallVector<InstAttr, 4>::const_iterator>;
358358
using type_range =
359359
llvm::iterator_range<SmallVector<InlineAsmType *, 4>::const_iterator>;
360360
using op_range =
@@ -369,12 +369,16 @@ class InlineAsmInstruction : public InlineAsmStmt {
369369
}
370370

371371
template <typename... Ts> bool hasAttr(Ts... Attrs) const {
372-
return (Attributes.contains(Attrs) || ...);
372+
return (llvm::is_contained(Attributes, Attrs) || ...);
373373
}
374374
const InlineAsmIdentifierInfo *getOpcodeID() const { return Opcode; }
375375
asmtok::TokenKind getOpcode() const { return Opcode->getTokenID(); }
376376
ArrayRef<InlineAsmType *> getTypes() const { return Types; }
377377
const InlineAsmType *getType(unsigned I) const { return Types[I]; }
378+
/// Positional accessor for this instruction's attribute list.
/// \param I Zero-based position; asserted to be below Attributes.size().
/// \returns The attribute stored at position \p I.
InstAttr getAttr(unsigned I) const {
  assert(I < Attributes.size() && "Attributes index out of range");
  return *(Attributes.begin() + I);
}
378382
unsigned getNumTypes() const { return Types.size(); }
379383
const InlineAsmExpr *getOutputOperand() const { return OutputOp; }
380384
const InlineAsmExpr *getPredOutputOperand() const { return PredOutputOp; }

clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,13 @@ MODIFIER(v2, ".v2")
274274
MODIFIER(v4, ".v4")
275275
MODIFIER(v8, ".v8")
276276

277+
// Matrix modifiers
278+
MODIFIER(row, ".row")
279+
MODIFIER(col, ".col")
280+
281+
// Matrix shape
282+
MODIFIER(m16n8k16, ".m16n8k16")
283+
277284
STATE_SPACE(reg, ".reg")
278285
STATE_SPACE(sreg, ".sreg")
279286
STATE_SPACE(const, ".const")
@@ -418,6 +425,7 @@ MODIFIER(rc8, ".rc8")
418425
MODIFIER(ecl, ".ecl")
419426
MODIFIER(ecr, ".ecr")
420427
MODIFIER(rc16, ".rc16")
428+
MODIFIER(aligned, ".aligned")
421429

422430
#undef LINKAGE
423431
#undef TARGET

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2055,6 +2055,79 @@ class joint_matrix {
20552055
matrix_accessor x;
20562056
const size_t num_elements;
20572057
};
2058+
2059+
/// Multiplies 2 16x16 & 16x8 matrices and accumulates the result to a 16x8 b32
2060+
/// matrix
2061+
/// \tparam [in] MulType The type of the multiplication result
2062+
/// \tparam [in] ABType The type of the input matrices
2063+
/// \tparam [in] CDType The type of the output matrix
2064+
/// \tparam [in] ItemT The type of the sycl::nd_item index space class
2065+
/// \param [in] d0 The 1st element to be written to the output D matrix
2066+
/// \param [in] d1 The 2nd element to be written to the output D matrix
2067+
/// \param [in] d2 The 3rd element to be written to the output D matrix
2068+
/// \param [in] d3 The 4th element to be written to the output D matrix
2069+
/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix
2070+
/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix
2071+
/// \param [in] a2 The 3rd element from A matrix to be multiplied with B matrix
2072+
/// \param [in] a3 The 4th element from A matrix to be multiplied with B matrix
2073+
/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix
2074+
/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix
2075+
/// \param [in] c0 The 1st element from C matrix to be added with d0
2076+
/// \param [in] c1 The 2nd element from C matrix to be added with d1
2077+
/// \param [in] c2 The 3rd element from C matrix to be added with d2
2078+
/// \param [in] c3 The 4th element from C matrix to be added with d3
2079+
/// \param [in] item The sycl::nd_item index space class
2080+
template <typename MulType, typename ABType, typename CDType, typename ItemT>
2081+
__attribute__((optnone)) void
2082+
mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1,
2083+
ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1, CDType c2,
2084+
CDType c3, const ItemT &item) {
2085+
int lane = item.get_sub_group().get_local_linear_id();
2086+
2087+
short ROW_LOAD_OFFSET = 4 * (lane / 4);
2088+
short COL_LOAD_OFFSET = 8 * (lane % 4);
2089+
2090+
ABType recv_a[4 * 4], recv_b[4 * 4];
2091+
for (int i = 0; i < 4; i++) {
2092+
recv_a[0 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), a0,
2093+
ROW_LOAD_OFFSET + i);
2094+
recv_a[1 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), a2,
2095+
ROW_LOAD_OFFSET + i);
2096+
recv_a[2 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), a1,
2097+
ROW_LOAD_OFFSET + i);
2098+
recv_a[3 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), a3,
2099+
ROW_LOAD_OFFSET + i);
2100+
2101+
recv_b[0 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), b0,
2102+
COL_LOAD_OFFSET + i);
2103+
recv_b[1 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), b1,
2104+
COL_LOAD_OFFSET + i);
2105+
recv_b[2 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), b0,
2106+
COL_LOAD_OFFSET + 4 + i);
2107+
recv_b[3 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), b1,
2108+
COL_LOAD_OFFSET + 4 + i);
2109+
}
2110+
2111+
auto *a = reinterpret_cast<MulType *>(recv_a);
2112+
auto *b = reinterpret_cast<MulType *>(recv_b);
2113+
for (int i = 0; i < 16; i++) {
2114+
auto a0 = static_cast<CDType>(a[i]);
2115+
auto a1 = static_cast<CDType>(a[i + 16]);
2116+
auto b0 = static_cast<CDType>(b[i]);
2117+
auto b1 = static_cast<CDType>(b[i + 16]);
2118+
2119+
c0 += a0 * b0;
2120+
c1 += a0 * b1;
2121+
c2 += a1 * b0;
2122+
c3 += a1 * b1;
2123+
}
2124+
2125+
*d0 = c0;
2126+
*d1 = c1;
2127+
*d2 = c2;
2128+
*d3 = c3;
2129+
}
2130+
20582131
} // namespace matrix
20592132
} // namespace experimental
20602133

clang/test/dpct/asm/mma.cu

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// UNSUPPORTED: cuda-8.0, cuda-9.0, cuda-9.1, cuda-9.2, cuda-10.0, cuda-10.1, cuda-10.2
2+
// UNSUPPORTED: v8.0, v9.0, v9.1, v9.2, v10.0, v10.1, v10.2
3+
// RUN: dpct --format-range=none -out-root %T/mma %s --cuda-include-path="%cuda-path/include" -- -std=c++14 -x cuda --cuda-host-only
4+
// RUN: FileCheck %s --match-full-lines --input-file %T/mma/mma.dp.cpp
5+
// RUN: %if build_lit %{icpx -c -DNO_BUILD_TEST -fsycl %T/mma/mma.dp.cpp -o %T/mma/mma.dp.o %}
6+
7+
// clang-format off
8+
#include <cuda_runtime.h>
9+
#include <cuda_fp16.h>
10+
11+
/*
12+
mma.sync.aligned.m16n8k16.alayout.blayout.dtype.f16.f16.ctype d, a, b, c;
13+
14+
Below are the currently supported configurations:
15+
16+
.alayout = {.row};
17+
.blayout = {.col};
18+
.ctype = {.f32};
19+
.dtype = {.f32};
20+
*/
21+
22+
// Test kernel: issues a single inline-asm
// mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 instruction so that the
// migrated output can be matched against the FileCheck pattern below.
__global__ void mma_kernel() {
23+
// A operand values, passed to the asm as four "r" inputs.
int a[4];
24+
// B operand values, passed to the asm as two "r" inputs.
int b[2];
25+
// Accumulator values, bound as read-write "+f" operands; they are used both
// as the D output vector and as the C input vector of the instruction.
float c[4];
26+
27+
// CHECK: dpct::experimental::matrix::mma<sycl::half>(&c[0], &c[1], &c[2], &c[3], a[0], a[1], a[2], a[3], b[0], b[1], c[0], c[1], c[2], c[3], item_ct1);
28+
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
29+
" { %0, %1, %2, %3 }, "
30+
" { %4, %5, %6, %7 }, "
31+
" { %8, %9 }, "
32+
" { %0, %1, %2, %3 };"
33+
: "+f"(c[0]), "+f"(c[1]), "+f"(c[2]), "+f"(c[3])
34+
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
35+
"r"(b[0]), "r"(b[1]));
36+
}
37+
38+
39+
// Host entry point: launches mma_kernel with <<<1, 32>>>. The FileCheck
// pattern below expects the migrated kernel lambda to carry a
// sycl::reqd_sub_group_size(32) attribute.
int main () {
40+
// CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} {
41+
mma_kernel<<<1, 32>>>();
42+
43+
return 0;
44+
}
45+
// clang-format on

docs/dev_guide/api-mapping-status/ASM_API_migration_status.csv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ max,YES,
5454
mbarrier,NO,
5555
membar,YES, Partial
5656
min,YES,
57-
mma,NO,
57+
mma,YES, Partial
5858
mov,YES,
5959
movmatrix,NO,
6060
mul,YES, Partial

0 commit comments

Comments
 (0)