9
9
#include < random>
10
10
#include < sycl/usm.hpp>
11
11
12
+ #ifdef SLM
13
+ #include " slm_utils.hpp"
14
+ #endif
15
+
12
16
// number of test iterations
13
17
constexpr unsigned int testIterations = 100 ;
14
18
// start recording time after X iterations
@@ -51,6 +55,12 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
51
55
std::chrono::high_resolution_clock::now ();
52
56
53
57
q.submit ([&](handler &h) {
58
+ #ifdef SLM
59
+ local_accessor<TOperand, 2 > tileA{{MCache2, KCache2}, h};
60
+ local_accessor<TOperand, 2 > tileB{
61
+ {KCache2 / vnniFactor, NCache2 * vnniFactor}, h};
62
+ #endif
63
+
54
64
h.parallel_for <MatMul<TM, TN, TK>>( // cache layer#1
55
65
nd_range<2 >{global, cachelocal},
56
66
// loop global
@@ -60,15 +70,16 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
60
70
[[intel::reqd_sub_group_size (SG_SZ)]]
61
71
#endif // SG_SZ
62
72
{
73
+ // sg::load and sg::store expect decorations to be ON
63
74
auto pA =
64
75
address_space_cast<sycl::access::address_space::global_space,
65
- sycl::access::decorated::no >(A);
76
+ sycl::access::decorated::yes >(A);
66
77
auto pB =
67
78
address_space_cast<sycl::access::address_space::global_space,
68
- sycl::access::decorated::no >(B);
79
+ sycl::access::decorated::yes >(B);
69
80
auto pC =
70
81
address_space_cast<sycl::access::address_space::global_space,
71
- sycl::access::decorated::no >(C);
82
+ sycl::access::decorated::yes >(C);
72
83
auto m2 = it.get_group (0 );
73
84
auto n2 = it.get_group (1 );
74
85
auto m1 = it.get_local_id (0 );
@@ -112,7 +123,6 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
112
123
colsA, layout::row_major,
113
124
syclex::properties{syclex::prefetch_hint_L1});
114
125
115
- #ifdef VNNI
116
126
for (int p = 0 ; p < prefDistance; p++)
117
127
joint_matrix_prefetch<prefRow, prefCol>(
118
128
sg,
@@ -122,15 +132,6 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
122
132
(n2 * NCache2 * vnniFactor + pn1B * prefCol),
123
133
colsB * vnniFactor, layout::row_major,
124
134
syclex::properties{syclex::prefetch_hint_L1});
125
- #else // VNNI
126
- for (int p = 0 ; p < prefDistance; p++)
127
- joint_matrix_prefetch<prefRow, prefCol>(
128
- sg,
129
- B + (p * KCache2 + pm1B * prefRow) * colsB + n2 * NCache2 +
130
- pn1B * prefCol,
131
- colsB, layout::row_major,
132
- syclex::properties{syclex::prefetch_hint_L1});
133
- #endif // VNNI
134
135
#endif // PREFETCH
135
136
136
137
joint_matrix<sub_group, TResult, use::accumulator, TM, TN>
@@ -157,7 +158,16 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
157
158
}
158
159
#endif // MANUAL_UNROLL
159
160
161
+ #ifdef SLM
162
+ constexpr unsigned int SGs =
163
+ (MCache2 / MCache1) * (NCache2 / NCache1);
164
+ #endif // SLM
160
165
for (unsigned int k2 = 0 ; k2 < colsA / KCache2; k2++) {
166
+ #ifdef SLM
167
+ slm_read_write<colsA, colsB, MCache2, NCache2, KCache2, vnniFactor,
168
+ SGs>(pA, pB, tileA, tileB, sg, k2, m2, n2, sgSize);
169
+ it.barrier (access::fence_space::local_space);
170
+ #endif // SLM
161
171
joint_matrix<sub_group, TOperand, use::a, TM, TK, layout::row_major>
162
172
tA[MCache1 / TM][KCache2 / KCache1]
163
173
#ifdef INIT_LIST
@@ -192,6 +202,14 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
192
202
#else // MANUAL_UNROLL
193
203
for (unsigned int m = 0 ; m < MCache1 / TM; m++) {
194
204
#endif // MANUAL_UNROLL
205
+ #ifdef SLM
206
+ joint_matrix_load (sg, tA[m][k1],
207
+ tileA.template get_multi_ptr <
208
+ sycl::access::decorated::no>() +
209
+ (m1 * MCache1 + m * TM) * KCache2 +
210
+ k1 * TK,
211
+ KCache2);
212
+ #else // SLM
195
213
#ifdef OOB
196
214
ext::intel::experimental::matrix::joint_matrix_load_checked (
197
215
sg, tA[m][k1], pA, colsA, rowsA, colsA,
@@ -203,6 +221,7 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
203
221
k * TK,
204
222
colsA);
205
223
#endif // OOB
224
+ #endif // SLM
206
225
#ifdef MANUAL_UNROLL
207
226
}); // m
208
227
#else // MANUAL_UNROLL
@@ -213,32 +232,28 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
213
232
#else // MANUAL_UNROLL
214
233
for (unsigned int n = 0 ; n < NCache1 / TN; n++) {
215
234
#endif // MANUAL_UNROLL
235
+ #ifdef SLM
236
+ joint_matrix_load (sg, tB[n][k1],
237
+ tileB.template get_multi_ptr <
238
+ sycl::access::decorated::no>() +
239
+ (k1 * TK / vnniFactor) *
240
+ (NCache2 * vnniFactor) +
241
+ (n1 * NCache1 + n * TN) * vnniFactor,
242
+ NCache2 * vnniFactor);
243
+ #else // SLM
216
244
#ifdef OOB
217
- #ifdef VNNI
218
245
ext::intel::experimental::matrix::joint_matrix_load_checked (
219
246
sg, tB[n][k1], pB, colsB * vnniFactor, rowsB / vnniFactor,
220
247
colsB * vnniFactor, k * TK / vnniFactor,
221
248
(n2 * NCache2 + n1 * NCache1 + n * TN) * vnniFactor);
222
- #else // VNNI
223
- ext::intel::experimental::matrix::joint_matrix_load_checked (
224
- sg, tB[n][k1], pB, colsB, rowsB, colsB, k * TK,
225
- n2 * NCache2 + n1 * NCache1 + n * TN);
226
-
227
- #endif // VNNI
228
249
#else // OOB
229
- #ifdef VNNI
230
250
joint_matrix_load (
231
251
sg, tB[n][k1],
232
252
pB + (k * TK / vnniFactor) * (colsB * vnniFactor) +
233
253
(n2 * NCache2 + n1 * NCache1 + n * TN) * vnniFactor,
234
254
colsB * vnniFactor);
235
- #else // VNNI
236
- joint_matrix_load (sg, tB[n][k1],
237
- pB + (k * TK) * (colsB) +
238
- (n2 * NCache2 + n1 * NCache1 + n * TN),
239
- colsB);
240
- #endif // VNNI
241
255
#endif // OOB
256
+ #endif // SLM
242
257
#ifdef MANUAL_UNROLL
243
258
}); // n
244
259
#else // MANUAL_UNROLL
@@ -266,6 +281,9 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
266
281
} // m
267
282
} // k1
268
283
#endif // MANUAL_UNROLL
284
+ #ifdef SLM
285
+ it.barrier (access::fence_space::local_space);
286
+ #endif // SLM
269
287
#ifdef PREFETCH
270
288
auto prefetch_offsetA = (m2 * MCache2 + sgId * prefRow) * colsA +
271
289
(k2 + prefDistance) * prefCol;
@@ -275,7 +293,6 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
275
293
sg, A + prefetch_offsetA, colsA, layout::row_major,
276
294
syclex::properties{syclex::prefetch_hint_L1});
277
295
278
- #ifdef VNNI
279
296
auto prefetch_offsetB =
280
297
((k2 + prefDistance) * (KCache2 / vnniFactor) +
281
298
pm1B * prefRow) *
@@ -287,16 +304,6 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
287
304
sg, B + prefetch_offsetB, colsB * vnniFactor,
288
305
layout::row_major,
289
306
syclex::properties{syclex::prefetch_hint_L1});
290
- #else // VNNI
291
- auto prefetch_offsetB =
292
- ((k2 + prefDistance) * KCache2 + pm1B * prefRow) * (colsB) +
293
- (n2 * NCache2 + pn1B * prefCol);
294
- if ((prefetch_offsetB + (prefRow * MATRIX_SIZE) + prefCol) <
295
- (MATRIX_SIZE * MATRIX_SIZE))
296
- joint_matrix_prefetch<prefRow, prefCol>(
297
- sg, B + prefetch_offsetB, colsB, layout::row_major,
298
- syclex::properties{syclex::prefetch_hint_L1});
299
- #endif // VNNI
300
307
#endif // PREFETCH
301
308
} // for k2
302
309
#ifdef MANUAL_UNROLL
@@ -411,29 +418,33 @@ int main() {
411
418
constexpr size_t NCache2 = 256 ;
412
419
constexpr size_t KCache2 = 32 ;
413
420
421
+ #ifdef VNNI
422
+ constexpr unsigned int VnniFactor = 2 ;
423
+ #else // VNNI
424
+ constexpr unsigned int VnniFactor = 1 ;
425
+ #endif // VNNI
426
+
414
427
for (unsigned int i = 0 ; i < combinations.size (); i++) {
415
428
if (combinations[i].nsize == 0 ) { // Intel AMX
416
429
constexpr size_t NCache1 = 32 ;
417
430
constexpr size_t KCache1 = 32 ;
418
-
419
- test<bfloat16, float , 2 , /* TM*/ 16 , /* TN*/ 16 , /* TK*/ 32 , MCache1,
420
- NCache1, KCache1, MCache2, NCache2, KCache2>();
431
+ test<bfloat16, float , VnniFactor, /* TM*/ 16 , /* TN*/ 16 , /* TK*/ 32 ,
432
+ MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>();
421
433
break ;
422
434
}
423
435
424
436
if (combinations[i].nsize == 16 ) { // architecture::intel_gpu_pvc
425
437
constexpr size_t NCache1 = 4 * /* TN*/ 16 ;
426
438
constexpr size_t KCache1 = 16 ;
427
-
428
- test<bfloat16, float , 2 , /* TM*/ 8 , /* TN*/ 16 , /* TK*/ 16 , MCache1, NCache1,
429
- KCache1, MCache2, NCache2, KCache2>();
439
+ test<bfloat16, float , VnniFactor, /* TM*/ 8 , /* TN*/ 16 , /* TK*/ 16 , MCache1,
440
+ NCache1, KCache1, MCache2, NCache2, KCache2>();
430
441
#if (!defined(SG_SZ) || SG_SZ != 32)
431
442
// These combinations are not currently supported for subgroup size = 32 in
432
443
// IGC
433
- test<bfloat16, float , 2 , /* TM*/ 16 , /* TN*/ 16 , /* TK*/ 16 , MCache1 ,
434
- NCache1, KCache1, MCache2, NCache2, KCache2>();
435
- test<bfloat16, float , 2 , /* TM*/ 32 , /* TN*/ 64 , /* TK*/ 16 , MCache1 ,
436
- NCache1, KCache1, MCache2, NCache2, KCache2>();
444
+ test<bfloat16, float , VnniFactor , /* TM*/ 16 , /* TN*/ 16 , /* TK*/ 16 ,
445
+ MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>();
446
+ test<bfloat16, float , VnniFactor , /* TM*/ 32 , /* TN*/ 64 , /* TK*/ 16 ,
447
+ MCache1, NCache1, KCache1, MCache2, NCache2, KCache2>();
437
448
#endif
438
449
break ;
439
450
}
@@ -442,9 +453,10 @@ int main() {
442
453
constexpr size_t NCache1 = 4 * /* TN*/ 8 ;
443
454
constexpr size_t KCache1 = 16 ;
444
455
445
- test<bfloat16, float , 2 , /* TM*/ 8 , /* TN*/ 8 , /* TK*/ 16 , MCache1, NCache1,
446
- KCache1, MCache2, NCache2, KCache2>();
447
- // test<bfloat16, float, 2, /*TM*/ 32, /*TN*/ 32, /*TK*/ 16, MCache1,
456
+ test<bfloat16, float , VnniFactor, /* TM*/ 8 , /* TN*/ 8 , /* TK*/ 16 , MCache1,
457
+ NCache1, KCache1, MCache2, NCache2, KCache2>();
458
+ // test<bfloat16, float, VnniFactor, /*TM*/ 32, /*TN*/ 32, /*TK*/ 16,
459
+ // MCache1,
448
460
// NCache1, KCache1, MCache2, NCache2, KCache2>();
449
461
break ;
450
462
}
0 commit comments