Skip to content

Commit 891e459

Browse files
Added support for m8n8k4 shape
1 parent 44f5434 commit 891e459

File tree

6 files changed

+300
-32
lines changed

6 files changed

+300
-32
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 92 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,9 @@ bool SYCLGenBase::emitVectorType(const InlineAsmVectorType *T) {
556556
return SYCLGenError();
557557
OS() << ", ";
558558
switch (T->getKind()) {
559+
case InlineAsmVectorType::v1:
560+
OS() << 1;
561+
break;
559562
case InlineAsmVectorType::v2:
560563
OS() << 2;
561564
break;
@@ -1309,53 +1312,118 @@ class SYCLGen : public SYCLGenBase {
13091312
if (Inst->getNumInputOperands() != 3)
13101313
return SYCLGenError();
13111314

1312-
if (!Inst->hasAttr(InstAttr::m16n8k16))
1315+
const InlineAsmVectorExpr *DMatVE =
1316+
dyn_cast<InlineAsmVectorExpr>(Inst->getOutputOperand());
1317+
if (!DMatVE)
13131318
return SYCLGenError();
13141319

13151320
// Only row layout is supported for the A matrix and
13161321
// only col layout is supported for the B matrix
1317-
if (Inst->getAttr(3) != InstAttr::row ||
1318-
Inst->getAttr(4) != InstAttr::col) {
1322+
if (Inst->getAttr(3) != InstAttr::row || Inst->getAttr(4) != InstAttr::col)
13191323
return SYCLGenError();
1320-
}
13211324

13221325
// Element types of the D, A, B and C matrices; the supported combinations depend on the matrix shape
1326+
const auto *DType = dyn_cast<InlineAsmBuiltinType>(Inst->getType(0));
13231327
const auto *AType = dyn_cast<InlineAsmBuiltinType>(Inst->getType(1));
13241328
const auto *BType = dyn_cast<InlineAsmBuiltinType>(Inst->getType(2));
1329+
const auto *CType = dyn_cast<InlineAsmBuiltinType>(Inst->getType(3));
13251330

1326-
std::string TypeStr;
1327-
if (!AType || !BType ||
1328-
(AType->getKind() != InlineAsmBuiltinType::f16 ||
1329-
BType->getKind() != InlineAsmBuiltinType::f16)) {
1331+
if (!(AType && BType && CType && DType))
13301332
return SYCLGenError();
1331-
} else {
1332-
if (tryEmitType(TypeStr, AType))
1333-
return SYCLGenError();
1334-
}
13351333

1336-
const InlineAsmVectorExpr *VE =
1337-
dyn_cast<InlineAsmVectorExpr>(Inst->getOutputOperand());
1338-
if (VE && VE->getNumElements() != 4) {
1334+
// Data types of matrix elements for A&B and C&D matrices should be the same
1335+
if ((AType->getKind() != BType->getKind()) ||
1336+
(CType->getKind() != DType->getKind()))
13391337
return SYCLGenError();
1338+
1339+
// Check the validity of AB & CD types
1340+
std::string ABType, CDType;
1341+
if (tryEmitType(ABType, AType))
1342+
return SYCLGenError();
1343+
1344+
if (tryEmitType(CDType, CType))
1345+
return SYCLGenError();
1346+
1347+
// Expected number of vector elements for the A, B, C & D matrix operands
1348+
unsigned NumVecElements[4] = {0};
1349+
1350+
// Data type used to multiply A & B matrices
1351+
std::string MulType;
1352+
if (Inst->hasAttr(InstAttr::m16n8k16)) {
1353+
// Only f16 type is supported for A and B matrix data for m16n8k16
1354+
if (AType->getKind() == InlineAsmBuiltinType::f16) {
1355+
// If A matrix type is f16, then C&D matrix types can only be f16
1356+
if (CType->getKind() == AType->getKind()) {
1357+
NumVecElements[0] = 2; // A
1358+
NumVecElements[1] = 4; // B
1359+
NumVecElements[2] = 4; // C
1360+
NumVecElements[3] = 4; // D
1361+
} else
1362+
return SYCLGenError();
1363+
} else
1364+
return SYCLGenError();
1365+
} else if (Inst->hasAttr(InstAttr::m8n8k4)) {
1366+
// f16 & f64 types are supported for A and B matrix data for m8n8k4
1367+
if (AType->getKind() == InlineAsmBuiltinType::f16) {
1368+
// If A matrix type is f16, then C&D matrix types can only be f16/f32
1369+
if (CType->getKind() == AType->getKind()) {
1370+
NumVecElements[0] = 2; // A
1371+
NumVecElements[1] = 2; // B
1372+
NumVecElements[2] = 4; // C
1373+
NumVecElements[3] = 4; // D
1374+
} else if (CType->getKind() == InlineAsmBuiltinType::f32) {
1375+
NumVecElements[0] = 2; // A
1376+
NumVecElements[1] = 2; // B
1377+
NumVecElements[2] = 8; // C
1378+
NumVecElements[3] = 8; // D
1379+
} else
1380+
return SYCLGenError();
1381+
} else if (AType->getKind() == InlineAsmBuiltinType::f64) {
1382+
// If A matrix type is f64, then C&D matrix types can only be f64
1383+
if (CType->getKind() == AType->getKind()) {
1384+
NumVecElements[0] = 1; // A
1385+
NumVecElements[1] = 1; // B
1386+
NumVecElements[2] = 2; // C
1387+
NumVecElements[3] = 2; // D
1388+
} else
1389+
return SYCLGenError();
1390+
} else
1391+
return SYCLGenError();
1392+
} else
1393+
return SYCLGenError();
1394+
1395+
// Check the register sizes for vector elements of A, B, C & D matrices
1396+
for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands();
1397+
InputOp++) {
1398+
if (auto VE =
1399+
dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand(InputOp))) {
1400+
if (VE->getNumElements() != NumVecElements[InputOp])
1401+
return SYCLGenError();
1402+
} else
1403+
return SYCLGenError();
13401404
}
1405+
if (DMatVE->getNumElements() != NumVecElements[3])
1406+
return SYCLGenError();
13411407

1408+
MulType = ABType;
13421409
OS() << MapNames::getDpctNamespace() << "experimental::matrix::mma";
1343-
OS() << "<" << TypeStr << ">(";
1410+
OS() << "<" << MulType << ">(";
13441411

13451412
// Add D matrix address values to store the MAD result
1346-
for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) {
1347-
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
1413+
for (unsigned Inst = 0; Inst != DMatVE->getNumElements(); ++Inst) {
1414+
if (isa<InlineAsmDiscardExpr>(DMatVE->getElement(Inst)))
13481415
continue;
13491416
OS() << "&";
1350-
if (emitStmt(VE->getElement(Inst)))
1417+
if (emitStmt(DMatVE->getElement(Inst)))
13511418
return SYCLGenError();
13521419
OS() << ", ";
13531420
}
13541421

13551422
// Add A, B & C matrix values to compute MAD
13561423
for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands();
13571424
InputOp++) {
1358-
if (VE = dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand(InputOp))) {
1425+
if (auto VE =
1426+
dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand(InputOp))) {
13591427
for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) {
13601428
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
13611429
continue;
@@ -2607,11 +2675,10 @@ class SYCLGen : public SYCLGenBase {
26072675
Op = std::move(NewOp);
26082676
}
26092677

2610-
bool HasHalfOrBfloat16 =
2611-
SrcType->getKind() == InlineAsmBuiltinType::f16 ||
2612-
DesType->getKind() == InlineAsmBuiltinType::f16 ||
2613-
SrcType->getKind() == InlineAsmBuiltinType::bf16 ||
2614-
DesType->getKind() == InlineAsmBuiltinType::bf16;
2678+
bool HasHalfOrBfloat16 = SrcType->getKind() == InlineAsmBuiltinType::f16 ||
2679+
DesType->getKind() == InlineAsmBuiltinType::f16 ||
2680+
SrcType->getKind() == InlineAsmBuiltinType::bf16 ||
2681+
DesType->getKind() == InlineAsmBuiltinType::bf16;
26152682
if (DpctGlobalInfo::useIntelDeviceMath() && HasHalfOrBfloat16) {
26162683
insertHeader(HeaderType::HT_SYCL_Math);
26172684
if (SrcNeedBitCast)

clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class InlineAsmBuiltinType : public InlineAsmType {
116116
// This class is used for device asm vector types.
117117
class InlineAsmVectorType : public InlineAsmType {
118118
public:
119-
enum VecKind { v2, v4, v8 };
119+
enum VecKind { v1, v2, v4, v8 };
120120

121121
private:
122122
VecKind Kind;

clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,9 @@ InlineAsmParser::ActOnVectorExpr(ArrayRef<InlineAsmExpr *> Vec) {
739739
// Vector size must be 1, 2, 4, or 8.
740740
InlineAsmVectorType::VecKind Kind;
741741
switch (Vec.size()) {
742+
case 1:
743+
Kind = InlineAsmVectorType::v1;
744+
break;
742745
case 2:
743746
Kind = InlineAsmVectorType::v2;
744747
break;

clang/lib/DPCT/RulesAsm/Parser/AsmParser.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ class InlineAsmParser {
496496
/// .reg .sreg .const .local .param .shared .tex
497497
///
498498
/// vector-specifier: one of
499-
/// .v2 .v4 .v8
499+
/// .v1 .v2 .v4 .v8
500500
///
501501
/// type-specifier: one of
502502
/// .b8 .b16 .b32 .b64 .s8 .s16 .s32 .s64

clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ BUILTIN_TYPE(s16x2, ".s16x2")
270270
BUILTIN_TYPE(u16x2, ".u16x2")
271271

272272
// Vector modifiers
273+
MODIFIER(v1, ".v1")
273274
MODIFIER(v2, ".v2")
274275
MODIFIER(v4, ".v4")
275276
MODIFIER(v8, ".v8")
@@ -280,6 +281,7 @@ MODIFIER(col, ".col")
280281

281282
// Matrix shape
282283
MODIFIER(m16n8k16, ".m16n8k16")
284+
MODIFIER(m8n8k4, ".m8n8k4")
283285

284286
STATE_SPACE(reg, ".reg")
285287
STATE_SPACE(sreg, ".sreg")

0 commit comments

Comments
 (0)