Skip to content

Commit 9f8c46d

Browse files
[SYCLomatic][PTX] Added support for 4 more MMA shapes (#2839)
* Added support for 4 more MMA shapes * m8n8k4.col.row.f32.f16.f16.f32 * m8n8k16.row.col.s32.s8.s8.s32 * m16n8k8.row.col.f32.bf16.bf16.f32 * m16n8k8.row.col.f32.f16.f16.f32 * m16n8k16.row.col.s32.s8.s8.s32 * m16n8k16.row.col.f32.f16.f16.f32 * m16n8k32.row.col.s32.s8.s8.s32 * Added comments of MMA algo
1 parent 9a7b1de commit 9f8c46d

File tree

3 files changed

+485
-19
lines changed

3 files changed

+485
-19
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 101 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,9 +1386,16 @@ class SYCLGen : public SYCLGenBase {
13861386
return SYCLGenError();
13871387

13881388
// Only row Layout is supported for of A matrix and
1389-
// only col Layout is supported for of B matrix
1390-
if (Inst->getAttr(3) != InstAttr::row || Inst->getAttr(4) != InstAttr::col)
1391-
return SYCLGenError();
1389+
// only col Layout is supported for of B matrix (except for m8n8k4)
1390+
if (Inst->hasAttr(InstAttr::m8n8k4)) {
1391+
if (Inst->getAttr(3) != InstAttr::col ||
1392+
Inst->getAttr(4) != InstAttr::row)
1393+
return SYCLGenError();
1394+
} else {
1395+
if (Inst->getAttr(3) != InstAttr::row ||
1396+
Inst->getAttr(4) != InstAttr::col)
1397+
return SYCLGenError();
1398+
}
13921399

13931400
// Data types of D, A, B & C matrices respectively in the PTX instruction
13941401
const auto *DType = dyn_cast<InlineAsmBuiltinType>(Inst->getType(0));
@@ -1421,7 +1428,68 @@ class SYCLGen : public SYCLGenBase {
14211428
// Data types of A, B & C matrices respectively in the PTX arguments
14221429
std::string InMatrixType[3];
14231430

1424-
if (Inst->hasAttr(InstAttr::m16n8k16)) {
1431+
if (Inst->hasAttr(InstAttr::m8n8k4)) {
1432+
M = "8";
1433+
N = "8";
1434+
K = "4";
1435+
1436+
// Only f16 type is supported for A and B matrices of m8n8k4
1437+
if (AType->getKind() == InlineAsmBuiltinType::f16) {
1438+
InMatrixType[0] = "uint32_t"; // A type is .f16x2
1439+
InMatrixType[1] = "uint32_t"; // B type is .f16x2
1440+
1441+
// If A matrix type is f16, then C&D matrix types can only be f32
1442+
if (CType->getKind() == InlineAsmBuiltinType::f32) {
1443+
NumVecElements[0] = 2; // A
1444+
NumVecElements[1] = 2; // B
1445+
NumVecElements[2] = 8; // C
1446+
NumVecElements[3] = 8; // D
1447+
} else
1448+
return SYCLGenError();
1449+
} else
1450+
return SYCLGenError();
1451+
} else if (Inst->hasAttr(InstAttr::m8n8k16)) {
1452+
M = "8";
1453+
N = "8";
1454+
K = "16";
1455+
1456+
// Only s8 type is supported for A and B matrices of m8n8k16
1457+
if (AType->getKind() == InlineAsmBuiltinType::s8) {
1458+
InMatrixType[0] = "uint32_t"; // A type is .s8x4
1459+
InMatrixType[1] = "uint32_t"; // B type is .s8x4
1460+
1461+
// If A matrix type is s8, then C&D matrix types can only be s32
1462+
if (CType->getKind() == InlineAsmBuiltinType::s32) {
1463+
NumVecElements[0] = 1; // A
1464+
NumVecElements[1] = 1; // B
1465+
NumVecElements[2] = 2; // C
1466+
NumVecElements[3] = 2; // D
1467+
} else
1468+
return SYCLGenError();
1469+
} else
1470+
return SYCLGenError();
1471+
} else if (Inst->hasAttr(InstAttr::m16n8k8)) {
1472+
M = "16";
1473+
N = "8";
1474+
K = "8";
1475+
1476+
// Only f16/bf16 types are supported for A and B matrices of m16n8k8
1477+
if (AType->getKind() == InlineAsmBuiltinType::f16 ||
1478+
AType->getKind() == InlineAsmBuiltinType::bf16) {
1479+
InMatrixType[0] = "uint32_t"; // A type is .f16x2
1480+
InMatrixType[1] = "uint32_t"; // B type is .f16x2
1481+
1482+
// If A matrix type is f16/bf16, then C&D matrix types can only be f32
1483+
if (CType->getKind() == InlineAsmBuiltinType::f32) {
1484+
NumVecElements[0] = 2; // A
1485+
NumVecElements[1] = 1; // B
1486+
NumVecElements[2] = 4; // C
1487+
NumVecElements[3] = 4; // D
1488+
} else
1489+
return SYCLGenError();
1490+
} else
1491+
return SYCLGenError();
1492+
} else if (Inst->hasAttr(InstAttr::m16n8k16)) {
14251493
M = "16";
14261494
N = "8";
14271495
K = "16";
@@ -1440,8 +1508,8 @@ class SYCLGen : public SYCLGenBase {
14401508
} else
14411509
return SYCLGenError();
14421510
} else if (AType->getKind() == InlineAsmBuiltinType::s8) {
1443-
InMatrixType[0] = "uint32_t"; // A type is .f16x2
1444-
InMatrixType[1] = "uint32_t"; // B type is .f16x2
1511+
InMatrixType[0] = "uint32_t"; // A type is .s8x4
1512+
InMatrixType[1] = "uint32_t"; // B type is .s8x4
14451513

14461514
// If A matrix type is s8, then C&D matrix types can only be s32
14471515
if (CType->getKind() == InlineAsmBuiltinType::s32) {
@@ -1453,6 +1521,26 @@ class SYCLGen : public SYCLGenBase {
14531521
return SYCLGenError();
14541522
} else
14551523
return SYCLGenError();
1524+
} else if (Inst->hasAttr(InstAttr::m16n8k32)) {
1525+
M = "16";
1526+
N = "8";
1527+
K = "32";
1528+
1529+
// Only s8 type is supported for A and B matrices of m16n8k32
1530+
if (AType->getKind() == InlineAsmBuiltinType::s8) {
1531+
InMatrixType[0] = "uint32_t"; // A type is .s8x4
1532+
InMatrixType[1] = "uint32_t"; // B type is .s8x4
1533+
1534+
// If A matrix type is s8, then C&D matrix types can only be s32
1535+
if (CType->getKind() == InlineAsmBuiltinType::s32) {
1536+
NumVecElements[0] = 4; // A
1537+
NumVecElements[1] = 2; // B
1538+
NumVecElements[2] = 4; // C
1539+
NumVecElements[3] = 4; // D
1540+
} else
1541+
return SYCLGenError();
1542+
} else
1543+
return SYCLGenError();
14561544
} else
14571545
return SYCLGenError();
14581546

@@ -1472,7 +1560,9 @@ class SYCLGen : public SYCLGenBase {
14721560
return SYCLGenError();
14731561

14741562
// Declare and init an array for storing the addresses of D matrix elements
1475-
OS() << "{\n";
1563+
OS() << "{" << getNL();
1564+
incIndent();
1565+
indent();
14761566
OS() << "volatile " << CDType << " *d_mat_frag_ct1["
14771567
<< DMatVE->getNumElements() << "] = { ";
14781568
for (unsigned Inst = 0; Inst != DMatVE->getNumElements(); ++Inst) {
@@ -1494,6 +1584,7 @@ class SYCLGen : public SYCLGenBase {
14941584
InputOp++) {
14951585
if (auto VE =
14961586
dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand(InputOp))) {
1587+
indent();
14971588
OS() << "sycl::vec<" << InMatrixType[InputOp] << ", "
14981589
<< VE->getNumElements() << "> " << InMatrixName[InputOp]
14991590
<< "_mat_frag_ct1(";
@@ -1512,6 +1603,7 @@ class SYCLGen : public SYCLGenBase {
15121603
}
15131604
}
15141605

1606+
indent();
15151607
OS() << MapNames::getDpctNamespace() << "experimental::matrix::mma";
15161608
OS() << "<";
15171609
OS() << M << ", " << N << ", " << K << ", ";
@@ -1523,6 +1615,8 @@ class SYCLGen : public SYCLGenBase {
15231615
OS() << ", &" << InMatrixName[i] << "_mat_frag_ct1";
15241616
OS() << ")";
15251617
endstmt();
1618+
decIndent();
1619+
indent();
15261620
OS() << "}";
15271621
endstmt();
15281622

0 commit comments

Comments
 (0)