@@ -1386,9 +1386,16 @@ class SYCLGen : public SYCLGenBase {
1386
1386
return SYCLGenError ();
1387
1387
1388
1388
// Only row Layout is supported for the A matrix and
1389
- // only col Layout is supported for of B matrix
1390
- if (Inst->getAttr (3 ) != InstAttr::row || Inst->getAttr (4 ) != InstAttr::col)
1391
- return SYCLGenError ();
1389
+ // only col Layout is supported for the B matrix (m8n8k4: col A, row B)
1390
+ if (Inst->hasAttr (InstAttr::m8n8k4)) {
1391
+ if (Inst->getAttr (3 ) != InstAttr::col ||
1392
+ Inst->getAttr (4 ) != InstAttr::row)
1393
+ return SYCLGenError ();
1394
+ } else {
1395
+ if (Inst->getAttr (3 ) != InstAttr::row ||
1396
+ Inst->getAttr (4 ) != InstAttr::col)
1397
+ return SYCLGenError ();
1398
+ }
1392
1399
1393
1400
// Data types of D, A, B & C matrices respectively in the PTX instruction
1394
1401
const auto *DType = dyn_cast<InlineAsmBuiltinType>(Inst->getType (0 ));
@@ -1421,7 +1428,68 @@ class SYCLGen : public SYCLGenBase {
1421
1428
// Data types of A, B & C matrices respectively in the PTX arguments
1422
1429
std::string InMatrixType[3 ];
1423
1430
1424
- if (Inst->hasAttr (InstAttr::m16n8k16)) {
1431
+ if (Inst->hasAttr (InstAttr::m8n8k4)) {
1432
+ M = " 8" ;
1433
+ N = " 8" ;
1434
+ K = " 4" ;
1435
+
1436
+ // Only f16 type is supported for A and B matrices of m8n8k4
1437
+ if (AType->getKind () == InlineAsmBuiltinType::f16 ) {
1438
+ InMatrixType[0 ] = " uint32_t" ; // A type is .f16x2
1439
+ InMatrixType[1 ] = " uint32_t" ; // B type is .f16x2
1440
+
1441
+ // If A matrix type is f16, then C&D matrix types can only be f32
1442
+ if (CType->getKind () == InlineAsmBuiltinType::f32 ) {
1443
+ NumVecElements[0 ] = 2 ; // A
1444
+ NumVecElements[1 ] = 2 ; // B
1445
+ NumVecElements[2 ] = 8 ; // C
1446
+ NumVecElements[3 ] = 8 ; // D
1447
+ } else
1448
+ return SYCLGenError ();
1449
+ } else
1450
+ return SYCLGenError ();
1451
+ } else if (Inst->hasAttr (InstAttr::m8n8k16)) {
1452
+ M = " 8" ;
1453
+ N = " 8" ;
1454
+ K = " 16" ;
1455
+
1456
+ // Only s8 type is supported for A and B matrices of m8n8k16
1457
+ if (AType->getKind () == InlineAsmBuiltinType::s8) {
1458
+ InMatrixType[0 ] = " uint32_t" ; // A type is .s8x4
1459
+ InMatrixType[1 ] = " uint32_t" ; // B type is .s8x4
1460
+
1461
+ // If A matrix type is s8, then C&D matrix types can only be s32
1462
+ if (CType->getKind () == InlineAsmBuiltinType::s32) {
1463
+ NumVecElements[0 ] = 1 ; // A
1464
+ NumVecElements[1 ] = 1 ; // B
1465
+ NumVecElements[2 ] = 2 ; // C
1466
+ NumVecElements[3 ] = 2 ; // D
1467
+ } else
1468
+ return SYCLGenError ();
1469
+ } else
1470
+ return SYCLGenError ();
1471
+ } else if (Inst->hasAttr (InstAttr::m16n8k8)) {
1472
+ M = " 16" ;
1473
+ N = " 8" ;
1474
+ K = " 8" ;
1475
+
1476
+ // Only f16/bf16 types are supported for A and B matrices of m16n8k8
1477
+ if (AType->getKind () == InlineAsmBuiltinType::f16 ||
1478
+ AType->getKind () == InlineAsmBuiltinType::bf16 ) {
1479
+ InMatrixType[0 ] = " uint32_t" ; // A type is .f16x2
1480
+ InMatrixType[1 ] = " uint32_t" ; // B type is .f16x2
1481
+
1482
+ // If A matrix type is f16/bf16, then C&D matrix types can only be f32
1483
+ if (CType->getKind () == InlineAsmBuiltinType::f32 ) {
1484
+ NumVecElements[0 ] = 2 ; // A
1485
+ NumVecElements[1 ] = 1 ; // B
1486
+ NumVecElements[2 ] = 4 ; // C
1487
+ NumVecElements[3 ] = 4 ; // D
1488
+ } else
1489
+ return SYCLGenError ();
1490
+ } else
1491
+ return SYCLGenError ();
1492
+ } else if (Inst->hasAttr (InstAttr::m16n8k16)) {
1425
1493
M = " 16" ;
1426
1494
N = " 8" ;
1427
1495
K = " 16" ;
@@ -1440,8 +1508,8 @@ class SYCLGen : public SYCLGenBase {
1440
1508
} else
1441
1509
return SYCLGenError ();
1442
1510
} else if (AType->getKind () == InlineAsmBuiltinType::s8) {
1443
- InMatrixType[0 ] = " uint32_t" ; // A type is .f16x2
1444
- InMatrixType[1 ] = " uint32_t" ; // B type is .f16x2
1511
+ InMatrixType[0 ] = " uint32_t" ; // A type is .s8x4
1512
+ InMatrixType[1 ] = " uint32_t" ; // B type is .s8x4
1445
1513
1446
1514
// If A matrix type is s8, then C&D matrix types can only be s32
1447
1515
if (CType->getKind () == InlineAsmBuiltinType::s32) {
@@ -1453,6 +1521,26 @@ class SYCLGen : public SYCLGenBase {
1453
1521
return SYCLGenError ();
1454
1522
} else
1455
1523
return SYCLGenError ();
1524
+ } else if (Inst->hasAttr (InstAttr::m16n8k32)) {
1525
+ M = " 16" ;
1526
+ N = " 8" ;
1527
+ K = " 32" ;
1528
+
1529
+ // Only s8 type is supported for A and B matrices of m16n8k32
1530
+ if (AType->getKind () == InlineAsmBuiltinType::s8) {
1531
+ InMatrixType[0 ] = " uint32_t" ; // A type is .s8x4
1532
+ InMatrixType[1 ] = " uint32_t" ; // B type is .s8x4
1533
+
1534
+ // If A matrix type is s8, then C&D matrix types can only be s32
1535
+ if (CType->getKind () == InlineAsmBuiltinType::s32) {
1536
+ NumVecElements[0 ] = 4 ; // A
1537
+ NumVecElements[1 ] = 2 ; // B
1538
+ NumVecElements[2 ] = 4 ; // C
1539
+ NumVecElements[3 ] = 4 ; // D
1540
+ } else
1541
+ return SYCLGenError ();
1542
+ } else
1543
+ return SYCLGenError ();
1456
1544
} else
1457
1545
return SYCLGenError ();
1458
1546
@@ -1472,7 +1560,9 @@ class SYCLGen : public SYCLGenBase {
1472
1560
return SYCLGenError ();
1473
1561
1474
1562
// Declare and init an array for storing the addresses of D matrix elements
1475
- OS () << " {\n " ;
1563
+ OS () << " {" << getNL ();
1564
+ incIndent ();
1565
+ indent ();
1476
1566
OS () << " volatile " << CDType << " *d_mat_frag_ct1["
1477
1567
<< DMatVE->getNumElements () << " ] = { " ;
1478
1568
for (unsigned Inst = 0 ; Inst != DMatVE->getNumElements (); ++Inst) {
@@ -1494,6 +1584,7 @@ class SYCLGen : public SYCLGenBase {
1494
1584
InputOp++) {
1495
1585
if (auto VE =
1496
1586
dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand (InputOp))) {
1587
+ indent ();
1497
1588
OS () << " sycl::vec<" << InMatrixType[InputOp] << " , "
1498
1589
<< VE->getNumElements () << " > " << InMatrixName[InputOp]
1499
1590
<< " _mat_frag_ct1(" ;
@@ -1512,6 +1603,7 @@ class SYCLGen : public SYCLGenBase {
1512
1603
}
1513
1604
}
1514
1605
1606
+ indent ();
1515
1607
OS () << MapNames::getDpctNamespace () << " experimental::matrix::mma" ;
1516
1608
OS () << " <" ;
1517
1609
OS () << M << " , " << N << " , " << K << " , " ;
@@ -1523,6 +1615,8 @@ class SYCLGen : public SYCLGenBase {
1523
1615
OS () << " , &" << InMatrixName[i] << " _mat_frag_ct1" ;
1524
1616
OS () << " )" ;
1525
1617
endstmt ();
1618
+ decIndent ();
1619
+ indent ();
1526
1620
OS () << " }" ;
1527
1621
endstmt ();
1528
1622
0 commit comments