Skip to content

Commit c94ff56

Browse files
Added support for 10 shapes
1 parent 891e459 commit c94ff56

File tree

5 files changed

+1898
-176
lines changed

5 files changed

+1898
-176
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 257 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -514,14 +514,17 @@ bool SYCLGenBase::emitType(const InlineAsmType *T) {
514514
bool SYCLGenBase::emitBuiltinType(const InlineAsmBuiltinType *T) {
515515
switch (T->getKind()) {
516516
// clang-format off
517+
case InlineAsmBuiltinType::b1: OS() << "uint8_t"; break;
517518
case InlineAsmBuiltinType::b8: OS() << "uint8_t"; break;
518519
case InlineAsmBuiltinType::b16: OS() << "uint16_t"; break;
519520
case InlineAsmBuiltinType::b32: OS() << "uint32_t"; break;
520521
case InlineAsmBuiltinType::b64: OS() << "uint64_t"; break;
522+
case InlineAsmBuiltinType::u4: OS() << "uint8_t"; break;
521523
case InlineAsmBuiltinType::u8: OS() << "uint8_t"; break;
522524
case InlineAsmBuiltinType::u16: OS() << "uint16_t"; break;
523525
case InlineAsmBuiltinType::u32: OS() << "uint32_t"; break;
524526
case InlineAsmBuiltinType::u64: OS() << "uint64_t"; break;
527+
case InlineAsmBuiltinType::s4: OS() << "int8_t"; break;
525528
case InlineAsmBuiltinType::s8: OS() << "int8_t"; break;
526529
case InlineAsmBuiltinType::s16: OS() << "int16_t"; break;
527530
case InlineAsmBuiltinType::s32: OS() << "int32_t"; break;
@@ -1347,44 +1350,276 @@ class SYCLGen : public SYCLGenBase {
13471350
// Register sizes for vector elements of A, B, C & D matrices
13481351
unsigned NumVecElements[4] = {0};
13491352

1353+
// Sizes of A & B matrices
1354+
std::string M, N, K;
1355+
1356+
// Bitwise operator (and/xor) used by the b1 shapes m8n8k128/m16n8k128/m16n8k256
1357+
std::string MatrixOp;
1358+
13501359
// Data type used to multiply A & B matrices
13511360
std::string MulType;
1352-
if (Inst->hasAttr(InstAttr::m16n8k16)) {
1353-
// Only f16 type is supported for A and B matrix data for m16n8k16
1361+
if (Inst->hasAttr(InstAttr::m8n8k4)) {
1362+
M = "8";
1363+
N = "8";
1364+
K = "4";
1365+
// f16 & f64 types are supported for A and B matrices of m8n8k4
13541366
if (AType->getKind() == InlineAsmBuiltinType::f16) {
1355-
// If A matrix type is f16, then C&D matrix types can only be f16
1367+
// If A matrix type is f16, then C&D matrix types can only be f16/f32
13561368
if (CType->getKind() == AType->getKind()) {
13571369
NumVecElements[0] = 2; // A
1358-
NumVecElements[1] = 4; // B
1370+
NumVecElements[1] = 2; // B
13591371
NumVecElements[2] = 4; // C
13601372
NumVecElements[3] = 4; // D
1373+
} else if (CType->getKind() == InlineAsmBuiltinType::f32) {
1374+
NumVecElements[0] = 2; // A
1375+
NumVecElements[1] = 2; // B
1376+
NumVecElements[2] = 8; // C
1377+
NumVecElements[3] = 8; // D
1378+
} else
1379+
return SYCLGenError();
1380+
} else if (AType->getKind() == InlineAsmBuiltinType::f64) {
1381+
// If A matrix type is f64, then C&D matrix types can only be f64
1382+
if (CType->getKind() == AType->getKind()) {
1383+
NumVecElements[0] = 1; // A
1384+
NumVecElements[1] = 1; // B
1385+
NumVecElements[2] = 2; // C
1386+
NumVecElements[3] = 2; // D
1387+
} else
1388+
return SYCLGenError();
1389+
} else
1390+
return SYCLGenError();
1391+
} else if (Inst->hasAttr(InstAttr::m8n8k16)) {
1392+
M = "8";
1393+
N = "8";
1394+
K = "16";
1395+
// Only s8/u8 types are supported for A and B matrices of m8n8k16
1396+
if (AType->getKind() == InlineAsmBuiltinType::s8 ||
1397+
AType->getKind() == InlineAsmBuiltinType::u8) {
1398+
// If A matrix type is s8/u8, then C&D matrix types can only be s32
1399+
if (CType->getKind() == InlineAsmBuiltinType::s32) {
1400+
NumVecElements[0] = 1; // A
1401+
NumVecElements[1] = 1; // B
1402+
NumVecElements[2] = 2; // C
1403+
NumVecElements[3] = 2; // D
13611404
} else
13621405
return SYCLGenError();
13631406
} else
13641407
return SYCLGenError();
1365-
} else if (Inst->hasAttr(InstAttr::m8n8k4)) {
1366-
// f16 & f64 types are supported for A and B matrix data for m8n8k4
1408+
} else if (Inst->hasAttr(InstAttr::m8n8k32)) {
1409+
M = "8";
1410+
N = "8";
1411+
K = "32";
1412+
// Only s4/u4 types are supported for A and B matrices of m8n8k32
1413+
if (AType->getKind() == InlineAsmBuiltinType::s4 ||
1414+
AType->getKind() == InlineAsmBuiltinType::u4) {
1415+
// If A matrix type is s4/u4, then C&D matrix types can only be s32
1416+
if (CType->getKind() == InlineAsmBuiltinType::s32) {
1417+
NumVecElements[0] = 1; // A
1418+
NumVecElements[1] = 1; // B
1419+
NumVecElements[2] = 2; // C
1420+
NumVecElements[3] = 2; // D
1421+
} else
1422+
return SYCLGenError();
1423+
} else
1424+
return SYCLGenError();
1425+
} else if (Inst->hasAttr(InstAttr::m8n8k128)) {
1426+
M = "8";
1427+
N = "8";
1428+
K = "128";
1429+
// Only b1 type is supported for A and B matrices of m8n8k128
1430+
if (AType->getKind() == InlineAsmBuiltinType::b1) {
1431+
// If A matrix type is b1, then C&D matrix types can only be s32
1432+
if (CType->getKind() == InlineAsmBuiltinType::s32) {
1433+
NumVecElements[0] = 1; // A
1434+
NumVecElements[1] = 1; // B
1435+
NumVecElements[2] = 2; // C
1436+
NumVecElements[3] = 2; // D
1437+
1438+
// Only and/xor bitwise operations are supported for m8n8k128
1439+
if (Inst->hasAttr(InstAttr::op_and))
1440+
MatrixOp = "and";
1441+
else if (Inst->hasAttr(InstAttr::op_xor))
1442+
MatrixOp = "xor";
1443+
else
1444+
return SYCLGenError();
1445+
} else
1446+
return SYCLGenError();
1447+
} else
1448+
return SYCLGenError();
1449+
} else if (Inst->hasAttr(InstAttr::m16n8k4)) {
1450+
M = "16";
1451+
N = "8";
1452+
K = "4";
1453+
// Only f64 type is supported for A and B matrices of m16n8k4
1454+
if (AType->getKind() == InlineAsmBuiltinType::f64) {
1455+
// If A matrix type is f64, then C&D matrix types can only be f64
1456+
if (CType->getKind() == InlineAsmBuiltinType::f64) {
1457+
NumVecElements[0] = 2; // A
1458+
NumVecElements[1] = 1; // B
1459+
NumVecElements[2] = 4; // C
1460+
NumVecElements[3] = 4; // D
1461+
} else
1462+
return SYCLGenError();
1463+
} else
1464+
return SYCLGenError();
1465+
} else if (Inst->hasAttr(InstAttr::m16n8k8)) {
1466+
M = "16";
1467+
N = "8";
1468+
K = "8";
1469+
// Only f16/f64 types are supported for A and B matrices of m16n8k8
13671470
if (AType->getKind() == InlineAsmBuiltinType::f16) {
13681471
// If A matrix type is f16, then C&D matrix types can only be f16/f32
1369-
if (CType->getKind() == AType->getKind()) {
1472+
if (CType->getKind() == InlineAsmBuiltinType::f16) {
13701473
NumVecElements[0] = 2; // A
1474+
NumVecElements[1] = 1; // B
1475+
NumVecElements[2] = 2; // C
1476+
NumVecElements[3] = 2; // D
1477+
} else if (CType->getKind() == InlineAsmBuiltinType::f32) {
1478+
NumVecElements[0] = 2; // A
1479+
NumVecElements[1] = 1; // B
1480+
NumVecElements[2] = 4; // C
1481+
NumVecElements[3] = 4; // D
1482+
} else
1483+
return SYCLGenError();
1484+
} else if (AType->getKind() == InlineAsmBuiltinType::f64) {
1485+
// If A matrix type is f64, then C&D matrix types can only be f64
1486+
if (CType->getKind() == InlineAsmBuiltinType::f64) {
1487+
NumVecElements[0] = 4; // A
13711488
NumVecElements[1] = 2; // B
13721489
NumVecElements[2] = 4; // C
13731490
NumVecElements[3] = 4; // D
1491+
} else
1492+
return SYCLGenError();
1493+
} else
1494+
return SYCLGenError();
1495+
} else if (Inst->hasAttr(InstAttr::m16n8k16)) {
1496+
M = "16";
1497+
N = "8";
1498+
K = "16";
1499+
// Only f16/f64/s8/u8 types are supported for A and B matrices of m16n8k16
1500+
if (AType->getKind() == InlineAsmBuiltinType::f16) {
1501+
// If A matrix type is f16, then C&D matrix types can only be f16/f32
1502+
if (CType->getKind() == AType->getKind()) {
1503+
NumVecElements[0] = 4; // A
1504+
NumVecElements[1] = 2; // B
1505+
NumVecElements[2] = 2; // C
1506+
NumVecElements[3] = 2; // D
13741507
} else if (CType->getKind() == InlineAsmBuiltinType::f32) {
1375-
NumVecElements[0] = 2; // A
1508+
NumVecElements[0] = 4; // A
13761509
NumVecElements[1] = 2; // B
1377-
NumVecElements[2] = 8; // C
1378-
NumVecElements[3] = 8; // D
1510+
NumVecElements[2] = 4; // C
1511+
NumVecElements[3] = 4; // D
13791512
} else
13801513
return SYCLGenError();
13811514
} else if (AType->getKind() == InlineAsmBuiltinType::f64) {
13821515
// If A matrix type is f64, then C&D matrix types can only be f64
13831516
if (CType->getKind() == AType->getKind()) {
1384-
NumVecElements[0] = 1; // A
1517+
NumVecElements[0] = 8; // A
1518+
NumVecElements[1] = 4; // B
1519+
NumVecElements[2] = 4; // C
1520+
NumVecElements[3] = 4; // D
1521+
} else
1522+
return SYCLGenError();
1523+
} else if (AType->getKind() == InlineAsmBuiltinType::s8 ||
1524+
AType->getKind() == InlineAsmBuiltinType::u8) {
1525+
// If A matrix type is s8/u8, then C&D matrix types can only be s32
1526+
if (CType->getKind() == InlineAsmBuiltinType::s32) {
1527+
NumVecElements[0] = 2; // A
13851528
NumVecElements[1] = 1; // B
1386-
NumVecElements[2] = 2; // C
1387-
NumVecElements[3] = 2; // D
1529+
NumVecElements[2] = 4; // C
1530+
NumVecElements[3] = 4; // D
1531+
} else
1532+
return SYCLGenError();
1533+
} else
1534+
return SYCLGenError();
1535+
} else if (Inst->hasAttr(InstAttr::m16n8k32)) {
1536+
M = "16";
1537+
N = "8";
1538+
K = "32";
1539+
// Only s4/s8/u4/u8 types are supported for A and B matrices of m16n8k32
1540+
if (AType->getKind() == InlineAsmBuiltinType::s4 ||
1541+
AType->getKind() == InlineAsmBuiltinType::u4) {
1542+
// If A matrix type is s4/u4, then C&D matrix types can only be s32
1543+
if (CType->getKind() == InlineAsmBuiltinType::s32) {
1544+
NumVecElements[0] = 2; // A
1545+
NumVecElements[1] = 1; // B
1546+
NumVecElements[2] = 4; // C
1547+
NumVecElements[3] = 4; // D
1548+
} else
1549+
return SYCLGenError();
1550+
} else if (AType->getKind() == InlineAsmBuiltinType::s8 ||
1551+
AType->getKind() == InlineAsmBuiltinType::u8) {
1552+
// If A matrix type is s8/u8, then C&D matrix types can only be s32
1553+
if (CType->getKind() == InlineAsmBuiltinType::s32) {
1554+
NumVecElements[0] = 4; // A
1555+
NumVecElements[1] = 2; // B
1556+
NumVecElements[2] = 4; // C
1557+
NumVecElements[3] = 4; // D
1558+
} else
1559+
return SYCLGenError();
1560+
} else
1561+
return SYCLGenError();
1562+
} else if (Inst->hasAttr(InstAttr::m16n8k64)) {
1563+
M = "16";
1564+
N = "8";
1565+
K = "64";
1566+
// Only s4/u4 types are supported for A and B matrices of m16n8k64
1567+
if (AType->getKind() == InlineAsmBuiltinType::s4 ||
1568+
AType->getKind() == InlineAsmBuiltinType::u4) {
1569+
// If A matrix type is s4/u4, then C&D matrix types can only be s32
1570+
if (CType->getKind() == InlineAsmBuiltinType::s32) {
1571+
NumVecElements[0] = 4; // A
1572+
NumVecElements[1] = 2; // B
1573+
NumVecElements[2] = 4; // C
1574+
NumVecElements[3] = 4; // D
1575+
} else
1576+
return SYCLGenError();
1577+
} else
1578+
return SYCLGenError();
1579+
} else if (Inst->hasAttr(InstAttr::m16n8k128)) {
1580+
M = "16";
1581+
N = "8";
1582+
K = "128";
1583+
// Only b1 type is supported for A and B matrices of m16n8k128
1584+
if (AType->getKind() == InlineAsmBuiltinType::b1) {
1585+
// If A matrix type is b1, then C&D matrix types can only be s32
1586+
if (CType->getKind() == InlineAsmBuiltinType::s32) {
1587+
NumVecElements[0] = 2; // A
1588+
NumVecElements[1] = 1; // B
1589+
NumVecElements[2] = 4; // C
1590+
NumVecElements[3] = 4; // D
1591+
1592+
// Only and/xor bitwise operations are supported for m16n8k128
1593+
if (Inst->hasAttr(InstAttr::op_and))
1594+
MatrixOp = "and";
1595+
else if (Inst->hasAttr(InstAttr::op_xor))
1596+
MatrixOp = "xor";
1597+
else
1598+
return SYCLGenError();
1599+
} else
1600+
return SYCLGenError();
1601+
} else
1602+
return SYCLGenError();
1603+
} else if (Inst->hasAttr(InstAttr::m16n8k256)) {
1604+
M = "16";
1605+
N = "8";
1606+
K = "256";
1607+
// Only b1 type is supported for A and B matrices of m16n8k256
1608+
if (AType->getKind() == InlineAsmBuiltinType::b1) {
1609+
// If A matrix type is b1, then C&D matrix types can only be s32
1610+
if (CType->getKind() == InlineAsmBuiltinType::s32) {
1611+
NumVecElements[0] = 4; // A
1612+
NumVecElements[1] = 2; // B
1613+
NumVecElements[2] = 4; // C
1614+
NumVecElements[3] = 4; // D
1615+
1616+
// Only and/xor bitwise operations are supported for m16n8k256
1617+
if (Inst->hasAttr(InstAttr::op_and))
1618+
MatrixOp = "and";
1619+
else if (Inst->hasAttr(InstAttr::op_xor))
1620+
MatrixOp = "xor";
1621+
else
1622+
return SYCLGenError();
13881623
} else
13891624
return SYCLGenError();
13901625
} else
@@ -1407,7 +1642,12 @@ class SYCLGen : public SYCLGenBase {
14071642

14081643
MulType = ABType;
14091644
OS() << MapNames::getDpctNamespace() << "experimental::matrix::mma";
1410-
OS() << "<" << MulType << ">(";
1645+
if (!MatrixOp.empty()) {
1646+
OS() << "_" << MatrixOp;
1647+
}
1648+
OS() << "<";
1649+
OS() << M << ", " << N << ", " << K << ", ";
1650+
OS() << MulType << ">(";
14111651

14121652
// Add D matrix address values to store the MAD result
14131653
for (unsigned Inst = 0; Inst != DMatVE->getNumElements(); ++Inst) {
@@ -1416,7 +1656,8 @@ class SYCLGen : public SYCLGenBase {
14161656
OS() << "&";
14171657
if (emitStmt(DMatVE->getElement(Inst)))
14181658
return SYCLGenError();
1419-
OS() << ", ";
1659+
if ((Inst + 1) != DMatVE->getNumElements())
1660+
OS() << ", ";
14201661
}
14211662

14221663
// Add A, B & C matrix values to compute MAD
@@ -1427,16 +1668,15 @@ class SYCLGen : public SYCLGenBase {
14271668
for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) {
14281669
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
14291670
continue;
1671+
OS() << ", ";
14301672
if (emitStmt(VE->getElement(Inst)))
14311673
return SYCLGenError();
1432-
OS() << ", ";
14331674
}
14341675
} else {
14351676
return SYCLGenError();
14361677
}
14371678
}
14381679

1439-
OS() << DpctGlobalInfo::getItem(GAS);
14401680
OS() << ");";
14411681

14421682
const auto *KernelDecl = getImmediateOuterFuncDecl(GAS);

clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,9 @@ class InlineAsmBuiltinType : public InlineAsmType {
9999
return ((K == Kind) || ...);
100100
}
101101
template <class... Ks> bool isNot(Ks... K) { return ((K != Kind) && ...); }
102-
bool isBit() const { return isOneOf(b8, b16, b32, b64); }
103-
bool isSigned() const { return isOneOf(s8, s16, s32, s64); }
104-
bool isUnsigned() const { return isOneOf(u8, u16, u32, u64); }
102+
bool isBit() const { return isOneOf(b1, b8, b16, b32, b64); }
103+
bool isSigned() const { return isOneOf(s4, s8, s16, s32, s64); }
104+
bool isUnsigned() const { return isOneOf(u4, u8, u16, u32, u64); }
105105
bool isInt() const { return isSigned() || isUnsigned(); }
106106
bool isFloat() const { return isOneOf(f16, f32, f64); }
107107
bool isScalar() const { return isInt() || isFloat(); }

0 commit comments

Comments
 (0)