Skip to content

Commit 743f920

Browse files
aadeshps-mcwJaddyen
authored andcommitted
[SPIRV] Addition of matrix multiply accumulate operands (llvm#138665)
--Added Matrix multiply accumulate operands for the extension SPV_INTEL_subgroup_matrix_multiply_accumulate
1 parent ed0fcdc commit 743f920

File tree

4 files changed

+100
-29
lines changed

4 files changed

+100
-29
lines changed

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ namespace MemoryModel {
5757
#include "SPIRVGenTables.inc"
5858
} // namespace MemoryModel
5959

60+
namespace MatrixMultiplyAccumulateOperands {
61+
#define GET_MatrixMultiplyAccumulateOperands_DECL
62+
#include "SPIRVGenTables.inc"
63+
} // namespace MatrixMultiplyAccumulateOperands
64+
6065
namespace ExecutionMode {
6166
#define GET_ExecutionMode_DECL
6267
#include "SPIRVGenTables.inc"

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,34 @@ void SPIRVInstPrinter::printInst(const MCInst *MI, uint64_t Address,
242242
}
243243
break;
244244
}
245+
case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
246+
const unsigned NumOps = MI->getNumOperands();
247+
if (NumFixedOps >= NumOps)
248+
break;
249+
OS << ' ';
250+
const unsigned Flags = MI->getOperand(NumOps - 1).getImm();
251+
if (Flags == 0) {
252+
printSymbolicOperand<
253+
OperandCategory::MatrixMultiplyAccumulateOperandsOperand>(
254+
MI, NumOps - 1, OS);
255+
} else {
256+
std::string Buffer;
257+
for (unsigned Mask = 0x1;
258+
Mask <= SPIRV::MatrixMultiplyAccumulateOperands::
259+
MatrixBPackedBFloat16INTEL;
260+
Mask <<= 1) {
261+
if (Flags & Mask) {
262+
if (!Buffer.empty())
263+
Buffer += '|';
264+
Buffer += getSymbolicOperandMnemonic(
265+
OperandCategory::MatrixMultiplyAccumulateOperandsOperand,
266+
Mask);
267+
}
268+
}
269+
OS << Buffer;
270+
}
271+
break;
272+
}
245273
default:
246274
printRemainingVariableOps(MI, NumFixedOps, OS);
247275
break;

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def KernelProfilingInfoOperand : OperandCategory;
172172
def OpcodeOperand : OperandCategory;
173173
def CooperativeMatrixLayoutOperand : OperandCategory;
174174
def CooperativeMatrixOperandsOperand : OperandCategory;
175+
def MatrixMultiplyAccumulateOperandsOperand : OperandCategory;
175176

