Skip to content

Commit 03b6eba

Browse files
committed
Add necessary infrastructures for the uArch to show the full pipeline.
1 parent 9789835 commit 03b6eba

File tree

3 files changed

+216
-21
lines changed

3 files changed

+216
-21
lines changed

mlir/include/mlir/Dialect/XeGPU/Utils/IntelGpuPVC.h

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,33 @@
2424
namespace mlir {
2525
namespace xegpu {
2626
namespace uArch {
27-
namespace PVCuArch {
27+
namespace Xe2Plus {
2828
struct XeCoreInfo {
2929
uint num_threads;
3030
SharedMemory shared_memory;
3131
uint num_vector_units;
3232
uint num_matrix_units;
33+
34+
// Constructor
35+
XeCoreInfo(uint num_threads, const SharedMemory &shared_memory,
36+
uint num_vector_units, uint num_matrix_units)
37+
: num_threads(num_threads), shared_memory(shared_memory),
38+
num_vector_units(num_vector_units), num_matrix_units(num_matrix_units) {
39+
}
3340
};
3441

3542
struct Xe2Plus : public uArch {
3643
XeCoreInfo xe_core;
44+
Xe2Plus(const std::string &archName, const std::string &archDescription,
45+
const XeCoreInfo &xeCore,
46+
const std::vector<uArchHierarchyComponent> &hierarchy = {},
47+
const std::map<std::string, RegisterFileInfo> &regInfo = {},
48+
const std::vector<CacheInfo> &cacheInfo = {},
49+
const std::map<std::string, Instruction *> &instrs = {},
50+
const std::vector<Restriction<> *> &restrs = {})
51+
: uArch(archName, archDescription, hierarchy, regInfo, cacheInfo, instrs,
52+
restrs),
53+
xe_core(xeCore) {}
3754
};
3855

3956
// struct to represent DPAS instruction
@@ -48,6 +65,18 @@ struct DPASInstruction : public Instruction, public MatrixOpInterface {
4865

4966
// bool checkSupportedDPASTypes(mlir::Type dstType, mlir::Type src0Type,
5067
// mlir::Type src1Type, mlir::Type src2Type);
68+
69+
DPASInstruction()
70+
: Instruction("dpas", // name
71+
"Dot Product Accumulate", // description
72+
"0xABCD", // opcode
73+
FunctionalUnit::Matrix, // functional_unit
74+
InstructionType::SIMD, // type
75+
InstructionScope::Subgroup, // scope
76+
UnitOfComputation::Matrix) // unit_of_computation
77+
{}
78+
79+
// Override all virtuals from MatrixOpInterface
5180
virtual bool checkSupportedMMATypes(mlir::Type AType, mlir::Type BType,
5281
mlir::Type CType,
5382
mlir::Type DType) override;
@@ -99,7 +128,51 @@ struct LoadStorePrefetch2DInstruction : public Instruction {
99128
}
100129
};
101130

131+
namespace PVCuArch {
132+
struct PVCuArch : public Xe2Plus {
133+
// Maintaines ownership of the instructions owned by PVUarch
134+
std::vector<std::unique_ptr<Instruction>> owned_instructions;
135+
PVCuArch()
136+
: Xe2Plus("pvc", // archName
137+
"Ponte Vecchio Architecture", // archDescription
138+
XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8), // xeCore
139+
{/* register_file_info */}, // Optional: empty
140+
{/* cache_info */}, // Optional: empty
141+
{/* instructions */}, // Optional: empty
142+
{/* restrictions */} // Optional: empty
143+
) {
144+
// Initialize uArchHierarchy
145+
this->uArch_hierarchy.push_back(uArchHierarchyComponent("thread", 0));
146+
this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeCore", 8));
147+
this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeSlice", 16));
148+
this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeStack", 4));
149+
this->uArch_hierarchy.push_back(uArchHierarchyComponent("gpu", 2));
150+
// Intialize register file info
151+
// GRF
152+
this->register_file_info["GRF"] =
153+
RegisterFileInfo(64 * 1024, // size in bits
154+
{"small", "large"}, // GRF modes
155+
{128, 256}, // registers per thread per mode
156+
0, // number of banks
157+
0 // bank size
158+
);
159+
// Initialize cache info
160+
// L1 cache, XeCore level
161+
this->cache_info.push_back(
162+
CacheInfo(512 * 1024, 64, this->uArch_hierarchy[1]));
163+
// L3 cache, XeStack level
164+
this->cache_info.push_back(
165+
CacheInfo(512 * 1024, 64, this->uArch_hierarchy[3]));
166+
167+
// Add the instructions
168+
auto dpas = std::make_unique<DPASInstruction>();
169+
instructions[dpas->name] = dpas.get();
170+
owned_instructions.push_back(std::move(dpas));
171+
}
172+
};
102173
} // namespace PVCuArch
174+
175+
} // namespace Xe2Plus
103176
} // namespace uArch
104177
} // namespace xegpu
105178
} // namespace mlir

