Skip to content

Commit 3edd618

Browse files
authored
[SYCL][NFCI] Finalize switch to SPV_KHR_cooperative_matrix (#16045)
Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com>
1 parent 925ff76 commit 3edd618

File tree

95 files changed

+12
-2353
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

95 files changed

+12
-2353
lines changed

clang/lib/CodeGen/CodeGenTypes.cpp

Lines changed: 1 addition & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -350,34 +350,6 @@ llvm::Type *CodeGenTypes::ConvertFunctionTypeInternal(QualType QFT) {
350350
return ResultType;
351351
}
352352

353-
template <bool NeedTypeInterpret = false>
354-
llvm::Type *getJointMatrixINTELExtType(llvm::Type *CompTy,
355-
ArrayRef<TemplateArgument> TemplateArgs,
356-
const unsigned Val = 0) {
357-
// TODO: we should actually have exactly 5 template parameters: 1 for
358-
// type and 4 for type parameters. But in previous version of the SPIR-V
359-
// spec we have Layout matrix type parameter, that was later removed.
360-
// Once we update to the newest version of the spec - this should be updated.
361-
assert((TemplateArgs.size() == 5 || TemplateArgs.size() == 6) &&
362-
"Wrong JointMatrixINTEL template parameters number");
363-
// This is required to represent optional 'Component Type Interpretation'
364-
// parameter
365-
std::vector<unsigned> Params;
366-
for (size_t I = 1; I != TemplateArgs.size(); ++I) {
367-
assert(TemplateArgs[I].getKind() == TemplateArgument::Integral &&
368-
"Wrong JointMatrixINTEL template parameter");
369-
Params.push_back(TemplateArgs[I].getAsIntegral().getExtValue());
370-
}
371-
// Don't add type interpretation for legacy matrices.
372-
// Legacy matrices has 5 template parameters, while new representation
373-
// has 6.
374-
if (NeedTypeInterpret && TemplateArgs.size() != 5)
375-
Params.push_back(Val);
376-
377-
return llvm::TargetExtType::get(CompTy->getContext(),
378-
"spirv.JointMatrixINTEL", {CompTy}, Params);
379-
}
380-
381353
llvm::Type *
382354
getCooperativeMatrixKHRExtType(llvm::Type *CompTy,
383355
ArrayRef<TemplateArgument> TemplateArgs) {
@@ -394,49 +366,6 @@ getCooperativeMatrixKHRExtType(llvm::Type *CompTy,
394366
CompTy->getContext(), "spirv.CooperativeMatrixKHR", {CompTy}, Params);
395367
}
396368

397-
/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
398-
/// which is represented as a pointer to a structure to LLVM extension type
399-
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
400-
/// The expected representation is:
401-
/// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
402-
/// %use%, (optional) %element_type_interpretation%)
403-
llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType(RecordDecl *RD) {
404-
auto *TemplateDecl = cast<ClassTemplateSpecializationDecl>(RD);
405-
ArrayRef<TemplateArgument> TemplateArgs =
406-
TemplateDecl->getTemplateArgs().asArray();
407-
assert(TemplateArgs[0].getKind() == TemplateArgument::Type &&
408-
"1st JointMatrixINTEL template parameter must be type");
409-
llvm::Type *CompTy = ConvertType(TemplateArgs[0].getAsType());
410-
411-
// Per JointMatrixINTEL spec the type can have an optional
412-
// 'Component Type Interpretation' parameter. We should emit it in case
413-
// if on SYCL level joint matrix accepts 'bfloat16' or 'tf32' objects as
414-
// matrix's components. Yet 'bfloat16' should be represented as 'int16' and
415-
// 'tf32' as 'float' types.
416-
if (CompTy->isStructTy()) {
417-
StringRef LlvmTyName = CompTy->getStructName();
418-
// Emit half/int16/float for sycl[::*]::{half,bfloat16,tf32}
419-
if (LlvmTyName.starts_with("class.sycl::") ||
420-
LlvmTyName.starts_with("class.__sycl_internal::"))
421-
LlvmTyName = LlvmTyName.rsplit("::").second;
422-
if (LlvmTyName == "half") {
423-
CompTy = llvm::Type::getHalfTy(getLLVMContext());
424-
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
425-
} else if (LlvmTyName == "tf32") {
426-
CompTy = llvm::Type::getFloatTy(getLLVMContext());
427-
// 'tf32' interpretation is mapped to '0'
428-
return getJointMatrixINTELExtType<true>(CompTy, TemplateArgs, 0);
429-
} else if (LlvmTyName == "bfloat16") {
430-
CompTy = llvm::Type::getInt16Ty(getLLVMContext());
431-
// 'bfloat16' interpretation is mapped to '1'
432-
return getJointMatrixINTELExtType<true>(CompTy, TemplateArgs, 1);
433-
} else {
434-
llvm_unreachable("Wrong matrix base type!");
435-
}
436-
}
437-
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
438-
}
439-
440369
/// ConvertSPVCooperativeMatrixType - Convert SYCL joint_matrix type
441370
/// which is represented as a pointer to a structure to LLVM extension type
442371
/// with the parameters that follow SPIR-V CooperativeMatrixKHR type.
@@ -733,11 +662,7 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
733662
if (ClangETy && ClangETy->isStructureOrClassType()) {
734663
RecordDecl *RD = ClangETy->getAsCXXRecordDecl();
735664
if (RD && RD->getQualifiedNameAsString() ==
736-
"__spv::__spirv_JointMatrixINTEL") {
737-
ResultType = ConvertSYCLJointMatrixINTELType(RD);
738-
break;
739-
} else if (RD && RD->getQualifiedNameAsString() ==
740-
"__spv::__spirv_CooperativeMatrixKHR") {
665+
"__spv::__spirv_CooperativeMatrixKHR") {
741666
ResultType = ConvertSPVCooperativeMatrixType(RD);
742667
break;
743668
} else if (RD && RD->getQualifiedNameAsString() ==

clang/lib/CodeGen/CodeGenTypes.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -145,14 +145,6 @@ class CodeGenTypes {
145145
/// load/store type are the same.
146146
llvm::Type *convertTypeForLoadStore(QualType T, llvm::Type *LLVMTy = nullptr);
147147

148-
/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
149-
/// which is represented as a pointer to a structure to LLVM extension type
150-
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
151-
/// The expected representation is:
152-
/// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
153-
/// %use%, (optional) %element_type_interpretation%)
154-
llvm::Type *ConvertSYCLJointMatrixINTELType(RecordDecl *RD);
155-
156148
/// ConvertSPVCooperativeMatrixType - Convert SYCL joint_matrix type
157149
/// which is represented as a pointer to a structure to LLVM extension type
158150
/// with the parameters that follow SPIR-V CooperativeMatrixKHR type.

clang/test/CodeGenSYCL/joint_matrix.cpp

Lines changed: 0 additions & 41 deletions
This file was deleted.

sycl/include/sycl/__spirv/spirv_ops.hpp

Lines changed: 0 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -27,155 +27,6 @@
2727

2828
extern __DPCPP_SYCL_EXTERNAL float __spirv_RoundFToTF32INTEL(float a);
2929

30-
#ifndef __SPIRV_USE_COOPERATIVE_MATRIX
31-
template <typename T, typename Tp, std::size_t R, std::size_t C,
32-
__spv::MatrixUse U,
33-
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
34-
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
35-
extern __DPCPP_SYCL_EXTERNAL
36-
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
37-
__spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride,
38-
__spv::MatrixLayout Layout = L,
39-
__spv::Scope::Flag Sc = S, int MemOperand = 0);
40-
41-
template <typename T, typename Tp, std::size_t R, std::size_t C,
42-
__spv::MatrixUse U,
43-
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
44-
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
45-
extern __DPCPP_SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL(
46-
T *Ptr, __spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *Object,
47-
std::size_t Stride, __spv::MatrixLayout Layout = L,
48-
__spv::Scope::Flag Sc = S, int MemOperand = 0);
49-
50-
template <typename T, typename Tp, std::size_t R, std::size_t C,
51-
__spv::MatrixUse U,
52-
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
53-
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
54-
extern __DPCPP_SYCL_EXTERNAL
55-
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
56-
__spirv_CooperativeMatrixConstructCheckedINTEL(int32_t CoordX,
57-
int32_t CoordY,
58-
uint32_t Height,
59-
uint32_t Width,
60-
const T Value);
61-
62-
template <typename T, typename Tp, std::size_t R, std::size_t C,
63-
__spv::MatrixUse U,
64-
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
65-
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
66-
extern __DPCPP_SYCL_EXTERNAL
67-
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
68-
__spirv_CooperativeMatrixLoadCheckedINTEL(
69-
T *Ptr, int32_t CoordX, int32_t CoordY, __spv::MatrixLayout Layout = L,
70-
uint32_t Height = 0, uint32_t Width = 0, std::size_t Stride = 0,
71-
int MemOperand = 0);
72-
73-
template <typename T, typename Tp, std::size_t R, std::size_t C,
74-
__spv::MatrixUse U,
75-
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
76-
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
77-
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreCheckedINTEL(
78-
T *Ptr, int32_t CoordX, int32_t CoordY,
79-
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *Object,
80-
__spv::MatrixLayout Layout = L, uint32_t Height = 0, uint32_t Width = 0,
81-
std::size_t Stride = 0, int MemOperand = 0);
82-
83-
template <typename TA, typename TB, typename TC, std::size_t M, std::size_t K,
84-
std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
85-
__spv::MatrixUse UC,
86-
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
87-
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
88-
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
89-
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
90-
extern __DPCPP_SYCL_EXTERNAL
91-
__spv::__spirv_JointMatrixINTEL<TC, M, N, LC, S, UC> *
92-
__spirv_JointMatrixMadINTEL(
93-
__spv::__spirv_JointMatrixINTEL<TA, M, K, LA, S, UA> *A,
94-
__spv::__spirv_JointMatrixINTEL<TB, K, N, LB, S, UB> *B,
95-
__spv::__spirv_JointMatrixINTEL<TC, M, N, LC, S, UC> *C,
96-
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
97-
98-
template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
99-
std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
100-
__spv::MatrixUse UC,
101-
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
102-
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
103-
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
104-
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
105-
extern __DPCPP_SYCL_EXTERNAL
106-
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *
107-
__spirv_JointMatrixUUMadINTEL(
108-
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S, UA> *A,
109-
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S, UB> *B,
110-
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *C,
111-
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
112-
113-
template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
114-
std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
115-
__spv::MatrixUse UC,
116-
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
117-
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
118-
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
119-
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
120-
extern __DPCPP_SYCL_EXTERNAL
121-
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *
122-
__spirv_JointMatrixUSMadINTEL(
123-
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S, UA> *A,
124-
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S, UB> *B,
125-
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *C,
126-
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
127-
128-
template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
129-
std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
130-
__spv::MatrixUse UC,
131-
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
132-
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
133-
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
134-
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
135-
extern __DPCPP_SYCL_EXTERNAL
136-
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *
137-
__spirv_JointMatrixSUMadINTEL(
138-
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S, UA> *A,
139-
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S, UB> *B,
140-
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *C,
141-
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
142-
143-
template <typename T, typename Tp, std::size_t R, std::size_t C,
144-
__spv::MatrixUse U,
145-
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
146-
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
147-
extern __DPCPP_SYCL_EXTERNAL
148-
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
149-
__spirv_CompositeConstruct(const T v);
150-
151-
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
152-
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
153-
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
154-
extern __DPCPP_SYCL_EXTERNAL __ocl_vec_t<uint32_t, 2>
155-
__spirv_JointMatrixGetElementCoordINTEL(
156-
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *, size_t i);
157-
158-
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
159-
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
160-
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
161-
extern __DPCPP_SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL(
162-
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *);
163-
164-
template <typename Ts, typename T, std::size_t R, std::size_t C,
165-
__spv::MatrixUse U,
166-
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
167-
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
168-
extern __DPCPP_SYCL_EXTERNAL Ts __spirv_VectorExtractDynamic(
169-
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *, size_t i);
170-
171-
template <typename Ts, typename T, std::size_t R, std::size_t C,
172-
__spv::MatrixUse U,
173-
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
174-
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
175-
extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
176-
__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *,
177-
Ts val, size_t i);
178-
#else // __SPIRV_USE_COOPERATIVE_MATRIX
17930
template <typename T, typename Tp, std::size_t R, std::size_t C,
18031
__spv::MatrixUse U,
18132
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
@@ -304,7 +155,6 @@ extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreCheckedINTEL(
304155
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *Object,
305156
__spv::MatrixLayout Layout = L, uint32_t Height = 0, uint32_t Width = 0,
306157
std::size_t Stride = 0, int MemOperand = 0);
307-
#endif // __SPIRV_USE_COOPERATIVE_MATRIX
308158

309159
template <typename T>
310160
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixPrefetchINTEL(

sycl/include/sycl/__spirv/spirv_types.hpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ enum class MatrixLayout : uint32_t {
118118

119119
enum class MatrixUse : uint32_t { MatrixA = 0, MatrixB = 1, Accumulator = 2 };
120120

121-
#ifdef __SPIRV_USE_COOPERATIVE_MATRIX
122121
enum class MatrixOperands : uint32_t {
123122
// SPV_KHR_cooperative_matrix operands
124123
NoneKHR = 0,
@@ -133,19 +132,10 @@ enum class MatrixOperands : uint32_t {
133132
MatrixCBFloat16ComponentsINTEL = 0x80,
134133
MatrixResultBFloat16ComponentsINTEL = 0x100
135134
};
136-
#endif // __SPIRV_USE_COOPERATIVE_MATRIX
137135

138-
#ifndef __SPIRV_USE_COOPERATIVE_MATRIX
139-
140-
template <typename T, std::size_t R, std::size_t C, MatrixLayout L,
141-
Scope::Flag S = Scope::Flag::Subgroup,
142-
MatrixUse U = MatrixUse::MatrixA>
143-
struct __spirv_JointMatrixINTEL;
144-
#else
145136
template <typename T, Scope::Flag S = Scope::Flag::Subgroup, std::size_t R = 1,
146137
std::size_t C = 1, MatrixUse U = MatrixUse::MatrixA>
147138
struct __spirv_CooperativeMatrixKHR;
148-
#endif // __SPIRV_USE_COOPERATIVE_MATRIX
149139

150140
struct __spirv_TaskSequenceINTEL;
151141

0 commit comments

Comments
 (0)