@@ -2058,7 +2058,7 @@ class joint_matrix {
2058
2058
const size_t num_elements;
2059
2059
};
2060
2060
2061
- // / Collectively loads 1 8x8 b16 (128 bytes) matrix from private memory to local
2061
+ // / Collectively loads 1 8x8 b16 (128 bytes) matrix from local memory to private
2062
2062
// / memory per sub-group. Requires the sub-group size of kernel calling this
2063
2063
// / function to be 32.
2064
2064
// / 'mat' specifies the matrix index to be loaded. The first '(mat + 1) * 8'
@@ -2135,7 +2135,7 @@ void ldmatrix(uintptr_t addr, T *m, bool trans = false, unsigned mat = 0) {
2135
2135
}
2136
2136
}
2137
2137
2138
- // / Collectively loads 2 8x8 b16 (256 bytes) matrix from private memory to local
2138
+ // / Collectively loads 2 8x8 b16 (256 bytes) matrix from local memory to private
2139
2139
// / memory per sub-group. Requires the sub-group size of kernel calling this
2140
2140
// / function to be 32.
2141
2141
// / The first 16 work items of sub-group contain the starting address of their
@@ -2172,7 +2172,7 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, bool trans = false) {
2172
2172
ldmatrix (addr, m2, trans, 1 );
2173
2173
}
2174
2174
2175
- // / Collectively loads 4 8x8 b16 (512 bytes) matrix from private memory to local
2175
+ // / Collectively loads 4 8x8 b16 (512 bytes) matrix from local memory to private
2176
2176
// / memory per sub-group. Requires the sub-group size of kernel calling this
2177
2177
// / function to be 32.
2178
2178
// / Each work item of sub-group contains the starting address of their
@@ -2218,6 +2218,166 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) {
2218
2218
ldmatrix (addr, m4, trans, 3 );
2219
2219
}
2220
2220
2221
+ // / Collectively stores 1 8x8 b16 (128 bytes) matrix from private memory to
2222
+ // / local memory per sub-group.
2223
+ // / Requires the sub-group size of kernel calling this function to be 32.
2224
+ // / 'mat' specifies the matrix index to be stored. The first '(mat + 1) * 8'
2225
+ // / work items of sub-group contain the starting address of their respective
2226
+ // / matrix row in 'addr'.
2227
+ // / After distributing addresses to other work items, each of the 32 work items
2228
+ // / store 32-bits (2 packed 16-bit data) into 'm' for a total of 128 bytes.
2229
+ // / 'trans' specifies to perform a transposed/non-transposed store by each work
2230
+ // / item like below
2231
+ // / Row Major: Each row of the matrix is stored by a group of 4 work items(wi)
2232
+ // / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2233
+ // / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2234
+ // / ...
2235
+ // / row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2236
+ // / row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2237
+ // / Col Major: Each col of the matrix is stored by a group of 4 work items(wi)
2238
+ // / row-0: wi0 wi4 wi8 ... wi28
2239
+ // / row-1: wi0 wi4 wi8 ... wi28
2240
+ // / ...
2241
+ // / row-6: wi3 wi7 wi11 ... wi31
2242
+ // / row-7: wi3 wi7 wi11 ... wi31
2243
+ // / \tparam [in] T Type of result variable (currently only supports 16-bit type)
2244
+ // / \param [in] addr The starting address of corresponding matrix row for a work
2245
+ // / item in local memory
2246
+ // / \param [in] m The local memory to store the matrix. It points to 2 b16
2247
+ // / type elements.
2248
+ // / \param [in] trans Indicates whether the matrix to be stored transposed
2249
+ // / \param [in] mat The matrix index to be stored
2250
+ template <typename T>
2251
+ void stmatrix (uintptr_t addr, T m, bool trans = false , unsigned mat = 0 ) {
2252
+ auto sg = sycl::ext::oneapi::this_work_item::get_sub_group ();
2253
+ int lane = sg.get_local_linear_id ();
2254
+
2255
+ int lane_group8_row = lane / 8 ;
2256
+ int lane_group8_col = lane % 8 ;
2257
+
2258
+ if (!trans) {
2259
+ // calculate the source lane
2260
+ int src_lane = 2 * lane_group8_row;
2261
+ if (lane_group8_col >= 4 )
2262
+ src_lane += 1 ;
2263
+
2264
+ // Broadcast the address from the source lane
2265
+ auto recv_addr_uintp =
2266
+ dpct::select_from_sub_group (sg, addr, mat * 8 + src_lane);
2267
+
2268
+ // Cast the received address from uintptr_t to the type of 'm'
2269
+ auto recv_addr = reinterpret_cast <T *>(recv_addr_uintp);
2270
+
2271
+ // Non-transposed store
2272
+ recv_addr[lane_group8_col % 4 ] = m;
2273
+ } else {
2274
+ // calculate the source lane
2275
+ int src_lane = (lane % 4 ) * 2 ;
2276
+
2277
+ // Broadcast the address from the source lane
2278
+ auto recv_addr_uintp_1 =
2279
+ dpct::select_from_sub_group (sg, addr, mat * 8 + src_lane);
2280
+ auto recv_addr_uintp_2 =
2281
+ dpct::select_from_sub_group (sg, addr, mat * 8 + src_lane + 1 );
2282
+
2283
+ // Cast the received address from uintptr_t to 'half *'
2284
+ auto recv_addr_1 = reinterpret_cast <sycl::half *>(recv_addr_uintp_1);
2285
+ auto recv_addr_2 = reinterpret_cast <sycl::half *>(recv_addr_uintp_2);
2286
+
2287
+ // Split the 32-bit value of 'm' into two 16-bits
2288
+ sycl::half *val = reinterpret_cast <sycl::half *>(&m);
2289
+
2290
+ // Transposed store
2291
+ int index = lane / 4 ;
2292
+ recv_addr_1[index] = val[0 ];
2293
+ recv_addr_2[index] = val[1 ];
2294
+ }
2295
+ }
2296
+
2297
/// Collectively stores two 8x8 b16 matrices (256 bytes) from private memory to
/// local memory per sub-group.
/// The kernel calling this function must use a sub-group size of 32.
/// The first 16 work items of the sub-group pass, in 'addr', the starting
/// address of their respective matrix row (8 rows per matrix).
/// After the addresses are distributed to the other work items, each of the 32
/// work items writes 64 bits (32 bits per matrix) taken from 'm1' and 'm2',
/// for a total of 256 bytes.
/// 'trans' selects which matrix elements each work item (wi) writes:
///   Non-transposed: each row of the matrices is written by a group of 4 work
///   items
///     row-0:  wi0  wi0  wi1  wi1 ...  wi3  wi3
///     row-1:  wi4  wi4  wi5  wi5 ...  wi7  wi7
///     ...
///     row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
///     row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
///   Transposed: each column of the matrices is written by a group of 4 work
///   items
///     row-0: wi0 wi4 wi8  ... wi28
///     row-1: wi0 wi4 wi8  ... wi28
///     ...
///     row-6: wi3 wi7 wi11 ... wi31
///     row-7: wi3 wi7 wi11 ... wi31
/// \tparam T Type of the per-work-item values (each holds 2 packed b16
/// elements)
/// \param [in] addr The starting address, in local memory, of the matrix row
/// supplied by this work item
/// \param [in] m1 The value held by this work item that is written to the 1st
/// matrix (2 packed b16 elements)
/// \param [in] m2 The value held by this work item that is written to the 2nd
/// matrix (2 packed b16 elements)
/// \param [in] trans Indicates whether the matrices are stored transposed
template <typename T>
void stmatrix(uintptr_t addr, T m1, T m2, bool trans = false) {
  // Fragment i belongs to matrix i; store them in index order.
  T fragments[] = {m1, m2};
  for (unsigned mat = 0; mat < 2; ++mat) {
    stmatrix(addr, fragments[mat], trans, mat);
  }
}
2334
+
2335
/// Collectively stores four 8x8 b16 matrices (512 bytes) from private memory
/// to local memory per sub-group.
/// The kernel calling this function must use a sub-group size of 32.
/// Every work item of the sub-group passes, in 'addr', the starting address of
/// its respective matrix row (8 rows per matrix).
/// After the addresses are distributed to the other work items, each of the 32
/// work items writes 128 bits (32 bits per matrix) taken from 'm1', 'm2', 'm3'
/// and 'm4', for a total of 512 bytes.
/// 'trans' selects which matrix elements each work item (wi) writes:
///   Non-transposed: each row of the matrices is written by a group of 4 work
///   items
///     row-0:  wi0  wi0  wi1  wi1 ...  wi3  wi3
///     row-1:  wi4  wi4  wi5  wi5 ...  wi7  wi7
///     ...
///     row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
///     row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
///   Transposed: each column of the matrices is written by a group of 4 work
///   items
///     row-0: wi0 wi4 wi8  ... wi28
///     row-1: wi0 wi4 wi8  ... wi28
///     ...
///     row-6: wi3 wi7 wi11 ... wi31
///     row-7: wi3 wi7 wi11 ... wi31
/// \tparam T Type of the per-work-item values (each holds 2 packed b16
/// elements)
/// \param [in] addr The starting address, in local memory, of the matrix row
/// supplied by this work item
/// \param [in] m1 The value held by this work item that is written to the 1st
/// matrix (2 packed b16 elements)
/// \param [in] m2 The value held by this work item that is written to the 2nd
/// matrix (2 packed b16 elements)
/// \param [in] m3 The value held by this work item that is written to the 3rd
/// matrix (2 packed b16 elements)
/// \param [in] m4 The value held by this work item that is written to the 4th
/// matrix (2 packed b16 elements)
/// \param [in] trans Indicates whether the matrices are stored transposed
template <typename T>
void stmatrix(uintptr_t addr, T m1, T m2, T m3, T m4, bool trans = false) {
  // Fragment i belongs to matrix i; store them in index order.
  T fragments[] = {m1, m2, m3, m4};
  for (unsigned mat = 0; mat < 4; ++mat) {
    stmatrix(addr, fragments[mat], trans, mat);
  }
}
2380
+
2221
2381
// / A helper struct that defines the pack type for the input matrix fragments
2222
2382
// / of mma() function based on the type of input matrix fragments.
2223
2383
// / The MMAType struct is specialized for different types of input matrices.
0 commit comments