Skip to content

Commit a028eed

Browse files
[Joint Matrix] Enable different accumulator and output types in spirv. Add tests to cover bfloat16 and half floating-point sizes. (#17502)
Updated the matrix_compare and matrix_multiply_ref functions to better match bfloat16 calculations on device. Co-authored-by: Nikita Kornev <nikita.kornev@intel.com>
1 parent d3faf36 commit a028eed

19 files changed

+511
-499
lines changed

sycl/include/sycl/__spirv/spirv_ops.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,15 @@ extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreCheckedINTEL(
8484
std::size_t Stride, size_t Height, size_t Width, size_t CoordX,
8585
size_t CoordY, __spv::MatrixLayout Layout = L, int MemOperand = 0);
8686

87-
template <typename TA, typename TB, typename TC, std::size_t M, std::size_t K,
88-
std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
89-
__spv::MatrixUse UC,
87+
template <typename TA, typename TB, typename TC, typename TD, std::size_t M,
88+
std::size_t K, std::size_t N, __spv::MatrixUse UA,
89+
__spv::MatrixUse UB, __spv::MatrixUse UC,
9090
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
9191
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
9292
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
9393
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
9494
extern __DPCPP_SYCL_EXTERNAL
95-
__spv::__spirv_CooperativeMatrixKHR<TC, S, M, N, UC> *
95+
__spv::__spirv_CooperativeMatrixKHR<TD, S, M, N, UC> *
9696
__spirv_CooperativeMatrixMulAddKHR(
9797
__spv::__spirv_CooperativeMatrixKHR<TA, S, M, K, UA> *A,
9898
__spv::__spirv_CooperativeMatrixKHR<TB, S, K, N, UB> *B,

sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,25 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
263263
#endif // __SYCL_DEVICE_ONLY__
264264
}
265265

266+
operator float() {
267+
// Reads this work-item's bfloat16 element out of the joint matrix and
// widens it to fp32. Device-only: on host it always throws.
#ifdef __SYCL_DEVICE_ONLY__
268+
// SPIR-V access chain yields a pointer to this work-item's element
// (M.spvm is the cooperative-matrix handle, idx the element index).
sycl::ext::oneapi::bfloat16 *ExtractP =
269+
__spirv_AccessChain<sycl::ext::oneapi::bfloat16,
270+
sycl::ext::oneapi::bfloat16, NumRows, NumCols,
271+
spv_matrix_use_traits<Use>::value,
272+
spv_scope_traits<Group>::value>(&M.spvm, idx);
273+
// Anonymous union reinterprets the bfloat16 bit pattern as uint16_t,
// the operand type the bf16->fp32 conversion intrinsic expects.
union {
274+
uint16_t intStorage;
275+
sycl::ext::oneapi::bfloat16 floatValue;
276+
};
277+
floatValue = *ExtractP;
278+
return __spirv_ConvertBF16ToFINTEL(intStorage);
279+
#else
280+
// joint_matrix has no host implementation; element access on the host
// is always an error.
throw exception(make_error_code(errc::runtime),
281+
"joint matrix is not supported on host.");
282+
#endif // __SYCL_DEVICE_ONLY__
283+
}
284+
266285
explicit operator bool() {
267286
#ifdef __SYCL_DEVICE_ONLY__
268287
sycl::ext::oneapi::bfloat16 *ExtractP =
@@ -295,6 +314,21 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
295314
#endif // __SYCL_DEVICE_ONLY__
296315
}
297316

317+
wi_element &operator=(const float &rhs) {
318+
// Assigns an fp32 value into this work-item's element of the joint
// matrix. Device-only: on host it always throws.
#ifdef __SYCL_DEVICE_ONLY__
319+
// NOTE(review): the access chain is instantiated with float here even
// though this wi_element specialization is for bfloat16 matrices —
// presumably the storage element type is float in this path; confirm
// against the __spirv_AccessChain contract.
float *InsertP =
320+
__spirv_AccessChain<float, float, NumRows, NumCols,
321+
spv_matrix_use_traits<Use>::value,
322+
spv_scope_traits<Group>::value>(&M.spvm, idx);
323+
*InsertP = rhs;
324+
return *this;
325+
#else
326+
(void)rhs;
327+
// joint_matrix has no host implementation; element access on the host
// is always an error.
throw exception(make_error_code(errc::runtime),
328+
"joint matrix is not supported on host.");
329+
#endif // __SYCL_DEVICE_ONLY__
330+
}
331+
298332
wi_element &operator=(const wi_element<sycl::ext::oneapi::bfloat16, NumRows,
299333
NumCols, Use, Layout, Group> &rhs) {
300334
#ifdef __SYCL_DEVICE_ONLY__

sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -85,26 +85,26 @@ extern "C" constexpr __spv::MatrixLayout joint_matrix_layout_to_spv(
8585
}
8686
}
8787

88-
template<typename Ta, typename Tb, typename Tc>
88+
template <typename Ta, typename Tb, typename Tc, typename Td>
8989
constexpr uint32_t CalculateMatrixOperand() {
90+
uint32_t returnValue = 0x00;
9091
if constexpr (std::is_same<Ta, sycl::ext::oneapi::bfloat16>::value &&
91-
std::is_same<Tb, sycl::ext::oneapi::bfloat16>::value &&
92-
std::is_same<Tc, float>::value)
93-
return static_cast<uint32_t>(
92+
std::is_same<Tb, sycl::ext::oneapi::bfloat16>::value)
93+
returnValue += static_cast<uint32_t>(
9494
__spv::MatrixOperands::MatrixAAndBBFloat16ComponentsINTEL);
95-
if constexpr (std::is_signed<Ta>::value && std::is_unsigned<Tb>::value)
96-
return static_cast<uint32_t>(
95+
if constexpr (std::is_same<Tc, sycl::ext::oneapi::bfloat16>::value)
96+
returnValue += static_cast<uint32_t>(
97+
__spv::MatrixOperands::MatrixCBFloat16ComponentsINTEL);
98+
if constexpr (std::is_same<Td, sycl::ext::oneapi::bfloat16>::value)
99+
returnValue += static_cast<uint32_t>(
100+
__spv::MatrixOperands::MatrixResultBFloat16ComponentsINTEL);
101+
if constexpr (std::is_signed<Ta>::value)
102+
returnValue += static_cast<uint32_t>(
97103
__spv::MatrixOperands::MatrixASignedComponentsKHR);
98-
if constexpr (std::is_unsigned<Ta>::value && std::is_signed<Tb>::value)
99-
return static_cast<uint32_t>(
104+
if constexpr (std::is_signed<Tb>::value)
105+
returnValue += static_cast<uint32_t>(
100106
__spv::MatrixOperands::MatrixBSignedComponentsKHR);
101-
if constexpr (std::is_signed<Ta>::value && std::is_signed<Tb>::value) {
102-
return static_cast<uint32_t>(
103-
__spv::MatrixOperands::MatrixASignedComponentsKHR) +
104-
static_cast<uint32_t>(
105-
__spv::MatrixOperands::MatrixBSignedComponentsKHR);
106-
}
107-
return 0;
107+
return returnValue;
108108
}
109109

110110
} // namespace detail

sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -431,8 +431,7 @@ template <typename Group, typename Ta, typename Tb, typename Tc, typename Td,
431431
sycl::detail::convertTypeToMatrixTypeString<Tc>(),
432432
sycl::detail::convertTypeToMatrixTypeString<Td>(), M, K, N)]]
433433
#endif // defined(__SYCL_DEVICE_ONLY__)
434-
inline __SYCL_ALWAYS_INLINE void
435-
joint_matrix_mad(
434+
inline __SYCL_ALWAYS_INLINE void joint_matrix_mad(
436435
Group,
437436
joint_matrix<Group, Td, use::accumulator, M, N,
438437
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D,
@@ -462,9 +461,9 @@ joint_matrix_mad(
462461
}
463462
#else
464463
constexpr uint32_t MatrixOperand =
465-
sycl::detail::CalculateMatrixOperand<Ta, Tb, Tc>();
466-
D.spvm =
467-
__spirv_CooperativeMatrixMulAddKHR(A.spvm, B.spvm, C.spvm, MatrixOperand);
464+
sycl::detail::CalculateMatrixOperand<Ta, Tb, Tc, Td>();
465+
D.spvm = __spirv_CooperativeMatrixMulAddKHR<Ta, Tb, Tc, Td>(
466+
A.spvm, B.spvm, C.spvm, MatrixOperand);
468467
#endif // defined(__NVPTX__)
469468
#else
470469
std::ignore = A;
@@ -489,10 +488,23 @@ void joint_matrix_copy(
489488
using storage_element_type =
490489
typename oneapi::detail::jm_type_interpretation_helper_trait<
491490
T2>::storage_element_type;
491+
using src_storage_element_type =
492+
typename oneapi::detail::jm_type_interpretation_helper_trait<
493+
T1>::storage_element_type;
494+
492495
auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, src);
493496
auto wi_data_dst = sycl::ext::oneapi::detail::get_wi_data(sg, dst);
494497
for (int i = 0; i < wi_data_c.length(); i++) {
495-
wi_data_dst[i] = static_cast<storage_element_type>(wi_data_c[i]);
498+
if constexpr (std::is_same_v<T1, sycl::half>) {
499+
// Special case for SRC type sycl:half since we can't
500+
// cast directly from wi_element(typed half) to other type.
501+
// first cast is from wi_element to half (T1).
502+
// second cast is from half to dst type (T2).
503+
wi_data_dst[i] = static_cast<storage_element_type>(
504+
static_cast<src_storage_element_type>(wi_data_c[i]));
505+
} else {
506+
wi_data_dst[i] = static_cast<storage_element_type>(wi_data_c[i]);
507+
}
496508
}
497509
#endif // defined(__NVPTX__)
498510
#else

sycl/test-e2e/Matrix/Inputs/common.hpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
6767
for (unsigned int n = 0; n < N; n++) {
6868
int c_ind = transpose_c ? (n * M + m) : m * N + n;
6969
Tc acc = *(C + c_ind);
70-
70+
float tmp = 0.f;
7171
for (unsigned int k = 0; k < K; k++) {
7272
int a_ind = colmajor_a ? (k * M + m) : m * K + k;
7373
int b_ind = colmajor_b ? (n * K + k) : k * N + n;
@@ -80,6 +80,9 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
8080
acc += make_fp32(va[i]) * make_fp32(vb[i]);
8181
else if constexpr (std::is_same_v<Ta, sycl::half>)
8282
acc += (float)va[i] * (float)vb[i];
83+
else if constexpr (std::is_same_v<Ta, bfloat16> &&
84+
std::is_same_v<Tc, bfloat16>)
85+
tmp += (float)va[i] * (float)vb[i];
8386
else if constexpr (std::is_same_v<Ta, float> &&
8487
std::is_same_v<Tc, float> ||
8588
std::is_integral_v<Ta> && std::is_integral_v<Tc> ||
@@ -92,6 +95,9 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
9295
assert(false && "Unsupported type in matrix_multiply_ref.");
9396
}
9497
}
98+
if constexpr (std::is_same_v<Ta, bfloat16> &&
99+
std::is_same_v<Tc, bfloat16>)
100+
acc += (bfloat16)tmp;
95101

96102
if constexpr (!std::is_same_v<F, std::nullptr_t>) {
97103
lambda(acc);
@@ -182,10 +188,11 @@ template <typename T1, typename T2, bool exact = false>
182188
bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
183189
for (int i = 0; i < rows; i++) {
184190
for (int j = 0; j < cols; j++) {
185-
if constexpr (!exact && (std::is_same_v<T1, float> ||
186-
std::is_same_v<T1, bfloat16> ||
187-
(std::is_same_v<T1, double> &&
188-
std::is_same_v<T2, double>))) {
191+
if constexpr (!exact &&
192+
(std::is_same_v<T1, float> ||
193+
std::is_same_v<T1, bfloat16> || std::is_same_v<T1, half> ||
194+
(std::is_same_v<T1, double> &&
195+
std::is_same_v<T2, double>))) {
189196
float diff = std::fabs(src[i * cols + j] - (T1)ref[i * cols + j]);
190197
if (diff > FLOAT_EPSILON || std::isnan(src[i * cols + j])) {
191198
std::cerr << "Incorrect result in matrix. "
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
//===---joint_matrix_16bit_impl.hpp - DPC++ joint_matrix----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
template <typename Tab, typename TAcc, typename TResult, size_t TM, size_t TN,
10+
size_t TK, layout B_layout>
11+
class imatrix;
12+
13+
template <typename Tab, typename TAcc, typename TResult, size_t M, size_t N,
14+
size_t K, size_t TM, size_t TN, size_t TK, layout B_layout, size_t VF>
15+
void matrix_multiply(big_matrix<TResult, M, N> &D, big_matrix<TAcc, M, N> &C,
16+
big_matrix<Tab, M, K> &A,
17+
big_matrix<Tab, K / VF, N * VF> &B) {
18+
size_t NDRangeM = M / TM;
19+
size_t NDRangeN = N / TN;
20+
buffer<Tab, 2> bufA(A.get_data(), range<2>(M, K));
21+
buffer<Tab, 2> bufB(B.get_data(), range<2>(K, N));
22+
buffer<TAcc, 2> bufC((TAcc *)C.get_data(), range<2>(M, N));
23+
buffer<TResult, 2> bufD((TResult *)D.get_data(), range<2>(M, N));
24+
queue q;
25+
size_t sg_size =
26+
get_sg_size<imatrix<Tab, TAcc, TResult, TM, TN, TK, B_layout>>(q);
27+
28+
q.submit([&](handler &cgh) {
29+
accessor accA{bufA, cgh};
30+
accessor accB{bufB, cgh};
31+
accessor accC{bufC, cgh};
32+
accessor accD{bufD, cgh};
33+
34+
cgh.parallel_for<imatrix<Tab, TAcc, TResult, TM, TN, TK, B_layout>>(
35+
nd_range<2>({NDRangeM, NDRangeN * sg_size}, {1, 1 * sg_size}),
36+
[=](nd_item<2> spmd_item)
37+
#ifdef SG_SZ
38+
[[sycl::reqd_sub_group_size(SG_SZ)]]
39+
#endif
40+
{
41+
// The submatrix API has to be accessed by all the workitems in a
42+
// subgroup these functions will be called once by the subgroup no
43+
// code divergence between the workitems
44+
const auto global_idx = spmd_item.get_global_id(0);
45+
const auto global_idy = spmd_item.get_global_id(1);
46+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
47+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
48+
49+
sub_group sg = spmd_item.get_sub_group();
50+
joint_matrix<sub_group, Tab, use::a, TM, TK, layout::row_major>
51+
sub_a;
52+
joint_matrix<sub_group, Tab, use::b, TK, TN, B_layout> sub_b;
53+
joint_matrix<sub_group, TAcc, use::accumulator, TM, TN> sub_c;
54+
joint_matrix<sub_group, TResult, use::accumulator, TM, TN> sub_d;
55+
56+
joint_matrix_load(
57+
sg, sub_c,
58+
accC.template get_multi_ptr<access::decorated::no>() +
59+
(sg_startx * TM) * N + sg_starty / sg_size * TN,
60+
N, layout::row_major);
61+
62+
for (int k = 0; k < K / TK; k += 1) {
63+
joint_matrix_load(
64+
sg, sub_a,
65+
accA.template get_multi_ptr<access::decorated::no>() +
66+
(sg_startx * TM) * K + k * TK,
67+
K);
68+
joint_matrix_load(
69+
sg, sub_b,
70+
accB.template get_multi_ptr<access::decorated::no>() +
71+
(k * TK / VF) * (N * VF) + sg_starty / sg_size * TN * VF,
72+
N * VF);
73+
74+
joint_matrix_mad(sg, sub_d, sub_a, sub_b, sub_c);
75+
joint_matrix_copy(sg, sub_d, sub_c);
76+
}
77+
78+
joint_matrix_store(
79+
sg, sub_d,
80+
accD.template get_multi_ptr<access::decorated::no>() +
81+
(sg_startx * TM) * N + sg_starty / sg_size * TN,
82+
N, layout::row_major);
83+
}); // parallel for
84+
}).wait();
85+
}
86+
87+
template <typename Tab, typename TAcc, typename TResult, size_t TM, size_t TN,
88+
size_t TK, layout B_layout, size_t VF>
89+
void test() {
90+
std::cout << "Testing: " << TM << " x " << TN << " x " << TK
91+
<< " [TM x TN x TK]" << std::endl;
92+
93+
static constexpr size_t MATRIX_M = TM * 2;
94+
static constexpr size_t MATRIX_N = TN * 2;
95+
static constexpr size_t MATRIX_K = TK * 2;
96+
Tab A[MATRIX_M][MATRIX_K];
97+
Tab B[MATRIX_K / VF][MATRIX_N * VF];
98+
TAcc C[MATRIX_M][MATRIX_N];
99+
TResult D[MATRIX_M][MATRIX_N];
100+
TResult DRef[MATRIX_M][MATRIX_N];
101+
102+
matrix_rand<Tab>(MATRIX_M, MATRIX_K, (Tab *)A, Tab(1));
103+
matrix_rand<Tab>(MATRIX_K / VF, MATRIX_N * VF, (Tab *)B, Tab(1));
104+
105+
matrix_fill(MATRIX_M, MATRIX_N, (TAcc *)C, TAcc(1));
106+
matrix_fill(MATRIX_M, MATRIX_N, (TResult *)D, TResult(1));
107+
matrix_fill(MATRIX_M, MATRIX_N, (TResult *)DRef, TResult(1));
108+
109+
big_matrix<TAcc, MATRIX_M, MATRIX_N> MC((TAcc *)&C);
110+
big_matrix<TResult, MATRIX_M, MATRIX_N> MD((TResult *)&D);
111+
big_matrix<Tab, MATRIX_M, MATRIX_K> MA((Tab *)&A);
112+
big_matrix<Tab, MATRIX_K / VF, MATRIX_N * VF> MB((Tab *)&B);
113+
114+
matrix_multiply<Tab, TAcc, TResult, MATRIX_M, MATRIX_N, MATRIX_K, TM, TN, TK,
115+
B_layout, VF>(MD, MC, MA, MB);
116+
matrix_multiply_ref<Tab, Tab, TResult, VF>(
117+
(Tab *)A, (Tab *)B, (TResult *)DRef, MATRIX_M, MATRIX_N, MATRIX_K / VF);
118+
assert(matrix_compare(MATRIX_M, MATRIX_N, (TResult *)D, (TResult *)DRef));
119+
}
120+
121+
template <typename TLow, typename THigh, size_t TM, size_t TN, size_t TK,
122+
layout B_layout, size_t VF>
123+
void test_combo() {
124+
test<TLow, TLow, THigh, TM, TN, TK, B_layout, VF>();
125+
test<TLow, THigh, TLow, TM, TN, TK, B_layout, VF>();
126+
test<TLow, TLow, TLow, TM, TN, TK, B_layout, VF>();
127+
test<TLow, THigh, THigh, TM, TN, TK, B_layout, VF>();
128+
}
129+
130+
template <typename TLow, typename THigh, layout B_layout, size_t VF>
131+
void test_all() {
132+
test_combo<TLow, THigh, /*TM*/ 8, /*TN*/ 16, /*TK*/ 16, B_layout, VF>();
133+
test_combo<TLow, THigh, /*TM*/ 16, /*TN*/ 16, /*TK*/ 16, B_layout, VF>();
134+
test_combo<TLow, THigh, /*TM*/ 1, /*TN*/ 64, /*TK*/ 16, B_layout, VF>();
135+
test_combo<TLow, THigh, /*TM*/ 1, /*TN*/ 64, /*TK*/ 32, B_layout, VF>();
136+
test_combo<TLow, THigh, /*TM*/ 32, /*TN*/ 64, /*TK*/ 16, B_layout, VF>();
137+
test_combo<TLow, THigh, /*TM*/ 32, /*TN*/ 64, /*TK*/ 32, B_layout, VF>();
138+
}

0 commit comments

Comments
 (0)