Skip to content

Commit 8461560

Browse files
authored
[SYCL][E2E][Joint Matrix] OOB tests to support more shapes, layouts (#16837)
1 parent 7025de8 commit 8461560

File tree

7 files changed

+126
-97
lines changed

7 files changed

+126
-97
lines changed

sycl/test-e2e/Matrix/Inputs/common.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,10 @@ void matrix_print(unsigned int rows, unsigned int cols, T *mat) {
234234
std::cout << "\n";
235235
}
236236
}
237+
238+
template <typename T, layout Layout> constexpr int vnni_factor() {
239+
if constexpr (Layout != layout::ext_intel_packed)
240+
return 1;
241+
static_assert(sizeof(T) <= 4 && "Unsupported type in vnni_factor().");
242+
return 4 / sizeof(T);
243+
}

sycl/test-e2e/Matrix/Inputs/joint_matrix_out_bounds_impl.hpp

Lines changed: 91 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -9,27 +9,20 @@
99
#include <iostream>
1010
#include <sycl/usm.hpp>
1111

12-
constexpr size_t TM = 8;
13-
constexpr size_t TK = 16;
12+
template <typename Tab, size_t K, layout B_layout> class mult;
1413

15-
template <layout B_layout, unsigned int vnniFactor> class mult;
16-
17-
template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
18-
size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
19-
size_t NUM_COLS_C, layout B_layout, unsigned int vnniFactor>
14+
template <typename T1, typename T2, size_t M, size_t N, size_t K, size_t TM,
15+
size_t TN, size_t TK, layout A_layout, layout B_layout>
2016
void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) {
21-
size_t M = NUM_ROWS_C;
22-
size_t N = NUM_COLS_C;
23-
size_t K = NUM_COLS_A;
2417

25-
assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * vnniFactor);
2618
// Add one iteration for the out of bounds dpas instruction
2719
size_t NDRangeM = M / TM + (((M % TM) != 0) ? 1 : 0);
28-
size_t NDRangeN = N / TN;
29-
size_t sg_size = get_sg_size<mult<B_layout, vnniFactor>>(q);
20+
size_t NDRangeN = N / TN + (((N % TN) != 0) ? 1 : 0);
21+
size_t sg_size = get_sg_size<mult<T2, K, B_layout>>(q);
22+
std::cout << "SG size: " << sg_size << " ";
3023

