Skip to content

Commit 8df7dab

Browse files
authored
[SYCL][Joint Matrix][E2E] Fix several tests (#16548)
With this fix, joint_matrix_colA_rowB_colC.cpp passes on CPU.
1 parent e9822b2 commit 8df7dab

8 files changed

+29
-40
lines changed

sycl/test-e2e/Matrix/SG32/joint_matrix_bfloat16_colmajorA_colmajorB.cpp

Lines changed: 1 addition & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -16,16 +16,7 @@
1616

1717
// XFAIL: gpu
1818
// XFAIL-TRACKER: GSD-5768
19-
#include "../common.hpp"
20-
#include <iostream>
21-
#include <sycl/detail/core.hpp>
22-
#include <sycl/ext/oneapi/matrix/matrix.hpp>
23-
24-
using namespace sycl;
25-
using namespace sycl::ext::oneapi::experimental::matrix;
26-
using bfloat16 = sycl::ext::oneapi::bfloat16;
2719

20+
#include "../common.hpp"
2821
#define SG_SZ 32
29-
constexpr size_t TN = 16;
30-
3122
#include "../joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp"

sycl/test-e2e/Matrix/SG32/joint_matrix_colA_rowB_colC.cpp

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -11,12 +11,9 @@
1111
// RUN: %{build} -o %t.out
1212
// RUN: %{run} %t.out
1313

14-
// XFAIL: run-mode
14+
// XFAIL: gpu && run-mode
1515
// XFAIL-TRACKER: GSD-5768
1616

1717
#include "../common.hpp"
18-
1918
#define SG_SZ 32
20-
constexpr size_t TN = 16;
21-
2219
#include "../joint_matrix_colA_rowB_colC_impl.hpp"

sycl/test-e2e/Matrix/common.hpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -174,7 +174,7 @@ bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
174174
std::is_same_v<T2, double>))) {
175175
float diff = std::fabs(src[i * cols + j] - (T1)ref[i * cols + j]);
176176
if (diff > FLOAT_EPSILON || std::isnan(src[i * cols + j])) {
177-
std::cout << "Incorrect result in matrix. "
177+
std::cerr << "Incorrect result in matrix. "
178178
<< "i: " << i << ", j: " << j
179179
<< ", Ref: " << (T1)ref[i * cols + j]
180180
<< ", Val: " << src[i * cols + j] << ", Diff: " << diff
@@ -183,14 +183,14 @@ bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
183183
}
184184
} else if constexpr (exact || std::is_integral_v<T1>) {
185185
if (src[i * cols + j] != ref[i * cols + j]) {
186-
std::cout << "Incorrect result in matrix."
186+
std::cerr << "Incorrect result in matrix."
187187
<< "i: " << i << ", j: " << j
188188
<< ", Ref: " << ref[i * cols + j]
189189
<< ", Val: " << src[i * cols + j] << "\n";
190190
return false;
191191
}
192192
} else {
193-
std::cout << "Unsupported type in matrix_compare\n";
193+
std::cerr << "Unsupported type in matrix_compare\n";
194194
return false;
195195
}
196196
}

sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -119,7 +119,7 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i
119119
// along the workgroup prefetch for B matrix. For A matrix, sgId is
120120
// enough.
121121
size_t pm1B = sgId / 16; // prefetch m1 (sgId/16)
122-
size_t pn1B = sgId & 0x15; // prefetch n1 (sgId%16)
122+
size_t pn1B = sgId & 0xF; // prefetch n1 (sgId%16)
123123
#else // VNNI
124124
size_t pm1B = sgId / 8; // prefetch m1 (sgId/8)
125125
size_t pn1B = sgId & 0x7; // prefetch n1 (sgId%8)

sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB.cpp

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,4 @@
1717
// XFAIL-TRACKER: GSD-5768
1818

1919
#include "common.hpp"
20-
21-
constexpr size_t TN = 16;
22-
2320
#include "joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp"

sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -6,8 +6,9 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#define TM 8
10-
#define TK 16
9+
constexpr size_t TM = 8;
10+
constexpr size_t TN = 16;
11+
constexpr size_t TK = 16;
1112

