@@ -514,14 +514,17 @@ bool SYCLGenBase::emitType(const InlineAsmType *T) {
514
514
bool SYCLGenBase::emitBuiltinType (const InlineAsmBuiltinType *T) {
515
515
switch (T->getKind ()) {
516
516
// clang-format off
517
+ case InlineAsmBuiltinType::b1: OS () << " uint8_t" ; break ;
517
518
case InlineAsmBuiltinType::b8: OS () << " uint8_t" ; break ;
518
519
case InlineAsmBuiltinType::b16: OS () << " uint16_t" ; break ;
519
520
case InlineAsmBuiltinType::b32: OS () << " uint32_t" ; break ;
520
521
case InlineAsmBuiltinType::b64: OS () << " uint64_t" ; break ;
522
+ case InlineAsmBuiltinType::u4: OS () << " uint8_t" ; break ;
521
523
case InlineAsmBuiltinType::u8: OS () << " uint8_t" ; break ;
522
524
case InlineAsmBuiltinType::u16: OS () << " uint16_t" ; break ;
523
525
case InlineAsmBuiltinType::u32: OS () << " uint32_t" ; break ;
524
526
case InlineAsmBuiltinType::u64: OS () << " uint64_t" ; break ;
527
+ case InlineAsmBuiltinType::s4: OS () << " int8_t" ; break ;
525
528
case InlineAsmBuiltinType::s8: OS () << " int8_t" ; break ;
526
529
case InlineAsmBuiltinType::s16: OS () << " int16_t" ; break ;
527
530
case InlineAsmBuiltinType::s32: OS () << " int32_t" ; break ;
@@ -1347,44 +1350,276 @@ class SYCLGen : public SYCLGenBase {
1347
1350
// Register sizes for vector elements of A, B, C & D matrices
1348
1351
unsigned NumVecElements[4 ] = {0 };
1349
1352
1353
+ // Sizes of A & B matrices
1354
+ std::string M, N, K;
1355
+
1356
+ // Operator for m8n8k128/m16n8k128/m16n8k256
1357
+ std::string MatrixOp;
1358
+
1350
1359
// Data type used to multiply A & B matrices
1351
1360
std::string MulType;
1352
- if (Inst->hasAttr (InstAttr::m16n8k16)) {
1353
- // Only f16 type is supported for A and B matrix data for m16n8k16
1361
+ if (Inst->hasAttr (InstAttr::m8n8k4)) {
1362
+ M = " 8" ;
1363
+ N = " 8" ;
1364
+ K = " 4" ;
1365
+ // f16 & f64 types are supported for A and B matrices of m8n8k4
1354
1366
if (AType->getKind () == InlineAsmBuiltinType::f16) {
1355
- // If A matrix type is f16, then C&D matrix types can only be f16
1367
+ // If A matrix type is f16, then C&D matrix types can only be f16/f32
1356
1368
if (CType->getKind () == AType->getKind ()) {
1357
1369
NumVecElements[0 ] = 2 ; // A
1358
- NumVecElements[1 ] = 4 ; // B
1370
+ NumVecElements[1 ] = 2 ; // B
1359
1371
NumVecElements[2 ] = 4 ; // C
1360
1372
NumVecElements[3 ] = 4 ; // D
1373
+ } else if (CType->getKind () == InlineAsmBuiltinType::f32) {
1374
+ NumVecElements[0 ] = 2 ; // A
1375
+ NumVecElements[1 ] = 2 ; // B
1376
+ NumVecElements[2 ] = 8 ; // C
1377
+ NumVecElements[3 ] = 8 ; // D
1378
+ } else
1379
+ return SYCLGenError ();
1380
+ } else if (AType->getKind () == InlineAsmBuiltinType::f64) {
1381
+ // If A matrix type is f64, then C&D matrix types can only be f64
1382
+ if (CType->getKind () == AType->getKind ()) {
1383
+ NumVecElements[0 ] = 1 ; // A
1384
+ NumVecElements[1 ] = 1 ; // B
1385
+ NumVecElements[2 ] = 2 ; // C
1386
+ NumVecElements[3 ] = 2 ; // D
1387
+ } else
1388
+ return SYCLGenError ();
1389
+ } else
1390
+ return SYCLGenError ();
1391
+ } else if (Inst->hasAttr (InstAttr::m8n8k16)) {
1392
+ M = " 8" ;
1393
+ N = " 8" ;
1394
+ K = " 16" ;
1395
+ // Only s8/u8 types are supported for A and B matrices of m8n8k16
1396
+ if (AType->getKind () == InlineAsmBuiltinType::s8 ||
1397
+ AType->getKind () == InlineAsmBuiltinType::u8) {
1398
+ // If A matrix type is s8/u8, then C&D matrix types can only be s32
1399
+ if (CType->getKind () == InlineAsmBuiltinType::s32) {
1400
+ NumVecElements[0 ] = 1 ; // A
1401
+ NumVecElements[1 ] = 1 ; // B
1402
+ NumVecElements[2 ] = 2 ; // C
1403
+ NumVecElements[3 ] = 2 ; // D
1361
1404
} else
1362
1405
return SYCLGenError ();
1363
1406
} else
1364
1407
return SYCLGenError ();
1365
- } else if (Inst->hasAttr (InstAttr::m8n8k4)) {
1366
- // f16 & f64 types are supported for A and B matrix data for m8n8k4
1408
+ } else if (Inst->hasAttr (InstAttr::m8n8k32)) {
1409
+ M = " 8" ;
1410
+ N = " 8" ;
1411
+ K = " 32" ;
1412
+ // Only s4/u4 types are supported for A and B matrices of m8n8k32
1413
+ if (AType->getKind () == InlineAsmBuiltinType::s4 ||
1414
+ AType->getKind () == InlineAsmBuiltinType::u4) {
1415
+ // If A matrix type is s4/u4, then C&D matrix types can only be s32
1416
+ if (CType->getKind () == InlineAsmBuiltinType::s32) {
1417
+ NumVecElements[0 ] = 1 ; // A
1418
+ NumVecElements[1 ] = 1 ; // B
1419
+ NumVecElements[2 ] = 2 ; // C
1420
+ NumVecElements[3 ] = 2 ; // D
1421
+ } else
1422
+ return SYCLGenError ();
1423
+ } else
1424
+ return SYCLGenError ();
1425
+ } else if (Inst->hasAttr (InstAttr::m8n8k128)) {
1426
+ M = " 8" ;
1427
+ N = " 8" ;
1428
+ K = " 128" ;
1429
+ // Only b1 type is supported for A and B matrices of m8n8k128
1430
+ if (AType->getKind () == InlineAsmBuiltinType::b1) {
1431
+ // If A matrix type is b1, then C&D matrix types can only be s32
1432
+ if (CType->getKind () == InlineAsmBuiltinType::s32) {
1433
+ NumVecElements[0 ] = 1 ; // A
1434
+ NumVecElements[1 ] = 1 ; // B
1435
+ NumVecElements[2 ] = 2 ; // C
1436
+ NumVecElements[3 ] = 2 ; // D
1437
+
1438
+ // Only and/xor bitwise operations are supported for m8n8k128
1439
+ if (Inst->hasAttr (InstAttr::op_and))
1440
+ MatrixOp = " and" ;
1441
+ else if (Inst->hasAttr (InstAttr::op_xor))
1442
+ MatrixOp = " xor" ;
1443
+ else
1444
+ return SYCLGenError ();
1445
+ } else
1446
+ return SYCLGenError ();
1447
+ } else
1448
+ return SYCLGenError ();
1449
+ } else if (Inst->hasAttr (InstAttr::m16n8k4)) {
1450
+ M = " 16" ;
1451
+ N = " 8" ;
1452
+ K = " 4" ;
1453
+ // Only f64 type is supported for A and B matrices of m16n8k4
1454
+ if (AType->getKind () == InlineAsmBuiltinType::f64) {
1455
+ // If A matrix type is f64, then C&D matrix types can only be f64
1456
+ if (CType->getKind () == InlineAsmBuiltinType::f64) {
1457
+ NumVecElements[0 ] = 2 ; // A
1458
+ NumVecElements[1 ] = 1 ; // B
1459
+ NumVecElements[2 ] = 4 ; // C
1460
+ NumVecElements[3 ] = 4 ; // D
1461
+ } else
1462
+ return SYCLGenError ();
1463
+ } else
1464
+ return SYCLGenError ();
1465
+ } else if (Inst->hasAttr (InstAttr::m16n8k8)) {
1466
+ M = " 16" ;
1467
+ N = " 8" ;
1468
+ K = " 8" ;
1469
+ // Only f16/f64 types are supported for A and B matrices of m16n8k8
1367
1470
if (AType->getKind () == InlineAsmBuiltinType::f16) {
1368
1471
// If A matrix type is f16, then C&D matrix types can only be f16/f32
1369
- if (CType->getKind () == AType-> getKind () ) {
1472
+ if (CType->getKind () == InlineAsmBuiltinType::f16 ) {
1370
1473
NumVecElements[0 ] = 2 ; // A
1474
+ NumVecElements[1 ] = 1 ; // B
1475
+ NumVecElements[2 ] = 2 ; // C
1476
+ NumVecElements[3 ] = 2 ; // D
1477
+ } else if (CType->getKind () == InlineAsmBuiltinType::f32) {
1478
+ NumVecElements[0 ] = 2 ; // A
1479
+ NumVecElements[1 ] = 1 ; // B
1480
+ NumVecElements[2 ] = 4 ; // C
1481
+ NumVecElements[3 ] = 4 ; // D
1482
+ } else
1483
+ return SYCLGenError ();
1484
+ } else if (AType->getKind () == InlineAsmBuiltinType::f64) {
1485
+ // If A matrix type is f64, then C&D matrix types can only be f64
1486
+ if (CType->getKind () == InlineAsmBuiltinType::f64) {
1487
+ NumVecElements[0 ] = 4 ; // A
1371
1488
NumVecElements[1 ] = 2 ; // B
1372
1489
NumVecElements[2 ] = 4 ; // C
1373
1490
NumVecElements[3 ] = 4 ; // D
1491
+ } else
1492
+ return SYCLGenError ();
1493
+ } else
1494
+ return SYCLGenError ();
1495
+ } else if (Inst->hasAttr (InstAttr::m16n8k16)) {
1496
+ M = " 16" ;
1497
+ N = " 8" ;
1498
+ K = " 16" ;
1499
+ // Only f16/f64/s8/u8 types are supported for A and B matrices of m16n8k16
1500
+ if (AType->getKind () == InlineAsmBuiltinType::f16) {
1501
+ // If A matrix type is f16, then C&D matrix types can only be f16/f32
1502
+ if (CType->getKind () == AType->getKind ()) {
1503
+ NumVecElements[0 ] = 4 ; // A
1504
+ NumVecElements[1 ] = 2 ; // B
1505
+ NumVecElements[2 ] = 2 ; // C
1506
+ NumVecElements[3 ] = 2 ; // D
1374
1507
} else if (CType->getKind () == InlineAsmBuiltinType::f32) {
1375
- NumVecElements[0 ] = 2 ; // A
1508
+ NumVecElements[0 ] = 4 ; // A
1376
1509
NumVecElements[1 ] = 2 ; // B
1377
- NumVecElements[2 ] = 8 ; // C
1378
- NumVecElements[3 ] = 8 ; // D
1510
+ NumVecElements[2 ] = 4 ; // C
1511
+ NumVecElements[3 ] = 4 ; // D
1379
1512
} else
1380
1513
return SYCLGenError ();
1381
1514
} else if (AType->getKind () == InlineAsmBuiltinType::f64) {
1382
1515
// If A matrix type is f64, then C&D matrix types can only be f64
1383
1516
if (CType->getKind () == AType->getKind ()) {
1384
- NumVecElements[0 ] = 1 ; // A
1517
+ NumVecElements[0 ] = 8 ; // A
1518
+ NumVecElements[1 ] = 4 ; // B
1519
+ NumVecElements[2 ] = 4 ; // C
1520
+ NumVecElements[3 ] = 4 ; // D
1521
+ } else
1522
+ return SYCLGenError ();
1523
+ } else if (AType->getKind () == InlineAsmBuiltinType::s8 ||
1524
+ AType->getKind () == InlineAsmBuiltinType::u8) {
1525
+ // If A matrix type is s8/u8, then C&D matrix types can only be s32
1526
+ if (CType->getKind () == InlineAsmBuiltinType::s32) {
1527
+ NumVecElements[0 ] = 2 ; // A
1385
1528
NumVecElements[1 ] = 1 ; // B
1386
- NumVecElements[2 ] = 2 ; // C
1387
- NumVecElements[3 ] = 2 ; // D
1529
+ NumVecElements[2 ] = 4 ; // C
1530
+ NumVecElements[3 ] = 4 ; // D
1531
+ } else
1532
+ return SYCLGenError ();
1533
+ } else
1534
+ return SYCLGenError ();
1535
+ } else if (Inst->hasAttr (InstAttr::m16n8k32)) {
1536
+ M = " 16" ;
1537
+ N = " 8" ;
1538
+ K = " 32" ;
1539
+ // Only s4/s8/u4/u8 types are supported for A and B matrices of m16n8k32
1540
+ if (AType->getKind () == InlineAsmBuiltinType::s4 ||
1541
+ AType->getKind () == InlineAsmBuiltinType::u4) {
1542
+ // If A matrix type is s4/u4, then C&D matrix types can only be s32
1543
+ if (CType->getKind () == InlineAsmBuiltinType::s32) {
1544
+ NumVecElements[0 ] = 2 ; // A
1545
+ NumVecElements[1 ] = 1 ; // B
1546
+ NumVecElements[2 ] = 4 ; // C
1547
+ NumVecElements[3 ] = 4 ; // D
1548
+ } else
1549
+ return SYCLGenError ();
1550
+ } else if (AType->getKind () == InlineAsmBuiltinType::s8 ||
1551
+ AType->getKind () == InlineAsmBuiltinType::u8) {
1552
+ // If A matrix type is s8/u8, then C&D matrix types can only be s32
1553
+ if (CType->getKind () == InlineAsmBuiltinType::s32) {
1554
+ NumVecElements[0 ] = 4 ; // A
1555
+ NumVecElements[1 ] = 2 ; // B
1556
+ NumVecElements[2 ] = 4 ; // C
1557
+ NumVecElements[3 ] = 4 ; // D
1558
+ } else
1559
+ return SYCLGenError ();
1560
+ } else
1561
+ return SYCLGenError ();
1562
+ } else if (Inst->hasAttr (InstAttr::m16n8k64)) {
1563
+ M = " 16" ;
1564
+ N = " 8" ;
1565
+ K = " 64" ;
1566
+ // Only s4/u4 types are supported for A and B matrices of m16n8k64
1567
+ if (AType->getKind () == InlineAsmBuiltinType::s4 ||
1568
+ AType->getKind () == InlineAsmBuiltinType::u4) {
1569
+ // If A matrix type is s4/u4, then C&D matrix types can only be s32
1570
+ if (CType->getKind () == InlineAsmBuiltinType::s32) {
1571
+ NumVecElements[0 ] = 4 ; // A
1572
+ NumVecElements[1 ] = 2 ; // B
1573
+ NumVecElements[2 ] = 4 ; // C
1574
+ NumVecElements[3 ] = 4 ; // D
1575
+ } else
1576
+ return SYCLGenError ();
1577
+ } else
1578
+ return SYCLGenError ();
1579
+ } else if (Inst->hasAttr (InstAttr::m16n8k128)) {
1580
+ M = " 16" ;
1581
+ N = " 8" ;
1582
+ K = " 128" ;
1583
+ // Only b1 type is supported for A and B matrices of m16n8k128
1584
+ if (AType->getKind () == InlineAsmBuiltinType::b1) {
1585
+ // If A matrix type is b1, then C&D matrix types can only be s32
1586
+ if (CType->getKind () == InlineAsmBuiltinType::s32) {
1587
+ NumVecElements[0 ] = 2 ; // A
1588
+ NumVecElements[1 ] = 1 ; // B
1589
+ NumVecElements[2 ] = 4 ; // C
1590
+ NumVecElements[3 ] = 4 ; // D
1591
+
1592
+ // Only and/xor bitwise operations are supported for m16n8k128
1593
+ if (Inst->hasAttr (InstAttr::op_and))
1594
+ MatrixOp = " and" ;
1595
+ else if (Inst->hasAttr (InstAttr::op_xor))
1596
+ MatrixOp = " xor" ;
1597
+ else
1598
+ return SYCLGenError ();
1599
+ } else
1600
+ return SYCLGenError ();
1601
+ } else
1602
+ return SYCLGenError ();
1603
+ } else if (Inst->hasAttr (InstAttr::m16n8k256)) {
1604
+ M = " 16" ;
1605
+ N = " 8" ;
1606
+ K = " 256" ;
1607
+ // Only b1 type is supported for A and B matrices of m16n8k256
1608
+ if (AType->getKind () == InlineAsmBuiltinType::b1) {
1609
+ // If A matrix type is b1, then C&D matrix types can only be s32
1610
+ if (CType->getKind () == InlineAsmBuiltinType::s32) {
1611
+ NumVecElements[0 ] = 4 ; // A
1612
+ NumVecElements[1 ] = 2 ; // B
1613
+ NumVecElements[2 ] = 4 ; // C
1614
+ NumVecElements[3 ] = 4 ; // D
1615
+
1616
+ // Only and/xor bitwise operations are supported for m16n8k256
1617
+ if (Inst->hasAttr (InstAttr::op_and))
1618
+ MatrixOp = " and" ;
1619
+ else if (Inst->hasAttr (InstAttr::op_xor))
1620
+ MatrixOp = " xor" ;
1621
+ else
1622
+ return SYCLGenError ();
1388
1623
} else
1389
1624
return SYCLGenError ();
1390
1625
} else
@@ -1407,7 +1642,12 @@ class SYCLGen : public SYCLGenBase {
1407
1642
1408
1643
MulType = ABType;
1409
1644
OS () << MapNames::getDpctNamespace () << " experimental::matrix::mma" ;
1410
- OS () << " <" << MulType << " >(" ;
1645
+ if (!MatrixOp.empty ()) {
1646
+ OS () << " _" << MatrixOp;
1647
+ }
1648
+ OS () << " <" ;
1649
+ OS () << M << " , " << N << " , " << K << " , " ;
1650
+ OS () << MulType << " >(" ;
1411
1651
1412
1652
// Add D matrix address values to store the MAD result
1413
1653
for (unsigned Inst = 0 ; Inst != DMatVE->getNumElements (); ++Inst) {
@@ -1416,7 +1656,8 @@ class SYCLGen : public SYCLGenBase {
1416
1656
OS () << " &" ;
1417
1657
if (emitStmt (DMatVE->getElement (Inst)))
1418
1658
return SYCLGenError ();
1419
- OS () << " , " ;
1659
+ if ((Inst + 1 ) != DMatVE->getNumElements ())
1660
+ OS () << " , " ;
1420
1661
}
1421
1662
1422
1663
// Add A, B & C matrix values to compute MAD
@@ -1427,16 +1668,15 @@ class SYCLGen : public SYCLGenBase {
1427
1668
for (unsigned Inst = 0 ; Inst != VE->getNumElements (); ++Inst) {
1428
1669
if (isa<InlineAsmDiscardExpr>(VE->getElement (Inst)))
1429
1670
continue ;
1671
+ OS () << " , " ;
1430
1672
if (emitStmt (VE->getElement (Inst)))
1431
1673
return SYCLGenError ();
1432
- OS () << " , " ;
1433
1674
}
1434
1675
} else {
1435
1676
return SYCLGenError ();
1436
1677
}
1437
1678
}
1438
1679
1439
- OS () << DpctGlobalInfo::getItem (GAS);
1440
1680
OS () << " );" ;
1441
1681
1442
1682
const auto *KernelDecl = getImmediateOuterFuncDecl (GAS);
0 commit comments