1
// ==-joint_matrix_16bit_colmajorA_colmajorB.cpp - DPC++ joint_matrix-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// ===----------------------------------------------------------------------===//

// This tests support of col major layout for matrix B which does transpose and
// then VNNI transform. This is currently only available on AMX

// REQUIRES: aspect-ext_intel_matrix

// RUN: %{build} -o %t.out
// RUN: %{run} %t.out
// RUN: %{build} -o %t32.out -DSG_SZ=32
// RUN: %{run} %t32.out

// XFAIL: gpu
// XFAIL-TRACKER: GSD-5768

#include "common.hpp"
9
24
// Tile shape for one joint_matrix multiply:
//   C(TM x TN) += A(TM x TK) * B(TK x TN)
constexpr size_t TM = 8;
constexpr size_t TN = 16;
constexpr size_t TK = 16;

// Kernel-name tag, parameterized on the element type so that each
// tested type (bfloat16, half) instantiates a distinct kernel.
template <typename T> class imatrix;
29
+
13
30
template <typename T1, typename T2, size_t M, size_t N, size_t K>
14
31
void matrix_multiply (big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
15
32
big_matrix<T2, K, N> &B) {
16
33
size_t NDRangeM = M / TM;
17
34
size_t NDRangeN = N / TN;
18
- buffer<bfloat16 , 2 > bufA (A.get_data (), range<2 >(M, K));
19
- buffer<bfloat16 , 2 > bufB (B.get_data (), range<2 >(K, N));
35
+ buffer<T2 , 2 > bufA (A.get_data (), range<2 >(M, K));
36
+ buffer<T2 , 2 > bufB (B.get_data (), range<2 >(K, N));
20
37
buffer<float , 2 > bufC ((float *)C.get_data (), range<2 >(M, N));
21
38
22
39
queue q;
23
- size_t sg_size = get_sg_size<class imatrix >(q);
40
+ size_t sg_size = get_sg_size<class imatrix <T2>>(q);
41
+ std::cout << " subgroup size " << sg_size << " " ;
42
+
24
43
q.submit ([&](handler &cgh) {
25
44
auto accC = bufC.get_access <access::mode::read_write>(cgh);
26
- auto accA = bufA.get_access <access::mode::read_write>(cgh);
27
- auto accB = bufB.get_access <access::mode::read_write>(cgh);
45
+ auto accA = bufA.template get_access <access::mode::read_write>(cgh);
46
+ auto accB = bufB.template get_access <access::mode::read_write>(cgh);
28
47
29
- cgh.parallel_for <class imatrix >(
48
+ cgh.parallel_for <class imatrix <T2> >(
30
49
nd_range<2 >({NDRangeM, NDRangeN * sg_size}, {1 , 1 * sg_size}),
31
50
[=](nd_item<2 > spmd_item)
32
51
#ifdef SG_SZ
@@ -42,10 +61,8 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
42
61
const auto sg_starty = global_idy - spmd_item.get_local_id (1 );
43
62
44
63
sub_group sg = spmd_item.get_sub_group ();
45
- joint_matrix<sub_group, bfloat16, use::a, TM, TK, layout::col_major>
46
- sub_a;
47
- joint_matrix<sub_group, bfloat16, use::b, TK, TN, layout::col_major>
48
- sub_b;
64
+ joint_matrix<sub_group, T2, use::a, TM, TK, layout::col_major> sub_a;
65
+ joint_matrix<sub_group, T2, use::b, TK, TN, layout::col_major> sub_b;
49
66
joint_matrix<sub_group, float , use::accumulator, TM, TN> sub_c;
50
67
51
68
joint_matrix_load (
@@ -75,31 +92,57 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
75
92
}).wait ();
76
93
}
77
94
78
- int main () {
95
+ template < typename T> void test () {
79
96
static constexpr size_t MATRIX_M = TM * 2 ;
80
97
static constexpr size_t MATRIX_N = TN * 2 ;
81
98
static constexpr size_t MATRIX_K = TK * 2 ;
82
- bfloat16 A[MATRIX_K][MATRIX_M];
83
- bfloat16 B[MATRIX_N][MATRIX_K];
99
+ T A[MATRIX_K][MATRIX_M];
100
+ T B[MATRIX_N][MATRIX_K];
84
101
float C[MATRIX_M][MATRIX_N];
85
102
float D[MATRIX_M][MATRIX_N];
86
103
87
- matrix_fill (MATRIX_K, MATRIX_M, (bfloat16 *)A,
104
+ matrix_fill (MATRIX_K, MATRIX_M, (T *)A,
88
105
[](int i, int j) { return 1 .0f * (i + j); });
89
- matrix_fill (MATRIX_N, MATRIX_K, (bfloat16 *)B,
106
+ matrix_fill (MATRIX_N, MATRIX_K, (T *)B,
90
107
[](int i, int j) { return 2 .0f * i + 3 .0f * j; });
91
108
matrix_fill (MATRIX_M, MATRIX_N, (float *)C, 1 .0f );
92
109
matrix_fill (MATRIX_M, MATRIX_N, (float *)D, 1 .0f );
93
110
94
111
big_matrix<float , MATRIX_M, MATRIX_N> MC ((float *)&C);
95
112
big_matrix<float , MATRIX_M, MATRIX_N> MD ((float *)&D);
96
- big_matrix<bfloat16 , MATRIX_M, MATRIX_K> MA ((bfloat16 *)&A);
97
- big_matrix<bfloat16 , MATRIX_K, MATRIX_N> MB ((bfloat16 *)&B);
113
+ big_matrix<T , MATRIX_M, MATRIX_K> MA ((T *)&A);
114
+ big_matrix<T , MATRIX_K, MATRIX_N> MB ((T *)&B);
98
115
matrix_multiply (MC, MA, MB);
99
- matrix_multiply_ref ((bfloat16 *)A, (bfloat16 *)B, (float *)D, MATRIX_M,
100
- MATRIX_N, MATRIX_K, false , true , true );
116
+ matrix_multiply_ref ((T *)A, (T *)B, (float *)D, MATRIX_M, MATRIX_N, MATRIX_K,
117
+ false , true , true );
118
+
119
+ assert (matrix_compare (MATRIX_M, MATRIX_N, (float *)C, (float *)D));
120
+ std::cout << " passed" << std::endl;
121
+ }
122
+
123
+ int main () {
124
+ queue q;
125
+ std::vector<combination> combinations =
126
+ q.get_device ().get_info <syclex::info::device::matrix_combinations>();
127
+ bool bf16_run = false ;
128
+ bool half_run = false ;
129
+
130
+ for (auto &combination : combinations) {
131
+ if (!bf16_run && combination.atype == matrix_type::bf16 ) {
132
+ std::cout << " bf16 " ;
133
+ test<bfloat16>();
134
+ bf16_run = true ;
135
+ }
136
+
137
+ if (!half_run && combination.atype == matrix_type::fp16) {
138
+ std::cout << " half " ;
139
+ test<half>();
140
+ half_run = true ;
141
+ }
142
+
143
+ if (bf16_run && half_run)
144
+ break ;
145
+ }
101
146
102
- bool res = matrix_compare (MATRIX_M, MATRIX_N, (float *)C, (float *)D);
103
- std::cout << (res ? " passed" : " failed" ) << std::endl;
104
- return !res;
147
+ return 0 ;
105
148
}
0 commit comments