@@ -2218,27 +2218,33 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) {
2218
2218
ldmatrix (addr, m4, trans, 3 );
2219
2219
}
2220
2220
2221
- // / Stores 1 8x8 b16 matrix from private memory to local memory per sub-group.
2221
+ // / Collectively stores 1 8x8 b16 (128 bytes) matrix from private memory to
2222
+ // / local memory per sub-group.
2222
2223
// / Requires the sub-group size of kernel calling this function to be 32.
2223
- // / Each of the first 8 work items contain the starting address of their
2224
- // / respective matrix row.
2225
- // / Each of the 32 work items store 32-bits (2 packed 16-bit data) for a total
2226
- // / of 128 bytes.
2227
- // / Row Major: Each row of the matrix is stored by a group of 4 work items
2228
- // / r0: t0 t1 t2 t3
2229
- // / r1: t4 t5 t6 t7
2224
+ // / 'mat' specifies the matrix index to be stored. The first '(mat + 1) * 8'
2225
+ // / work items of sub-group contain the starting address of their respective
2226
+ // / matrix row in 'addr'.
2227
+ // / After distributing addresses to other work items, each of the 32 work items
2228
+ // / stores 32 bits (2 packed 16-bit values) from 'm' for a total of 128 bytes.
2229
+ // / 'trans' selects whether each work item performs a transposed or
2230
+ // / non-transposed store, as shown below:
2231
+ // / Row Major: Each row of the matrix is stored by a group of 4 work items(wi)
2232
+ // / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2233
+ // / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2230
2234
// / ...
2231
- // / r7: t24 t25 t26 t27
2232
- // / r7: t28 t29 t30 t31
2233
- // / Col Major: Each col of the matrix is stored by a group of 4 work items
2234
- // / r0: t0 t4 t8 ... t28
2235
- // / r1: t0 t4 t8 ... t28
2235
+ // / row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2236
+ // / row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2237
+ // / Col Major: Each col of the matrix is stored by a group of 4 work items(wi)
2238
+ // / row-0: wi0 wi4 wi8 ... wi28
2239
+ // / row-1: wi0 wi4 wi8 ... wi28
2236
2240
// / ...
2237
- // / r6: t3 t7 t11 ... t31
2238
- // / r7: t3 t7 t11 ... t31
2239
- // / \tparam [in] T The type of matrix elements
2240
- // / \param [in] addr The address of the matrix in local memory
2241
- // / \param [in] m The private memory containing data of matrix
2241
+ // / row-6: wi3 wi7 wi11 ... wi31
2242
+ // / row-7: wi3 wi7 wi11 ... wi31
2243
+ // / \tparam [in] T Type of result variable (currently only supports 16-bit type)
2244
+ // / \param [in] addr The starting address of corresponding matrix row for a work
2245
+ // / item in local memory
2246
+ // / \param [in] m The private memory to store the matrix. It points to 2 b16
2247
+ // / type elements.
2242
2248
// / \param [in] trans Indicates whether the matrix is to be stored transposed
2243
2249
// / \param [in] mat The matrix index to be stored
2244
2250
template <typename T>
@@ -2288,16 +2294,35 @@ void stmatrix(uintptr_t addr, T m, bool trans = false, unsigned mat = 0) {
2288
2294
}
2289
2295
}
2290
2296
2291
- // / Stores 2 8x8 b16 matrix from private memory to local memory per sub-group.
2297
+ // / Collectively stores 2 8x8 b16 (256 bytes) matrices from private memory to
2298
+ // / local memory per sub-group.
2292
2299
// / Requires the sub-group size of kernel calling this function to be 32.
2293
- // / Each of the first 16 work items contain the starting address of their
2294
- // / respective matrix row.
2295
- // / Each of the 32 work items store 64-bits (32-bit per matrix) for a total
2296
- // / of 256 bytes.
2297
- // / \tparam [in] T The type of matrix elements
2298
- // / \param [in] addr The address of the matrix in local memory
2299
- // / \param [in] m1 The private memory containing data of 1st matrix
2300
- // / \param [in] m2 The private memory containing data of 2nd matrix
2300
+ // / The first 16 work items of sub-group contain the starting address of their
2301
+ // / respective matrix row in 'addr'.
2302
+ // / After distributing addresses to other work items, each of the 32 work items
2303
+ // / stores 64 bits (32 bits per matrix) from 'm1' & 'm2' for a total of 256
2304
+ // / bytes.
2305
+ // / 'trans' selects whether each work item performs a transposed or
2306
+ // / non-transposed store, as shown below:
2307
+ // / Row Major: Each row of the matrices is stored by a group of 4 work items(wi)
2308
+ // / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2309
+ // / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2310
+ // / ...
2311
+ // / row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2312
+ // / row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2313
+ // / Col Major: Each col of the matrices is stored by a group of 4 work items(wi)
2314
+ // / row-0: wi0 wi4 wi8 ... wi28
2315
+ // / row-1: wi0 wi4 wi8 ... wi28
2316
+ // / ...
2317
+ // / row-6: wi3 wi7 wi11 ... wi31
2318
+ // / row-7: wi3 wi7 wi11 ... wi31
2319
+ // / \tparam [in] T Type of result variable (currently only supports 16-bit type)
2320
+ // / \param [in] addr The starting address of corresponding matrix row for a work
2321
+ // / item in local memory
2322
+ // / \param [in] m1 The private memory to store the data of 1st matrix. It points
2323
+ // / to 2 b16 type elements.
2324
+ // / \param [in] m2 The private memory to store the data of 2nd matrix. It points
2325
+ // / to 2 b16 type elements.
2301
2326
// / \param [in] trans Indicates whether the matrices are to be stored transposed
2302
2327
template <typename T>
2303
2328
void stmatrix (uintptr_t addr, T m1, T m2, bool trans = false ) {
@@ -2307,18 +2332,39 @@ void stmatrix(uintptr_t addr, T m1, T m2, bool trans = false) {
2307
2332
stmatrix (addr, m2, trans, 1 );
2308
2333
}
2309
2334
2310
- // / Stores 4 8x8 b16 matrix from private memory to local memory per sub-group.
2335
+ // / Collectively stores 4 8x8 b16 (512 bytes) matrices from private memory to
2336
+ // / local memory per sub-group.
2311
2337
// / Requires the sub-group size of kernel calling this function to be 32.
2312
- // / Each of the 32 work items contain the starting address of their
2313
- // / respective matrix row.
2314
- // / Each of the 32 work items store 128-bits (32-bit per matrix) for a total
2338
+ // / Each work item of the sub-group contains the starting address of its
2339
+ // / respective matrix row in 'addr'.
2340
+ // / After distributing addresses to other work items, each of the 32 work items
2341
+ // / stores 128 bits (32 bits per matrix) from 'm1', 'm2', 'm3' & 'm4' for a total
2315
2342
// / of 512 bytes.
2316
- // / \tparam [in] T The type of matrix elements
2317
- // / \param [in] addr The address of the matrix in local memory
2318
- // / \param [in] m1 The private memory containing data of 1st matrix
2319
- // / \param [in] m2 The private memory containing data of 2nd matrix
2320
- // / \param [in] m3 The private memory containing data of 3rd matrix
2321
- // / \param [in] m4 The private memory containing data of 4th matrix
2343
+ // / 'trans' selects whether each work item performs a transposed or
2344
+ // / non-transposed store, as shown below:
2345
+ // / Row Major: Each row of the matrices is stored by a group of 4 work items(wi)
2346
+ // / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2347
+ // / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2348
+ // / ...
2349
+ // / row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2350
+ // / row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2351
+ // / Col Major: Each col of the matrices is stored by a group of 4 work items(wi)
2352
+ // / row-0: wi0 wi4 wi8 ... wi28
2353
+ // / row-1: wi0 wi4 wi8 ... wi28
2354
+ // / ...
2355
+ // / row-6: wi3 wi7 wi11 ... wi31
2356
+ // / row-7: wi3 wi7 wi11 ... wi31
2357
+ // / \tparam [in] T Type of result variable (currently only supports 16-bit type)
2358
+ // / \param [in] addr The starting address of corresponding matrix row for a work
2359
+ // / item in local memory
2360
+ // / \param [in] m1 The private memory to store the data of 1st matrix. It points
2361
+ // / to 2 b16 type elements.
2362
+ // / \param [in] m2 The private memory to store the data of 2nd matrix. It points
2363
+ // / to 2 b16 type elements.
2364
+ // / \param [in] m3 The private memory to store the data of 3rd matrix. It points
2365
+ // / to 2 b16 type elements.
2366
+ // / \param [in] m4 The private memory to store the data of 4th matrix. It points
2367
+ // / to 2 b16 type elements.
2322
2368
// / \param [in] trans Indicates whether the matrices are to be stored transposed
2323
2369
template <typename T>
2324
2370
void stmatrix (uintptr_t addr, T m1, T m2, T m3, T m4, bool trans = false ) {
0 commit comments