27
27
* *****************************************************************************/
28
28
29
29
#include <arm_sve.h>
30
+
30
31
#include "common.h"
31
32
32
- int CNAME (BLASLONG m , BLASLONG n , BLASLONG k , FLOAT alpha , IFLOAT * A , IFLOAT * B ,
33
- FLOAT * C , BLASLONG ldc ) {
33
+ int CNAME (BLASLONG m , BLASLONG n , BLASLONG k , FLOAT alpha , IFLOAT * A , IFLOAT * B , FLOAT * C ,
34
+ BLASLONG ldc ) {
34
35
// printf("m: %d, n: %d, k: %d\n", m, n, k);
35
36
BLASLONG padk = (k + 3 ) & ~3 ;
36
37
BLASLONG padm = (m + 1 ) & ~1 ;
37
38
BLASLONG padn = (n + 1 ) & ~1 ;
38
- FLOAT * RC = (FLOAT * ) calloc (padm * padn , sizeof (float ));
39
+ FLOAT * RC = (FLOAT * )calloc (padm * padn , sizeof (float ));
39
40
BLASLONG nldc = padm ;
40
41
41
42
IFLOAT * ptr_a = A ;
@@ -52,10 +53,10 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
52
53
svbool_t pg32 = svptrue_b32 ();
53
54
svfloat32_t svalpha = svdup_f32 (alpha );
54
55
55
- uint32_t off_c [] = {0 , (uint32_t ) nldc , 1 , (uint32_t ) nldc + 1 }; // 00 01 10 11
56
+ uint32_t off_c [] = {0 , (uint32_t )nldc , 1 , (uint32_t )nldc + 1 }; // 00 01 10 11
56
57
svuint32_t off_vc = svld1_u32 (pg32 , off_c );
57
58
58
- for (BLASLONG j = 0 ; j < padn / 4 ; j ++ ) {
59
+ for (BLASLONG j = 0 ; j < padn / 4 ; j ++ ) {
59
60
ptr_c00 = ptr_c ;
60
61
ptr_c10 = ptr_c00 + 2 ;
61
62
ptr_c20 = ptr_c10 + 2 ;
@@ -68,7 +69,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
68
69
69
70
ptr_a = A ;
70
71
71
- for (BLASLONG i = 0 ; i < padm / 8 ; i ++ ) {
72
+ for (BLASLONG i = 0 ; i < padm / 8 ; i ++ ) {
72
73
ptr_a0 = ptr_a ;
73
74
ptr_a1 = ptr_a0 + 2 * padk ;
74
75
ptr_a2 = ptr_a1 + 2 * padk ;
@@ -78,18 +79,22 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
78
79
ptr_b0 = ptr_b ;
79
80
ptr_b1 = ptr_b0 + 2 * padk ;
80
81
81
- mc00 = svdup_f32 (0 ); mc01 = svdup_f32 (0 );
82
- mc10 = svdup_f32 (0 ); mc11 = svdup_f32 (0 );
83
- mc20 = svdup_f32 (0 ); mc21 = svdup_f32 (0 );
84
- mc30 = svdup_f32 (0 ); mc31 = svdup_f32 (0 );
85
-
86
- for (BLASLONG p = 0 ; p < padk /4 ; p ++ ) {
87
- ma0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a0 );
88
- ma1 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a1 );
89
- ma2 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a2 );
90
- ma3 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a3 );
91
- mb0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_b0 );
92
- mb1 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_b1 );
82
+ mc00 = svdup_f32 (0 );
83
+ mc01 = svdup_f32 (0 );
84
+ mc10 = svdup_f32 (0 );
85
+ mc11 = svdup_f32 (0 );
86
+ mc20 = svdup_f32 (0 );
87
+ mc21 = svdup_f32 (0 );
88
+ mc30 = svdup_f32 (0 );
89
+ mc31 = svdup_f32 (0 );
90
+
91
+ for (BLASLONG p = 0 ; p < padk / 4 ; p ++ ) {
92
+ ma0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a0 );
93
+ ma1 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a1 );
94
+ ma2 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a2 );
95
+ ma3 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a3 );
96
+ mb0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_b0 );
97
+ mb1 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_b1 );
93
98
94
99
mc00 = svbfmmla (mc00 , ma0 , mb0 );
95
100
mc10 = svbfmmla (mc10 , ma1 , mb0 );
@@ -135,13 +140,15 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
135
140
ptr_b0 = ptr_b ;
136
141
ptr_b1 = ptr_b0 + 2 * padk ;
137
142
138
- mc00 = svdup_f32 (0 ); mc01 = svdup_f32 (0 );
139
- mc10 = svdup_f32 (0 ); mc11 = svdup_f32 (0 );
140
- for (BLASLONG p = 0 ; p < padk /4 ; p ++ ) {
141
- ma0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a0 );
142
- ma1 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a1 );
143
- mb0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_b0 );
144
- mb1 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_b1 );
143
+ mc00 = svdup_f32 (0 );
144
+ mc01 = svdup_f32 (0 );
145
+ mc10 = svdup_f32 (0 );
146
+ mc11 = svdup_f32 (0 );
147
+ for (BLASLONG p = 0 ; p < padk / 4 ; p ++ ) {
148
+ ma0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a0 );
149
+ ma1 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a1 );
150
+ mb0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_b0 );
151
+ mb1 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_b1 );
145
152
146
153
mc00 = svbfmmla (mc00 , ma0 , mb0 );
147
154
mc10 = svbfmmla (mc10 , ma1 , mb0 );
@@ -171,11 +178,12 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
171
178
ptr_b0 = ptr_b ;
172
179
ptr_b1 = ptr_b0 + 2 * padk ;
173
180
174
- mc00 = svdup_f32 (0 ); mc01 = svdup_f32 (0 );
175
- for (BLASLONG p = 0 ; p < padk /4 ; p ++ ) {
176
- ma0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a0 );
177
- mb0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_b0 );
178
- mb1 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_b1 );
181
+ mc00 = svdup_f32 (0 );
182
+ mc01 = svdup_f32 (0 );
183
+ for (BLASLONG p = 0 ; p < padk / 4 ; p ++ ) {
184
+ ma0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a0 );
185
+ mb0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_b0 );
186
+ mb1 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_b1 );
179
187
mc00 = svbfmmla (mc00 , ma0 , mb0 );
180
188
mc01 = svbfmmla (mc01 , ma0 , mb1 );
181
189
ptr_a0 += 8 ;
@@ -189,7 +197,6 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
189
197
}
190
198
191
199
ptr_b += 4 * padk ;
192
-
193
200
}
194
201
195
202
if (padn & 2 ) {
@@ -202,7 +209,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
202
209
203
210
ptr_a = A ;
204
211
205
- for (BLASLONG i = 0 ; i < padm / 8 ; i ++ ) {
212
+ for (BLASLONG i = 0 ; i < padm / 8 ; i ++ ) {
206
213
ptr_a0 = ptr_a ;
207
214
ptr_a1 = ptr_a0 + 2 * padk ;
208
215
ptr_a2 = ptr_a1 + 2 * padk ;
@@ -216,12 +223,12 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
216
223
mc20 = svdup_f32 (0 );
217
224
mc30 = svdup_f32 (0 );
218
225
219
- for (BLASLONG p = 0 ; p < padk / 4 ; p ++ ) {
220
- ma0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a0 );
221
- ma1 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a1 );
222
- ma2 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a2 );
223
- ma3 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a3 );
224
- mb0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_b0 );
226
+ for (BLASLONG p = 0 ; p < padk / 4 ; p ++ ) {
227
+ ma0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a0 );
228
+ ma1 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a1 );
229
+ ma2 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a2 );
230
+ ma3 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a3 );
231
+ mb0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_b0 );
225
232
mc00 = svbfmmla (mc00 , ma0 , mb0 );
226
233
mc10 = svbfmmla (mc10 , ma1 , mb0 );
227
234
mc20 = svbfmmla (mc20 , ma2 , mb0 );
@@ -251,10 +258,10 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
251
258
252
259
mc00 = svdup_f32 (0 );
253
260
mc10 = svdup_f32 (0 );
254
- for (BLASLONG p = 0 ; p < padk / 4 ; p ++ ) {
255
- ma0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a0 );
256
- ma1 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a1 );
257
- mb0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_b0 );
261
+ for (BLASLONG p = 0 ; p < padk / 4 ; p ++ ) {
262
+ ma0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a0 );
263
+ ma1 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a1 );
264
+ mb0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_b0 );
258
265
mc00 = svbfmmla (mc00 , ma0 , mb0 );
259
266
mc10 = svbfmmla (mc10 , ma1 , mb0 );
260
267
ptr_a0 += 8 ;
@@ -272,9 +279,9 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
272
279
ptr_a += 2 * padk ;
273
280
ptr_b0 = ptr_b ;
274
281
mc00 = svdup_f32 (0 );
275
- for (BLASLONG p = 0 ; p < padk / 4 ; p ++ ) {
276
- ma0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_a0 );
277
- mb0 = svld1_bf16 (pg16 , (bfloat16_t * ) ptr_b0 );
282
+ for (BLASLONG p = 0 ; p < padk / 4 ; p ++ ) {
283
+ ma0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_a0 );
284
+ mb0 = svld1_bf16 (pg16 , (bfloat16_t * )ptr_b0 );
278
285
mc00 = svbfmmla (mc00 , ma0 , mb0 );
279
286
ptr_a0 += 8 ;
280
287
ptr_b0 += 8 ;
@@ -296,10 +303,11 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
296
303
org_c += ldc ;
297
304
raw_c += nldc ;
298
305
BLASLONG i ;
299
- for (i = 0 ; i < m / 4 ; i ++ ) {
306
+ for (i = 0 ; i < m / 4 ; i ++ ) {
300
307
org_vc0 = svld1_f32 (pg32 , org_c0 );
301
308
raw_vc0 = svld1_f32 (pg32 , raw_c0 );
302
- org_vc0 = svmad_z (pg32 , svalpha , raw_vc0 , org_vc0 ); // alpha * raw + org, raw -> a * b
309
+ org_vc0 = svmad_z (pg32 , svalpha , raw_vc0 ,
310
+ org_vc0 ); // alpha * raw + org, raw -> a * b
303
311
svst1_f32 (pg32 , org_c0 , org_vc0 );
304
312
org_c0 += 4 ;
305
313
raw_c0 += 4 ;
@@ -310,5 +318,6 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
310
318
raw_c0 ++ ;
311
319
}
312
320
}
321
+
313
322
return 0 ;
314
323
}
0 commit comments