1213
template <typename T1, typename T2, size_t M, size_t N, size_t K>
1314
void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
@@ -43,7 +44,6 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
4344
sub_group sg = spmd_item.get_sub_group();
4445
joint_matrix<sub_group, bfloat16, use::a, TM, TK, layout::col_major>
4546
sub_a;
46-
// For B, we assume B has been already VNNIed.
4747
joint_matrix<sub_group, bfloat16, use::b, TK, TN, layout::col_major>
4848
sub_b;
4949
joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;

sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC.cpp

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -10,11 +10,8 @@
1010
// RUN: %{build} -o %t.out
1111
// RUN: %{run} %t.out
1212

13-
// XFAIL: run-mode
13+
// XFAIL: gpu && run-mode
1414
// XFAIL-TRACKER: GSD-5768
1515

1616
#include "common.hpp"
17-
18-
constexpr size_t TN = 16;
19-
2017
#include "joint_matrix_colA_rowB_colC_impl.hpp"

sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp

Lines changed: 19 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
#include <sycl/usm.hpp>
1212

1313
constexpr size_t TM = 8;
14+
constexpr size_t TN = 16;
1415
constexpr size_t TK = 16;
1516

1617
template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
@@ -60,14 +61,15 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) {
6061
joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
6162
joint_matrix_fill(sg, sub_c, 1);
6263
for (int k = 0; k < K; k += TK) {
63-
joint_matrix_load(sg, sub_a, pA + (sg_startx * TM) * K + k, K);
64+
joint_matrix_load(sg, sub_a, pA + k * M + sg_startx * TM, M);
6465
joint_matrix_load(sg, sub_b, pB + k * N + sg_starty / sg_size * TN,
6566
N);
6667
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
6768
}
68-
joint_matrix_store(
69-
sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / sg_size * TN,
70-
N, layout::col_major);
69+
joint_matrix_store(sg, sub_c,
70+
pC + (sg_startx * TM) +
71+
(sg_starty / sg_size * TN) * M,
72+
M, layout::col_major);
7173
}); // parallel for
7274
}).wait();
7375
}
@@ -76,23 +78,28 @@ int main() {
7678
static constexpr size_t MATRIX_M = 1024;
7779
static constexpr size_t MATRIX_N = 1024;
7880
static constexpr size_t MATRIX_K = 1024;
81+
7982
queue q;
80-
bfloat16 *A = malloc_shared<bfloat16>(MATRIX_M * MATRIX_K, q);
83+
bfloat16 *A = malloc_shared<bfloat16>(MATRIX_K * MATRIX_M, q);
8184
bfloat16 *B = malloc_shared<bfloat16>(MATRIX_K * MATRIX_N, q);
82-
float *C = malloc_shared<float>(MATRIX_M * MATRIX_N, q);
83-
float *D = malloc_shared<float>(MATRIX_M * MATRIX_N, q);
85+
float *C = malloc_shared<float>(MATRIX_N * MATRIX_M, q);
86+
float *D = malloc_shared<float>(MATRIX_N * MATRIX_M, q);
8487

85-
matrix_rand(MATRIX_M, MATRIX_K, A, (bfloat16)5);
88+
matrix_rand(MATRIX_K, MATRIX_M, A, (bfloat16)5);
8689
matrix_rand(MATRIX_K, MATRIX_N, B, (bfloat16)5);
87-
matrix_fill(MATRIX_M, MATRIX_N, C, (float)1.0);
88-
matrix_fill(MATRIX_M, MATRIX_N, D, (float)1.0);
90+
matrix_fill(MATRIX_N, MATRIX_M, D, (float)1.0);
8991

9092
matrix_multiply<float, bfloat16, MATRIX_M, MATRIX_K, MATRIX_K, MATRIX_N,
9193
MATRIX_M, MATRIX_N>(C, A, B, q);
9294
matrix_multiply_ref(A, B, D, MATRIX_M, MATRIX_N, MATRIX_K,
93-
true /*transposed c*/);
95+
/*transposed c*/ true, /*colmajor a*/ true);
96+
97+
bool res = matrix_compare(MATRIX_N, MATRIX_M, C, D);
9498

95-
bool res = matrix_compare(MATRIX_M, MATRIX_N, C, D);
99+
sycl::free(A, q);
100+
sycl::free(B, q);
101+
sycl::free(C, q);
102+
sycl::free(D, q);
96103

97104
std::cout << (res ? "passed" : "failed") << std::endl;
98105
return !res;

0 commit comments

Comments
 (0)