Skip to content

Commit 3ff6428

Browse files
authored
[SYCL][E2E][Joint Matrix] Add half colmajor A colmajor B test (#16683)
- Add test for half colmajor A, colmajor B load for 8x16x16
- Refactor to get rid of unnecessary SG32-specific file
1 parent ab25aa3 commit 3ff6428

File tree

3 files changed

+66
-65
lines changed

3 files changed

+66
-65
lines changed

sycl/test-e2e/Matrix/SG32/joint_matrix_bfloat16_colmajorA_colmajorB.cpp

Lines changed: 0 additions & 22 deletions
This file was deleted.

sycl/test-e2e/Matrix/Inputs/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp renamed to sycl/test-e2e/Matrix/joint_matrix_16bit_colmajorA_colmajorB.cpp

Lines changed: 66 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -1,32 +1,51 @@
1-
//==-joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp- DPC++ joint_matrix-==//
1+
//==-joint_matrix_16bit_colmajorA_colmajorB.cpp- DPC++ joint_matrix-==//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
88

9+
// This tests support of col major layout for matrix B which does transpose and
10+
// then VNNI transform. This is currently only available on AMX
11+
12+
// REQUIRES: aspect-ext_intel_matrix
13+
14+
// RUN: %{build} -o %t.out
15+
// RUN: %{run} %t.out
16+
// RUN: %{build} -o %t32.out -DSG_SZ=32
17+
// RUN: %{run} %t32.out
18+
19+
// XFAIL: gpu
20+
// XFAIL-TRACKER: GSD-5768
21+
22+
#include "common.hpp"
23+
924
constexpr size_t TM = 8;
1025
constexpr size_t TN = 16;
1126
constexpr size_t TK = 16;
1227

28+
template <typename T> class imatrix;
29+
1330
template <typename T1, typename T2, size_t M, size_t N, size_t K>
1431
void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
1532
big_matrix<T2, K, N> &B) {
1633
size_t NDRangeM = M / TM;
1734
size_t NDRangeN = N / TN;
18-
buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, K));
19-
buffer<bfloat16, 2> bufB(B.get_data(), range<2>(K, N));
35+
buffer<T2, 2> bufA(A.get_data(), range<2>(M, K));
36+
buffer<T2, 2> bufB(B.get_data(), range<2>(K, N));
2037
buffer<float, 2> bufC((float *)C.get_data(), range<2>(M, N));
2138

2239
queue q;
23-
size_t sg_size = get_sg_size<class imatrix>(q);
40+
size_t sg_size = get_sg_size<class imatrix<T2>>(q);
41+
std::cout << "subgroup size " << sg_size << " ";
42+
2443
q.submit([&](handler &cgh) {
2544
auto accC = bufC.get_access<access::mode::read_write>(cgh);
26-
auto accA = bufA.get_access<access::mode::read_write>(cgh);
27-
auto accB = bufB.get_access<access::mode::read_write>(cgh);
45+
auto accA = bufA.template get_access<access::mode::read_write>(cgh);
46+
auto accB = bufB.template get_access<access::mode::read_write>(cgh);
2847

29-
cgh.parallel_for<class imatrix>(
48+
cgh.parallel_for<class imatrix<T2>>(
3049
nd_range<2>({NDRangeM, NDRangeN * sg_size}, {1, 1 * sg_size}),
3150
[=](nd_item<2> spmd_item)
3251
#ifdef SG_SZ
@@ -42,10 +61,8 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
4261
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
4362

4463
sub_group sg = spmd_item.get_sub_group();
45-
joint_matrix<sub_group, bfloat16, use::a, TM, TK, layout::col_major>
46-
sub_a;
47-
joint_matrix<sub_group, bfloat16, use::b, TK, TN, layout::col_major>
48-
sub_b;
64+
joint_matrix<sub_group, T2, use::a, TM, TK, layout::col_major> sub_a;
65+
joint_matrix<sub_group, T2, use::b, TK, TN, layout::col_major> sub_b;
4966
joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
5067

5168
joint_matrix_load(
@@ -75,31 +92,57 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
7592
}).wait();
7693
}
7794

78-
int main() {
95+
template <typename T> void test() {
7996
static constexpr size_t MATRIX_M = TM * 2;
8097
static constexpr size_t MATRIX_N = TN * 2;
8198
static constexpr size_t MATRIX_K = TK * 2;
82-
bfloat16 A[MATRIX_K][MATRIX_M];
83-
bfloat16 B[MATRIX_N][MATRIX_K];
99+
T A[MATRIX_K][MATRIX_M];
100+
T B[MATRIX_N][MATRIX_K];
84101
float C[MATRIX_M][MATRIX_N];
85102
float D[MATRIX_M][MATRIX_N];
86103

87-
matrix_fill(MATRIX_K, MATRIX_M, (bfloat16 *)A,
104+
matrix_fill(MATRIX_K, MATRIX_M, (T *)A,
88105
[](int i, int j) { return 1.0f * (i + j); });
89-
matrix_fill(MATRIX_N, MATRIX_K, (bfloat16 *)B,
106+
matrix_fill(MATRIX_N, MATRIX_K, (T *)B,
90107
[](int i, int j) { return 2.0f * i + 3.0f * j; });
91108
matrix_fill(MATRIX_M, MATRIX_N, (float *)C, 1.0f);
92109
matrix_fill(MATRIX_M, MATRIX_N, (float *)D, 1.0f);
93110

94111
big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
95112
big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
96-
big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA((bfloat16 *)&A);
97-
big_matrix<bfloat16, MATRIX_K, MATRIX_N> MB((bfloat16 *)&B);
113+
big_matrix<T, MATRIX_M, MATRIX_K> MA((T *)&A);
114+
big_matrix<T, MATRIX_K, MATRIX_N> MB((T *)&B);
98115
matrix_multiply(MC, MA, MB);
99-
matrix_multiply_ref((bfloat16 *)A, (bfloat16 *)B, (float *)D, MATRIX_M,
100-
MATRIX_N, MATRIX_K, false, true, true);
116+
matrix_multiply_ref((T *)A, (T *)B, (float *)D, MATRIX_M, MATRIX_N, MATRIX_K,
117+
false, true, true);
118+
119+
assert(matrix_compare(MATRIX_M, MATRIX_N, (float *)C, (float *)D));
120+
std::cout << "passed" << std::endl;
121+
}
122+
123+
int main() {
124+
queue q;
125+
std::vector<combination> combinations =
126+
q.get_device().get_info<syclex::info::device::matrix_combinations>();
127+
bool bf16_run = false;
128+
bool half_run = false;
129+
130+
for (auto &combination : combinations) {
131+
if (!bf16_run && combination.atype == matrix_type::bf16) {
132+
std::cout << "bf16 ";
133+
test<bfloat16>();
134+
bf16_run = true;
135+
}
136+
137+
if (!half_run && combination.atype == matrix_type::fp16) {
138+
std::cout << "half ";
139+
test<half>();
140+
half_run = true;
141+
}
142+
143+
if (bf16_run && half_run)
144+
break;
145+
}
101146

102-
bool res = matrix_compare(MATRIX_M, MATRIX_N, (float *)C, (float *)D);
103-
std::cout << (res ? "passed" : "failed") << std::endl;
104-
return !res;
147+
return 0;
105148
}

sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB.cpp

Lines changed: 0 additions & 20 deletions
This file was deleted.

0 commit comments

Comments
 (0)