mlir/include/mlir/Dialect/XeGPU/Utils/uArch.h

Lines changed: 130 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616

1717
#include <functional>
1818
#include <iostream>
19+
#include <mutex>
20+
#include <shared_mutex>
1921
#include <tuple>
22+
2023
namespace mlir {
2124
namespace xegpu {
2225
namespace uArch {
@@ -99,6 +102,17 @@ struct Restriction {
99102
std::any apply() { return std::apply(func, data); }
100103
};
101104

105+
// One level of the hardware component hierarchy (thread, core, socket, ...).
struct uArchHierarchyComponent {
  // Optional label of this level; empty when no name is given.
  std::string name;
  // How many components of the next-lower level one component of this level
  // contains. Example: a PVC XeCore contains 8 threads, so its
  // no_of_component is 8.
  uint no_of_component;

  // Copies `name` and stores `no_of_component` unchanged.
  uArchHierarchyComponent(const std::string &name, uint no_of_component)
      : name(name),
        no_of_component(no_of_component) {}
};
115+
102116
// An enum class to represent the functional unit of an instruction
103117
enum class FunctionalUnit {
104118
ALU,
@@ -179,6 +193,12 @@ struct Instruction {
179193
// std::string pipeline;
180194
// std::string resource;
181195
// std::string comment;
196+
Instruction(std::string name, std::string desc, std::string opcode,
197+
FunctionalUnit fu, InstructionType itype, InstructionScope sc,
198+
UnitOfComputation uoc)
199+
: name(std::move(name)), description(std::move(desc)),
200+
opcode(std::move(opcode)), functional_unit(fu), type(itype), scope(sc),
201+
unit_of_computation(uoc) {}
182202
};
183203

184204
// A struct to represent register file information
@@ -189,18 +209,30 @@ struct RegisterFileInfo {
189209
num_regs_per_thread_per_mode; // number of registers per thread per mode
190210
uint num_banks;
191211
uint bank_size;
212+
213+
// Constructor
214+
RegisterFileInfo(uint size, const std::vector<std::string> &mode,
215+
const std::vector<uint> &numRegs, uint num_banks,
216+
uint bank_size)
217+
: size(size), mode(mode), num_regs_per_thread_per_mode(numRegs),
218+
num_banks(num_banks), bank_size(bank_size) {}
192219
};
193220

194221
// A struct to represent cache information
195222
struct CacheInfo {
196223
uint size;
197-
uint associativity;
198224
uint line_size;
199-
uint num_banks;
200-
uint bank_size;
201-
uint num_ports;
202-
uint port_width;
203-
uint bank_conflicts;
225+
// At which component level the cache is shared
226+
uArchHierarchyComponent component;
227+
// uint associativity;
228+
// uint num_banks;
229+
// uint bank_size;
230+
// uint num_ports;
231+
// uint port_width;
232+
// uint bank_conflicts;
233+
// Constructor
234+
CacheInfo(uint size, uint line_size, const uArchHierarchyComponent &component)
235+
: size(size), line_size(line_size), component(component) {}
204236
};
205237

206238
// A struct to represent the uArch
@@ -225,19 +257,38 @@ struct CacheInfo {
225257
struct uArch {
226258
std::string name; // similar to target triple
227259
std::string description;
260+
// Represent the whole uArch hierarchy
261+
// For 2 stack Intel PVC it would look something like this:
262+
// uArchHierarchy[0] = {thread, 0}
263+
// uArchHierarchy[1] = {XeCore, 8}
264+
// uArchHierarchy[2] = {XeSlice, 16}
265+
// uArchHierarchy[3] = {XeStack, 4}
266+
// uArchHierarchy[4] = {gpu, 2}
267+
std::vector<uArchHierarchyComponent> uArch_hierarchy;
228268
// Different kind of regiger file information (e.g., GRF, ARF, etc.)
229-
std::vector<RegisterFileInfo> register_file_info;
269+
std::map<std::string, RegisterFileInfo> register_file_info;
230270
// Each level of cache is indexed lower to higher in the vector
231271
// (e.g., L1 indexed at 0, L2 at 1 and so on) L1, L2, L3, etc.
232272
std::vector<CacheInfo> cache_info;
233-
std::vector<Instruction *> instructions;
273+
std::map<std::string, Instruction *> instructions;
234274
std::vector<Restriction<> *> restrictions;
275+
276+
// Constructor
277+
uArch(const std::string &name, const std::string &description,
278+
const std::vector<uArchHierarchyComponent> &uArch_hierarchy = {},
279+
const std::map<std::string, RegisterFileInfo> &register_file_info = {},
280+
const std::vector<CacheInfo> &cache_info = {},
281+
const std::map<std::string, Instruction *> &instructions = {},
282+
const std::vector<Restriction<> *> &restrictions = {})
283+
: name(name), description(description), uArch_hierarchy(uArch_hierarchy),
284+
register_file_info(register_file_info), cache_info(cache_info),
285+
instructions(instructions), restrictions(restrictions) {}
235286
};
236287

237288
// A struct to represent shared memory information
238289
struct SharedMemory {
239-
uint size;
240-
uint alignment;
290+
uint size; // in bytes
291+
uint alignment; // in bytes
241292
// @TODO: Add more fields as needed
242293
// uint latency;
243294
// uint throughput;
@@ -247,6 +298,9 @@ struct SharedMemory {
247298
// uint bank_size;
248299
// uint bank_conflicts;
249300
// uint num_banks;
301+
302+
// Constructor
303+
// Stores `size` and `alignment` in the members of the same name.
SharedMemory(uint size, uint alignment)
    : size(size),
      alignment(alignment) {}
250304
};
251305

252306
// For future use case in Xe4+
@@ -293,6 +347,7 @@ struct TileOpInterface {
293347
// @param array_len, array length
294348
virtual bool validate(Tile tile, Tile surface, mlir::Type dataType,
295349
uint surface_pitch, uint array_len = 1) = 0;
350+
// Virtual destructor so that an implementation can be destroyed through a
// TileOpInterface pointer.
virtual ~TileOpInterface() = default;
296351
};
297352

298353
enum class MatrixType { MatrixA, MatrixB, MatrixC, MatrixD };
@@ -304,11 +359,75 @@ struct MatrixOpInterface {
304359
virtual std::vector<uint> getSupportedN(mlir::Type type) = 0;
305360
virtual std::vector<std::pair<unsigned, unsigned>>
306361
getSupportedMatrix(mlir::Type type, MatrixType matrixType) = 0;
362+
363+
// Virtual destructor so that an implementation can be destroyed through a
// MatrixOpInterface pointer.
virtual ~MatrixOpInterface() = default;
307364
};
308365

366+
struct uArchMap {
367+
public:
368+
// Singleton instance
369+
static uArchMap &instance() {
370+
static uArchMap instance;
371+
return instance;
372+
}
373+
374+
// Insert or update a key-value pair
375+
void insert(const std::string &key, uArch value) {
376+
std::unique_lock lock(mutex_);
377+
map_[key] = value;
378+
}
379+
380+
// Get a value by key (concurrent safe read)
381+
std::optional<uArch> get(const std::string &key) const {
382+
std::shared_lock lock(mutex_);
383+
auto it = map_.find(key);
384+
if (it != map_.end())
385+
return it->second;
386+
return std::nullopt;
387+
}
388+
389+
// Check if a key exists
390+
bool contains(const std::string &key) const {
391+
std::shared_lock lock(mutex_);
392+
return map_.find(key) != map_.end();
393+
}
394+
395+
// Remove a key
396+
bool erase(const std::string &key) {
397+
std::unique_lock lock(mutex_);
398+
return map_.erase(key) > 0;
399+
}
400+
401+
private:
402+
uArchMap() = default;
403+
uArchMap(const uArchMap &) = delete;
404+
uArchMap &operator=(const uArchMap &) = delete;
405+
406+
mutable std::shared_mutex mutex_;
407+
std::map<std::string, uArch> map_;
408+
};
409+
410+
// std::unordered_map<std::string, uArch> uArchMap;
411+
// std::shared_mutex uArchMapMutex;
412+
413+
// void getuArch(const std::string &key) {
414+
// std::shared_lock<std::shared_mutex> lock(uArchMapMutex);
415+
// auto it = uArchMap.find(key);
416+
// if(it != uArchMap.end())
417+
// return *it;
418+
// else
419+
420+
// // safe concurrent read
421+
// }
422+
423+
// void AdduArch(const std::string &key, uArch &value) {
424+
// std::unique_lock<std::shared_mutex> lock(uArchMapMutex);
425+
426+
// // exclusive write
427+
// }
428+
309429
} // namespace uArch
310430
} // namespace xegpu
311431
} // namespace mlir
312432

313433
#endif // MLIR_DIALECT_XEGPU_UTILS_UARCH_H
314-
//===--- uArch.h ---------------------------------------*- C++ -*-===//

mlir/lib/Dialect/XeGPU/Utils/IntelGpuPVC.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,48 +5,51 @@
55
#include <vector>
66

77
using namespace mlir::xegpu::uArch;
8-
using namespace mlir::xegpu::uArch::PVCuArch;
8+
using namespace mlir::xegpu::uArch::Xe2Plus;
99

1010
namespace mlir {
1111
namespace xegpu {
1212
namespace uArch {
13-
namespace PVCuArch {
13+
namespace Xe2Plus {
1414
bool DPASInstruction::checkSupportedMMATypes(mlir::Type AType, mlir::Type BType,
1515
mlir::Type CType,
1616
mlir::Type DType) {
1717
if (AType.isF16() || BType.isF16()) {
1818
if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
19-
(!DType.isF32() && !DType.isF16()))
19+
(!DType.isF32() && !DType.isF16())) {
2020
llvm::errs()
2121
<< "Unsupported dpas combinations of Dst, Acc, A and B matrices, "
2222
<< "Supported types are:\n"
2323
<< " Dst | Acc | A | B \n"
2424
<< " f, hf | f, hf | hf | hf \n"
2525
<< "AType: " << AType << " BType: " << BType << " CType: " << CType
2626
<< " DType: " << DType;
27-
return false;
27+
return false;
28+
}
2829
} else if (AType.isBF16() || BType.isBF16()) {
2930
if (AType != BType || (CType && (!CType.isF32() && !CType.isBF16())) ||
30-
(!DType.isF32() && !DType.isBF16()))
31+
(!DType.isF32() && !DType.isBF16())) {
3132
llvm::errs()
3233
<< "Unsupported dpas combinations of Dst, Acc, A and B matrices, "
3334
<< "Supported types are:\n"
3435
<< " Dst | Acc | A | B \n"
3536
<< " f, bf | f, bf | bf | bf \n"
3637
<< "AType: " << AType << " BType: " << BType << " CType: " << CType
3738
<< " DType: " << DType;
38-
return false;
39+
return false;
40+
}
3941
} else if (AType.isTF32() || BType.isTF32()) {
4042
if (AType != BType || (CType && (!CType.isF32() && !DType.isF32())) ||
41-
(!DType.isF32()))
43+
(!DType.isF32())) {
4244
llvm::errs()
4345
<< "Unsupported dpas combinations of Dst, Acc, A and B matrices, "
4446
<< "Supported types are:\n"
4547
<< " Dst | Acc | A | B \n"
4648
<< " f | f | tf32 | tf32 \n"
4749
<< "AType: " << AType << " BType: " << BType << " CType: " << CType
4850
<< " DType: " << DType;
49-
return false;
51+
return false;
52+
}
5053
} else if (!(AType.isInteger(2) || AType.isInteger(4) ||
5154
AType.isInteger(8)) &&
5255
!(BType.isInteger(2) || BType.isInteger(4) ||
@@ -136,7 +139,7 @@ DPASInstruction::getSupportedMatrix(mlir::Type type, MatrixType matrixType) {
136139
}
137140
}
138141

139-
} // namespace PVCuArch
142+
} // namespace Xe2Plus
140143
} // namespace uArch
141144
} // namespace xegpu
142145
} // namespace mlir

0 commit comments

Comments
 (0)