14
14
#ifndef MLIR_DIALECT_XEGPU_UTILS_UARCH_H
15
15
#define MLIR_DIALECT_XEGPU_UTILS_UARCH_H
16
16
17
+ #include < any>
17
18
#include < functional>
18
19
#include < iostream>
20
+ #include < map>
19
21
#include < mutex>
20
22
#include < shared_mutex>
21
23
#include < tuple>
22
24
25
+ #include " mlir/IR/Types.h"
26
+
23
27
namespace mlir {
24
28
namespace xegpu {
25
29
namespace uArch {
@@ -37,8 +41,8 @@ struct Range {
37
41
// dim: [2, 2]
38
42
// This represents a 2x2 tile
39
43
struct Tile {
  // Number of dimensions of the tile. Zero-initialised so that a
  // default-constructed Tile never exposes an indeterminate value.
  uint32_t no_of_dims = 0;
  // Extent of each dimension, e.g. {2, 2} for a 2x2 tile.
  std::vector<uint32_t> dims;
};
43
47
44
48
// RangeTile represents a range of tiles instead of a single tile
@@ -52,7 +56,7 @@ struct Tile {
52
56
// This represents a 2x2 RangeTile where the first dimension can have values
53
57
// from 1 to 32 and the second dimension can have values from 2 to 16
54
58
struct RangeTile {
55
- uint no_of_dims;
59
+ uint32_t no_of_dims;
56
60
std::vector<Range> dims;
57
61
};
58
62
@@ -68,8 +72,8 @@ struct RangeTile {
68
72
// This represents a 2x2 DiscreteTile where the first dimension can have values
69
73
// 1, 2, 4, 8, 16, 32 and the second dimension can have values 2, 4, 8, 16
70
74
struct DiscreteTile {
  // Number of dimensions of the tile. Zero-initialised so that a
  // default-constructed DiscreteTile never exposes an indeterminate value.
  uint32_t no_of_dims = 0;
  // For each dimension, the explicit list of values that dimension may take,
  // e.g. {{1, 2, 4, 8, 16, 32}, {2, 4, 8, 16}}.
  std::vector<std::vector<uint32_t>> dims;
};
74
78
75
79
// Restriction struct
@@ -93,9 +97,9 @@ struct DiscreteTile {
93
97
template <typename ... Args>
94
98
struct Restriction {
95
99
std::tuple<Args...> data;
96
- std::function<void (Args...)> func;
100
+ std::function<bool (Args...)> func;
97
101
98
- Restriction (Args... args, std::function<void (Args...)> f)
102
+ Restriction (Args... args, std::function<bool (Args...)> f)
99
103
: data(args...), func(f) {}
100
104
101
105
bool validate () { return std::apply (func, data); }
@@ -107,9 +111,9 @@ struct uArchHierarchyComponent {
107
111
std::string name = " " ; // optional name of the hierarchy component
108
112
// no. of lower hierarchy component it contains, e.g., for PVC XeCore it
109
113
// contains 8 threads, so no_of_component=8
110
- uint no_of_component;
114
+ uint32_t no_of_component;
111
115
// Constructor
112
- uArchHierarchyComponent (const std::string &name, uint no_of_component)
116
+ uArchHierarchyComponent (const std::string &name, uint32_t no_of_component)
113
117
: name(name), no_of_component(no_of_component) {}
114
118
};
115
119
@@ -203,35 +207,37 @@ struct Instruction {
203
207
204
208
// A struct to represent register file information
205
209
struct RegisterFileInfo {
  // The scalar members carry in-class initialisers so that the defaulted
  // constructor below yields a fully defined (all-zero / empty) object rather
  // than one with indeterminate values.
  uint32_t size = 0;             // size per register in bits
  std::vector<std::string> mode; // e.g., "small", "large" GRF modes
  std::vector<uint32_t>
      num_regs_per_thread_per_mode; // number of registers per thread per mode
  uint32_t num_banks = 0;
  uint32_t bank_size = 0;

  // Default constructor: zero sizes/counts and empty mode lists.
  RegisterFileInfo() = default;

  // Full constructor.
  // @param size, size per register in bits
  // @param mode, names of the register-file modes
  // @param numRegs, number of registers per thread, one entry per mode
  // @param num_banks, stored in num_banks
  // @param bank_size, stored in bank_size
  RegisterFileInfo(uint32_t size, const std::vector<std::string> &mode,
                   const std::vector<uint32_t> &numRegs, uint32_t num_banks,
                   uint32_t bank_size)
      : size(size), mode(mode), num_regs_per_thread_per_mode(numRegs),
        num_banks(num_banks), bank_size(bank_size) {}
};
220
225
221
226
// A struct to represent cache information
222
227
struct CacheInfo {
223
- uint size;
224
- uint line_size;
228
+ uint32_t size;
229
+ uint32_t line_size;
225
230
// At which component level the cache is shared
226
231
uArchHierarchyComponent component;
227
- // uint associativity;
228
- // uint num_banks;
229
- // uint bank_size;
230
- // uint num_ports;
231
- // uint port_width;
232
- // uint bank_conflicts;
232
+ // uint32_t associativity;
233
+ // uint32_t num_banks;
234
+ // uint32_t bank_size;
235
+ // uint32_t num_ports;
236
+ // uint32_t port_width;
237
+ // uint32_t bank_conflicts;
233
238
// Constructor
234
- CacheInfo (uint size, uint line_size, const uArchHierarchyComponent &component)
239
+ CacheInfo (uint32_t size, uint32_t line_size,
240
+ const uArchHierarchyComponent &component)
235
241
: size(size), line_size(line_size), component(component) {}
236
242
};
237
243
@@ -274,6 +280,7 @@ struct uArch {
274
280
std::vector<Restriction<> *> restrictions;
275
281
276
282
// Constructor
283
+ uArch () = default ;
277
284
uArch (const std::string &name, const std::string &description,
278
285
const std::vector<uArchHierarchyComponent> &uArch_hierarchy = {},
279
286
const std::map<std::string, RegisterFileInfo> ®ister_file_info = {},
@@ -287,48 +294,49 @@ struct uArch {
287
294
288
295
// A struct to represent shared memory information
289
296
struct SharedMemory {
  uint32_t size;      // in bytes
  uint32_t alignment; // in bytes

  // @TODO: Add more fields as needed. Candidates:
  // uint32_t latency;
  // uint32_t throughput;
  // uint32_t bandwidth;
  // uint32_t num_ports;
  // uint32_t port_width;
  // uint32_t bank_size;
  // uint32_t bank_conflicts;
  // uint32_t num_banks;

  // Constructor: records the size and the alignment, both in bytes.
  SharedMemory(uint32_t size, uint32_t alignment)
      : size{size}, alignment{alignment} {}
};
305
313
306
314
// For future use case in Xe4+
307
315
308
316
// struct EUInfo {
309
- // uint num_eu_threads;
317
+ // uint32_t num_eu_threads;
310
318
// SharedMemory shared_memory;
311
319
// };
312
320
313
- // uint num_simd_units;
314
- // uint num_spus;
315
- // uint num_smt;
316
- // uint num_hardware_threads;
317
- // uint num_threads_per_spu;
318
- // uint num_threads_per_simd_unit;
319
- // uint num_threads_per_hardware_thread;
320
- // uint num_threads_per_smt;
321
+ // uint32_t num_simd_units;
322
+ // uint32_t num_spus;
323
+ // uint32_t num_smt;
324
+ // uint32_t num_hardware_threads;
325
+ // uint32_t num_threads_per_spu;
326
+ // uint32_t num_threads_per_simd_unit;
327
+ // uint32_t num_threads_per_hardware_thread;
328
+ // uint32_t num_threads_per_smt;
321
329
// SharedMemory shared_memory;
322
330
// };
323
331
324
332
// A struct to represent a GPU uArch
325
333
// This struct is used to represent the GPU microarchitecture of a target device
326
334
// struct GPUuArch : public uArch {
327
- // uint num_compute_units;
328
- // uint num_vector_units;
329
- // uint num_scalar_units;
330
- // uint num_tensor_units;
331
- // uint num_matrix_units;
335
+ // uint32_t num_compute_units;
336
+ // uint32_t num_vector_units;
337
+ // uint32_t num_scalar_units;
338
+ // uint32_t num_tensor_units;
339
+ // uint32_t num_matrix_units;
332
340
// SharedMemory shared_memory;
333
341
// };
334
342
@@ -346,17 +354,17 @@ struct TileOpInterface {
346
354
// @param surface_pitch, surface pitch
347
355
// @param array_len, array length
348
356
virtual bool validate (Tile tile, Tile surface, mlir::Type dataType,
349
- uint surface_pitch, uint array_len = 1 ) = 0;
357
+ uint32_t surface_pitch, uint32_t array_len = 1 ) = 0;
350
358
// Virtual so that an implementation can be destroyed through a
// TileOpInterface pointer.
virtual ~TileOpInterface() = default;
351
359
};
352
360
353
361
// Selects one of the four matrices of a matrix operation; passed to
// MatrixOpInterface::getSupportedMatrix.
enum class MatrixType {
  MatrixA,
  MatrixB,
  MatrixC,
  MatrixD
};
354
362
struct MatrixOpInterface {
355
363
// Pure virtual: takes the element types of the A, B, C and D matrices and
// returns a bool (presumably whether the combination is supported; confirm
// against the implementations).
virtual bool checkSupportedMMATypes(mlir::Type AType,
                                    mlir::Type BType,
                                    mlir::Type CType,
                                    mlir::Type DType) = 0;
357
- virtual std::vector<uint > getSupportedM (mlir::Type type) = 0;
358
- virtual std::vector<uint > getSupportedK (mlir::Type type) = 0;
359
- virtual std::vector<uint > getSupportedN (mlir::Type type) = 0;
365
+ virtual std::vector<uint32_t > getSupportedM (mlir::Type type) = 0;
366
+ virtual std::vector<uint32_t > getSupportedK (mlir::Type type) = 0;
367
+ virtual std::vector<uint32_t > getSupportedN (mlir::Type type) = 0;
360
368
virtual std::vector<std::pair<unsigned , unsigned >>
361
369
getSupportedMatrix (mlir::Type type, MatrixType matrixType) = 0 ;
362
370
@@ -373,13 +381,14 @@ struct uArchMap {
373
381
374
382
// Insert or update a key-value pair
375
383
void insert (const std::string &key, uArch value) {
376
- std::unique_lock lock (mutex_);
377
- map_[key] = value;
384
+ std::unique_lock<std::shared_mutex> lock (mutex_);
385
+ // map_[key] = value;
386
+ map_.emplace (key, value);
378
387
}
379
388
380
389
// Get a value by key (concurrent safe read)
381
390
std::optional<uArch> get (const std::string &key) const {
382
- std::shared_lock lock (mutex_);
391
+ std::shared_lock<std::shared_mutex> lock (mutex_);
383
392
auto it = map_.find (key);
384
393
if (it != map_.end ())
385
394
return it->second ;
@@ -388,13 +397,13 @@ struct uArchMap {
388
397
389
398
// Check if a key exists
390
399
bool contains (const std::string &key) const {
391
- std::shared_lock lock (mutex_);
400
+ std::shared_lock<std::shared_mutex> lock (mutex_);
392
401
return map_.find (key) != map_.end ();
393
402
}
394
403
395
404
// Remove a key
396
405
bool erase (const std::string &key) {
397
- std::unique_lock lock (mutex_);
406
+ std::unique_lock<std::shared_mutex> lock (mutex_);
398
407
return map_.erase (key) > 0 ;
399
408
}
400
409
0 commit comments