176177
//===----------------------------------------------------------------------===//
177178
// Multiclass used to define Extesions enum values and at the same time
@@ -1750,3 +1751,40 @@ defm MatrixAAndBTF32ComponentsINTEL : CooperativeMatrixOperandsOperand<0x20, [SP
17501751
defm MatrixAAndBBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x40, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>;
17511752
defm MatrixCBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x80, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>;
17521753
defm MatrixResultBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x100, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>;
1754+
1755+
//===----------------------------------------------------------------------===//
1756+
// Multiclass used to define Matrix Multiply Accumulate Operands enum values and at the same time
1757+
// SymbolicOperand entries with string mnemonics and capabilities.
1758+
//===----------------------------------------------------------------------===//
1759+
def MatrixMultiplyAccumulateOperands : GenericEnum, Operand<i32> {
1760+
let FilterClass = "MatrixMultiplyAccumulateOperands";
1761+
let NameField = "Name";
1762+
let ValueField = "Value";
1763+
let PrintMethod = !strconcat("printSymbolicOperand<OperandCategory::", FilterClass, "Operand>");
1764+
}
1765+
1766+
class MatrixMultiplyAccumulateOperands<string name, bits<32> value> {
1767+
string Name = name;
1768+
bits<32> Value = value;
1769+
}
1770+
1771+
multiclass MatrixMultiplyAccumulateOperandsOperand<bits<32> value, list<Extension> reqExtensions> {
1772+
def : MatrixMultiplyAccumulateOperands<NAME, value>;
1773+
defm : SymbolicOperandWithRequirements<MatrixMultiplyAccumulateOperandsOperand, value, NAME, 0, 0, reqExtensions, []>;
1774+
}
1775+
1776+
defm None : MatrixMultiplyAccumulateOperandsOperand<0x0, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1777+
defm MatrixASignedComponentsINTEL : MatrixMultiplyAccumulateOperandsOperand<0x1, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1778+
defm MatrixBSignedComponentsINTEL : MatrixMultiplyAccumulateOperandsOperand<0x2, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1779+
defm MatrixCBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x4, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1780+
defm MatrixResultBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x8, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1781+
defm MatrixAPackedInt8INTEL : MatrixMultiplyAccumulateOperandsOperand<0x10, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1782+
defm MatrixBPackedInt8INTEL : MatrixMultiplyAccumulateOperandsOperand<0x20, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1783+
defm MatrixAPackedInt4INTEL : MatrixMultiplyAccumulateOperandsOperand<0x40, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1784+
defm MatrixBPackedInt4INTEL : MatrixMultiplyAccumulateOperandsOperand<0x80, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1785+
defm MatrixATF32INTEL : MatrixMultiplyAccumulateOperandsOperand<0x100, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1786+
defm MatrixBTF32INTEL : MatrixMultiplyAccumulateOperandsOperand<0x200, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1787+
defm MatrixAPackedFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x400, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1788+
defm MatrixBPackedFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x800, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1789+
defm MatrixAPackedBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x1000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
1790+
defm MatrixBPackedBFloat16INTEL : MatrixMultiplyAccumulateOperandsOperand<0x2000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;

llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_subgroup_matrix_multiply_accumulate/subgroup_matrix_multiply_accumulate_generic.ll

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -131,35 +131,35 @@
131131
; CHECK: %[[#hM2:]] = OpFunctionParameter %[[#Vec2HalfTy]]
132132
; CHECK: %[[#hM4:]] = OpFunctionParameter %[[#Vec4HalfTy]]
133133
; CHECK: %[[#hM8:]] = OpFunctionParameter %[[#Vec8HalfTy]]
134-
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Int32Ty]] %[[#Const42]] %[[#iM]] %[[#iM8]] %[[#iM]] 10
135-
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2Int32Ty]] %[[#Const42]] %[[#iM2]] %[[#iM8]] %[[#iM2]] 10
136-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4Int32Ty]] %[[#Const42]] %[[#iM4]] %[[#iM8]] %[[#iM4]] 10
137-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8Int32Ty]] %[[#Const42]] %[[#iM8]] %[[#iM8]] %[[#iM8]] 10
138-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#FloatTy]] %[[#Const42]] %[[#iM]] %[[#iM8]] %[[#fM]] 10
139-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2FloatTy]] %[[#Const42]] %[[#iM2]] %[[#iM8]] %[[#fM2]] 10
140-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4FloatTy]] %[[#Const42]] %[[#iM4]] %[[#iM8]] %[[#fM4]] 10
141-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8FloatTy]] %[[#Const42]] %[[#iM8]] %[[#iM8]] %[[#fM8]] 10
142-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Int32Ty]] %[[#Const42]] %[[#sM]] %[[#iM8]] %[[#iM]] 10
143-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2Int32Ty]] %[[#Const42]] %[[#sM2]] %[[#iM8]] %[[#iM2]] 10
144-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4Int32Ty]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#iM4]] 10
145-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8Int32Ty]] %[[#Const42]] %[[#sM8]] %[[#iM8]] %[[#iM8]] 10
146-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#FloatTy]] %[[#Const42]] %[[#sM]] %[[#iM8]] %[[#fM]] 10
147-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2FloatTy]] %[[#Const42]] %[[#sM2]] %[[#iM8]] %[[#fM2]] 10
148-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4FloatTy]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#fM4]] 10
149-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8FloatTy]] %[[#Const42]] %[[#sM8]] %[[#iM8]] %[[#fM8]] 10
150-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#HalfTy]] %[[#Const42]] %[[#sM]] %[[#iM8]] %[[#hM]] 10
151-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2HalfTy]] %[[#Const42]] %[[#sM2]] %[[#iM8]] %[[#hM2]] 10
152-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4HalfTy]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#hM4]] 10
153-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8HalfTy]] %[[#Const42]] %[[#sM8]] %[[#iM8]] %[[#hM8]] 10
154-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Int16Ty]] %[[#Const42]] %[[#sM]] %[[#iM8]] %[[#sM]] 10
155-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2Int16Ty]] %[[#Const42]] %[[#sM2]] %[[#iM8]] %[[#sM2]] 10
156-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4Int16Ty]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#sM4]] 10
157-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8Int16Ty]] %[[#Const42]] %[[#sM8]] %[[#iM8]] %[[#sM8]] 10
158-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#FloatTy]] %[[#Const42]] %[[#fM]] %[[#fM8]] %[[#fM]] 10
159-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2FloatTy]] %[[#Const42]] %[[#fM2]] %[[#fM8]] %[[#fM2]] 10
160-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4FloatTy]] %[[#Const42]] %[[#fM4]] %[[#fM8]] %[[#fM4]] 10
161-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8FloatTy]] %[[#Const42]] %[[#fM8]] %[[#fM8]] %[[#fM8]] 10
162-
; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4FloatTy]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#fM4]]
134+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Int32Ty]] %[[#Const42]] %[[#iM]] %[[#iM8]] %[[#iM]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
135+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2Int32Ty]] %[[#Const42]] %[[#iM2]] %[[#iM8]] %[[#iM2]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
136+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4Int32Ty]] %[[#Const42]] %[[#iM4]] %[[#iM8]] %[[#iM4]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
137+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8Int32Ty]] %[[#Const42]] %[[#iM8]] %[[#iM8]] %[[#iM8]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
138+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#FloatTy]] %[[#Const42]] %[[#iM]] %[[#iM8]] %[[#fM]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
139+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2FloatTy]] %[[#Const42]] %[[#iM2]] %[[#iM8]] %[[#fM2]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
140+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4FloatTy]] %[[#Const42]] %[[#iM4]] %[[#iM8]] %[[#fM4]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
141+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8FloatTy]] %[[#Const42]] %[[#iM8]] %[[#iM8]] %[[#fM8]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
142+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Int32Ty]] %[[#Const42]] %[[#sM]] %[[#iM8]] %[[#iM]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
143+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2Int32Ty]] %[[#Const42]] %[[#sM2]] %[[#iM8]] %[[#iM2]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
144+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4Int32Ty]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#iM4]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
145+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8Int32Ty]] %[[#Const42]] %[[#sM8]] %[[#iM8]] %[[#iM8]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
146+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#FloatTy]] %[[#Const42]] %[[#sM]] %[[#iM8]] %[[#fM]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
147+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2FloatTy]] %[[#Const42]] %[[#sM2]] %[[#iM8]] %[[#fM2]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
148+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4FloatTy]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#fM4]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
149+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8FloatTy]] %[[#Const42]] %[[#sM8]] %[[#iM8]] %[[#fM8]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
150+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#HalfTy]] %[[#Const42]] %[[#sM]] %[[#iM8]] %[[#hM]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
151+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2HalfTy]] %[[#Const42]] %[[#sM2]] %[[#iM8]] %[[#hM2]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
152+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4HalfTy]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#hM4]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
153+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8HalfTy]] %[[#Const42]] %[[#sM8]] %[[#iM8]] %[[#hM8]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
154+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Int16Ty]] %[[#Const42]] %[[#sM]] %[[#iM8]] %[[#sM]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
155+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2Int16Ty]] %[[#Const42]] %[[#sM2]] %[[#iM8]] %[[#sM2]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
156+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4Int16Ty]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#sM4]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
157+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8Int16Ty]] %[[#Const42]] %[[#sM8]] %[[#iM8]] %[[#sM8]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
158+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#FloatTy]] %[[#Const42]] %[[#fM]] %[[#fM8]] %[[#fM]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
159+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2FloatTy]] %[[#Const42]] %[[#fM2]] %[[#fM8]] %[[#fM2]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
160+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4FloatTy]] %[[#Const42]] %[[#fM4]] %[[#fM8]] %[[#fM4]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
161+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8FloatTy]] %[[#Const42]] %[[#fM8]] %[[#fM8]] %[[#fM8]] MatrixBSignedComponentsINTEL|MatrixResultBFloat16INTEL
162+
; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4FloatTy]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#fM4]]
163163

164164
define spir_func void @foo(i32 %iM, <2 x i32> %iM2, <4 x i32> %iM4, <8 x i32> %iM8,
165165
i16 signext %sM, <2 x i16> %sM2, <4 x i16> %sM4, <8 x i16> %sM8,

0 commit comments

Comments
 (0)