@@ -2058,6 +2058,7 @@ class joint_matrix {
2058
2058
2059
2059
// / Multiplies 2 16x16 & 16x8 matrices and accumulates the result to a 16x8 b32
2060
2060
// / matrix
2061
+ // / Requires the sub-group size of kernel calling this function to be 32
2061
2062
// / \tparam [in] MulType The type of the multiplication result
2062
2063
// / \tparam [in] ABType The type of the input matrices
2063
2064
// / \tparam [in] CDType The type of the output matrix
@@ -2078,48 +2079,51 @@ class joint_matrix {
2078
2079
// / \param [in] c3 The 4th element from C matrix to be added with d3
2079
2080
// / \param [in] item The sycl::nd_item index space class
2080
2081
template <typename MulType, typename ABType, typename CDType, typename ItemT>
2081
- __attribute__ ((optnone)) void
2082
- mma (CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1,
2083
- ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1, CDType c2,
2084
- CDType c3, const ItemT &item) {
2082
+ void mma (CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1,
2083
+ ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1,
2084
+ CDType c2, CDType c3, const ItemT &item) {
2085
2085
int lane = item.get_sub_group ().get_local_linear_id ();
2086
2086
2087
2087
short ROW_LOAD_OFFSET = 4 * (lane / 4 );
2088
2088
short COL_LOAD_OFFSET = 8 * (lane % 4 );
2089
2089
2090
- ABType recv_a[4 * 4 ], recv_b[4 * 4 ];
2091
2090
for (int i = 0 ; i < 4 ; i++) {
2092
- recv_a[0 * 4 + i] = dpct::select_from_sub_group (item.get_sub_group (), a0,
2093
- ROW_LOAD_OFFSET + i);
2094
- recv_a[1 * 4 + i] = dpct::select_from_sub_group (item.get_sub_group (), a2,
2095
- ROW_LOAD_OFFSET + i);
2096
- recv_a[2 * 4 + i] = dpct::select_from_sub_group (item.get_sub_group (), a1,
2097
- ROW_LOAD_OFFSET + i);
2098
- recv_a[3 * 4 + i] = dpct::select_from_sub_group (item.get_sub_group (), a3,
2099
- ROW_LOAD_OFFSET + i);
2100
-
2101
- recv_b[0 * 4 + i] = dpct::select_from_sub_group (item.get_sub_group (), b0,
2102
- COL_LOAD_OFFSET + i);
2103
- recv_b[1 * 4 + i] = dpct::select_from_sub_group (item.get_sub_group (), b1,
2104
- COL_LOAD_OFFSET + i);
2105
- recv_b[2 * 4 + i] = dpct::select_from_sub_group (item.get_sub_group (), b0,
2106
- COL_LOAD_OFFSET + 4 + i);
2107
- recv_b[3 * 4 + i] = dpct::select_from_sub_group (item.get_sub_group (), b1,
2108
- COL_LOAD_OFFSET + 4 + i);
2109
- }
2110
-
2111
- auto *a = reinterpret_cast <MulType *>(recv_a);
2112
- auto *b = reinterpret_cast <MulType *>(recv_b);
2113
- for (int i = 0 ; i < 16 ; i++) {
2114
- auto a0 = static_cast <CDType>(a[i]);
2115
- auto a1 = static_cast <CDType>(a[i + 16 ]);
2116
- auto b0 = static_cast <CDType>(b[i]);
2117
- auto b1 = static_cast <CDType>(b[i + 16 ]);
2118
-
2119
- c0 += a0 * b0;
2120
- c1 += a0 * b1;
2121
- c2 += a1 * b0;
2122
- c3 += a1 * b1;
2091
+ ABType recv_a[4 ], recv_b[4 ];
2092
+
2093
+ recv_a[0 ] = dpct::select_from_sub_group (item.get_sub_group (), a0,
2094
+ ROW_LOAD_OFFSET + i);
2095
+ recv_a[1 ] = dpct::select_from_sub_group (item.get_sub_group (), a2,
2096
+ ROW_LOAD_OFFSET + i);
2097
+ recv_a[2 ] = dpct::select_from_sub_group (item.get_sub_group (), a1,
2098
+ ROW_LOAD_OFFSET + i);
2099
+ recv_a[3 ] = dpct::select_from_sub_group (item.get_sub_group (), a3,
2100
+ ROW_LOAD_OFFSET + i);
2101
+
2102
+ recv_b[0 ] = dpct::select_from_sub_group (item.get_sub_group (), b0,
2103
+ COL_LOAD_OFFSET + i);
2104
+ recv_b[1 ] = dpct::select_from_sub_group (item.get_sub_group (), b1,
2105
+ COL_LOAD_OFFSET + i);
2106
+ recv_b[2 ] = dpct::select_from_sub_group (item.get_sub_group (), b0,
2107
+ COL_LOAD_OFFSET + 4 + i);
2108
+ recv_b[3 ] = dpct::select_from_sub_group (item.get_sub_group (), b1,
2109
+ COL_LOAD_OFFSET + 4 + i);
2110
+
2111
+ auto *ra0 = reinterpret_cast <MulType *>(recv_a);
2112
+ auto *ra1 = reinterpret_cast <MulType *>(recv_a + 2 );
2113
+ auto *rb0 = reinterpret_cast <MulType *>(recv_b);
2114
+ auto *rb1 = reinterpret_cast <MulType *>(recv_b + 2 );
2115
+
2116
+ for (int j = 0 ; j < 2 * 2 ; j++) {
2117
+ auto a0 = static_cast <CDType>(ra0[j]);
2118
+ auto a1 = static_cast <CDType>(ra1[j]);
2119
+ auto b0 = static_cast <CDType>(rb0[j]);
2120
+ auto b1 = static_cast <CDType>(rb1[j]);
2121
+
2122
+ c0 += a0 * b0;
2123
+ c1 += a0 * b1;
2124
+ c2 += a1 * b0;
2125
+ c3 += a1 * b1;
2126
+ }
2123
2127
}
2124
2128
2125
2129
*d0 = c0;
0 commit comments