Skip to content

Commit 9789835

Browse files
committed
Modify the uArch definition.
This version focuses on the utilities to be the pivot. It also saves info directly in C++ files as part of get functions. Don't use the yamls anymore. Adds support for DPAS instruction.
1 parent 2986356 commit 9789835

File tree

3 files changed

+198
-27
lines changed

3 files changed

+198
-27
lines changed

mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,25 @@ struct Xe2Plus : public uArch {
3737
};
3838

3939
// struct to represent DPAS instruction
40-
struct DPASInstruction : public Instruction {
41-
Range systolic_depth;
42-
Range repreat_count;
43-
Range execution_size;
44-
std::map<std::string, uint> ops_per_channel;
45-
std::vector<std::vector<std::string>> supported_types;
46-
std::map<std::string, std::map<std::string, std::vector<std::string>>>
47-
matrix_size;
40+
struct DPASInstruction : public Instruction, public MatrixOpInterface {
41+
// Range systolic_depth;
42+
// Range repreat_count;
43+
// Range execution_size;
44+
// std::map<std::string, uint> ops_per_channel;
45+
// std::vector<std::vector<std::string>> supported_types;
46+
// std::map<std::string, std::map<std::string, std::vector<std::string>>>
47+
// matrix_size;
4848

49-
bool checkSupportedDPASTypes(mlir::Type dstType, mlir::Type src0Type,
50-
mlir::Type src1Type, mlir::Type src2Type);
49+
// bool checkSupportedDPASTypes(mlir::Type dstType, mlir::Type src0Type,
50+
// mlir::Type src1Type, mlir::Type src2Type);
51+
virtual bool checkSupportedMMATypes(mlir::Type AType, mlir::Type BType,
52+
mlir::Type CType,
53+
mlir::Type DType) override;
54+
virtual std::vector<uint> getSupportedM(mlir::Type type) override;
55+
virtual std::vector<uint> getSupportedK(mlir::Type type) override;
56+
virtual std::vector<uint> getSupportedN(mlir::Type type) override;
57+
virtual std::vector<std::pair<unsigned, unsigned>>
58+
getSupportedMatrix(mlir::Type type, MatrixType matrixType) override;
5159
};
5260

5361
struct LoadStore2DTileInfo : public RangeTile {

mlir/include/mlir/Dialect/XeGPU/Utils/uArch.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,35 @@ struct SharedMemory {
277277
// uint num_matrix_units;
278278
// SharedMemory shared_memory;
279279
// };
280+
281+
// Create a TileLikeOp Interface
282+
struct TileOpInterface {
283+
// Get the supported tiles for the specific data type.
284+
// Can provide load/store/prefetch ops supported tile sizes for a specific
285+
// uarch
286+
virtual DiscreteTile getSupportedTiles(mlir::Type type) = 0;
287+
288+
// Validate the tile ops restrictions
289+
// @param tile, tile to load/store/prefetch
290+
// @param surface, surface to load/store/prefetch data from
291+
// @param dataType, data type of the data
292+
// @param surface_pitch, suface pitch
293+
// @param array_len, array length
294+
virtual bool validate(Tile tile, Tile surface, mlir::Type dataType,
295+
uint surface_pitch, uint array_len = 1) = 0;
296+
};
297+
298+
enum class MatrixType { MatrixA, MatrixB, MatrixC, MatrixD };
299+
struct MatrixOpInterface {
300+
virtual bool checkSupportedMMATypes(mlir::Type AType, mlir::Type BType,
301+
mlir::Type CType, mlir::Type DType) = 0;
302+
virtual std::vector<uint> getSupportedM(mlir::Type type) = 0;
303+
virtual std::vector<uint> getSupportedK(mlir::Type type) = 0;
304+
virtual std::vector<uint> getSupportedN(mlir::Type type) = 0;
305+
virtual std::vector<std::pair<unsigned, unsigned>>
306+
getSupportedMatrix(mlir::Type type, MatrixType matrixType) = 0;
307+
};
308+
280309
} // namespace uArch
281310
} // namespace xegpu
282311
} // namespace mlir

mlir/lib/Dialect/XeGPU/Utils/IntelGpuPVC.cpp

Lines changed: 151 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,160 @@
77
using namespace mlir::xegpu::uArch;
88
using namespace mlir::xegpu::uArch::PVCuArch;
99

