Commit 851a90a

[SYCL][Matrix] Add support for missing matrix combinations for half and bfloat16 types (#15540)
Spec added in #15547. The new combinations are added as comments for now; we will uncomment them once IGC support becomes available.
1 parent ca5cc18 commit 851a90a

File tree: 4 files changed (+122 −16 lines)

sycl/source/detail/device_info.hpp

Lines changed: 82 additions & 0 deletions
@@ -850,14 +850,96 @@ struct get_device_info_impl<
          matrix_type::sint32, matrix_type::sint32},
         {8, 0, 0, 0, 16, 16, matrix_type::fp16, matrix_type::fp16,
          matrix_type::fp32, matrix_type::fp32},
+        {8, 0, 0, 0, 16, 16, matrix_type::fp16, matrix_type::fp16,
+         matrix_type::fp16, matrix_type::fp32},
+        {8, 0, 0, 0, 16, 16, matrix_type::fp16, matrix_type::fp16,
+         matrix_type::fp32, matrix_type::fp16},
+        {8, 0, 0, 0, 16, 16, matrix_type::fp16, matrix_type::fp16,
+         matrix_type::fp16, matrix_type::fp16},
+        {0, 0, 0, 16, 16, 16, matrix_type::fp16, matrix_type::fp16,
+         matrix_type::fp32, matrix_type::fp16},
+        {0, 0, 0, 16, 16, 16, matrix_type::fp16, matrix_type::fp16,
+         matrix_type::fp16, matrix_type::fp16},
+        {0, 0, 0, 1, 64, 16, matrix_type::fp16, matrix_type::fp16,
+         matrix_type::fp32, matrix_type::fp32},
+        {0, 0, 0, 1, 64, 16, matrix_type::fp16, matrix_type::fp16,
+         matrix_type::fp16, matrix_type::fp32},
+        {0, 0, 0, 1, 64, 16, matrix_type::fp16, matrix_type::fp16,
+         matrix_type::fp32, matrix_type::fp16},
+        {0, 0, 0, 1, 64, 16, matrix_type::fp16, matrix_type::fp16,
+         matrix_type::fp16, matrix_type::fp16},
+        {0, 0, 0, 32, 64, 16, matrix_type::fp16, matrix_type::fp16,
+         matrix_type::fp32, matrix_type::fp32},
+        {0, 0, 0, 32, 64, 16, matrix_type::fp16, matrix_type::fp16,
+         matrix_type::fp16, matrix_type::fp32},
+        {0, 0, 0, 32, 64, 16, matrix_type::fp16, matrix_type::fp16,
+         matrix_type::fp32, matrix_type::bf16},
+        {0, 0, 0, 32, 64, 16, matrix_type::fp16, matrix_type::fp16,
+         matrix_type::fp16, matrix_type::fp16},
+        {0, 0, 0, 1, 64, 32, matrix_type::fp16, matrix_type::fp16,
+         matrix_type::fp32, matrix_type::fp32},
+        {0, 0, 0, 1, 64, 32, matrix_type::fp16, matrix_type::fp16,
+         matrix_type::fp16, matrix_type::fp32},
+        {0, 0, 0, 1, 64, 32, matrix_type::fp16, matrix_type::fp16,
+         matrix_type::fp32, matrix_type::fp16},
+        {0, 0, 0, 1, 64, 32, matrix_type::fp16, matrix_type::fp16,
+         matrix_type::fp16, matrix_type::fp16},
+        {0, 0, 0, 32, 64, 32, matrix_type::fp16, matrix_type::fp16,
+         matrix_type::fp32, matrix_type::fp32},
+        {0, 0, 0, 32, 64, 32, matrix_type::fp16, matrix_type::fp16,
+         matrix_type::fp16, matrix_type::fp32},
+        {0, 0, 0, 32, 64, 32, matrix_type::fp16, matrix_type::fp16,
+         matrix_type::fp32, matrix_type::fp16},
+        {0, 0, 0, 32, 64, 32, matrix_type::fp16, matrix_type::fp16,
+         matrix_type::fp16, matrix_type::fp16},
+        {8, 0, 0, 0, 16, 16, matrix_type::bf16, matrix_type::bf16,
+         matrix_type::bf16, matrix_type::bf16},
+        {8, 0, 0, 0, 16, 16, matrix_type::bf16, matrix_type::bf16,
+         matrix_type::fp32, matrix_type::bf16},
+        {8, 0, 0, 0, 16, 16, matrix_type::bf16, matrix_type::bf16,
+         matrix_type::bf16, matrix_type::fp32},
         {8, 0, 0, 0, 16, 16, matrix_type::bf16, matrix_type::bf16,
          matrix_type::fp32, matrix_type::fp32},
         {0, 0, 0, 16, 16, 16, matrix_type::bf16, matrix_type::bf16,
          matrix_type::fp32, matrix_type::fp32},
+        {0, 0, 0, 16, 16, 16, matrix_type::bf16, matrix_type::bf16,
+         matrix_type::bf16, matrix_type::fp32},
+        {0, 0, 0, 16, 16, 16, matrix_type::bf16, matrix_type::bf16,
+         matrix_type::fp32, matrix_type::bf16},
+        {0, 0, 0, 16, 16, 16, matrix_type::bf16, matrix_type::bf16,
+         matrix_type::bf16, matrix_type::bf16},
         {0, 0, 0, 1, 64, 16, matrix_type::bf16, matrix_type::bf16,
          matrix_type::fp32, matrix_type::fp32},
+        {0, 0, 0, 1, 64, 16, matrix_type::bf16, matrix_type::bf16,
+         matrix_type::bf16, matrix_type::fp32},
+        {0, 0, 0, 1, 64, 16, matrix_type::bf16, matrix_type::bf16,
+         matrix_type::fp32, matrix_type::bf16},
+        {0, 0, 0, 1, 64, 16, matrix_type::bf16, matrix_type::bf16,
+         matrix_type::bf16, matrix_type::bf16},
         {0, 0, 0, 32, 64, 16, matrix_type::bf16, matrix_type::bf16,
          matrix_type::fp32, matrix_type::fp32},
+        {0, 0, 0, 32, 64, 16, matrix_type::bf16, matrix_type::bf16,
+         matrix_type::bf16, matrix_type::fp32},
+        {0, 0, 0, 32, 64, 16, matrix_type::bf16, matrix_type::bf16,
+         matrix_type::fp32, matrix_type::bf16},
+        {0, 0, 0, 32, 64, 16, matrix_type::bf16, matrix_type::bf16,
+         matrix_type::bf16, matrix_type::bf16},
+        {0, 0, 0, 1, 64, 32, matrix_type::bf16, matrix_type::bf16,
+         matrix_type::fp32, matrix_type::fp32},
+        {0, 0, 0, 1, 64, 32, matrix_type::bf16, matrix_type::bf16,
+         matrix_type::bf16, matrix_type::fp32},
+        {0, 0, 0, 1, 64, 32, matrix_type::bf16, matrix_type::bf16,
+         matrix_type::fp32, matrix_type::bf16},
+        {0, 0, 0, 1, 64, 32, matrix_type::bf16, matrix_type::bf16,
+         matrix_type::bf16, matrix_type::bf16},
+        {0, 0, 0, 32, 64, 32, matrix_type::bf16, matrix_type::bf16,
+         matrix_type::fp32, matrix_type::fp32},
+        {0, 0, 0, 32, 64, 32, matrix_type::bf16, matrix_type::bf16,
+         matrix_type::bf16, matrix_type::fp32},
+        {0, 0, 0, 32, 64, 32, matrix_type::bf16, matrix_type::bf16,
+         matrix_type::fp32, matrix_type::bf16},
+        {0, 0, 0, 32, 64, 32, matrix_type::bf16, matrix_type::bf16,
+         matrix_type::bf16, matrix_type::bf16},
         {8, 0, 0, 0, 16, 8, matrix_type::tf32, matrix_type::tf32,
          matrix_type::fp32, matrix_type::fp32},
     };
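The table above is what backs the runtime matrix capability query. As a rough sketch (assuming the sycl_ext_oneapi_matrix extension's `matrix_combinations` device query and `combination` struct, which the tests below also rely on), an application could list the half/bfloat16 entries a device reports like this:

```cpp
#include <iostream>
#include <sycl/sycl.hpp>

namespace syclex = sycl::ext::oneapi::experimental;
using syclex::matrix::matrix_type;

int main() {
  sycl::queue q;
  // Query the joint_matrix combinations the device advertises; the fp16/bf16
  // entries added in this commit would appear here once enabled.
  auto combinations =
      q.get_device().get_info<syclex::info::device::matrix_combinations>();
  for (const auto &c : combinations) {
    if (c.atype == matrix_type::fp16 || c.atype == matrix_type::bf16)
      std::cout << "max_M=" << c.max_msize << " M=" << c.msize
                << " N=" << c.nsize << " K=" << c.ksize << '\n';
  }
  return 0;
}
```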

sycl/test-e2e/Matrix/common.hpp

Lines changed: 5 additions & 3 deletions
@@ -63,15 +63,17 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
       if constexpr (std::is_same_v<Ta, bfloat16> &&
                     std::is_same_v<Tc, float>)
         acc += make_fp32(va[i]) * make_fp32(vb[i]);
+      else if constexpr (std::is_same_v<Ta, sycl::half> &&
+                         std::is_same_v<Tc, float>)
+        acc += (float)va[i] * (float)vb[i];
       else if constexpr (std::is_same_v<Ta, float> &&
                              std::is_same_v<Tc, float> ||
                          std::is_integral_v<Ta> && std::is_integral_v<Tc> ||
+                         (std::is_same_v<Ta, bfloat16> ||
+                          std::is_same_v<Ta, sycl::half>) ||
                          (std::is_same_v<Ta, double> &&
                           std::is_same_v<Tc, double>))
         acc += va[i] * vb[i];
-      else if constexpr (std::is_same_v<Ta, sycl::half> &&
-                         std::is_same_v<Tc, float>)
-        acc += (float)va[i] * (float)vb[i];
       else
         assert(false && "Unsupported type in matrix_multiply_ref.");
     }
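For illustration, here is a minimal standalone sketch (not the actual matrix_multiply_ref helper) of the two accumulation behaviours the widened branches now cover for sycl::half inputs: explicit conversion to float when the accumulator type is float, and plain half arithmetic when the types match:

```cpp
#include <iostream>
#include <sycl/sycl.hpp>

int main() {
  constexpr int K = 16;
  sycl::half a[K], b[K];
  for (int i = 0; i < K; ++i) {
    a[i] = sycl::half(0.5f);
    b[i] = sycl::half(2.0f);
  }

  // Path 1: Ta = sycl::half, Tc = float -- convert each operand to float
  // before multiplying, as the new branch does.
  float acc_f = 0.0f;
  for (int i = 0; i < K; ++i)
    acc_f += (float)a[i] * (float)b[i];

  // Path 2: Ta = Tc = sycl::half -- the widened "same type" branch simply
  // multiplies and accumulates in half precision.
  sycl::half acc_h{0.0f};
  for (int i = 0; i < K; ++i)
    acc_h += a[i] * b[i];

  std::cout << acc_f << " " << (float)acc_h << "\n"; // both print 16
  return 0;
}
```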

sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp

Lines changed: 17 additions & 11 deletions
@@ -14,16 +14,16 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
                      big_matrix<T2, K / 2, N * 2> &B) {
   size_t NDRangeM = M / TM;
   size_t NDRangeN = N / TN;
-  buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, K));
-  buffer<bfloat16, 2> bufB(B.get_data(), range<2>(K, N));
-  buffer<float, 2> bufC((float *)C.get_data(), range<2>(M, N));
+  buffer<T2, 2> bufA(A.get_data(), range<2>(M, K));
+  buffer<T2, 2> bufB(B.get_data(), range<2>(K, N));
+  buffer<T1, 2> bufC((T1 *)C.get_data(), range<2>(M, N));

   queue q;
   size_t sg_size = get_sg_size<imatrix<T1, TM, TN, TK>>(q);
   q.submit([&](handler &cgh) {
-    auto accC = bufC.get_access<access::mode::read_write>(cgh);
-    auto accA = bufA.get_access<access::mode::read_write>(cgh);
-    auto accB = bufB.get_access<access::mode::read_write>(cgh);
+    accessor accA{bufA, cgh};
+    accessor accB{bufB, cgh};
+    accessor accC{bufC, cgh};

     cgh.parallel_for<imatrix<T1, TM, TN, TK>>(
         nd_range<2>({NDRangeM, NDRangeN * sg_size}, {1, 1 * sg_size}),
@@ -41,13 +41,11 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
           const auto sg_starty = global_idy - spmd_item.get_local_id(1);

           sub_group sg = spmd_item.get_sub_group();
-          joint_matrix<sub_group, bfloat16, use::a, TM, TK, layout::row_major>
-              sub_a;
+          joint_matrix<sub_group, T2, use::a, TM, TK, layout::row_major> sub_a;
           // For B, we assume B has been already VNNIed.
-          joint_matrix<sub_group, bfloat16, use::b, TK, TN,
-                       layout::ext_intel_packed>
+          joint_matrix<sub_group, T2, use::b, TK, TN, layout::ext_intel_packed>
               sub_b;
-          joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
+          joint_matrix<sub_group, T1, use::accumulator, TM, TN> sub_c;

           joint_matrix_load(
               sg, sub_c,
@@ -122,13 +120,21 @@ int main() {

     if (combinations[i].nsize == 16) { // architecture::intel_gpu_pvc
       test<bfloat16, float, /*TM*/ 8, /*TN*/ 16, /*TK*/ 16>();
+      // test<bfloat16, bfloat16, /*TM*/ 8, /*TN*/ 16, /*TK*/ 16>();

       // This combination is not currently supported for sub group size = 32 in
       // IGC
 #if (!defined(SG_SZ) || SG_SZ != 32)
       test<bfloat16, float, /*TM*/ 16, /*TN*/ 16, /*TK*/ 16>();
+      // test<bfloat16, bfloat16, /*TM*/ 16, /*TN*/ 16, /*TK*/ 16>();
       test<bfloat16, float, /*TM*/ 1, /*TN*/ 64, /*TK*/ 16>();
+      // test<bfloat16, bfloat16, /*TM*/ 1, /*TN*/ 64, /*TK*/ 16>();
       test<bfloat16, float, /*TM*/ 32, /*TN*/ 64, /*TK*/ 16>();
+      // test<bfloat16, bfloat16, /*TM*/ 32, /*TN*/ 64, /*TK*/ 16>();
+      // test<bfloat16, float, /*TM*/ 32, /*TN*/ 64, /*TK*/ 32>();
+      // test<bfloat16, bfloat16, /*TM*/ 32, /*TN*/ 64, /*TK*/ 32>();
+      // test<bfloat16, float, /*TM*/ 1, /*TN*/ 64, /*TK*/ 32>();
+      // test<bfloat16, bfloat16, /*TM*/ 1, /*TN*/ 64, /*TK*/ 32>();
 #endif
       break;
     }
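To make the effect of the T1/T2 templating concrete, here is a hypothetical device-side helper (a sketch only; declare_fragments is not part of the test) showing that the A/B fragments now follow the input type while the accumulator fragment follows the separately templated accumulator type, which is what the commented-out bfloat16-accumulator tests will exercise once IGC support lands:

```cpp
#include <sycl/sycl.hpp>

namespace sm = sycl::ext::oneapi::experimental::matrix;
using sycl::ext::oneapi::bfloat16;

// Hypothetical helper meant to be called from inside a joint-matrix kernel.
template <typename TAccum, typename TInput, size_t TM, size_t TN, size_t TK>
void declare_fragments(sycl::sub_group sg) {
  // A and B fragments use the input element type (e.g. bfloat16)...
  sm::joint_matrix<sycl::sub_group, TInput, sm::use::a, TM, TK,
                   sm::layout::row_major>
      sub_a;
  sm::joint_matrix<sycl::sub_group, TInput, sm::use::b, TK, TN,
                   sm::layout::ext_intel_packed>
      sub_b;
  // ...while the accumulator element type is a separate parameter, so the
  // same kernel body works for float and bfloat16 accumulators.
  sm::joint_matrix<sycl::sub_group, TAccum, sm::use::accumulator, TM, TN> sub_c;
  sm::joint_matrix_fill(sg, sub_c, TAccum(0));
  (void)sub_a;
  (void)sub_b;
}
```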

sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp

Lines changed: 18 additions & 2 deletions
@@ -18,13 +18,13 @@ void matrix_multiply(big_matrix<TResult, M, N> &C, big_matrix<T, M, K> &A,
   buffer<TResult, 2> bufC(C.get_data(), range<2>(M, N));

   queue q;
-  size_t sg_size = get_sg_size<mult<T, TM, TN, TK>>(q);
+  size_t sg_size = get_sg_size<mult<TResult, TM, TN, TK>>(q);
   q.submit([&](handler &cgh) {
     accessor accA{bufA, cgh};
     accessor accB{bufB, cgh};
     accessor accC{bufC, cgh};

-    cgh.parallel_for<mult<T, TM, TN, TK>>(
+    cgh.parallel_for<mult<TResult, TM, TN, TK>>(
         nd_range<2>({NDRangeM, NDRangeN * sg_size}, {1, sg_size}),
         [=](nd_item<2> spmd_item)
 #ifdef SG_SZ
@@ -122,6 +122,22 @@ int main() {

     if (combinations[i].nsize == 16) { // architecture::intel_gpu_pvc
       test<float, half, 2, /*TM*/ 8, /*TN*/ 16, /*TK*/ 16>();
+      // test<half, half, 2, /*TM*/ 8, /*TN*/ 16, /*TK*/ 16>();
+
+      // This combination is not currently supported for sub group size = 32 in
+      // IGC
+#if (!defined(SG_SZ) || SG_SZ != 32)
+      // test<float, half, /*TM*/ 16, /*TN*/ 16, /*TK*/ 16>();
+      // test<half, half, /*TM*/ 16, /*TN*/ 16, /*TK*/ 16>();
+      // test<float, half, /*TM*/ 1, /*TN*/ 64, /*TK*/ 16>();
+      // test<half, half, /*TM*/ 1, /*TN*/ 64, /*TK*/ 16>();
+      // test<float, half, /*TM*/ 32, /*TN*/ 64, /*TK*/ 16>();
+      // test<half, half, /*TM*/ 32, /*TN*/ 64, /*TK*/ 16>();
+      // test<float, half, /*TM*/ 1, /*TN*/ 64, /*TK*/ 32>();
+      // test<half, half, /*TM*/ 1, /*TN*/ 64, /*TK*/ 32>();
+      // test<float, half, /*TM*/ 32, /*TN*/ 64, /*TK*/ 32>();
+      // test<half, half, /*TM*/ 32, /*TN*/ 64, /*TK*/ 32>();
+#endif
       break;
     }
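The get_sg_size fix above matters because the sub-group size is queried per kernel name: if get_sg_size and parallel_for name different kernel types, the queried size need not match the kernel that is actually launched. A rough sketch of such a helper using SYCL 2020 kernel-bundle queries (the test's real helper lives in common.hpp and may differ in detail):

```cpp
#include <sycl/sycl.hpp>

// Hypothetical stand-in for the test's kernel-name template.
template <typename T, size_t TM, size_t TN, size_t TK> class mult;

// Query the maximum sub-group size the device will use for KernelName.
// Both this query and the later parallel_for<KernelName> must use the very
// same name, which is why the diff switches both sites to mult<TResult, ...>.
template <typename KernelName> size_t get_max_sg_size(sycl::queue &q) {
  auto id = sycl::get_kernel_id<KernelName>();
  auto bundle = sycl::get_kernel_bundle<sycl::bundle_state::executable>(
      q.get_context(), {id});
  sycl::kernel k = bundle.get_kernel(id);
  return k.get_info<sycl::info::kernel_device_specific::max_sub_group_size>(
      q.get_device());
}
```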
