@@ -2223,13 +2223,12 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) {
2223
2223
// / \tparam [in] T The type of matrix elements
2224
2224
// / \param [in] addr The address of the matrix in local memory
2225
2225
// / \param [in] m The private memory containing data of matrix
2226
- // / \param [in] item The sycl::nd_item index space class
2227
2226
// / \param [in] trans Indicates whether the matrix is to be stored transposed
2228
2227
// / \param [in] mat The index of the matrix to be stored
2229
- template <typename T, typename ItemT >
2230
- void stmatrix (uintptr_t addr, T m, const ItemT &item, bool trans = false ,
2231
- unsigned mat = 0 ) {
2232
- int lane = item. get_sub_group () .get_local_linear_id ();
2228
+ template <typename T>
2229
+ void stmatrix (uintptr_t addr, T m, bool trans = false , unsigned mat = 0 ) {
2230
+ auto sg = sycl::ext::oneapi::this_work_item::get_sub_group ();
2231
+ int lane = sg .get_local_linear_id ();
2233
2232
2234
2233
int lane_group8_row = lane / 8 ;
2235
2234
int lane_group8_col = lane % 8 ;
@@ -2241,8 +2240,8 @@ void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false,
2241
2240
src_lane += 1 ;
2242
2241
2243
2242
// Broadcast the address from the source lane
2244
- auto recv_addr_uintp = dpct::select_from_sub_group (
2245
- item. get_sub_group () , addr, mat * 8 + src_lane);
2243
+ auto recv_addr_uintp =
2244
+ dpct::select_from_sub_group (sg , addr, mat * 8 + src_lane);
2246
2245
2247
2246
// Cast the received address from uintptr_t to the type of 'm'
2248
2247
auto recv_addr = reinterpret_cast <T *>(recv_addr_uintp);
@@ -2254,10 +2253,10 @@ void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false,
2254
2253
int src_lane = (lane % 4 ) * 2 ;
2255
2254
2256
2255
// Broadcast the address from the source lane
2257
- auto recv_addr_uintp_1 = dpct::select_from_sub_group (
2258
- item. get_sub_group () , addr, mat * 8 + src_lane);
2259
- auto recv_addr_uintp_2 = dpct::select_from_sub_group (
2260
- item. get_sub_group () , addr, mat * 8 + src_lane + 1 );
2256
+ auto recv_addr_uintp_1 =
2257
+ dpct::select_from_sub_group (sg , addr, mat * 8 + src_lane);
2258
+ auto recv_addr_uintp_2 =
2259
+ dpct::select_from_sub_group (sg , addr, mat * 8 + src_lane + 1 );
2261
2260
2262
2261
// Cast the received address from uintptr_t to 'half *'
2263
2262
auto recv_addr_1 = reinterpret_cast <sycl::half *>(recv_addr_uintp_1);
@@ -2279,15 +2278,13 @@ void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false,
2279
2278
// / \param [in] addr The address of the matrix in local memory
2280
2279
// / \param [in] m1 The private memory containing data of 1st matrix
2281
2280
// / \param [in] m2 The private memory containing data of 2nd matrix
2282
- // / \param [in] item The sycl::nd_item index space class
2283
2281
// / \param [in] trans Indicates whether the matrices are to be stored transposed
2284
- template <typename T, typename ItemT>
2285
- void stmatrix (uintptr_t addr, T m1, T m2, const ItemT &item,
2286
- bool trans = false ) {
2282
+ template <typename T>
2283
+ void stmatrix (uintptr_t addr, T m1, T m2, bool trans = false ) {
2287
2284
// Store 1st matrix
2288
- stmatrix (addr, m1, item, trans, 0 );
2285
+ stmatrix (addr, m1, trans, 0 );
2289
2286
// Store 2nd matrix
2290
- stmatrix (addr, m2, item, trans, 1 );
2287
+ stmatrix (addr, m2, trans, 1 );
2291
2288
}
2292
2289
2293
2290
// / Stores 4 8x8 b16 matrices from private memory to local memory (32 bits per work-item)
@@ -2298,19 +2295,17 @@ void stmatrix(uintptr_t addr, T m1, T m2, const ItemT &item,
2298
2295
// / \param [in] m2 The private memory containing data of 2nd matrix
2299
2296
// / \param [in] m3 The private memory containing data of 3rd matrix
2300
2297
// / \param [in] m4 The private memory containing data of 4th matrix
2301
- // / \param [in] item The sycl::nd_item index space class
2302
2298
// / \param [in] trans Indicates whether the matrices are to be stored transposed
2303
- template <typename T, typename ItemT>
2304
- void stmatrix (uintptr_t addr, T m1, T m2, T m3, T m4, const ItemT &item,
2305
- bool trans = false ) {
2299
+ template <typename T>
2300
+ void stmatrix (uintptr_t addr, T m1, T m2, T m3, T m4, bool trans = false ) {
2306
2301
// Store 1st matrix
2307
- stmatrix (addr, m1, item, trans, 0 );
2302
+ stmatrix (addr, m1, trans, 0 );
2308
2303
// Store 2nd matrix
2309
- stmatrix (addr, m2, item, trans, 1 );
2304
+ stmatrix (addr, m2, trans, 1 );
2310
2305
// Store 3rd matrix
2311
- stmatrix (addr, m3, item, trans, 2 );
2306
+ stmatrix (addr, m3, trans, 2 );
2312
2307
// Store 4th matrix
2313
- stmatrix (addr, m4, item, trans, 3 );
2308
+ stmatrix (addr, m4, trans, 3 );
2314
2309
}
2315
2310
2316
2311
// / A helper struct that defines the pack type for the input matrix fragments
0 commit comments