Skip to content

Commit e82504f

Browse files
Added comments and removed optnone attr
1 parent 7d0682f commit e82504f

File tree

1 file changed

+40
-36
lines changed
  • clang/runtime/dpct-rt/include/dpct

1 file changed

+40
-36
lines changed

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 40 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2058,6 +2058,7 @@ class joint_matrix {
20582058

20592059
/// Multiplies 2 16x16 & 16x8 matrices and accumulates the result to a 16x8 b32
20602060
/// matrix
2061+
/// Requires the sub-group size of kernel calling this function to be 32
20612062
/// \tparam [in] MulType The type of the multiplication result
20622063
/// \tparam [in] ABType The type of the input matrices
20632064
/// \tparam [in] CDType The type of the output matrix
@@ -2078,48 +2079,51 @@ class joint_matrix {
20782079
/// \param [in] c3 The 4th element from C matrix to be added with d3
20792080
/// \param [in] item The sycl::nd_item index space class
20802081
template <typename MulType, typename ABType, typename CDType, typename ItemT>
2081-
__attribute__((optnone)) void
2082-
mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1,
2083-
ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1, CDType c2,
2084-
CDType c3, const ItemT &item) {
2082+
void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1,
2083+
ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1,
2084+
CDType c2, CDType c3, const ItemT &item) {
20852085
int lane = item.get_sub_group().get_local_linear_id();
20862086

20872087
short ROW_LOAD_OFFSET = 4 * (lane / 4);
20882088
short COL_LOAD_OFFSET = 8 * (lane % 4);
20892089

2090-
ABType recv_a[4 * 4], recv_b[4 * 4];
20912090
for (int i = 0; i < 4; i++) {
2092-
recv_a[0 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), a0,
2093-
ROW_LOAD_OFFSET + i);
2094-
recv_a[1 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), a2,
2095-
ROW_LOAD_OFFSET + i);
2096-
recv_a[2 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), a1,
2097-
ROW_LOAD_OFFSET + i);
2098-
recv_a[3 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), a3,
2099-
ROW_LOAD_OFFSET + i);
2100-
2101-
recv_b[0 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), b0,
2102-
COL_LOAD_OFFSET + i);
2103-
recv_b[1 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), b1,
2104-
COL_LOAD_OFFSET + i);
2105-
recv_b[2 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), b0,
2106-
COL_LOAD_OFFSET + 4 + i);
2107-
recv_b[3 * 4 + i] = dpct::select_from_sub_group(item.get_sub_group(), b1,
2108-
COL_LOAD_OFFSET + 4 + i);
2109-
}
2110-
2111-
auto *a = reinterpret_cast<MulType *>(recv_a);
2112-
auto *b = reinterpret_cast<MulType *>(recv_b);
2113-
for (int i = 0; i < 16; i++) {
2114-
auto a0 = static_cast<CDType>(a[i]);
2115-
auto a1 = static_cast<CDType>(a[i + 16]);
2116-
auto b0 = static_cast<CDType>(b[i]);
2117-
auto b1 = static_cast<CDType>(b[i + 16]);
2118-
2119-
c0 += a0 * b0;
2120-
c1 += a0 * b1;
2121-
c2 += a1 * b0;
2122-
c3 += a1 * b1;
2091+
ABType recv_a[4], recv_b[4];
2092+
2093+
recv_a[0] = dpct::select_from_sub_group(item.get_sub_group(), a0,
2094+
ROW_LOAD_OFFSET + i);
2095+
recv_a[1] = dpct::select_from_sub_group(item.get_sub_group(), a2,
2096+
ROW_LOAD_OFFSET + i);
2097+
recv_a[2] = dpct::select_from_sub_group(item.get_sub_group(), a1,
2098+
ROW_LOAD_OFFSET + i);
2099+
recv_a[3] = dpct::select_from_sub_group(item.get_sub_group(), a3,
2100+
ROW_LOAD_OFFSET + i);
2101+
2102+
recv_b[0] = dpct::select_from_sub_group(item.get_sub_group(), b0,
2103+
COL_LOAD_OFFSET + i);
2104+
recv_b[1] = dpct::select_from_sub_group(item.get_sub_group(), b1,
2105+
COL_LOAD_OFFSET + i);
2106+
recv_b[2] = dpct::select_from_sub_group(item.get_sub_group(), b0,
2107+
COL_LOAD_OFFSET + 4 + i);
2108+
recv_b[3] = dpct::select_from_sub_group(item.get_sub_group(), b1,
2109+
COL_LOAD_OFFSET + 4 + i);
2110+
2111+
auto *ra0 = reinterpret_cast<MulType *>(recv_a);
2112+
auto *ra1 = reinterpret_cast<MulType *>(recv_a + 2);
2113+
auto *rb0 = reinterpret_cast<MulType *>(recv_b);
2114+
auto *rb1 = reinterpret_cast<MulType *>(recv_b + 2);
2115+
2116+
for (int j = 0; j < 2 * 2; j++) {
2117+
auto a0 = static_cast<CDType>(ra0[j]);
2118+
auto a1 = static_cast<CDType>(ra1[j]);
2119+
auto b0 = static_cast<CDType>(rb0[j]);
2120+
auto b1 = static_cast<CDType>(rb1[j]);
2121+
2122+
c0 += a0 * b0;
2123+
c1 += a0 * b1;
2124+
c2 += a1 * b0;
2125+
c3 += a1 * b1;
2126+
}
21232127
}
21242128

21252129
*d0 = c0;

0 commit comments

Comments
 (0)