Skip to content

Commit dae7032

Browse files
authored
[SYCL][E2E] Add tile store test with rowmajor and use.b. (#14698)
1 parent 0b65c98 commit dae7032

6 files changed

+198
-0
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
//==---- joint_matrix_bf16_rowmajorB_load_store.cpp - DPC++ joint_matrix ---==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: aspect-ext_intel_matrix, cpu
9+
10+
// RUN: %{build} -o %t.out
11+
// RUN: %{run} %t.out
12+
13+
#include "../common.hpp"
14+
15+
#define SG_SZ 32
16+
17+
#include "../joint_matrix_bf16_rowmajorB_load_store_impl.hpp"
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
//==- joint_matrix_bf16_rowmajorB_pair_load_store.cpp - DPC++ joint_matrix--==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: aspect-ext_intel_matrix, cpu
9+
10+
// RUN: %{build} -o %t.out
11+
// RUN: %{run} %t.out
12+
13+
#include "../common.hpp"
14+
15+
#define SG_SZ 32
16+
17+
#include "../joint_matrix_bf16_rowmajorB_pair_load_store_impl.hpp"
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//==---- joint_matrix_bf16_rowmajorB_load_store.cpp - DPC++ joint_matrix ---==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: aspect-ext_intel_matrix, cpu
9+
10+
// RUN: %{build} -o %t.out
11+
// RUN: %{run} %t.out
12+
13+
#include "common.hpp"
14+
15+
#include "joint_matrix_bf16_rowmajorB_load_store_impl.hpp"
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//==- joint_matrix_bf16_rowmajorB_load_store_impl.hpp - DPC++ joint_matrix -==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===-------------------------------------------------------------------------===//
8+
9+
#include <sycl/usm.hpp>
10+
11+
template <typename Tb, unsigned int rows, unsigned int cols>
12+
void joint_B_rowmajor_load_store(Tb *B, Tb *OutB, queue &q) {
13+
14+
range<1> global{1};
15+
range<1> local{1};
16+
17+
q.submit([&](handler &h) {
18+
h.parallel_for<class Load>(
19+
nd_range<1>{global, local}, [=](nd_item<1> it)
20+
#ifdef SG_SZ
21+
[[intel::reqd_sub_group_size(SG_SZ)]]
22+
#endif
23+
{
24+
auto pB =
25+
address_space_cast<sycl::access::address_space::global_space,
26+
sycl::access::decorated::no>(B);
27+
auto pOutB =
28+
address_space_cast<sycl::access::address_space::global_space,
29+
sycl::access::decorated::no>(OutB);
30+
31+
auto sg = it.get_sub_group();
32+
33+
joint_matrix<sub_group, Tb, use::b, rows, cols, layout::row_major> tB;
34+
35+
joint_matrix_load(sg, tB, pB, cols);
36+
ext::intel::experimental::matrix::joint_matrix_store(sg, tB, pOutB,
37+
cols);
38+
}); // parallel_for
39+
}); // queue.submit
40+
41+
q.wait();
42+
}
43+
44+
template <typename Tb, size_t ROW_SIZE, size_t COL_SIZE> void test(queue &q) {
45+
Tb *B = malloc_shared<Tb>(ROW_SIZE * COL_SIZE, q);
46+
Tb *outB = malloc_shared<Tb>(ROW_SIZE * COL_SIZE, q);
47+
48+
matrix_fill(ROW_SIZE, COL_SIZE, B, [](int i, int j) { return i + j; });
49+
50+
joint_B_rowmajor_load_store<Tb, ROW_SIZE, COL_SIZE>(B, outB, q);
51+
52+
assert(matrix_compare(ROW_SIZE, COL_SIZE, outB, B));
53+
54+
free(B, q);
55+
free(outB, q);
56+
}
57+
58+
int main(void) {
59+
queue q;
60+
61+
test<bfloat16, 8, 16>(q);
62+
63+
return 0;
64+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//==- joint_matrix_bf16_rowmajorB_pair_load_store.cpp - DPC++ joint_matrix--==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: aspect-ext_intel_matrix, cpu
9+
10+
// RUN: %{build} -o %t.out
11+
// RUN: %{run} %t.out
12+
13+
#include "common.hpp"
14+
15+
#include "joint_matrix_bf16_rowmajorB_pair_load_store_impl.hpp"
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
//==- joint_matrix_bf16_rowmajorB_pair_load_store_impl.hpp - joint_matrix --==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===-------------------------------------------------------------------------===//
8+
9+
#include <sycl/usm.hpp>
10+
11+
template <typename Tb, unsigned rows, unsigned cols, unsigned HW_MAX_COL_SIZE>
12+
void joint_B_rowmajor_pair_load_store(Tb *B, Tb *OutB, queue &q) {
13+
14+
range<1> global{1};
15+
range<1> local{1};
16+
17+
q.submit([&](handler &h) {
18+
h.parallel_for<class Load>(
19+
nd_range<1>{global, local}, [=](nd_item<1> it)
20+
#ifdef SG_SZ
21+
[[intel::reqd_sub_group_size(SG_SZ)]]
22+
#endif
23+
{
24+
auto pB =
25+
address_space_cast<sycl::access::address_space::global_space,
26+
sycl::access::decorated::no>(B);
27+
auto pOutB =
28+
address_space_cast<sycl::access::address_space::global_space,
29+
sycl::access::decorated::no>(OutB);
30+
31+
auto sg = it.get_sub_group();
32+
33+
joint_matrix<sub_group, Tb, use::b, rows, HW_MAX_COL_SIZE,
34+
layout::row_major>
35+
tB[2];
36+
37+
joint_matrix_load(sg, tB[0], pB, cols);
38+
joint_matrix_load(sg, tB[1], pB + HW_MAX_COL_SIZE, cols);
39+
ext::intel::experimental::matrix::joint_matrix_store(sg, tB[0], pOutB,
40+
cols);
41+
ext::intel::experimental::matrix::joint_matrix_store(
42+
sg, tB[1], pOutB + HW_MAX_COL_SIZE, cols);
43+
}); // parallel_for
44+
}); // queue.submit
45+
46+
q.wait();
47+
}
48+
49+
template <typename Tb, size_t ROW_SIZE, size_t COL_SIZE, size_t HW_MAX_COL_SIZE>
50+
void test(queue &q) {
51+
Tb *B = malloc_shared<Tb>(ROW_SIZE * COL_SIZE, q);
52+
Tb *outB = malloc_shared<Tb>(ROW_SIZE * COL_SIZE, q);
53+
54+
matrix_fill(ROW_SIZE, COL_SIZE, B, [](int i, int j) { return i + j; });
55+
56+
joint_B_rowmajor_pair_load_store<Tb, ROW_SIZE, COL_SIZE, HW_MAX_COL_SIZE>(
57+
B, outB, q);
58+
59+
assert(matrix_compare(ROW_SIZE, COL_SIZE, outB, B));
60+
61+
free(B, q);
62+
free(outB, q);
63+
}
64+
65+
int main(void) {
66+
queue q;
67+
68+
test<bfloat16, 8, 32, 16>(q);
69+
return 0;
70+
}

0 commit comments

Comments
 (0)