@@ -2218,8 +2218,24 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) {
2218
2218
ldmatrix (addr, m4, trans, 3 );
2219
2219
}
2220
2220
2221
- // / Stores 1 8x8 b16 matrix from private memory to local memory (32-bits per wi)
2222
- // / Requires the sub-group size of kernel calling this function to be 32
2221
+ // / Stores 1 8x8 b16 matrix from private memory to local memory per sub-group.
2222
+ // / Requires the sub-group size of kernel calling this function to be 32.
2223
+ // / Each of the first 8 work items contains the starting address of its
2224
+ // / respective matrix row.
2225
+ // / Each of the 32 work items stores 32 bits (2 packed b16 values) for a total
2226
+ // / of 128 bytes.
2227
+ // / Row Major: Each row of the matrix is stored by a group of 4 work items
2228
+ // / r0: t0 t1 t2 t3
2229
+ // / r1: t4 t5 t6 t7
2230
+ // / ...
2231
+ // / r6: t24 t25 t26 t27
2232
+ // / r7: t28 t29 t30 t31
2233
+ // / Col Major: Each col of the matrix is stored by a group of 4 work items
2234
+ // / r0: t0 t4 t8 ... t28
2235
+ // / r1: t0 t4 t8 ... t28
2236
+ // / ...
2237
+ // / r6: t3 t7 t11 ... t31
2238
+ // / r7: t3 t7 t11 ... t31
2223
2239
// / \tparam [in] T The type of matrix elements
2224
2240
// / \param [in] addr The address of the matrix in local memory
2225
2241
// / \param [in] m The private memory containing data of matrix
@@ -2272,8 +2288,12 @@ void stmatrix(uintptr_t addr, T m, bool trans = false, unsigned mat = 0) {
2272
2288
}
2273
2289
}
2274
2290
2275
- // / Stores 2 8x8 b16 matrix from private memory to local memory (32-bits per wi)
2276
- // / Requires the sub-group size of kernel calling this function to be 32
2291
+ // / Stores 2 8x8 b16 matrices from private memory to local memory per sub-group.
2292
+ // / Requires the sub-group size of kernel calling this function to be 32.
2293
+ // / Each of the first 16 work items contains the starting address of its
2294
+ // / respective matrix row.
2295
+ // / Each of the 32 work items stores 64 bits (32 bits per matrix) for a total
2296
+ // / of 256 bytes.
2277
2297
// / \tparam [in] T The type of matrix elements
2278
2298
// / \param [in] addr The address of the matrix in local memory
2279
2299
// / \param [in] m1 The private memory containing data of 1st matrix
@@ -2287,8 +2307,12 @@ void stmatrix(uintptr_t addr, T m1, T m2, bool trans = false) {
2287
2307
stmatrix (addr, m2, trans, 1 );
2288
2308
}
2289
2309
2290
- // / Stores 4 8x8 b16 matrix from private memory to local memory (32-bits per wi)
2291
- // / Requires the sub-group size of kernel calling this function to be 32
2310
+ // / Stores 4 8x8 b16 matrices from private memory to local memory per sub-group.
2311
+ // / Requires the sub-group size of kernel calling this function to be 32.
2312
+ // / Each of the 32 work items contains the starting address of its
2313
+ // / respective matrix row.
2314
+ // / Each of the 32 work items stores 128 bits (32 bits per matrix) for a total
2315
+ // / of 512 bytes.
2292
2316
// / \tparam [in] T The type of matrix elements
2293
2317
// / \param [in] addr The address of the matrix in local memory
2294
2318
// / \param [in] m1 The private memory containing data of 1st matrix
0 commit comments