Skip to content

Commit f8742e2

Browse files
MrSidimsvmaksimo
authored andcommitted
Add 'Use' parameter to TypeJointMatrixINTEL
'Use' is an optional parameter that shows where in a math operation the matrix is used. It must be the result of a constant instruction with scalar 'integer type'. Spec: #5944 Signed-off-by: Dmitry Sidorov <dmitry.sidorov@intel.com> Original commit: KhronosGroup/SPIRV-LLVM-Translator@348da24
1 parent b07be69 commit f8742e2

File tree

9 files changed

+72
-46
lines changed

9 files changed

+72
-46
lines changed

llvm-spirv/lib/SPIRV/SPIRVReader.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,9 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool IsClassMember) {
471471
SS << kSPIRVTypeName::PostfixDelim << R << kSPIRVTypeName::PostfixDelim << C
472472
<< kSPIRVTypeName::PostfixDelim << L << kSPIRVTypeName::PostfixDelim
473473
<< S;
474+
if (auto *Use = MT->getUse())
475+
SS << kSPIRVTypeName::PostfixDelim
476+
<< static_cast<SPIRVConstant *>(Use)->getZExtIntValue();
474477
std::string Name =
475478
getSPIRVTypeName(kSPIRVTypeName::JointMatrixINTEL, SS.str());
476479
return mapType(T, getOrCreateOpaquePtrType(M, Name));

llvm-spirv/lib/SPIRV/SPIRVRegularizeLLVM.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -473,16 +473,26 @@ void SPIRVRegularizeLLVMBase::adaptStructTypes(StructType *ST) {
473473
auto *PtrTy = dyn_cast<PointerType>(ST->getElementType(0));
474474
assert(PtrTy &&
475475
"Expected a pointer to an array to represent joint matrix type");
476-
size_t TypeLayout[4] = {0, 0, 0, 0};
476+
std::vector<size_t> TypeLayout;
477477
ArrayType *ArrayTy = dyn_cast<ArrayType>(PtrTy->getPointerElementType());
478478
assert(ArrayTy && "Expected a pointer element type of an array type to "
479479
"represent joint matrix type");
480-
TypeLayout[0] = ArrayTy->getNumElements();
480+
TypeLayout.push_back(ArrayTy->getNumElements());
481481
for (size_t I = 1; I != 4; ++I) {
482482
ArrayTy = dyn_cast<ArrayType>(ArrayTy->getElementType());
483483
assert(ArrayTy &&
484484
"Expected a element type to represent joint matrix type");
485-
TypeLayout[I] = ArrayTy->getNumElements();
485+
TypeLayout.push_back(ArrayTy->getNumElements());
486+
}
487+
// JointMatrixINTEL type can have optional 'Use' parameter, which is encoded
488+
// as another array dimention. In case if it has default 'Unnecessary' (4)
489+
// parameter - ignore it.
490+
if (isa<ArrayType>(ArrayTy->getElementType())) {
491+
ArrayTy = cast<ArrayType>(ArrayTy->getElementType());
492+
uint32_t UseInt = ArrayTy->getNumElements();
493+
assert(UseInt <= 4 && "Use parameter encoded in the array must be < 5 ");
494+
if (UseInt != 4)
495+
TypeLayout.push_back(UseInt);
486496
}
487497

488498
auto *ElemTy = ArrayTy->getElementType();
@@ -536,6 +546,9 @@ void SPIRVRegularizeLLVMBase::adaptStructTypes(StructType *ST) {
536546
<< kSPIRVTypeName::PostfixDelim << std::to_string(TypeLayout[2] - 1)
537547
<< kSPIRVTypeName::PostfixDelim
538548
<< std::to_string(TypeLayout[3] - 1);
549+
if (TypeLayout.size() == 5)
550+
SPVName << kSPIRVTypeName::PostfixDelim
551+
<< std::to_string(TypeLayout[4] - 1);
539552
// Note, that this structure is not opaque and there is no way to make it
540553
// opaque but to recreate it entirely and replace it everywhere. Lets
541554
// keep the structure as is, dealing with it during SPIR-V generation.

llvm-spirv/lib/SPIRV/SPIRVWriter.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -582,12 +582,10 @@ SPIRVType *LLVMToSPIRVBase::transSPIRVJointMatrixINTELType(
582582
consumeUnsignedInteger(Postfix, 10, N);
583583
return getUInt32(M, N);
584584
};
585-
SPIRVValue *Rows = transConstant(ParseInteger(Postfixes[1]));
586-
SPIRVValue *Columns = transConstant(ParseInteger(Postfixes[2]));
587-
SPIRVValue *Layout = transConstant(ParseInteger(Postfixes[3]));
588-
SPIRVValue *Scope = transConstant(ParseInteger(Postfixes[4]));
589-
return mapType(T, BM->addJointMatrixINTELType(transType(ElemTy), Rows,
590-
Columns, Layout, Scope));
585+
std::vector<SPIRVValue *> Args;
586+
for (size_t I = 1; I != Postfixes.size(); ++I)
587+
Args.emplace_back(transConstant(ParseInteger(Postfixes[I])));
588+
return mapType(T, BM->addJointMatrixINTELType(transType(ElemTy), Args));
591589
}
592590

593591
SPIRVType *LLVMToSPIRVBase::transSPIRVOpaqueType(Type *T) {

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,8 @@ class SPIRVModuleImpl : public SPIRVModule {
245245
SPIRVEntry *addTypeStructContinuedINTEL(unsigned NumMembers) override;
246246
void closeStructType(SPIRVTypeStruct *T, bool) override;
247247
SPIRVTypeVector *addVectorType(SPIRVType *, SPIRVWord) override;
248-
SPIRVTypeJointMatrixINTEL *addJointMatrixINTELType(SPIRVType *, SPIRVValue *,
249-
SPIRVValue *, SPIRVValue *,
250-
SPIRVValue *) override;
248+
SPIRVTypeJointMatrixINTEL *
249+
addJointMatrixINTELType(SPIRVType *, std::vector<SPIRVValue *>) override;
251250
SPIRVType *addOpaqueGenericType(Op) override;
252251
SPIRVTypeDeviceEvent *addDeviceEventType() override;
253252
SPIRVTypeQueue *addQueueType() override;
@@ -916,11 +915,10 @@ SPIRVTypeVector *SPIRVModuleImpl::addVectorType(SPIRVType *CompType,
916915
return addType(new SPIRVTypeVector(this, getId(), CompType, CompCount));
917916
}
918917

919-
SPIRVTypeJointMatrixINTEL *SPIRVModuleImpl::addJointMatrixINTELType(
920-
SPIRVType *CompType, SPIRVValue *Rows, SPIRVValue *Columns,
921-
SPIRVValue *Layout, SPIRVValue *Scope) {
922-
return addType(new SPIRVTypeJointMatrixINTEL(this, getId(), CompType, Rows,
923-
Columns, Layout, Scope));
918+
SPIRVTypeJointMatrixINTEL *
919+
SPIRVModuleImpl::addJointMatrixINTELType(SPIRVType *CompType,
920+
std::vector<SPIRVValue *> Args) {
921+
return addType(new SPIRVTypeJointMatrixINTEL(this, getId(), CompType, Args));
924922
}
925923

926924
SPIRVType *SPIRVModuleImpl::addOpaqueGenericType(Op TheOpCode) {

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,7 @@ class SPIRVModule {
245245
virtual void closeStructType(SPIRVTypeStruct *, bool) = 0;
246246
virtual SPIRVTypeVector *addVectorType(SPIRVType *, SPIRVWord) = 0;
247247
virtual SPIRVTypeJointMatrixINTEL *
248-
addJointMatrixINTELType(SPIRVType *, SPIRVValue *, SPIRVValue *, SPIRVValue *,
249-
SPIRVValue *) = 0;
248+
addJointMatrixINTELType(SPIRVType *, std::vector<SPIRVValue *>) = 0;
250249
virtual SPIRVTypeVoid *addVoidType() = 0;
251250
virtual SPIRVType *addOpaqueGenericType(Op) = 0;
252251
virtual SPIRVTypeDeviceEvent *addDeviceEventType() = 0;

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVType.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -275,16 +275,23 @@ void SPIRVTypeForwardPointer::decode(std::istream &I) {
275275
}
276276

277277
SPIRVTypeJointMatrixINTEL::SPIRVTypeJointMatrixINTEL(
278-
SPIRVModule *M, SPIRVId TheId, SPIRVType *CompType, SPIRVValue *Rows,
279-
SPIRVValue *Columns, SPIRVValue *Layout, SPIRVValue *Scope)
280-
: SPIRVType(M, FixedWC, OC, TheId), CompType(CompType), Rows(Rows),
281-
Columns(Columns), Layout(Layout), Scope(Scope) {}
278+
SPIRVModule *M, SPIRVId TheId, SPIRVType *CompType,
279+
std::vector<SPIRVValue *> Args)
280+
: SPIRVType(M, FixedWC + Args.size(), OC, TheId), CompType(CompType),
281+
Args(Args) {}
282282

283283
SPIRVTypeJointMatrixINTEL::SPIRVTypeJointMatrixINTEL()
284-
: SPIRVType(OC), CompType(nullptr), Rows(nullptr), Columns(nullptr),
285-
Layout(nullptr), Scope(nullptr) {}
284+
: SPIRVType(OC), CompType(nullptr),
285+
Args({nullptr, nullptr, nullptr, nullptr}) {}
286286

287-
_SPIRV_IMP_ENCDEC6(SPIRVTypeJointMatrixINTEL, Id, CompType, Rows, Columns,
288-
Layout, Scope)
287+
void SPIRVTypeJointMatrixINTEL::encode(spv_ostream &O) const {
288+
auto Encoder = getEncoder(O);
289+
Encoder << Id << CompType << Args;
290+
}
291+
292+
void SPIRVTypeJointMatrixINTEL::decode(std::istream &I) {
293+
auto Decoder = getDecoder(I);
294+
Decoder >> Id >> CompType >> Args;
295+
}
289296

290297
} // namespace SPIRV

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVType.h

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,18 +1060,14 @@ class SPIRVTypeTokenINTEL : public SPIRVType {
10601060

10611061
class SPIRVTypeJointMatrixINTEL : public SPIRVType {
10621062
SPIRVType *CompType;
1063-
SPIRVValue *Rows;
1064-
SPIRVValue *Columns;
1065-
SPIRVValue *Layout;
1066-
SPIRVValue *Scope;
1063+
std::vector<SPIRVValue *> Args;
10671064

10681065
public:
10691066
const static Op OC = internal::OpTypeJointMatrixINTEL;
1070-
const static SPIRVWord FixedWC = 7;
1067+
const static SPIRVWord FixedWC = 3;
10711068
// Complete constructor
10721069
SPIRVTypeJointMatrixINTEL(SPIRVModule *M, SPIRVId TheId, SPIRVType *CompType,
1073-
SPIRVValue *Rows, SPIRVValue *Columns,
1074-
SPIRVValue *Layout, SPIRVValue *Scope);
1070+
std::vector<SPIRVValue *> Args);
10751071
// Incomplete constructor
10761072
SPIRVTypeJointMatrixINTEL();
10771073
_SPIRV_DCL_ENCDEC
@@ -1081,11 +1077,16 @@ class SPIRVTypeJointMatrixINTEL : public SPIRVType {
10811077
SPIRVCapVec getRequiredCapability() const override {
10821078
return {internal::CapabilityJointMatrixINTEL};
10831079
}
1080+
void setWordCount(SPIRVWord WordCount) override {
1081+
SPIRVType::setWordCount(WordCount);
1082+
Args.resize(WordCount - FixedWC);
1083+
}
10841084
SPIRVType *getCompType() const { return CompType; }
1085-
SPIRVValue *getLayout() const { return Layout; }
1086-
SPIRVValue *getRows() const { return Rows; }
1087-
SPIRVValue *getColumns() const { return Columns; }
1088-
SPIRVValue *getScope() const { return Scope; }
1085+
SPIRVValue *getRows() const { return Args[0]; }
1086+
SPIRVValue *getColumns() const { return Args[1]; }
1087+
SPIRVValue *getLayout() const { return Args[2]; }
1088+
SPIRVValue *getScope() const { return Args[3]; }
1089+
SPIRVValue *getUse() const { return Args.size() > 4 ? Args[4] : nullptr; }
10891090
};
10901091

10911092
} // namespace SPIRV

llvm-spirv/lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,14 @@ enum InternalLoopControlMask { ILoopControlLoopCountINTELMask = 0x1000000 };
9393
constexpr LinkageType LinkageTypeInternal =
9494
static_cast<LinkageType>(ILTInternal);
9595

96-
enum InternalJointMatrixLayout { RowMajor, ColumnMajor, PackedA, PackedB };
96+
enum InternalJointMatrixLayout {
97+
RowMajor = 0,
98+
ColumnMajor = 1,
99+
PackedA = 2,
100+
PackedB = 3
101+
};
102+
103+
enum InternalJointMatrixUse { MatrixA = 0, MatrixB = 1, Accumulator = 2 };
97104

98105
enum InternalBuiltIn {
99106
IBuiltInSubDeviceIDINTEL = 6135,

llvm-spirv/test/transcoding/SPV_INTEL_joint_matrix/joint_matrix.ll

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefix=CHECK-LLVM
99

1010
; CHECK-PRE: %spirv.JointMatrixINTEL._short_2_2_0_3
11-
; CHECK-PRE: %spirv.JointMatrixINTEL._char_2_16_0_3
11+
; CHECK-PRE: %spirv.JointMatrixINTEL._char_2_16_0_3_0
1212
; CHECK-PRE: %spirv.JointMatrixINTEL._char_16_2_3_3
1313

1414
; CHECK-SPIRV: Capability JointMatrixINTEL
@@ -24,7 +24,7 @@
2424
; CHECK-SPIRV-DAG: Constant [[#IntTy]] [[#Sixteen:]] 16
2525
; CHECK-SPIRV-DAG: Constant [[#IntTy]] [[#FortyTwo:]] 42
2626
; CHECK-SPIRV: TypeJointMatrixINTEL [[#CTy:]] [[#ShortTy]] [[#Two]] [[#Two]] [[#Zero]] [[#Three]]
27-
; CHECK-SPIRV: TypeJointMatrixINTEL [[#ATy:]] [[#CharTy]] [[#Two]] [[#Sixteen]] [[#Zero]] [[#Three]]
27+
; CHECK-SPIRV: TypeJointMatrixINTEL [[#ATy:]] [[#CharTy]] [[#Two]] [[#Sixteen]] [[#Zero]] [[#Three]] [[#Zero]]
2828
; CHECK-SPIRV: TypeJointMatrixINTEL [[#BTy:]] [[#CharTy]] [[#Sixteen]] [[#Two]] [[#Three]] [[#Three]]
2929

3030
; CHECK-SPIRV: Function [[#]] [[#Kernel]]
@@ -48,14 +48,14 @@
4848

4949

5050
; CHECK-LLVM: %spirv.JointMatrixINTEL._short_2_2_0_3
51-
; CHECK-LLVM: %spirv.JointMatrixINTEL._char_2_16_0_3
51+
; CHECK-LLVM: %spirv.JointMatrixINTEL._char_2_16_0_3_0
5252
; CHECK-LLVM: %spirv.JointMatrixINTEL._char_16_2_3_3
5353

5454
; CHECK-LLVM: [[CLoaded:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z77__spirv_JointMatrixLoadINTEL_RPU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3PU3AS4sliii(i16 addrspace(4)* [[CPtr:%.*]], i64 [[Stride:%.*]], i32 0, i32 3, i32 0)
5555
; CHECK-LLVM: [[C:%.*]] = phi %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [ [[CLoaded]], %entry ], [ [[CMad:%.*]], %for.body.i ]
56-
; CHECK-LLVM: [[A:%.*]] = call spir_func %spirv.JointMatrixINTEL._char_2_16_0_3 addrspace(1)* @_Z77__spirv_JointMatrixLoadINTEL_RPU3AS139__spirv_JointMatrixINTEL__char_2_16_0_3PU3AS4cliii(i8 addrspace(4)* [[APtr:%.*]], i64 [[Stride]], i32 0, i32 3, i32 0)
56+
; CHECK-LLVM: [[A:%.*]] = call spir_func %spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* @_Z79__spirv_JointMatrixLoadINTEL_RPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS4cliii(i8 addrspace(4)* [[APtr:%.*]], i64 [[Stride]], i32 0, i32 3, i32 0)
5757
; CHECK-LLVM: [[B:%.*]] = call spir_func %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* @_Z77__spirv_JointMatrixLoadINTEL_RPU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS4cliii(i8 addrspace(4)* [[BPtr:%.*]], i64 [[Stride]], i32 0, i32 3, i32 0)
58-
; CHECK-LLVM: [[CMad:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z27__spirv_JointMatrixMadINTELPU3AS139__spirv_JointMatrixINTEL__char_2_16_0_3PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3)
58+
; CHECK-LLVM: [[CMad:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z27__spirv_JointMatrixMadINTELPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3)
5959
; CHECK-LLVM: call spir_func void @_Z29__spirv_JointMatrixStoreINTELPU3AS4sPU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3liii(i16 addrspace(4)* [[CPtr]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i64 [[Stride]], i32 0, i32 3, i32 0)
6060
; CHECK-LLVM: call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z26__spirv_CompositeConstructi(i32 42)
6161
; CHECK-LLVM: store i32 0, i32 addrspace(4)* [[StoredZero:%.*]], align 4
@@ -67,8 +67,8 @@ source_filename = "./joint_matrix_test.cpp"
6767
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
6868
target triple = "spir64-unknown-unknown"
6969

70-
%"struct.__spv::__spirv_JointMatrixINTEL" = type { [2 x [2 x [1 x [4 x i16]]]]* }
71-
%"struct.__spv::__spirv_JointMatrixINTEL.0" = type { [2 x [16 x [1 x [4 x i8]]]]* }
70+
%"struct.__spv::__spirv_JointMatrixINTEL" = type { [2 x [2 x [1 x [4 x [4 x i16]]]]]* }
71+
%"struct.__spv::__spirv_JointMatrixINTEL.0" = type { [2 x [16 x [1 x [4 x [1 x i8]]]]]* }
7272
%"struct.__spv::__spirv_JointMatrixINTEL.2" = type { [16 x [2 x [4 x [4 x i8]]]]* }
7373

7474
$_ZTSZ4mainE11matrix_test = comdat any

0 commit comments

Comments
 (0)