|
30 | 30 |
|
31 | 31 | #include "common.h"
|
32 | 32 |
|
33 |
| -int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, |
34 |
| - BLASLONG ldc) { |
35 |
| - // printf("m: %d, n: %d, k: %d\n", m, n, k); |
36 |
| - BLASLONG padk = (k + 3) & ~3; |
37 |
| - BLASLONG padm = (m + 1) & ~1; |
38 |
| - BLASLONG padn = (n + 1) & ~1; |
39 |
| - FLOAT *RC = (FLOAT *)calloc(padm * padn, sizeof(float)); |
40 |
| - BLASLONG nldc = padm; |
41 |
| - |
42 |
| - IFLOAT *ptr_a = A; |
43 |
| - IFLOAT *ptr_b = B; |
44 |
| - FLOAT *ptr_c = RC; |
45 |
| - |
46 |
| - IFLOAT *ptr_a0, *ptr_a1, *ptr_a2, *ptr_a3; |
47 |
| - IFLOAT *ptr_b0, *ptr_b1; |
48 |
| - FLOAT *ptr_c00, *ptr_c10, *ptr_c20, *ptr_c30, *ptr_c01, *ptr_c11, *ptr_c21, *ptr_c31; |
49 |
| - |
50 |
| - svbfloat16_t ma0, ma1, ma2, ma3, mb0, mb1; |
51 |
| - svfloat32_t mc00, mc01, mc10, mc11, mc20, mc21, mc30, mc31; |
52 |
| - svbool_t pg16 = svptrue_b16(); |
53 |
| - svbool_t pg32 = svptrue_b32(); |
54 |
| - svfloat32_t svalpha = svdup_f32(alpha); |
55 |
| - |
56 |
| - uint32_t off_c[] = {0, (uint32_t)nldc, 1, (uint32_t)nldc + 1}; // 00 01 10 11 |
57 |
| - svuint32_t off_vc = svld1_u32(pg32, off_c); |
58 |
| - |
59 |
| - for (BLASLONG j = 0; j < padn / 4; j++) { |
60 |
| - ptr_c00 = ptr_c; |
61 |
| - ptr_c10 = ptr_c00 + 2; |
62 |
| - ptr_c20 = ptr_c10 + 2; |
63 |
| - ptr_c30 = ptr_c20 + 2; |
64 |
| - ptr_c01 = ptr_c + 2 * nldc; |
65 |
| - ptr_c11 = ptr_c01 + 2; |
66 |
| - ptr_c21 = ptr_c11 + 2; |
67 |
| - ptr_c31 = ptr_c21 + 2; |
68 |
| - ptr_c += 4 * nldc; |
69 |
| - |
70 |
| - ptr_a = A; |
71 |
| - |
72 |
| - for (BLASLONG i = 0; i < padm / 8; i++) { |
73 |
| - ptr_a0 = ptr_a; |
74 |
| - ptr_a1 = ptr_a0 + 2 * padk; |
75 |
| - ptr_a2 = ptr_a1 + 2 * padk; |
76 |
| - ptr_a3 = ptr_a2 + 2 * padk; |
77 |
| - ptr_a += 8 * padk; |
78 |
| - |
79 |
| - ptr_b0 = ptr_b; |
80 |
| - ptr_b1 = ptr_b0 + 2 * padk; |
81 |
| - |
82 |
| - mc00 = svdup_f32(0); |
83 |
| - mc01 = svdup_f32(0); |
84 |
| - mc10 = svdup_f32(0); |
85 |
| - mc11 = svdup_f32(0); |
86 |
| - mc20 = svdup_f32(0); |
87 |
| - mc21 = svdup_f32(0); |
88 |
| - mc30 = svdup_f32(0); |
89 |
| - mc31 = svdup_f32(0); |
90 |
| - |
91 |
| - for (BLASLONG p = 0; p < padk / 4; p++) { |
92 |
| - ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0); |
93 |
| - ma1 = svld1_bf16(pg16, (bfloat16_t *)ptr_a1); |
94 |
| - ma2 = svld1_bf16(pg16, (bfloat16_t *)ptr_a2); |
95 |
| - ma3 = svld1_bf16(pg16, (bfloat16_t *)ptr_a3); |
96 |
| - mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0); |
97 |
| - mb1 = svld1_bf16(pg16, (bfloat16_t *)ptr_b1); |
98 |
| - |
99 |
| - mc00 = svbfmmla(mc00, ma0, mb0); |
100 |
| - mc10 = svbfmmla(mc10, ma1, mb0); |
101 |
| - mc20 = svbfmmla(mc20, ma2, mb0); |
102 |
| - mc30 = svbfmmla(mc30, ma3, mb0); |
103 |
| - mc01 = svbfmmla(mc01, ma0, mb1); |
104 |
| - mc11 = svbfmmla(mc11, ma1, mb1); |
105 |
| - mc21 = svbfmmla(mc21, ma2, mb1); |
106 |
| - mc31 = svbfmmla(mc31, ma3, mb1); |
107 |
| - |
108 |
| - ptr_a0 += 8; |
109 |
| - ptr_a1 += 8; |
110 |
| - ptr_a2 += 8; |
111 |
| - ptr_a3 += 8; |
112 |
| - ptr_b0 += 8; |
113 |
| - ptr_b1 += 8; |
114 |
| - } |
115 |
| - svst1_scatter_index(pg32, ptr_c00, off_vc, mc00); |
116 |
| - svst1_scatter_index(pg32, ptr_c10, off_vc, mc10); |
117 |
| - svst1_scatter_index(pg32, ptr_c20, off_vc, mc20); |
118 |
| - svst1_scatter_index(pg32, ptr_c30, off_vc, mc30); |
119 |
| - svst1_scatter_index(pg32, ptr_c01, off_vc, mc01); |
120 |
| - svst1_scatter_index(pg32, ptr_c11, off_vc, mc11); |
121 |
| - svst1_scatter_index(pg32, ptr_c21, off_vc, mc21); |
122 |
| - svst1_scatter_index(pg32, ptr_c31, off_vc, mc31); |
123 |
| - |
124 |
| - ptr_c00 += 8; |
125 |
| - ptr_c10 += 8; |
126 |
| - ptr_c20 += 8; |
127 |
| - ptr_c30 += 8; |
128 |
| - ptr_c01 += 8; |
129 |
| - ptr_c11 += 8; |
130 |
| - ptr_c21 += 8; |
131 |
| - ptr_c31 += 8; |
132 |
| - } |
133 |
| - |
134 |
| - if (padm & 4) { |
135 |
| - // rest 4 or 6 |
136 |
| - ptr_a0 = ptr_a; |
137 |
| - ptr_a1 = ptr_a0 + 2 * padk; |
138 |
| - ptr_a += 4 * padk; |
139 |
| - |
140 |
| - ptr_b0 = ptr_b; |
141 |
| - ptr_b1 = ptr_b0 + 2 * padk; |
142 |
| - |
143 |
| - mc00 = svdup_f32(0); |
144 |
| - mc01 = svdup_f32(0); |
145 |
| - mc10 = svdup_f32(0); |
146 |
| - mc11 = svdup_f32(0); |
147 |
| - for (BLASLONG p = 0; p < padk / 4; p++) { |
148 |
| - ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0); |
149 |
| - ma1 = svld1_bf16(pg16, (bfloat16_t *)ptr_a1); |
150 |
| - mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0); |
151 |
| - mb1 = svld1_bf16(pg16, (bfloat16_t *)ptr_b1); |
152 |
| - |
153 |
| - mc00 = svbfmmla(mc00, ma0, mb0); |
154 |
| - mc10 = svbfmmla(mc10, ma1, mb0); |
155 |
| - mc01 = svbfmmla(mc01, ma0, mb1); |
156 |
| - mc11 = svbfmmla(mc11, ma1, mb1); |
157 |
| - |
158 |
| - ptr_a0 += 8; |
159 |
| - ptr_a1 += 8; |
160 |
| - ptr_b0 += 8; |
161 |
| - ptr_b1 += 8; |
162 |
| - } |
163 |
| - svst1_scatter_index(pg32, ptr_c00, off_vc, mc00); |
164 |
| - svst1_scatter_index(pg32, ptr_c10, off_vc, mc10); |
165 |
| - svst1_scatter_index(pg32, ptr_c01, off_vc, mc01); |
166 |
| - svst1_scatter_index(pg32, ptr_c11, off_vc, mc11); |
167 |
| - |
168 |
| - ptr_c00 += 4; |
169 |
| - ptr_c10 += 4; |
170 |
| - ptr_c01 += 4; |
171 |
| - ptr_c11 += 4; |
172 |
| - } |
173 |
| - |
174 |
| - if (padm & 2) { |
175 |
| - // rest 2 |
176 |
| - ptr_a0 = ptr_a; |
177 |
| - |
178 |
| - ptr_b0 = ptr_b; |
179 |
| - ptr_b1 = ptr_b0 + 2 * padk; |
180 |
| - |
181 |
| - mc00 = svdup_f32(0); |
182 |
| - mc01 = svdup_f32(0); |
183 |
| - for (BLASLONG p = 0; p < padk / 4; p++) { |
184 |
| - ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0); |
185 |
| - mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0); |
186 |
| - mb1 = svld1_bf16(pg16, (bfloat16_t *)ptr_b1); |
187 |
| - mc00 = svbfmmla(mc00, ma0, mb0); |
188 |
| - mc01 = svbfmmla(mc01, ma0, mb1); |
189 |
| - ptr_a0 += 8; |
190 |
| - ptr_b0 += 8; |
191 |
| - ptr_b1 += 8; |
192 |
| - } |
193 |
| - svst1_scatter_index(pg32, ptr_c00, off_vc, mc00); |
194 |
| - svst1_scatter_index(pg32, ptr_c01, off_vc, mc01); |
195 |
| - ptr_c00 += 2; |
196 |
| - ptr_c01 += 2; |
197 |
| - } |
198 |
| - |
199 |
| - ptr_b += 4 * padk; |
200 |
| - } |
201 |
| - |
202 |
| - if (padn & 2) { |
203 |
| - // rest 2 |
204 |
| - ptr_c00 = ptr_c; |
205 |
| - ptr_c10 = ptr_c00 + 2; |
206 |
| - ptr_c20 = ptr_c10 + 2; |
207 |
| - ptr_c30 = ptr_c20 + 2; |
208 |
| - ptr_c += 2 * nldc; |
209 |
| - |
210 |
| - ptr_a = A; |
211 |
| - |
212 |
| - for (BLASLONG i = 0; i < padm / 8; i++) { |
213 |
| - ptr_a0 = ptr_a; |
214 |
| - ptr_a1 = ptr_a0 + 2 * padk; |
215 |
| - ptr_a2 = ptr_a1 + 2 * padk; |
216 |
| - ptr_a3 = ptr_a2 + 2 * padk; |
217 |
| - ptr_a += 8 * padk; |
218 |
| - |
219 |
| - ptr_b0 = ptr_b; |
220 |
| - |
221 |
| - mc00 = svdup_f32(0); |
222 |
| - mc10 = svdup_f32(0); |
223 |
| - mc20 = svdup_f32(0); |
224 |
| - mc30 = svdup_f32(0); |
225 |
| - |
226 |
| - for (BLASLONG p = 0; p < padk / 4; p++) { |
227 |
| - ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0); |
228 |
| - ma1 = svld1_bf16(pg16, (bfloat16_t *)ptr_a1); |
229 |
| - ma2 = svld1_bf16(pg16, (bfloat16_t *)ptr_a2); |
230 |
| - ma3 = svld1_bf16(pg16, (bfloat16_t *)ptr_a3); |
231 |
| - mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0); |
232 |
| - mc00 = svbfmmla(mc00, ma0, mb0); |
233 |
| - mc10 = svbfmmla(mc10, ma1, mb0); |
234 |
| - mc20 = svbfmmla(mc20, ma2, mb0); |
235 |
| - mc30 = svbfmmla(mc30, ma3, mb0); |
236 |
| - ptr_a0 += 8; |
237 |
| - ptr_a1 += 8; |
238 |
| - ptr_a2 += 8; |
239 |
| - ptr_a3 += 8; |
240 |
| - ptr_b0 += 8; |
241 |
| - } |
242 |
| - svst1_scatter_index(pg32, ptr_c00, off_vc, mc00); |
243 |
| - svst1_scatter_index(pg32, ptr_c10, off_vc, mc10); |
244 |
| - svst1_scatter_index(pg32, ptr_c20, off_vc, mc20); |
245 |
| - svst1_scatter_index(pg32, ptr_c30, off_vc, mc30); |
246 |
| - ptr_c00 += 8; |
247 |
| - ptr_c10 += 8; |
248 |
| - ptr_c20 += 8; |
249 |
| - ptr_c30 += 8; |
250 |
| - } |
251 |
| - |
252 |
| - if (padm & 4) { |
253 |
| - ptr_a0 = ptr_a; |
254 |
| - ptr_a1 = ptr_a0 + 2 * padk; |
255 |
| - ptr_a += 4 * padk; |
256 |
| - |
257 |
| - ptr_b0 = ptr_b; |
258 |
| - |
259 |
| - mc00 = svdup_f32(0); |
260 |
| - mc10 = svdup_f32(0); |
261 |
| - for (BLASLONG p = 0; p < padk / 4; p++) { |
262 |
| - ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0); |
263 |
| - ma1 = svld1_bf16(pg16, (bfloat16_t *)ptr_a1); |
264 |
| - mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0); |
265 |
| - mc00 = svbfmmla(mc00, ma0, mb0); |
266 |
| - mc10 = svbfmmla(mc10, ma1, mb0); |
267 |
| - ptr_a0 += 8; |
268 |
| - ptr_a1 += 8; |
269 |
| - ptr_b0 += 8; |
270 |
| - } |
271 |
| - svst1_scatter_index(pg32, ptr_c00, off_vc, mc00); |
272 |
| - svst1_scatter_index(pg32, ptr_c10, off_vc, mc10); |
273 |
| - ptr_c00 += 4; |
274 |
| - ptr_c10 += 4; |
275 |
| - } |
276 |
| - |
277 |
| - if (padm & 2) { |
278 |
| - ptr_a0 = ptr_a; |
279 |
| - ptr_a += 2 * padk; |
280 |
| - ptr_b0 = ptr_b; |
281 |
| - mc00 = svdup_f32(0); |
282 |
| - for (BLASLONG p = 0; p < padk / 4; p++) { |
283 |
| - ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0); |
284 |
| - mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0); |
285 |
| - mc00 = svbfmmla(mc00, ma0, mb0); |
286 |
| - ptr_a0 += 8; |
287 |
| - ptr_b0 += 8; |
288 |
| - } |
289 |
| - svst1_scatter_index(pg32, ptr_c00, off_vc, mc00); |
290 |
| - ptr_c00 += 2; |
291 |
| - } |
292 |
| - |
293 |
| - ptr_b += 2 * padk; |
294 |
| - } |
295 |
| - |
296 |
| - FLOAT *org_c = C; |
297 |
| - FLOAT *raw_c = RC; |
298 |
| - FLOAT *org_c0, *raw_c0; |
299 |
| - svfloat32_t org_vc0, raw_vc0; |
300 |
| - for (BLASLONG j = 0; j < n; j++) { |
301 |
| - org_c0 = org_c; |
302 |
| - raw_c0 = raw_c; |
303 |
| - org_c += ldc; |
304 |
| - raw_c += nldc; |
305 |
| - BLASLONG i; |
306 |
| - for (i = 0; i < m / 4; i++) { |
307 |
| - org_vc0 = svld1_f32(pg32, org_c0); |
308 |
| - raw_vc0 = svld1_f32(pg32, raw_c0); |
309 |
| - org_vc0 = svmad_z(pg32, svalpha, raw_vc0, |
310 |
| - org_vc0); // alpha * raw + org, raw -> a * b |
311 |
| - svst1_f32(pg32, org_c0, org_vc0); |
312 |
| - org_c0 += 4; |
313 |
| - raw_c0 += 4; |
314 |
| - } |
315 |
| - for (i = 0; i < (m & 3); i++) { |
316 |
| - *org_c0 += alpha * (*raw_c0); |
317 |
| - org_c0++; |
318 |
| - raw_c0++; |
319 |
| - } |
320 |
| - } |
321 |
| - |
| 33 | +#define ALPHA_ONE |
| 34 | +#include "sbgemm_kernel_8x4_neoversen2_impl.c" |
| 35 | +#undef ALPHA_ONE |
| 36 | +#include "sbgemm_kernel_8x4_neoversen2_impl.c" |
| 37 | + |
| 38 | +int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B, |
| 39 | + FLOAT *C, BLASLONG ldc) { |
| 40 | + if (alpha == 1.0f) |
| 41 | + return sbgemm_kernel_neoversen2_alpha_one(m, n, k, alpha, A, B, C, ldc); |
| 42 | + else |
| 43 | + return sbgemm_kernel_neoversen2_alpha(m, n, k, alpha, A, B, C, ldc); |
322 | 44 | return 0;
|
323 | 45 | }
|
0 commit comments