10-
namespace llvm {
11-
namespace yaml {
12-
template <>
13-
struct MappingTraits<XeCoreInfo> {
14-
static void mapping(IO &io, XeCoreInfo &xe_core) {
15-
io.mapRequired("num_threads", xe_core.num_threads);
16-
io.mapRequired("shared_memory", xe_core.shared_memory);
17-
io.mapRequired("num_vector_units", xe_core.num_vector_units);
18-
io.mapRequired("num_matrix_units", xe_core.num_matrix_units);
10+
namespace mlir {
11+
namespace xegpu {
12+
namespace uArch {
13+
namespace PVCuArch {
14+
bool DPASInstruction::checkSupportedMMATypes(mlir::Type AType, mlir::Type BType,
15+
mlir::Type CType,
16+
mlir::Type DType) {
17+
if (AType.isF16() || BType.isF16()) {
18+
if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
19+
(!DType.isF32() && !DType.isF16()))
20+
llvm::errs()
21+
<< "Unsupported dpas combinations of Dst, Acc, A and B matrices, "
22+
<< "Supported types are:\n"
23+
<< " Dst | Acc | A | B \n"
24+
<< " f, hf | f, hf | hf | hf \n"
25+
<< "AType: " << AType << " BType: " << BType << " CType: " << CType
26+
<< " DType: " << DType;
27+
return false;
28+
} else if (AType.isBF16() || BType.isBF16()) {
29+
if (AType != BType || (CType && (!CType.isF32() && !CType.isBF16())) ||
30+
(!DType.isF32() && !DType.isBF16()))
31+
llvm::errs()
32+
<< "Unsupported dpas combinations of Dst, Acc, A and B matrices, "
33+
<< "Supported types are:\n"
34+
<< " Dst | Acc | A | B \n"
35+
<< " f, bf | f, bf | bf | bf \n"
36+
<< "AType: " << AType << " BType: " << BType << " CType: " << CType
37+
<< " DType: " << DType;
38+
return false;
39+
} else if (AType.isTF32() || BType.isTF32()) {
40+
if (AType != BType || (CType && (!CType.isF32() && !DType.isF32())) ||
41+
(!DType.isF32()))
42+
llvm::errs()
43+
<< "Unsupported dpas combinations of Dst, Acc, A and B matrices, "
44+
<< "Supported types are:\n"
45+
<< " Dst | Acc | A | B \n"
46+
<< " f | f | tf32 | tf32 \n"
47+
<< "AType: " << AType << " BType: " << BType << " CType: " << CType
48+
<< " DType: " << DType;
49+
return false;
50+
} else if (!(AType.isInteger(2) || AType.isInteger(4) ||
51+
AType.isInteger(8)) &&
52+
!(BType.isInteger(2) || BType.isInteger(4) ||
53+
BType.isInteger(8))) {
54+
llvm::errs()
55+
<< "Unsupported dpas combinations of Dst, Acc, A and B matrices, "
56+
<< "Supported types are:\n"
57+
<< " Dst | Acc | A | B "
58+
" \n"
59+
<< " ud, d | ud,d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 "
60+
<< "AType: " << AType << " BType: " << BType << " CType: " << CType
61+
<< " DType: " << DType;
62+
return false;
1963
}
20-
};
2164

22-
template <>
23-
struct MappingTraits<Xe2Plus> {
24-
static void mapping(IO &io, Xe2Plus &xe2plus) {
25-
io.mapRequired("xe_core", xe2plus.xe_core);
65+
return true;
66+
}
67+
68+
std::vector<uint> DPASInstruction::getSupportedM(mlir::Type type) {
69+
return {1, 2, 3, 4, 5, 6, 7, 8};
70+
}
71+
72+
std::vector<uint> DPASInstruction::getSupportedK(mlir::Type type) {
73+
// assert if data type is not int or float type
74+
assert(type.isIntOrFloat() && "Matrix type must be int or float");
75+
auto bitWidth = type.getIntOrFloatBitWidth();
76+
uint kSize = -1;
77+
switch (bitWidth) {
78+
case 2:
79+
kSize = 64;
80+
break;
81+
case 4:
82+
kSize = 64;
83+
break;
84+
case 8:
85+
kSize = 32;
86+
break;
87+
case 16:
88+
kSize = 16;
89+
break;
90+
case 32:
91+
kSize = 8;
92+
break;
93+
default:
94+
llvm_unreachable("Invalid int or float");
95+
}
96+
}
97+
98+
std::vector<uint> DPASInstruction::getSupportedN(mlir::Type type) {
99+
return {16};
100+
}
101+
102+
std::vector<std::pair<unsigned, unsigned>>
103+
DPASInstruction::getSupportedMatrix(mlir::Type type, MatrixType matrixType) {
104+
auto combineVectors = [](const std::vector<unsigned> &a,
105+
const std::vector<unsigned> &b)
106+
-> std::vector<std::pair<unsigned, unsigned>> {
107+
std::vector<std::pair<unsigned, unsigned>> result;
108+
for (unsigned x : a) {
109+
for (unsigned y : b) {
110+
result.emplace_back(x, y);
111+
}
112+
}
113+
return result;
114+
};
115+
116+
auto M = getSupportedM(type);
117+
auto K = getSupportedK(type);
118+
auto N = getSupportedN(type);
119+
std::vector<std::pair<unsigned, unsigned>> resultMatrix;
120+
121+
switch (matrixType) {
122+
case MatrixType::MatrixA:
123+
resultMatrix = combineVectors(M, K);
124+
break;
125+
case MatrixType::MatrixB:
126+
resultMatrix = combineVectors(K, N);
127+
break;
128+
case MatrixType::MatrixC:
129+
resultMatrix = combineVectors(M, N);
130+
break;
131+
case MatrixType::MatrixD:
132+
resultMatrix = combineVectors(M, N);
133+
break;
134+
default:
135+
break;
26136
}
27-
};
28-
} // namespace yaml
29-
} // namespace llvm
137+
}
138+
139+
} // namespace PVCuArch
140+
} // namespace uArch
141+
} // namespace xegpu
142+
} // namespace mlir
143+
144+
// namespace llvm {
145+
// namespace yaml {
146+
// template <>
147+
// struct MappingTraits<XeCoreInfo> {
148+
// static void mapping(IO &io, XeCoreInfo &xe_core) {
149+
// io.mapRequired("num_threads", xe_core.num_threads);
150+
// io.mapRequired("shared_memory", xe_core.shared_memory);
151+
// io.mapRequired("num_vector_units", xe_core.num_vector_units);
152+
// io.mapRequired("num_matrix_units", xe_core.num_matrix_units);
153+
// }
154+
// };
155+
156+
// template <>
157+
// struct MappingTraits<Xe2Plus> {
158+
// static void mapping(IO &io, Xe2Plus &xe2plus) {
159+
// io.mapRequired("xe_core", xe2plus.xe_core);
160+
// }
161+
// };
162+
// } // namespace yaml
163+
// } // namespace llvm
30164

31165
// namespace mlir {
32166
// namespace xe_gpu {

0 commit comments

Comments
 (0)