|
9 | 9 | #ifndef __DPCT_MATH_HPP__
|
10 | 10 | #define __DPCT_MATH_HPP__
|
11 | 11 |
|
12 |
| -#include <limits> |
13 | 12 | #include <climits>
|
| 13 | +#include <limits> |
14 | 14 | #include <sycl/sycl.hpp>
|
15 | 15 | #include <type_traits>
|
16 | 16 |
|
@@ -1636,7 +1636,8 @@ inline constexpr unsigned extend_vcompare2_add(AT a, BT b, unsigned c,
|
1636 | 1636 | /// \returns The extend vectorized average of the two values
|
1637 | 1637 | template <typename RetT, typename AT, typename BT>
|
1638 | 1638 | inline constexpr RetT extend_vavrg2(AT a, BT b, RetT c) {
|
1639 |
| - return detail::extend_vbinary2<RetT, false, false>(a, b, c, detail::average()); |
| 1639 | + return detail::extend_vbinary2<RetT, false, false>(a, b, c, |
| 1640 | + detail::average()); |
1640 | 1641 | }
|
1641 | 1642 |
|
1642 | 1643 | /// Compute vectorized average of \p a and \p b, with each value treated as a 2
|
@@ -1933,7 +1934,8 @@ inline constexpr unsigned extend_vcompare4_add(AT a, BT b, unsigned c,
|
1933 | 1934 | /// \returns The extend vectorized average of the two values
|
1934 | 1935 | template <typename RetT, typename AT, typename BT>
|
1935 | 1936 | inline constexpr RetT extend_vavrg4(AT a, BT b, RetT c) {
|
1936 |
| - return detail::extend_vbinary4<RetT, false, false>(a, b, c, detail::average()); |
| 1937 | + return detail::extend_vbinary4<RetT, false, false>(a, b, c, |
| 1938 | + detail::average()); |
1937 | 1939 | }
|
1938 | 1940 |
|
1939 | 1941 | /// Compute vectorized average of \p a and \p b, with each value treated as a 4
|
@@ -2056,6 +2058,198 @@ class joint_matrix {
|
2056 | 2058 | const size_t num_elements;
|
2057 | 2059 | };
|
2058 | 2060 |
|
| 2061 | +/// Multiplies 2 8x4 & 4x8 matrices and accumulates the result to a 8x8 b16 |
| 2062 | +/// matrix |
| 2063 | +/// Requires the sub-group size of kernel calling this function to be 32 |
| 2064 | +/// \tparam [in] MulType The type of the multiplication result |
| 2065 | +/// \tparam [in] ABType The type of the input matrices |
| 2066 | +/// \tparam [in] CDType The type of the output matrix |
| 2067 | +/// \tparam [in] ItemT The type of the sycl::nd_item index space class |
| 2068 | +/// \param [in] d0 The 1st element to be written to the output D matrix |
| 2069 | +/// \param [in] d1 The 2nd element to be written to the output D matrix |
| 2070 | +/// \param [in] d2 The 3rd element to be written to the output D matrix |
| 2071 | +/// \param [in] d3 The 4th element to be written to the output D matrix |
| 2072 | +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix |
| 2073 | +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix |
| 2074 | +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix |
| 2075 | +/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix |
| 2076 | +/// \param [in] c0 The 1st element from C matrix to be added with d0 |
| 2077 | +/// \param [in] c1 The 2nd element from C matrix to be added with d1 |
| 2078 | +/// \param [in] c2 The 3rd element from C matrix to be added with d2 |
| 2079 | +/// \param [in] c3 The 4th element from C matrix to be added with d3 |
| 2080 | +/// \param [in] item The sycl::nd_item index space class |
| 2081 | +template <typename MulType, typename ABType, typename CDType, typename ItemT> |
| 2082 | +void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, |
| 2083 | + ABType b0, ABType b1, CDType c0, CDType c1, CDType c2, CDType c3, |
| 2084 | + const ItemT &item) { |
| 2085 | + int lane = item.get_sub_group().get_local_linear_id(); |
| 2086 | + |
| 2087 | + short COL_LOAD_OFFSET = 4 * ((lane % 16) / 4); |
| 2088 | + |
| 2089 | + ABType recv_a[2] recv_a[0] = a0; |
| 2090 | + recv_a[1] = a1; |
| 2091 | + |
| 2092 | + MulType *ra = reinterpret_cast<MulType *>(recv_a); |
| 2093 | + MulType *c_h[4]; |
| 2094 | + c_h[0] = reinterpret_cast<MulType *>(&c0); |
| 2095 | + c_h[1] = reinterpret_cast<MulType *>(&c1); |
| 2096 | + c_h[2] = reinterpret_cast<MulType *>(&c2); |
| 2097 | + c_h[3] = reinterpret_cast<MulType *>(&c3); |
| 2098 | + for (int i = 0; i < 4; i++) { |
| 2099 | + ABType recv_b[4]; |
| 2100 | + |
| 2101 | + recv_b[0] = dpct::select_from_sub_group(item.get_sub_group(), b0, |
| 2102 | + COL_LOAD_OFFSET + i); |
| 2103 | + recv_b[1] = dpct::select_from_sub_group(item.get_sub_group(), b1, |
| 2104 | + COL_LOAD_OFFSET + i); |
| 2105 | + recv_b[2] = dpct::select_from_sub_group(item.get_sub_group(), b0, |
| 2106 | + COL_LOAD_OFFSET + 16 + i); |
| 2107 | + recv_b[3] = dpct::select_from_sub_group(item.get_sub_group(), b1, |
| 2108 | + COL_LOAD_OFFSET + 16 + i); |
| 2109 | + |
| 2110 | + MulType *rb = reinterpret_cast<MulType *>(recv_b); |
| 2111 | + // Iterate for k times |
| 2112 | + for (int j = 0; j < 4; j++) { |
| 2113 | + c_h[(i >> 1)][i % 2] += |
| 2114 | + static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j]); |
| 2115 | + c_h[2 + (i >> 1)][i % 2] += |
| 2116 | + static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[4 + j]); |
| 2117 | + } |
| 2118 | + } |
| 2119 | + |
| 2120 | + d0 = c0; |
| 2121 | + d1 = c1; |
| 2122 | + d2 = c2; |
| 2123 | + d3 = c3; |
| 2124 | +} |
| 2125 | + |
| 2126 | +/// Multiplies 2 8x4 & 4x8 matrices and accumulates the result to a 8x8 b32 |
| 2127 | +/// matrix |
| 2128 | +/// Requires the sub-group size of kernel calling this function to be 32 |
| 2129 | +/// \tparam [in] MulType The type of the multiplication result |
| 2130 | +/// \tparam [in] ABType The type of the input matrices |
| 2131 | +/// \tparam [in] CDType The type of the output matrix |
| 2132 | +/// \tparam [in] ItemT The type of the sycl::nd_item index space class |
| 2133 | +/// \param [in] d0 The 1st element to be written to the output D matrix |
| 2134 | +/// \param [in] d1 The 2nd element to be written to the output D matrix |
| 2135 | +/// \param [in] d2 The 3rd element to be written to the output D matrix |
| 2136 | +/// \param [in] d3 The 4th element to be written to the output D matrix |
| 2137 | +/// \param [in] d4 The 5th element to be written to the output D matrix |
| 2138 | +/// \param [in] d5 The 6th element to be written to the output D matrix |
| 2139 | +/// \param [in] d6 The 7th element to be written to the output D matrix |
| 2140 | +/// \param [in] d7 The 8th element to be written to the output D matrix |
| 2141 | +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix |
| 2142 | +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix |
| 2143 | +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix |
| 2144 | +/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix |
| 2145 | +/// \param [in] c0 The 1st element from C matrix to be added with d0 |
| 2146 | +/// \param [in] c1 The 2nd element from C matrix to be added with d1 |
| 2147 | +/// \param [in] c2 The 3rd element from C matrix to be added with d2 |
| 2148 | +/// \param [in] c3 The 4th element from C matrix to be added with d3 |
| 2149 | +/// \param [in] c4 The 5th element from C matrix to be added with d4 |
| 2150 | +/// \param [in] c5 The 6th element from C matrix to be added with d5 |
| 2151 | +/// \param [in] c6 The 7th element from C matrix to be added with d6 |
| 2152 | +/// \param [in] c7 The 8th element from C matrix to be added with d7 |
| 2153 | +/// \param [in] item The sycl::nd_item index space class |
| 2154 | +template <typename MulType, typename ABType, typename CDType, typename ItemT> |
| 2155 | +void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, CDType *d4, CDType *d5, |
| 2156 | + CDType *d6, CDType *d7, ABType a0, ABType a1, ABType b0, ABType b1, |
| 2157 | + CDType c0, CDType c1, CDType c2, CDType c3, CDType c4, CDType c5, |
| 2158 | + CDType c6, CDType c7, const ItemT &item) { |
| 2159 | + int lane = item.get_sub_group().get_local_linear_id(); |
| 2160 | + |
| 2161 | + short ROW_LOAD_OFFSET = 4 * (lane / 4) + (lane % 2); |
| 2162 | + short COL_LOAD_OFFSET = 4 * ((lane % 16) / 4) + 2 * ((lane / 2) % 2); |
| 2163 | + |
| 2164 | + ABType recv_a[2 * 2], recv_b[4 * 2]; |
| 2165 | + |
| 2166 | + recv_a[0] = |
| 2167 | + dpct::select_from_sub_group(item.get_sub_group(), a0, ROW_LOAD_OFFSET); |
| 2168 | + recv_a[1] = |
| 2169 | + dpct::select_from_sub_group(item.get_sub_group(), a1, ROW_LOAD_OFFSET); |
| 2170 | + recv_a[2] = dpct::select_from_sub_group(item.get_sub_group(), a0, |
| 2171 | + ROW_LOAD_OFFSET + 2); |
| 2172 | + recv_a[3] = dpct::select_from_sub_group(item.get_sub_group(), a1, |
| 2173 | + ROW_LOAD_OFFSET + 2); |
| 2174 | + |
| 2175 | + recv_b[0] = |
| 2176 | + dpct::select_from_sub_group(item.get_sub_group(), b0, COL_LOAD_OFFSET); |
| 2177 | + recv_b[1] = |
| 2178 | + dpct::select_from_sub_group(item.get_sub_group(), b1, COL_LOAD_OFFSET); |
| 2179 | + recv_b[2] = dpct::select_from_sub_group(item.get_sub_group(), b0, |
| 2180 | + COL_LOAD_OFFSET + 1); |
| 2181 | + recv_b[3] = dpct::select_from_sub_group(item.get_sub_group(), b1, |
| 2182 | + COL_LOAD_OFFSET + 1); |
| 2183 | + recv_b[4] = dpct::select_from_sub_group(item.get_sub_group(), b0, |
| 2184 | + COL_LOAD_OFFSET + 16); |
| 2185 | + recv_b[5] = dpct::select_from_sub_group(item.get_sub_group(), b1, |
| 2186 | + COL_LOAD_OFFSET + 16); |
| 2187 | + recv_b[6] = dpct::select_from_sub_group(item.get_sub_group(), b0, |
| 2188 | + COL_LOAD_OFFSET + 17); |
| 2189 | + recv_b[7] = dpct::select_from_sub_group(item.get_sub_group(), b1, |
| 2190 | + COL_LOAD_OFFSET + 17); |
| 2191 | + |
| 2192 | + MulType *ra = reinterpret_cast<MulType *>(recv_a); |
| 2193 | + MulType *rb = reinterpret_cast<MulType *>(recv_b); |
| 2194 | + for (int i = 0; i < 4 /*k*/; i++) { |
| 2195 | + c0 += static_cast<CDType>(ra[i]) * static_cast<CDType>(rb[i]); |
| 2196 | + c1 += static_cast<CDType>(ra[i]) * static_cast<CDType>(rb[i + 4]); |
| 2197 | + c2 += static_cast<CDType>(ra[i + 4]) * static_cast<CDType>(rb[i]); |
| 2198 | + c3 += static_cast<CDType>(ra[i + 4]) * static_cast<CDType>(rb[i + 4]); |
| 2199 | + c4 += static_cast<CDType>(ra[i]) * static_cast<CDType>(rb[i + 8]); |
| 2200 | + c5 += static_cast<CDType>(ra[i]) * static_cast<CDType>(rb[i + 12]); |
| 2201 | + c6 += static_cast<CDType>(ra[i + 4]) * static_cast<CDType>(rb[i + 8]); |
| 2202 | + c7 += static_cast<CDType>(ra[i + 4]) * static_cast<CDType>(rb[i + 12]); |
| 2203 | + } |
| 2204 | + |
| 2205 | + d0 = c0; |
| 2206 | + d1 = c1; |
| 2207 | + d2 = c2; |
| 2208 | + d3 = c3; |
| 2209 | + d4 = c4; |
| 2210 | + d5 = c5; |
| 2211 | + d6 = c6; |
| 2212 | + d7 = c7; |
| 2213 | +} |
| 2214 | + |
| 2215 | +/// Multiplies 2 8x4 & 4x8 matrices and accumulates the result to a 8x8 b32 |
| 2216 | +/// matrix |
| 2217 | +/// Requires the sub-group size of kernel calling this function to be 32 |
| 2218 | +/// \tparam [in] MulType The type of the multiplication result |
| 2219 | +/// \tparam [in] ABType The type of the input matrices |
| 2220 | +/// \tparam [in] CDType The type of the output matrix |
| 2221 | +/// \tparam [in] ItemT The type of the sycl::nd_item index space class |
| 2222 | +/// \param [in] d0 The 1st element to be written to the output D matrix |
| 2223 | +/// \param [in] d1 The 2nd element to be written to the output D matrix |
| 2224 | +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix |
| 2225 | +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix |
| 2226 | +/// \param [in] c0 The 1st element from C matrix to be added with d0 |
| 2227 | +/// \param [in] c1 The 2nd element from C matrix to be added with d1 |
| 2228 | +/// \param [in] item The sycl::nd_item index space class |
| 2229 | +template <typename MulType, typename ABType, typename CDType, typename ItemT> |
| 2230 | +void mma(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, CDType c1, |
| 2231 | + const ItemT &item) { |
| 2232 | + int lane = item.get_sub_group().get_local_linear_id(); |
| 2233 | + |
| 2234 | + short ROW_LOAD_OFFSET = 4 * (lane / 4); |
| 2235 | + short COL_LOAD_OFFSET = 8 * (lane % 4); |
| 2236 | + |
| 2237 | + for (int i = 0; i < 4; i++) { |
| 2238 | + ABType recv_a = dpct::select_from_sub_group(item.get_sub_group(), a0, |
| 2239 | + ROW_LOAD_OFFSET + i); |
| 2240 | + ABType recv_b = dpct::select_from_sub_group(item.get_sub_group(), b0, |
| 2241 | + COL_LOAD_OFFSET + i); |
| 2242 | + c0 += recv_a * recv_b; |
| 2243 | + |
| 2244 | + recv_b = dpct::select_from_sub_group(item.get_sub_group(), b0, |
| 2245 | + COL_LOAD_OFFSET + i + 4); |
| 2246 | + c1 += recv_a * recv_b; |
| 2247 | + } |
| 2248 | + |
| 2249 | + d0 = c0; |
| 2250 | + d1 = c1; |
| 2251 | +} |
| 2252 | + |
2059 | 2253 | /// Multiplies 2 16x16 & 16x8 matrices and accumulates the result to a 16x8 b32
|
2060 | 2254 | /// matrix
|
2061 | 2255 | /// Requires the sub-group size of kernel calling this function to be 32
|
@@ -2084,7 +2278,7 @@ void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1,
|
2084 | 2278 | CDType c2, CDType c3, const ItemT &item) {
|
2085 | 2279 | int lane = item.get_sub_group().get_local_linear_id();
|
2086 | 2280 |
|
2087 |
| - short ROW_LOAD_OFFSET = 4 * (lane / 4); |
| 2281 | + short ROW_LOAD_OFFSET = 4 * (lane >> 2); |
2088 | 2282 | short COL_LOAD_OFFSET = 8 * (lane % 4);
|
2089 | 2283 |
|
2090 | 2284 | for (int i = 0; i < 4; i++) {
|
@@ -2113,7 +2307,8 @@ void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1,
|
2113 | 2307 | auto *rb0 = reinterpret_cast<MulType *>(recv_b);
|
2114 | 2308 | auto *rb1 = reinterpret_cast<MulType *>(recv_b + 2);
|
2115 | 2309 |
|
2116 |
| - for (int j = 0; j < 2 * 2; j++) { |
| 2310 | + // Iterate for k (i * j) times |
| 2311 | + for (int j = 0; j < 4; j++) { |
2117 | 2312 | auto a0 = static_cast<CDType>(ra0[j]);
|
2118 | 2313 | auto a1 = static_cast<CDType>(ra1[j]);
|
2119 | 2314 | auto b0 = static_cast<CDType>(rb0[j]);
|
|
0 commit comments