3124
q.submit([&](handler &cgh) {
32-
cgh.parallel_for<mult<B_layout, vnniFactor>>(
25+
cgh.parallel_for<mult<T2, K, B_layout>>(
3326
nd_range<2>({NDRangeM, NDRangeN * sg_size}, {1, 1 * sg_size}),
3427
[=](nd_item<2> spmd_item)
3528
#ifdef SG_SZ
@@ -45,6 +38,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) {
4538
auto pC =
4639
address_space_cast<sycl::access::address_space::global_space,
4740
sycl::access::decorated::no>(C);
41+
4842
// The submatrix API has to be accessed by all the workitems in a
4943
// subgroup these functions will be called once by the subgroup no
5044
// code divergence between the workitems
@@ -54,27 +48,41 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) {
5448
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
5549

5650
sub_group sg = spmd_item.get_sub_group();
57-
joint_matrix<sub_group, bfloat16, use::a, TM, TK, layout::row_major>
58-
sub_a;
59-
60-
// For B, since current implementation does not support non-packed
61-
// layout, users need to specify the packed_b layout.
62-
joint_matrix<sub_group, bfloat16, use::b, TK, TN, B_layout> sub_b;
63-
joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
64-
// bounds-checked load where width and height are added
51+
joint_matrix<sub_group, T2, use::a, TM, TK, A_layout> sub_a;
52+
joint_matrix<sub_group, T2, use::b, TK, TN, B_layout> sub_b;
53+
joint_matrix<sub_group, T1, use::accumulator, TM, TN> sub_c;
54+
55+
// bounds-checked fill where width and height are added
6556
ext::intel::experimental::matrix::joint_matrix_fill_checked(
6657
sg, sub_c, 1, M, N, sg_startx * TM, sg_starty / sg_size * TN);
58+
6759
for (int k = 0; k < K; k += TK) {
6860
// bounds-checked load where width and height are added
69-
ext::intel::experimental::matrix::joint_matrix_load_checked(
70-
sg, sub_a, pA, K, M, K, sg_startx * TM, k);
71-
// Assume we alreay in vnni format.
61+
// params order: Stride, Height, Width, CoordX, CoordY
62+
if constexpr (A_layout == layout::row_major) {
63+
ext::intel::experimental::matrix::joint_matrix_load_checked(
64+
sg, sub_a, pA, K, M, K, sg_startx * TM, k);
65+
} else {
66+
ext::intel::experimental::matrix::joint_matrix_load_checked(
67+
sg, sub_a, pA, M, K, M, k, sg_startx * TM);
68+
}
69+
7270
// bounds-checked load where width and height are added
73-
ext::intel::experimental::matrix::joint_matrix_load_checked(
74-
sg, sub_b, pB, N * vnniFactor, K / vnniFactor, N * vnniFactor,
75-
k / vnniFactor, sg_starty / sg_size * TN * vnniFactor);
71+
// params order: Stride, Height, Width, CoordX, CoordY
72+
if constexpr (B_layout != layout::col_major) {
73+
constexpr unsigned int vnniFactor = vnni_factor<T2, B_layout>();
74+
ext::intel::experimental::matrix::joint_matrix_load_checked(
75+
sg, sub_b, pB, N * vnniFactor, K / vnniFactor,
76+
N * vnniFactor, k / vnniFactor,
77+
sg_starty / sg_size * TN * vnniFactor);
78+
} else {
79+
ext::intel::experimental::matrix::joint_matrix_load_checked(
80+
sg, sub_b, pB, K, N, K, sg_starty / sg_size * TN, k);
81+
}
82+
7683
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
7784
}
85+
7886
// bounds-checked store where width and height are added
7987
ext::intel::experimental::matrix::joint_matrix_store_checked(
8088
sg, sub_c, pC, N, layout::row_major, M, N, sg_startx * TM,
@@ -83,42 +91,69 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) {
8391
}).wait();
8492
}
8593

86-
int main() {
87-
static constexpr size_t MATRIX_M = 1024 + 14;
88-
static constexpr size_t MATRIX_N = 1024;
89-
static constexpr unsigned int vnniFactor = 2;
90-
94+
template <typename Tab, typename Tc, size_t MATRIX_M, size_t MATRIX_N,
95+
size_t MATRIX_K, size_t TM, size_t TN, size_t TK, layout A_layout,
96+
layout B_layout>
97+
void test() {
98+
std::cout << MATRIX_M << "x" << MATRIX_N << "x" << MATRIX_K << ", " << TM
99+
<< "x" << TN << "x" << TK << ": ";
91100
queue q;
92-
bfloat16 *A = malloc_shared<bfloat16>(MATRIX_M * MATRIX_K, q);
93-
bfloat16 *B = malloc_shared<bfloat16>(MATRIX_K * MATRIX_N, q);
94-
bfloat16 *vnniB = malloc_shared<bfloat16>(MATRIX_K * MATRIX_N, q);
95-
float *C = malloc_shared<float>(MATRIX_M * MATRIX_N, q);
96-
float *D = malloc_shared<float>(MATRIX_M * MATRIX_N, q);
97-
98-
matrix_rand(MATRIX_M, MATRIX_K, A, (bfloat16)5);
99-
matrix_rand(MATRIX_K, MATRIX_N, B, (bfloat16)5);
100-
matrix_fill(MATRIX_M, MATRIX_N, C, (float)1);
101-
matrix_fill(MATRIX_M, MATRIX_N, D, (float)1);
102-
103-
matrix_vnni<bfloat16>(MATRIX_K, MATRIX_N, B, vnniB, vnniFactor);
104101

102+
// reference data
103+
Tab *A = malloc_shared<Tab>(MATRIX_M * MATRIX_K, q);
104+
Tab *B = malloc_shared<Tab>(MATRIX_K * MATRIX_N, q);
105+
Tc *C = malloc_shared<Tc>(MATRIX_M * MATRIX_N, q);
106+
Tc *D = malloc_shared<Tc>(MATRIX_M * MATRIX_N, q);
107+
matrix_rand(MATRIX_M, MATRIX_K, A, (Tab)5);
108+
matrix_rand(MATRIX_K, MATRIX_N, B, (Tab)5);
109+
matrix_fill(MATRIX_M, MATRIX_N, D, (Tc)1);
105110
matrix_multiply_ref(A, B, D, MATRIX_M, MATRIX_N, MATRIX_K);
106-
matrix_multiply<float, bfloat16, MATRIX_M, MATRIX_K, MATRIX_K / vnniFactor,
107-
MATRIX_N * vnniFactor, MATRIX_M, MATRIX_N,
108-
layout::ext_intel_packed, vnniFactor>(C, A, vnniB, q);
109-
bool res = matrix_compare(MATRIX_M, MATRIX_N, C, D);
110-
111-
matrix_multiply<float, bfloat16, MATRIX_M, MATRIX_K, MATRIX_K, MATRIX_N,
112-
MATRIX_M, MATRIX_N, layout::row_major, 1>(C, A, B, q);
113-
res = res && matrix_compare(MATRIX_M, MATRIX_N, C, D);
114111

115-
std::cout << (res ? "passed" : "failed") << std::endl;
112+
// test data
113+
if constexpr (A_layout == layout::col_major) {
114+
Tab *colA = malloc_shared<Tab>(MATRIX_K * MATRIX_M, q);
115+
matrix_transpose(MATRIX_M, MATRIX_K, colA, A);
116+
Tab *tmp = A;
117+
A = colA;
118+
free(tmp, q);
119+
}
120+
121+
if constexpr (B_layout == layout::col_major) {
122+
Tab *colB = malloc_shared<Tab>(MATRIX_N * MATRIX_K, q);
123+
matrix_transpose(MATRIX_K, MATRIX_N, colB, B);
124+
Tab *tmp = B;
125+
B = colB;
126+
free(tmp, q);
127+
}
128+
129+
if constexpr (B_layout == layout::ext_intel_packed) {
130+
Tab *vnniB = malloc_shared<Tab>(MATRIX_K * MATRIX_N, q);
131+
matrix_vnni(MATRIX_K, MATRIX_N, B, vnniB, vnni_factor<Tab, B_layout>());
132+
Tab *tmp = B;
133+
B = vnniB;
134+
free(tmp, q);
135+
}
136+
137+
matrix_multiply<Tc, Tab, MATRIX_M, MATRIX_N, MATRIX_K, TM, TN, TK, A_layout,
138+
B_layout>(C, A, B, q);
139+
assert(matrix_compare(MATRIX_M, MATRIX_N, C, D));
140+
std::cout << "passed" << std::endl;
116141

117142
free(A, q);
118143
free(B, q);
119-
free(vnniB, q);
120144
free(C, q);
121145
free(D, q);
146+
}
122147

123-
return !res;
148+
template <layout A_layout, layout B_layout> void test_all() {
149+
std::cout << "bf16: ";
150+
test<bfloat16, float, /*MATRIX_M*/ 1024 + 20, /*MATRIX_N*/ 1024 + 20,
151+
/*MATRIX_K*/ 1024 + 24, /*TM*/ 8, /*TN*/ 16, /*TK*/ 16, A_layout,
152+
B_layout>();
153+
std::cout << "half: ";
154+
test<half, float, 1024 + 20, 1024 + 20, 1024 + 24, 8, 16, 16, A_layout,
155+
B_layout>();
156+
std::cout << "int8: ";
157+
test<int8_t, int32_t, 1024, 1024 + 20, 1024 + 24, 8, 16, 32, A_layout,
158+
B_layout>();
124159
}

sycl/test-e2e/Matrix/SG32/joint_matrix_out_bounds.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
// REQUIRES-INTEL-DRIVER: lin: 27501, win: 101.4943
1010

1111
// UNSUPPORTED: gpu-intel-dg2, cpu
12+
// UNSUPPORTED-INTENDED: Checked load/stores are not supported by DG2 and CPU HW
1213

1314
// RUN: %{build} -o %t.out
1415
// RUN: %{run} %t.out
@@ -17,9 +18,12 @@
1718
// XFAIL-TRACKER: GSD-4181
1819

1920
#include "common.hpp"
20-
2121
#define SG_SZ 32
22-
constexpr size_t TN = 16;
23-
constexpr size_t MATRIX_K = 1024 + 24;
24-
2522
#include "joint_matrix_out_bounds_impl.hpp"
23+
24+
int main() {
25+
std::cout << "A row major, B row major:\n";
26+
test_all<layout::row_major, layout::row_major>();
27+
std::cout << "A row major, B packed:\n";
28+
test_all<layout::row_major, layout::ext_intel_packed>();
29+
}

sycl/test-e2e/Matrix/joint_matrix_out_bounds.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,17 @@
88
// REQUIRES: aspect-ext_intel_matrix
99

1010
// UNSUPPORTED: gpu-intel-dg2, cpu
11+
// UNSUPPORTED-INTENDED: Checked load/stores are not supported by DG2 and CPU HW
1112

1213
// RUN: %{build} -o %t.out
1314
// RUN: %{run} %t.out
1415

1516
#include "common.hpp"
16-
17-
constexpr size_t TN = 16;
18-
constexpr size_t MATRIX_K = 1024 + 24;
19-
2017
#include "joint_matrix_out_bounds_impl.hpp"
18+
19+
int main() {
20+
std::cout << "A row major, B row major:\n";
21+
test_all<layout::row_major, layout::row_major>();
22+
std::cout << "A row major, B packed:\n";
23+
test_all<layout::row_major, layout::ext_intel_packed>();
24+
}
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,28 @@
1-
//==-------- joint_matrix_unaligned_k.cpp - DPC++ joint_matrix-------------==//
1+
//==----joint_matrix_out_bounds_colmajor.cpp - DPC++ joint_matrix---------==//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
88
// REQUIRES: aspect-ext_intel_matrix
9-
// REQUIRES-INTEL-DRIVER: lin: 27501, win: 101.4943
109

1110
// UNSUPPORTED: gpu-intel-dg2, cpu
11+
// UNSUPPORTED-INTENDED: Checked load/stores are not supported by DG2 and CPU HW
1212

1313
// RUN: %{build} -o %t.out
1414
// RUN: %{run} %t.out
1515

16+
// RUN: %{build} -o %t32.out -DSG_SZ=32
17+
// RUN: %{run} %t32.out
18+
1619
// XFAIL:gpu
17-
// XFAIL-TRACKER: GSD-4181
20+
// XFAIL-TRACKER: GSD-5768
1821

1922
#include "common.hpp"
20-
21-
#define SG_SZ 32
22-
constexpr size_t TN = 16;
23-
static constexpr size_t MATRIX_K = 1024 + 14;
24-
2523
#include "joint_matrix_out_bounds_impl.hpp"
24+
25+
int main() {
26+
std::cout << "A col major, B col major:\n";
27+
test_all<layout::col_major, layout::col_major>();
28+
}

sycl/test-e2e/Matrix/joint_matrix_unaligned_k.cpp

Lines changed: 0 additions & 20 deletions
This file was deleted.

sycl/test/e2e_test_requirements/no-unsupported-without-info.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
// tests to match the required format and in that case you should just update
5555
// (i.e. reduce) the number and the list below.
5656
//
57-
// NUMBER-OF-UNSUPPORTED-WITHOUT-INFO: 375
57+
// NUMBER-OF-UNSUPPORTED-WITHOUT-INFO: 371
5858
//
5959
// List of improperly UNSUPPORTED tests.
6060
// Remove the CHECK once the test has been properly UNSUPPORTED.
@@ -273,22 +273,18 @@
273273
// CHECK-NEXT: Matrix/SG32/joint_matrix_down_convert.cpp
274274
// CHECK-NEXT: Matrix/SG32/joint_matrix_half.cpp
275275
// CHECK-NEXT: Matrix/SG32/joint_matrix_int8_rowmajorA_rowmajorB.cpp
276-
// CHECK-NEXT: Matrix/SG32/joint_matrix_out_bounds.cpp
277276
// CHECK-NEXT: Matrix/SG32/joint_matrix_prefetch.cpp
278277
// CHECK-NEXT: Matrix/SG32/joint_matrix_rowmajorA_rowmajorB.cpp
279278
// CHECK-NEXT: Matrix/SG32/joint_matrix_ss_int8.cpp
280279
// CHECK-NEXT: Matrix/SG32/joint_matrix_su_int8.cpp
281280
// CHECK-NEXT: Matrix/SG32/joint_matrix_transposeC.cpp
282-
// CHECK-NEXT: Matrix/SG32/joint_matrix_unaligned_k.cpp
283281
// CHECK-NEXT: Matrix/SG32/joint_matrix_us_int8.cpp
284282
// CHECK-NEXT: Matrix/SG32/joint_matrix_uu_int8.cpp
285283
// CHECK-NEXT: Matrix/joint_matrix_annotated_ptr.cpp
286284
// CHECK-NEXT: Matrix/joint_matrix_bf16_fill_k_cache_OOB.cpp
287285
// CHECK-NEXT: Matrix/joint_matrix_bf16_fill_k_cache_prefetch.cpp
288286
// CHECK-NEXT: Matrix/joint_matrix_down_convert.cpp
289-
// CHECK-NEXT: Matrix/joint_matrix_out_bounds.cpp
290287
// CHECK-NEXT: Matrix/joint_matrix_rowmajorA_rowmajorB.cpp
291-
// CHECK-NEXT: Matrix/joint_matrix_unaligned_k.cpp
292288
// CHECK-NEXT: NewOffloadDriver/aot-gpu.cpp
293289
// CHECK-NEXT: NewOffloadDriver/spirv_device_obj_smoke.cpp
294290
// CHECK-NEXT: NonUniformGroups/ballot_group.cpp

0 commit comments

Comments
 (0)