@@ -556,6 +556,9 @@ bool SYCLGenBase::emitVectorType(const InlineAsmVectorType *T) {
556
556
return SYCLGenError ();
557
557
OS () << " , " ;
558
558
switch (T->getKind ()) {
559
+ case InlineAsmVectorType::v1:
560
+ OS () << 1 ;
561
+ break ;
559
562
case InlineAsmVectorType::v2:
560
563
OS () << 2 ;
561
564
break ;
@@ -1309,53 +1312,118 @@ class SYCLGen : public SYCLGenBase {
1309
1312
if (Inst->getNumInputOperands () != 3 )
1310
1313
return SYCLGenError ();
1311
1314
1312
- if (!Inst->hasAttr (InstAttr::m16n8k16))
1315
+ const InlineAsmVectorExpr *DMatVE =
1316
+ dyn_cast<InlineAsmVectorExpr>(Inst->getOutputOperand ());
1317
+ if (!DMatVE)
1313
1318
return SYCLGenError ();
1314
1319
1315
1320
// Only row layout is supported for the A matrix and
1316
1321
// only col layout is supported for the B matrix
1317
- if (Inst->getAttr (3 ) != InstAttr::row ||
1318
- Inst->getAttr (4 ) != InstAttr::col) {
1322
+ if (Inst->getAttr (3 ) != InstAttr::row || Inst->getAttr (4 ) != InstAttr::col)
1319
1323
return SYCLGenError ();
1320
- }
1321
1324
1322
1325
// Element types of the D, A, B and C matrices; must all be builtin types
1326
+ const auto *DType = dyn_cast<InlineAsmBuiltinType>(Inst->getType (0 ));
1323
1327
const auto *AType = dyn_cast<InlineAsmBuiltinType>(Inst->getType (1 ));
1324
1328
const auto *BType = dyn_cast<InlineAsmBuiltinType>(Inst->getType (2 ));
1329
+ const auto *CType = dyn_cast<InlineAsmBuiltinType>(Inst->getType (3 ));
1325
1330
1326
- std::string TypeStr;
1327
- if (!AType || !BType ||
1328
- (AType->getKind () != InlineAsmBuiltinType::f16 ||
1329
- BType->getKind () != InlineAsmBuiltinType::f16 )) {
1331
+ if (!(AType && BType && CType && DType))
1330
1332
return SYCLGenError ();
1331
- } else {
1332
- if (tryEmitType (TypeStr, AType))
1333
- return SYCLGenError ();
1334
- }
1335
1333
1336
- const InlineAsmVectorExpr *VE =
1337
- dyn_cast<InlineAsmVectorExpr>(Inst-> getOutputOperand ());
1338
- if (VE && VE-> getNumElements () != 4 ) {
1334
+ // Element data types must be the same for A & B and for C & D matrices
1335
+ if ((AType-> getKind () != BType-> getKind ()) ||
1336
+ (CType-> getKind () != DType-> getKind ()))
1339
1337
return SYCLGenError ();
1338
+
1339
+ // Check the validity of AB & CD types
1340
+ std::string ABType, CDType;
1341
+ if (tryEmitType (ABType, AType))
1342
+ return SYCLGenError ();
1343
+
1344
+ if (tryEmitType (CDType, CType))
1345
+ return SYCLGenError ();
1346
+
1347
+ // Expected number of vector elements for the A, B, C & D matrix operands
1348
+ unsigned NumVecElements[4 ] = {0 };
1349
+
1350
+ // Data type used to multiply A & B matrices
1351
+ std::string MulType;
1352
+ if (Inst->hasAttr (InstAttr::m16n8k16)) {
1353
+ // Only f16 type is supported for A and B matrix data for m16n8k16
1354
+ if (AType->getKind () == InlineAsmBuiltinType::f16 ) {
1355
+ // If A matrix type is f16, then C&D matrix types can only be f16
1356
+ if (CType->getKind () == AType->getKind ()) {
1357
+ NumVecElements[0 ] = 2 ; // A
1358
+ NumVecElements[1 ] = 4 ; // B
1359
+ NumVecElements[2 ] = 4 ; // C
1360
+ NumVecElements[3 ] = 4 ; // D
1361
+ } else
1362
+ return SYCLGenError ();
1363
+ } else
1364
+ return SYCLGenError ();
1365
+ } else if (Inst->hasAttr (InstAttr::m8n8k4)) {
1366
+ // f16 & f64 types are supported for A and B matrix data for m8n8k4
1367
+ if (AType->getKind () == InlineAsmBuiltinType::f16 ) {
1368
+ // If A matrix type is f16, then C&D matrix types can only be f16/f32
1369
+ if (CType->getKind () == AType->getKind ()) {
1370
+ NumVecElements[0 ] = 2 ; // A
1371
+ NumVecElements[1 ] = 2 ; // B
1372
+ NumVecElements[2 ] = 4 ; // C
1373
+ NumVecElements[3 ] = 4 ; // D
1374
+ } else if (CType->getKind () == InlineAsmBuiltinType::f32 ) {
1375
+ NumVecElements[0 ] = 2 ; // A
1376
+ NumVecElements[1 ] = 2 ; // B
1377
+ NumVecElements[2 ] = 8 ; // C
1378
+ NumVecElements[3 ] = 8 ; // D
1379
+ } else
1380
+ return SYCLGenError ();
1381
+ } else if (AType->getKind () == InlineAsmBuiltinType::f64 ) {
1382
+ // If A matrix type is f64, then C&D matrix types can only be f64
1383
+ if (CType->getKind () == AType->getKind ()) {
1384
+ NumVecElements[0 ] = 1 ; // A
1385
+ NumVecElements[1 ] = 1 ; // B
1386
+ NumVecElements[2 ] = 2 ; // C
1387
+ NumVecElements[3 ] = 2 ; // D
1388
+ } else
1389
+ return SYCLGenError ();
1390
+ } else
1391
+ return SYCLGenError ();
1392
+ } else
1393
+ return SYCLGenError ();
1394
+
1395
+ // Check the register sizes for vector elements of A, B, C & D matrices
1396
+ for (unsigned InputOp = 0 ; InputOp < Inst->getNumInputOperands ();
1397
+ InputOp++) {
1398
+ if (auto VE =
1399
+ dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand (InputOp))) {
1400
+ if (VE->getNumElements () != NumVecElements[InputOp])
1401
+ return SYCLGenError ();
1402
+ } else
1403
+ return SYCLGenError ();
1340
1404
}
1405
+ if (DMatVE->getNumElements () != NumVecElements[3 ])
1406
+ return SYCLGenError ();
1341
1407
1408
+ MulType = ABType;
1342
1409
OS () << MapNames::getDpctNamespace () << " experimental::matrix::mma" ;
1343
- OS () << " <" << TypeStr << " >(" ;
1410
+ OS () << " <" << MulType << " >(" ;
1344
1411
1345
1412
// Add D matrix address values to store the MAD result
1346
- for (unsigned Inst = 0 ; Inst != VE ->getNumElements (); ++Inst) {
1347
- if (isa<InlineAsmDiscardExpr>(VE ->getElement (Inst)))
1413
+ for (unsigned Inst = 0 ; Inst != DMatVE ->getNumElements (); ++Inst) {
1414
+ if (isa<InlineAsmDiscardExpr>(DMatVE ->getElement (Inst)))
1348
1415
continue ;
1349
1416
OS () << " &" ;
1350
- if (emitStmt (VE ->getElement (Inst)))
1417
+ if (emitStmt (DMatVE ->getElement (Inst)))
1351
1418
return SYCLGenError ();
1352
1419
OS () << " , " ;
1353
1420
}
1354
1421
1355
1422
// Add A, B & C matrix values to compute MAD
1356
1423
for (unsigned InputOp = 0 ; InputOp < Inst->getNumInputOperands ();
1357
1424
InputOp++) {
1358
- if (VE = dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand (InputOp))) {
1425
+ if (auto VE =
1426
+ dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand (InputOp))) {
1359
1427
for (unsigned Inst = 0 ; Inst != VE->getNumElements (); ++Inst) {
1360
1428
if (isa<InlineAsmDiscardExpr>(VE->getElement (Inst)))
1361
1429
continue ;
@@ -2607,11 +2675,10 @@ class SYCLGen : public SYCLGenBase {
2607
2675
Op = std::move (NewOp);
2608
2676
}
2609
2677
2610
- bool HasHalfOrBfloat16 =
2611
- SrcType->getKind () == InlineAsmBuiltinType::f16 ||
2612
- DesType->getKind () == InlineAsmBuiltinType::f16 ||
2613
- SrcType->getKind () == InlineAsmBuiltinType::bf16 ||
2614
- DesType->getKind () == InlineAsmBuiltinType::bf16 ;
2678
+ bool HasHalfOrBfloat16 = SrcType->getKind () == InlineAsmBuiltinType::f16 ||
2679
+ DesType->getKind () == InlineAsmBuiltinType::f16 ||
2680
+ SrcType->getKind () == InlineAsmBuiltinType::bf16 ||
2681
+ DesType->getKind () == InlineAsmBuiltinType::bf16 ;
2615
2682
if (DpctGlobalInfo::useIntelDeviceMath () && HasHalfOrBfloat16) {
2616
2683
insertHeader (HeaderType::HT_SYCL_Math);
2617
2684
if (SrcNeedBitCast)
0 commit comments