@@ -514,14 +514,17 @@ bool SYCLGenBase::emitType(const InlineAsmType *T) {
514
514
bool SYCLGenBase::emitBuiltinType (const InlineAsmBuiltinType *T) {
515
515
switch (T->getKind ()) {
516
516
// clang-format off
517
+ case InlineAsmBuiltinType::b1: OS () << " uint8_t" ; break ;
517
518
case InlineAsmBuiltinType::b8: OS () << " uint8_t" ; break ;
518
519
case InlineAsmBuiltinType::b16: OS () << " uint16_t" ; break ;
519
520
case InlineAsmBuiltinType::b32: OS () << " uint32_t" ; break ;
520
521
case InlineAsmBuiltinType::b64: OS () << " uint64_t" ; break ;
522
+ case InlineAsmBuiltinType::u4: OS () << " uint8_t" ; break ;
521
523
case InlineAsmBuiltinType::u8 : OS () << " uint8_t" ; break ;
522
524
case InlineAsmBuiltinType::u16 : OS () << " uint16_t" ; break ;
523
525
case InlineAsmBuiltinType::u32 : OS () << " uint32_t" ; break ;
524
526
case InlineAsmBuiltinType::u64 : OS () << " uint64_t" ; break ;
527
+ case InlineAsmBuiltinType::s4: OS () << " int8_t" ; break ;
525
528
case InlineAsmBuiltinType::s8: OS () << " int8_t" ; break ;
526
529
case InlineAsmBuiltinType::s16: OS () << " int16_t" ; break ;
527
530
case InlineAsmBuiltinType::s32: OS () << " int32_t" ; break ;
@@ -559,6 +562,9 @@ bool SYCLGenBase::emitVectorType(const InlineAsmVectorType *T) {
559
562
case InlineAsmVectorType::x1:
560
563
OS () << 1 ;
561
564
break ;
565
+ case InlineAsmVectorType::v1:
566
+ OS () << 1 ;
567
+ break ;
562
568
case InlineAsmVectorType::v2:
563
569
case InlineAsmVectorType::x2:
564
570
OS () << 2 ;
@@ -1370,6 +1376,167 @@ class SYCLGen : public SYCLGenBase {
1370
1376
return SYCLGenSuccess ();
1371
1377
}
1372
1378
1379
+ bool handle_mma (const InlineAsmInstruction *Inst) override {
1380
+ if (Inst->getNumInputOperands () != 3 )
1381
+ return SYCLGenError ();
1382
+
1383
+ const InlineAsmVectorExpr *DMatVE =
1384
+ dyn_cast<InlineAsmVectorExpr>(Inst->getOutputOperand ());
1385
+ if (!DMatVE)
1386
+ return SYCLGenError ();
1387
+
1388
+ // Only row Layout is supported for of A matrix and
1389
+ // only col Layout is supported for of B matrix
1390
+ if (Inst->getAttr (3 ) != InstAttr::row || Inst->getAttr (4 ) != InstAttr::col)
1391
+ return SYCLGenError ();
1392
+
1393
+ // Data types of D, A, B & C matrices respectively in the PTX instruction
1394
+ const auto *DType = dyn_cast<InlineAsmBuiltinType>(Inst->getType (0 ));
1395
+ const auto *AType = dyn_cast<InlineAsmBuiltinType>(Inst->getType (1 ));
1396
+ const auto *BType = dyn_cast<InlineAsmBuiltinType>(Inst->getType (2 ));
1397
+ const auto *CType = dyn_cast<InlineAsmBuiltinType>(Inst->getType (3 ));
1398
+
1399
+ if (!(AType && BType && CType && DType))
1400
+ return SYCLGenError ();
1401
+
1402
+ // Data types of matrix elements for A&B and C&D matrices should be same
1403
+ if ((AType->getKind () != BType->getKind ()) ||
1404
+ (CType->getKind () != DType->getKind ()))
1405
+ return SYCLGenError ();
1406
+
1407
+ // Check the validity of AB & CD types
1408
+ std::string ABType, CDType;
1409
+ if (tryEmitType (ABType, AType))
1410
+ return SYCLGenError ();
1411
+
1412
+ if (tryEmitType (CDType, CType))
1413
+ return SYCLGenError ();
1414
+
1415
+ // Register sizes for vector elements of A, B, C & D matrices
1416
+ unsigned NumVecElements[4 ] = {0 };
1417
+
1418
+ // Sizes of A & B matrices
1419
+ std::string M, N, K;
1420
+
1421
+ // Data types of A, B & C matrices respectively in the PTX arguments
1422
+ std::string InMatrixType[3 ];
1423
+
1424
+ if (Inst->hasAttr (InstAttr::m16n8k16)) {
1425
+ M = " 16" ;
1426
+ N = " 8" ;
1427
+ K = " 16" ;
1428
+
1429
+ // Only f16/s8 types are supported for A and B matrices of m16n8k16
1430
+ if (AType->getKind () == InlineAsmBuiltinType::f16 ) {
1431
+ InMatrixType[0 ] = " uint32_t" ; // A type is .f16x2
1432
+ InMatrixType[1 ] = " uint32_t" ; // B type is .f16x2
1433
+
1434
+ // If A matrix type is f16, then C&D matrix types can only be f32
1435
+ if (CType->getKind () == InlineAsmBuiltinType::f32 ) {
1436
+ NumVecElements[0 ] = 4 ; // A
1437
+ NumVecElements[1 ] = 2 ; // B
1438
+ NumVecElements[2 ] = 4 ; // C
1439
+ NumVecElements[3 ] = 4 ; // D
1440
+ } else
1441
+ return SYCLGenError ();
1442
+ } else if (AType->getKind () == InlineAsmBuiltinType::s8) {
1443
+ InMatrixType[0 ] = " uint32_t" ; // A type is .f16x2
1444
+ InMatrixType[1 ] = " uint32_t" ; // B type is .f16x2
1445
+
1446
+ // If A matrix type is s8, then C&D matrix types can only be s32
1447
+ if (CType->getKind () == InlineAsmBuiltinType::s32) {
1448
+ NumVecElements[0 ] = 2 ; // A
1449
+ NumVecElements[1 ] = 1 ; // B
1450
+ NumVecElements[2 ] = 4 ; // C
1451
+ NumVecElements[3 ] = 4 ; // D
1452
+ } else
1453
+ return SYCLGenError ();
1454
+ } else
1455
+ return SYCLGenError ();
1456
+ } else
1457
+ return SYCLGenError ();
1458
+
1459
+ InMatrixType[2 ] = CDType;
1460
+
1461
+ // Check the register sizes for vector elements of A, B, C & D matrices
1462
+ for (unsigned InputOp = 0 ; InputOp < Inst->getNumInputOperands ();
1463
+ InputOp++) {
1464
+ if (auto VE =
1465
+ dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand (InputOp))) {
1466
+ if (VE->getNumElements () != NumVecElements[InputOp])
1467
+ return SYCLGenError ();
1468
+ } else
1469
+ return SYCLGenError ();
1470
+ }
1471
+ if (DMatVE->getNumElements () != NumVecElements[3 ])
1472
+ return SYCLGenError ();
1473
+
1474
+ // Declare and init an array for storing the addresses of D matrix elements
1475
+ OS () << " {\n " ;
1476
+ OS () << " volatile " << CDType << " *d_mat_frag_ct1["
1477
+ << DMatVE->getNumElements () << " ] = { " ;
1478
+ for (unsigned Inst = 0 ; Inst != DMatVE->getNumElements (); ++Inst) {
1479
+ if (isa<InlineAsmDiscardExpr>(DMatVE->getElement (Inst)))
1480
+ continue ;
1481
+ OS () << " &" ;
1482
+ if (emitStmt (DMatVE->getElement (Inst)))
1483
+ return SYCLGenError ();
1484
+ if ((Inst + 1 ) != DMatVE->getNumElements ())
1485
+ OS () << " , " ;
1486
+ }
1487
+ OS () << " }" ;
1488
+ endstmt ();
1489
+
1490
+ // Declare and init vectors for storing the values of A, B & C matrix
1491
+ // elements
1492
+ std::string InMatrixName[3 ] = {" a" , " b" , " c" };
1493
+ for (unsigned InputOp = 0 ; InputOp < Inst->getNumInputOperands ();
1494
+ InputOp++) {
1495
+ if (auto VE =
1496
+ dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand (InputOp))) {
1497
+ OS () << " sycl::vec<" << InMatrixType[InputOp] << " , "
1498
+ << VE->getNumElements () << " > " << InMatrixName[InputOp]
1499
+ << " _mat_frag_ct1(" ;
1500
+ for (unsigned Inst = 0 ; Inst != VE->getNumElements (); ++Inst) {
1501
+ if (isa<InlineAsmDiscardExpr>(VE->getElement (Inst)))
1502
+ continue ;
1503
+ if (emitStmt (VE->getElement (Inst)))
1504
+ return SYCLGenError ();
1505
+ if ((Inst + 1 ) != VE->getNumElements ())
1506
+ OS () << " , " ;
1507
+ }
1508
+ OS () << " )" ;
1509
+ endstmt ();
1510
+ } else {
1511
+ return SYCLGenError ();
1512
+ }
1513
+ }
1514
+
1515
+ OS () << MapNames::getDpctNamespace () << " experimental::matrix::mma" ;
1516
+ OS () << " <" ;
1517
+ OS () << M << " , " << N << " , " << K << " , " ;
1518
+ OS () << ABType << " , " << CDType;
1519
+ OS () << " >(" ;
1520
+
1521
+ OS () << " reinterpret_cast<volatile void **>(d_mat_frag_ct1)" ;
1522
+ for (int i = 0 ; i < 3 ; i++)
1523
+ OS () << " , &" << InMatrixName[i] << " _mat_frag_ct1" ;
1524
+ OS () << " )" ;
1525
+ endstmt ();
1526
+ OS () << " }" ;
1527
+ endstmt ();
1528
+
1529
+ const auto *KernelDecl = getImmediateOuterFuncDecl (GAS);
1530
+ if (KernelDecl) {
1531
+ auto FuncInfo = DeviceFunctionDecl::LinkRedecls (KernelDecl);
1532
+ if (FuncInfo)
1533
+ FuncInfo->addSubGroupSizeRequest (32 , GAS->getBeginLoc (),
1534
+ DpctGlobalInfo::getSubGroup (GAS));
1535
+ }
1536
+
1537
+ return SYCLGenSuccess ();
1538
+ }
1539
+
1373
1540
bool handle_prefetch (const InlineAsmInstruction *Inst) override {
1374
1541
if (!DpctGlobalInfo::useExtPrefetch () || Inst->getNumInputOperands () != 1 )
1375
1542
return SYCLGenError ();
@@ -2595,11 +2762,10 @@ class SYCLGen : public SYCLGenBase {
2595
2762
Op = std::move (NewOp);
2596
2763
}
2597
2764
2598
- bool HasHalfOrBfloat16 =
2599
- SrcType->getKind () == InlineAsmBuiltinType::f16 ||
2600
- DesType->getKind () == InlineAsmBuiltinType::f16 ||
2601
- SrcType->getKind () == InlineAsmBuiltinType::bf16 ||
2602
- DesType->getKind () == InlineAsmBuiltinType::bf16 ;
2765
+ bool HasHalfOrBfloat16 = SrcType->getKind () == InlineAsmBuiltinType::f16 ||
2766
+ DesType->getKind () == InlineAsmBuiltinType::f16 ||
2767
+ SrcType->getKind () == InlineAsmBuiltinType::bf16 ||
2768
+ DesType->getKind () == InlineAsmBuiltinType::bf16 ;
2603
2769
if (DpctGlobalInfo::useIntelDeviceMath () && HasHalfOrBfloat16) {
2604
2770
insertHeader (HeaderType::HT_SYCL_Math);
2605
2771
if (SrcNeedBitCast)
0 commit comments