Skip to content

Commit 8f31872

Browse files
[SYCLomatic][PTX][MMA] Added support for migrating m16n8k16 of MMA asm (#2821)
* Added support for mma m16n8k16 migration, supported types: * f32.f16.f16.f32 * s32.s8.s8.s32
1 parent 4820d71 commit 8f31872

File tree

8 files changed

+435
-17
lines changed

8 files changed

+435
-17
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 171 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -514,14 +514,17 @@ bool SYCLGenBase::emitType(const InlineAsmType *T) {
514514
bool SYCLGenBase::emitBuiltinType(const InlineAsmBuiltinType *T) {
515515
switch (T->getKind()) {
516516
// clang-format off
517+
case InlineAsmBuiltinType::b1: OS() << "uint8_t"; break;
517518
case InlineAsmBuiltinType::b8: OS() << "uint8_t"; break;
518519
case InlineAsmBuiltinType::b16: OS() << "uint16_t"; break;
519520
case InlineAsmBuiltinType::b32: OS() << "uint32_t"; break;
520521
case InlineAsmBuiltinType::b64: OS() << "uint64_t"; break;
522+
case InlineAsmBuiltinType::u4: OS() << "uint8_t"; break;
521523
case InlineAsmBuiltinType::u8: OS() << "uint8_t"; break;
522524
case InlineAsmBuiltinType::u16: OS() << "uint16_t"; break;
523525
case InlineAsmBuiltinType::u32: OS() << "uint32_t"; break;
524526
case InlineAsmBuiltinType::u64: OS() << "uint64_t"; break;
527+
case InlineAsmBuiltinType::s4: OS() << "int8_t"; break;
525528
case InlineAsmBuiltinType::s8: OS() << "int8_t"; break;
526529
case InlineAsmBuiltinType::s16: OS() << "int16_t"; break;
527530
case InlineAsmBuiltinType::s32: OS() << "int32_t"; break;
@@ -559,6 +562,9 @@ bool SYCLGenBase::emitVectorType(const InlineAsmVectorType *T) {
559562
case InlineAsmVectorType::x1:
560563
OS() << 1;
561564
break;
565+
case InlineAsmVectorType::v1:
566+
OS() << 1;
567+
break;
562568
case InlineAsmVectorType::v2:
563569
case InlineAsmVectorType::x2:
564570
OS() << 2;
@@ -1370,6 +1376,167 @@ class SYCLGen : public SYCLGenBase {
13701376
return SYCLGenSuccess();
13711377
}
13721378

1379+
bool handle_mma(const InlineAsmInstruction *Inst) override {
1380+
if (Inst->getNumInputOperands() != 3)
1381+
return SYCLGenError();
1382+
1383+
const InlineAsmVectorExpr *DMatVE =
1384+
dyn_cast<InlineAsmVectorExpr>(Inst->getOutputOperand());
1385+
if (!DMatVE)
1386+
return SYCLGenError();
1387+
1388+
// Only row Layout is supported for of A matrix and
1389+
// only col Layout is supported for of B matrix
1390+
if (Inst->getAttr(3) != InstAttr::row || Inst->getAttr(4) != InstAttr::col)
1391+
return SYCLGenError();
1392+
1393+
// Data types of D, A, B & C matrices respectively in the PTX instruction
1394+
const auto *DType = dyn_cast<InlineAsmBuiltinType>(Inst->getType(0));
1395+
const auto *AType = dyn_cast<InlineAsmBuiltinType>(Inst->getType(1));
1396+
const auto *BType = dyn_cast<InlineAsmBuiltinType>(Inst->getType(2));
1397+
const auto *CType = dyn_cast<InlineAsmBuiltinType>(Inst->getType(3));
1398+
1399+
if (!(AType && BType && CType && DType))
1400+
return SYCLGenError();
1401+
1402+
// Data types of matrix elements for A&B and C&D matrices should be same
1403+
if ((AType->getKind() != BType->getKind()) ||
1404+
(CType->getKind() != DType->getKind()))
1405+
return SYCLGenError();
1406+
1407+
// Check the validity of AB & CD types
1408+
std::string ABType, CDType;
1409+
if (tryEmitType(ABType, AType))
1410+
return SYCLGenError();
1411+
1412+
if (tryEmitType(CDType, CType))
1413+
return SYCLGenError();
1414+
1415+
// Register sizes for vector elements of A, B, C & D matrices
1416+
unsigned NumVecElements[4] = {0};
1417+
1418+
// Sizes of A & B matrices
1419+
std::string M, N, K;
1420+
1421+
// Data types of A, B & C matrices respectively in the PTX arguments
1422+
std::string InMatrixType[3];
1423+
1424+
if (Inst->hasAttr(InstAttr::m16n8k16)) {
1425+
M = "16";
1426+
N = "8";
1427+
K = "16";
1428+
1429+
// Only f16/s8 types are supported for A and B matrices of m16n8k16
1430+
if (AType->getKind() == InlineAsmBuiltinType::f16) {
1431+
InMatrixType[0] = "uint32_t"; // A type is .f16x2
1432+
InMatrixType[1] = "uint32_t"; // B type is .f16x2
1433+
1434+
// If A matrix type is f16, then C&D matrix types can only be f32
1435+
if (CType->getKind() == InlineAsmBuiltinType::f32) {
1436+
NumVecElements[0] = 4; // A
1437+
NumVecElements[1] = 2; // B
1438+
NumVecElements[2] = 4; // C
1439+
NumVecElements[3] = 4; // D
1440+
} else
1441+
return SYCLGenError();
1442+
} else if (AType->getKind() == InlineAsmBuiltinType::s8) {
1443+
InMatrixType[0] = "uint32_t"; // A type is .f16x2
1444+
InMatrixType[1] = "uint32_t"; // B type is .f16x2
1445+
1446+
// If A matrix type is s8, then C&D matrix types can only be s32
1447+
if (CType->getKind() == InlineAsmBuiltinType::s32) {
1448+
NumVecElements[0] = 2; // A
1449+
NumVecElements[1] = 1; // B
1450+
NumVecElements[2] = 4; // C
1451+
NumVecElements[3] = 4; // D
1452+
} else
1453+
return SYCLGenError();
1454+
} else
1455+
return SYCLGenError();
1456+
} else
1457+
return SYCLGenError();
1458+
1459+
InMatrixType[2] = CDType;
1460+
1461+
// Check the register sizes for vector elements of A, B, C & D matrices
1462+
for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands();
1463+
InputOp++) {
1464+
if (auto VE =
1465+
dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand(InputOp))) {
1466+
if (VE->getNumElements() != NumVecElements[InputOp])
1467+
return SYCLGenError();
1468+
} else
1469+
return SYCLGenError();
1470+
}
1471+
if (DMatVE->getNumElements() != NumVecElements[3])
1472+
return SYCLGenError();
1473+
1474+
// Declare and init an array for storing the addresses of D matrix elements
1475+
OS() << "{\n";
1476+
OS() << "volatile " << CDType << " *d_mat_frag_ct1["
1477+
<< DMatVE->getNumElements() << "] = { ";
1478+
for (unsigned Inst = 0; Inst != DMatVE->getNumElements(); ++Inst) {
1479+
if (isa<InlineAsmDiscardExpr>(DMatVE->getElement(Inst)))
1480+
continue;
1481+
OS() << "&";
1482+
if (emitStmt(DMatVE->getElement(Inst)))
1483+
return SYCLGenError();
1484+
if ((Inst + 1) != DMatVE->getNumElements())
1485+
OS() << ", ";
1486+
}
1487+
OS() << " }";
1488+
endstmt();
1489+
1490+
// Declare and init vectors for storing the values of A, B & C matrix
1491+
// elements
1492+
std::string InMatrixName[3] = {"a", "b", "c"};
1493+
for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands();
1494+
InputOp++) {
1495+
if (auto VE =
1496+
dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand(InputOp))) {
1497+
OS() << "sycl::vec<" << InMatrixType[InputOp] << ", "
1498+
<< VE->getNumElements() << "> " << InMatrixName[InputOp]
1499+
<< "_mat_frag_ct1(";
1500+
for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) {
1501+
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
1502+
continue;
1503+
if (emitStmt(VE->getElement(Inst)))
1504+
return SYCLGenError();
1505+
if ((Inst + 1) != VE->getNumElements())
1506+
OS() << ", ";
1507+
}
1508+
OS() << ")";
1509+
endstmt();
1510+
} else {
1511+
return SYCLGenError();
1512+
}
1513+
}
1514+
1515+
OS() << MapNames::getDpctNamespace() << "experimental::matrix::mma";
1516+
OS() << "<";
1517+
OS() << M << ", " << N << ", " << K << ", ";
1518+
OS() << ABType << ", " << CDType;
1519+
OS() << ">(";
1520+
1521+
OS() << "reinterpret_cast<volatile void **>(d_mat_frag_ct1)";
1522+
for (int i = 0; i < 3; i++)
1523+
OS() << ", &" << InMatrixName[i] << "_mat_frag_ct1";
1524+
OS() << ")";
1525+
endstmt();
1526+
OS() << "}";
1527+
endstmt();
1528+
1529+
const auto *KernelDecl = getImmediateOuterFuncDecl(GAS);
1530+
if (KernelDecl) {
1531+
auto FuncInfo = DeviceFunctionDecl::LinkRedecls(KernelDecl);
1532+
if (FuncInfo)
1533+
FuncInfo->addSubGroupSizeRequest(32, GAS->getBeginLoc(),
1534+
DpctGlobalInfo::getSubGroup(GAS));
1535+
}
1536+
1537+
return SYCLGenSuccess();
1538+
}
1539+
13731540
bool handle_prefetch(const InlineAsmInstruction *Inst) override {
13741541
if (!DpctGlobalInfo::useExtPrefetch() || Inst->getNumInputOperands() != 1)
13751542
return SYCLGenError();
@@ -2595,11 +2762,10 @@ class SYCLGen : public SYCLGenBase {
25952762
Op = std::move(NewOp);
25962763
}
25972764

2598-
bool HasHalfOrBfloat16 =
2599-
SrcType->getKind() == InlineAsmBuiltinType::f16 ||
2600-
DesType->getKind() == InlineAsmBuiltinType::f16 ||
2601-
SrcType->getKind() == InlineAsmBuiltinType::bf16 ||
2602-
DesType->getKind() == InlineAsmBuiltinType::bf16;
2765+
bool HasHalfOrBfloat16 = SrcType->getKind() == InlineAsmBuiltinType::f16 ||
2766+
DesType->getKind() == InlineAsmBuiltinType::f16 ||
2767+
SrcType->getKind() == InlineAsmBuiltinType::bf16 ||
2768+
DesType->getKind() == InlineAsmBuiltinType::bf16;
26032769
if (DpctGlobalInfo::useIntelDeviceMath() && HasHalfOrBfloat16) {
26042770
insertHeader(HeaderType::HT_SYCL_Math);
26052771
if (SrcNeedBitCast)

clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,9 @@ class InlineAsmBuiltinType : public InlineAsmType {
9999
return ((K == Kind) || ...);
100100
}
101101
template <class... Ks> bool isNot(Ks... K) { return ((K != Kind) && ...); }
102-
bool isBit() const { return isOneOf(b8, b16, b32, b64); }
103-
bool isSigned() const { return isOneOf(s8, s16, s32, s64); }
104-
bool isUnsigned() const { return isOneOf(u8, u16, u32, u64); }
102+
bool isBit() const { return isOneOf(b1, b8, b16, b32, b64); }
103+
bool isSigned() const { return isOneOf(s4, s8, s16, s32, s64); }
104+
bool isUnsigned() const { return isOneOf(u4, u8, u16, u32, u64); }
105105
bool isInt() const { return isSigned() || isUnsigned(); }
106106
bool isFloat() const { return isOneOf(f16, f32, f64); }
107107
bool isScalar() const { return isInt() || isFloat(); }
@@ -116,7 +116,7 @@ class InlineAsmBuiltinType : public InlineAsmType {
116116
// This class is used for device asm vector types.
117117
class InlineAsmVectorType : public InlineAsmType {
118118
public:
119-
enum VecKind { v2, v4, v8, x1, x2, x4 };
119+
enum VecKind { v1, v2, v4, v8, x1, x2, x4 };
120120

121121
private:
122122
VecKind Kind;
@@ -322,7 +322,7 @@ class InlineAsmInstruction : public InlineAsmStmt {
322322

323323
/// This represents arrtibutes like: comparsion operator, rounding modifiers,
324324
/// ... e.g. instruction setp.eq.s32 has a comparsion operator 'eq'.
325-
SmallSet<InstAttr, 4> Attributes;
325+
SmallVector<InstAttr, 4> Attributes;
326326

327327
/// This represents types in instruction, e.g. mov.u32.
328328
SmallVector<InlineAsmType *, 4> Types;
@@ -350,11 +350,11 @@ class InlineAsmInstruction : public InlineAsmStmt {
350350
OutputOp(Out), PredOutputOp(Pred), InputOps(InOps) {
351351
StateSpaces.insert(StateSpaces.begin(), AsmStateSpaces.begin(),
352352
AsmStateSpaces.end());
353-
Attributes.insert(Attrs.begin(), Attrs.end());
353+
Attributes.insert(Attributes.begin(), Attrs.begin(), Attrs.end());
354354
}
355355

356356
using attr_range =
357-
llvm::iterator_range<SmallSet<InstAttr, 4>::const_iterator>;
357+
llvm::iterator_range<SmallVector<InstAttr, 4>::const_iterator>;
358358
using type_range =
359359
llvm::iterator_range<SmallVector<InlineAsmType *, 4>::const_iterator>;
360360
using op_range =
@@ -369,12 +369,16 @@ class InlineAsmInstruction : public InlineAsmStmt {
369369
}
370370

371371
template <typename... Ts> bool hasAttr(Ts... Attrs) const {
372-
return (Attributes.contains(Attrs) || ...);
372+
return (llvm::is_contained(Attributes, Attrs) || ...);
373373
}
374374
const InlineAsmIdentifierInfo *getOpcodeID() const { return Opcode; }
375375
asmtok::TokenKind getOpcode() const { return Opcode->getTokenID(); }
376376
ArrayRef<InlineAsmType *> getTypes() const { return Types; }
377377
const InlineAsmType *getType(unsigned I) const { return Types[I]; }
378+
InstAttr getAttr(unsigned I) const {
379+
assert(I < Attributes.size() && "Attributes index out of range");
380+
return Attributes[I];
381+
}
378382
unsigned getNumTypes() const { return Types.size(); }
379383
const InlineAsmExpr *getOutputOperand() const { return OutputOp; }
380384
const InlineAsmExpr *getPredOutputOperand() const { return PredOutputOp; }

clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,9 @@ InlineAsmParser::ActOnVectorExpr(ArrayRef<InlineAsmExpr *> Vec) {
756756
} else {
757757
// Vector size must be 2, 4, or 8.
758758
switch (Vec.size()) {
759+
case 1:
760+
Kind = InlineAsmVectorType::v1;
761+
break;
759762
case 2:
760763
Kind = InlineAsmVectorType::v2;
761764
break;

clang/lib/DPCT/RulesAsm/Parser/AsmParser.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ class InlineAsmParser {
498498
/// .reg .sreg .const .local .param .shared .tex
499499
///
500500
/// vector-specifier: one of
501-
/// .v2 .v4 .v8
501+
/// .v1 .v2 .v4 .v8
502502
///
503503
/// type-specifier: one of
504504
/// .b8 .b16 .b32 .b64 .s8 .s16 .s32 .s64

clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,14 +240,17 @@ SPECIAL_REG(warpid, "%warpid", s64)
240240
SPECIAL_REG(WARP_SZ, "WARP_SZ", s64)
241241

242242
// Built-in type names
243+
BUILTIN_TYPE(b1, ".b1")
243244
BUILTIN_TYPE(b8, ".b8")
244245
BUILTIN_TYPE(b16, ".b16")
245246
BUILTIN_TYPE(b32, ".b32")
246247
BUILTIN_TYPE(b64, ".b64")
248+
BUILTIN_TYPE(u4, ".u4")
247249
BUILTIN_TYPE(u8, ".u8")
248250
BUILTIN_TYPE(u16, ".u16")
249251
BUILTIN_TYPE(u32, ".u32")
250252
BUILTIN_TYPE(u64, ".u64")
253+
BUILTIN_TYPE(s4, ".s4")
251254
BUILTIN_TYPE(s8, ".s8")
252255
BUILTIN_TYPE(s16, ".s16")
253256
BUILTIN_TYPE(s32, ".s32")
@@ -270,6 +273,7 @@ BUILTIN_TYPE(s16x2, ".s16x2")
270273
BUILTIN_TYPE(u16x2, ".u16x2")
271274

272275
// Vector modifiers
276+
MODIFIER(v1, ".v1")
273277
MODIFIER(v2, ".v2")
274278
MODIFIER(v4, ".v4")
275279
MODIFIER(v8, ".v8")
@@ -279,8 +283,23 @@ MODIFIER(x1, ".x1")
279283
MODIFIER(x2, ".x2")
280284
MODIFIER(x4, ".x4")
281285

286+
// Matrix modifiers
287+
MODIFIER(row, ".row")
288+
MODIFIER(col, ".col")
289+
282290
// Matrix shape
283291
MODIFIER(m8n8, ".m8n8")
292+
MODIFIER(m8n8k4, ".m8n8k4")
293+
MODIFIER(m8n8k16, ".m8n8k16")
294+
MODIFIER(m8n8k32, ".m8n8k32")
295+
MODIFIER(m8n8k128, ".m8n8k128")
296+
MODIFIER(m16n8k4, ".m16n8k4")
297+
MODIFIER(m16n8k8, ".m16n8k8")
298+
MODIFIER(m16n8k16, ".m16n8k16")
299+
MODIFIER(m16n8k32, ".m16n8k32")
300+
MODIFIER(m16n8k64, ".m16n8k64")
301+
MODIFIER(m16n8k128, ".m16n8k128")
302+
MODIFIER(m16n8k256, ".m16n8k256")
284303

285304
STATE_SPACE(reg, ".reg")
286305
STATE_SPACE(sreg, ".sreg")
@@ -376,6 +395,7 @@ MODIFIER(max, ".max")
376395
MODIFIER(op_or, ".or")
377396
MODIFIER(op_xor, ".xor")
378397
MODIFIER(op_and, ".and")
398+
MODIFIER(op_popc, ".popc")
379399
MODIFIER(cas, ".cas")
380400
MODIFIER(exch, ".exch")
381401
MODIFIER(inc, ".inc")

0 commit comments

Comments
 (0)