#include <iostream>
#include <sycl/usm.hpp>

-constexpr size_t TM = 8;
-constexpr size_t TK = 16;
+template <typename Tab, size_t K, layout B_layout> class mult;

-template <layout B_layout, unsigned int vnniFactor> class mult;
-
-template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
-          size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
-          size_t NUM_COLS_C, layout B_layout, unsigned int vnniFactor>
+template <typename T1, typename T2, size_t M, size_t N, size_t K, size_t TM,
+          size_t TN, size_t TK, layout A_layout, layout B_layout>
void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) {
-  size_t M = NUM_ROWS_C;
-  size_t N = NUM_COLS_C;
-  size_t K = NUM_COLS_A;

-  assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * vnniFactor);
  // Add one iteration for the out of bounds dpas instruction
  size_t NDRangeM = M / TM + (((M % TM) != 0) ? 1 : 0);
-  size_t NDRangeN = N / TN;
-  size_t sg_size = get_sg_size<mult<B_layout, vnniFactor>>(q);
+  size_t NDRangeN = N / TN + (((N % TN) != 0) ? 1 : 0);
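+  // e.g. with the sizes used by test_all() below, M = 1024 + 20 and TM = 8
+  // give NDRangeM = ceil(1044 / 8) = 131; the partial tiles of that extra
+  // row of work-groups rely on the checked loads/stores below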
+  size_t sg_size = get_sg_size<mult<T2, K, B_layout>>(q);
+  std::cout << "SG size: " << sg_size << " ";

  q.submit([&](handler &cgh) {
-    cgh.parallel_for<mult<B_layout, vnniFactor>>(
+    cgh.parallel_for<mult<T2, K, B_layout>>(
        nd_range<2>({NDRangeM, NDRangeN * sg_size}, {1, 1 * sg_size}),
        [=](nd_item<2> spmd_item)
#ifdef SG_SZ
@@ -45,6 +38,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) {
          auto pC =
              address_space_cast<sycl::access::address_space::global_space,
                                 sycl::access::decorated::no>(C);
+
          // The submatrix API has to be accessed by all the workitems in a
          // subgroup; these functions will be called once by the subgroup,
          // with no code divergence between the workitems
@@ -54,27 +48,41 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) {
          const auto sg_starty = global_idy - spmd_item.get_local_id(1);

          sub_group sg = spmd_item.get_sub_group();
-          joint_matrix<sub_group, bfloat16, use::a, TM, TK, layout::row_major>
-              sub_a;
-
-          // For B, since current implementation does not support non-packed
-          // layout, users need to specify the packed_b layout.
-          joint_matrix<sub_group, bfloat16, use::b, TK, TN, B_layout> sub_b;
-          joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
-          // bounds-checked load where width and height are added
+          joint_matrix<sub_group, T2, use::a, TM, TK, A_layout> sub_a;
+          joint_matrix<sub_group, T2, use::b, TK, TN, B_layout> sub_b;
+          joint_matrix<sub_group, T1, use::accumulator, TM, TN> sub_c;
+
+          // bounds-checked fill where width and height are added
          ext::intel::experimental::matrix::joint_matrix_fill_checked(
              sg, sub_c, 1, M, N, sg_startx * TM, sg_starty / sg_size * TN);
+
          for (int k = 0; k < K; k += TK) {
            // bounds-checked load where width and height are added
-            ext::intel::experimental::matrix::joint_matrix_load_checked(
-                sg, sub_a, pA, K, M, K, sg_startx * TM, k);
-            // Assume we alreay in vnni format.
+            // params order: Stride, Height, Width, CoordX, CoordY
+            if constexpr (A_layout == layout::row_major) {
+              ext::intel::experimental::matrix::joint_matrix_load_checked(
+                  sg, sub_a, pA, K, M, K, sg_startx * TM, k);
+            } else {
+              ext::intel::experimental::matrix::joint_matrix_load_checked(
+                  sg, sub_a, pA, M, K, M, k, sg_startx * TM);
+            }
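+            // e.g. for row-major A the TM x TK tile starts at element
+            // (sg_startx * TM, k) of the M x K matrix with row stride K;
+            // for col-major A the extents and coordinates are transposed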
+
            // bounds-checked load where width and height are added
-            ext::intel::experimental::matrix::joint_matrix_load_checked(
-                sg, sub_b, pB, N * vnniFactor, K / vnniFactor, N * vnniFactor,
-                k / vnniFactor, sg_starty / sg_size * TN * vnniFactor);
+            // params order: Stride, Height, Width, CoordX, CoordY
+            if constexpr (B_layout != layout::col_major) {
+              constexpr unsigned int vnniFactor = vnni_factor<T2, B_layout>();
+              ext::intel::experimental::matrix::joint_matrix_load_checked(
+                  sg, sub_b, pB, N * vnniFactor, K / vnniFactor,
+                  N * vnniFactor, k / vnniFactor,
+                  sg_starty / sg_size * TN * vnniFactor);
+            } else {
+              ext::intel::experimental::matrix::joint_matrix_load_checked(
+                  sg, sub_b, pB, K, N, K, sg_starty / sg_size * TN, k);
+            }
+
            joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
          }
+
          // bounds-checked store where width and height are added
          ext::intel::experimental::matrix::joint_matrix_store_checked(
              sg, sub_c, pC, N, layout::row_major, M, N, sg_startx * TM,
@@ -83,42 +91,69 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) {
  }).wait();
}
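Note: get_sg_size comes from the test suite's shared helpers and is not part of this diff. A minimal sketch of such a helper, under the assumption that it only has to report a sub-group size the device supports (the real one may also honor the SG_SZ override used above), could look like:

#include <algorithm>
#include <sycl/sycl.hpp>

// Hypothetical stand-in for the shared get_sg_size helper: report the
// largest sub-group size the device supports. The KernelName parameter
// mirrors the call sites above; this sketch does not need it.
template <typename KernelName> size_t get_sg_size(sycl::queue q) {
  auto sizes = q.get_device().get_info<sycl::info::device::sub_group_sizes>();
  return *std::max_element(sizes.begin(), sizes.end());
}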

-int main() {
-  static constexpr size_t MATRIX_M = 1024 + 14;
-  static constexpr size_t MATRIX_N = 1024;
-  static constexpr unsigned int vnniFactor = 2;
-
+template <typename Tab, typename Tc, size_t MATRIX_M, size_t MATRIX_N,
+          size_t MATRIX_K, size_t TM, size_t TN, size_t TK, layout A_layout,
+          layout B_layout>
+void test() {
+  std::cout << MATRIX_M << "x" << MATRIX_N << "x" << MATRIX_K << ", " << TM
+            << "x" << TN << "x" << TK << ": ";
  queue q;
-  bfloat16 *A = malloc_shared<bfloat16>(MATRIX_M * MATRIX_K, q);
-  bfloat16 *B = malloc_shared<bfloat16>(MATRIX_K * MATRIX_N, q);
-  bfloat16 *vnniB = malloc_shared<bfloat16>(MATRIX_K * MATRIX_N, q);
-  float *C = malloc_shared<float>(MATRIX_M * MATRIX_N, q);
-  float *D = malloc_shared<float>(MATRIX_M * MATRIX_N, q);
-
-  matrix_rand(MATRIX_M, MATRIX_K, A, (bfloat16)5);
-  matrix_rand(MATRIX_K, MATRIX_N, B, (bfloat16)5);
-  matrix_fill(MATRIX_M, MATRIX_N, C, (float)1);
-  matrix_fill(MATRIX_M, MATRIX_N, D, (float)1);
-
-  matrix_vnni<bfloat16>(MATRIX_K, MATRIX_N, B, vnniB, vnniFactor);

+  // reference data
+  Tab *A = malloc_shared<Tab>(MATRIX_M * MATRIX_K, q);
+  Tab *B = malloc_shared<Tab>(MATRIX_K * MATRIX_N, q);
+  Tc *C = malloc_shared<Tc>(MATRIX_M * MATRIX_N, q);
+  Tc *D = malloc_shared<Tc>(MATRIX_M * MATRIX_N, q);
+  matrix_rand(MATRIX_M, MATRIX_K, A, (Tab)5);
+  matrix_rand(MATRIX_K, MATRIX_N, B, (Tab)5);
+  matrix_fill(MATRIX_M, MATRIX_N, D, (Tc)1);
  matrix_multiply_ref(A, B, D, MATRIX_M, MATRIX_N, MATRIX_K);
-  matrix_multiply<float, bfloat16, MATRIX_M, MATRIX_K, MATRIX_K / vnniFactor,
-                  MATRIX_N * vnniFactor, MATRIX_M, MATRIX_N,
-                  layout::ext_intel_packed, vnniFactor>(C, A, vnniB, q);
-  bool res = matrix_compare(MATRIX_M, MATRIX_N, C, D);
-
-  matrix_multiply<float, bfloat16, MATRIX_M, MATRIX_K, MATRIX_K, MATRIX_N,
-                  MATRIX_M, MATRIX_N, layout::row_major, 1>(C, A, B, q);
-  res = res && matrix_compare(MATRIX_M, MATRIX_N, C, D);

-  std::cout << (res ? "passed" : "failed") << std::endl;
+  // test data
+  if constexpr (A_layout == layout::col_major) {
+    Tab *colA = malloc_shared<Tab>(MATRIX_K * MATRIX_M, q);
+    matrix_transpose(MATRIX_M, MATRIX_K, colA, A);
+    Tab *tmp = A;
+    A = colA;
+    free(tmp, q);
+  }
+
+  if constexpr (B_layout == layout::col_major) {
+    Tab *colB = malloc_shared<Tab>(MATRIX_N * MATRIX_K, q);
+    matrix_transpose(MATRIX_K, MATRIX_N, colB, B);
+    Tab *tmp = B;
+    B = colB;
+    free(tmp, q);
+  }
+
+  if constexpr (B_layout == layout::ext_intel_packed) {
+    Tab *vnniB = malloc_shared<Tab>(MATRIX_K * MATRIX_N, q);
+    matrix_vnni(MATRIX_K, MATRIX_N, B, vnniB, vnni_factor<Tab, B_layout>());
+    Tab *tmp = B;
+    B = vnniB;
+    free(tmp, q);
+  }
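+  // Illustration: matrix_vnni (a shared test helper, not shown in this
+  // diff) is assumed to repack B from K x N row-major into the VNNI shape
+  // (K / f) x (N * f), with f = vnni_factor<Tab, B_layout>(), i.e.
+  //   vnniB[(k / f) * (N * f) + n * f + k % f] = B[k * N + n];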
+
+  matrix_multiply<Tc, Tab, MATRIX_M, MATRIX_N, MATRIX_K, TM, TN, TK, A_layout,
+                  B_layout>(C, A, B, q);
+  assert(matrix_compare(MATRIX_M, MATRIX_N, C, D));
+  std::cout << "passed" << std::endl;

  free(A, q);
  free(B, q);
-  free(vnniB, q);
  free(C, q);
  free(D, q);
+}

-  return !res;
+template <layout A_layout, layout B_layout> void test_all() {
+  std::cout << "bf16: ";
+  test<bfloat16, float, /*MATRIX_M*/ 1024 + 20, /*MATRIX_N*/ 1024 + 20,
+       /*MATRIX_K*/ 1024 + 24, /*TM*/ 8, /*TN*/ 16, /*TK*/ 16, A_layout,
+       B_layout>();
+  std::cout << "half: ";
+  test<half, float, 1024 + 20, 1024 + 20, 1024 + 24, 8, 16, 16, A_layout,
+       B_layout>();
+  std::cout << "int8: ";
+  test<int8_t, int32_t, 1024, 1024 + 20, 1024 + 24, 8, 16, 32, A_layout,
+       B_layout>();
}
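Note: main() sits outside this diff; presumably it now just drives test_all() over the layout combinations the kernel supports. A hypothetical driver:

int main() {
  // Hypothetical: run the full type sweep for each A/B layout combination.
  // Which combinations the real test enables is not shown in this diff.
  test_all<layout::row_major, layout::row_major>();
  test_all<layout::row_major, layout::ext_intel_packed>();
  test_all<layout::row_major, layout::col_major>();
  test_all<layout::col_major, layout::row_major>();
  return 0;
}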