@@ -68,21 +68,21 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
68
68
vec_make_mult1 (v_x0 , false);
69
69
70
70
for (; i + 8 <= n8 ; i += 8 ) {
71
- vec_load8_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
72
-
73
71
vec_load_mult184_mma (& temp [0 ], & va0 [i + 0 ], & v_x0 [ 0 ]);
74
72
vec_load_mult184_mma (& temp [2 ], & va0 [i + 4 ], & v_x0 [ 0 ]);
75
73
74
+ vec_load8_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
75
+
76
76
vec_reduce88_mma (& temp [0 ], temp0 + 0 , v_alpha , vy0 + 0 );
77
77
78
78
vec_store8_pair (& v_y [(i * 2 ) + 0 ], vy0 );
79
79
}
80
80
81
81
if (n8 & 4 ) {
82
- vec_load4_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
83
-
84
82
vec_load_mult184_mma (& temp [0 ], & va0 [i + 0 ], & v_x0 [ 0 ]);
85
83
84
+ vec_load4_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
85
+
86
86
vec_reduce84_mma (& temp [0 ], temp0 , v_alpha , vy0 );
87
87
88
88
vec_store4_pair (& v_y [(i * 2 ) + 0 ], vy0 );
@@ -95,41 +95,41 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
95
95
vec_f32 vy0 [2 * 4 ];
96
96
97
97
for (; i + 4 <= n8 ; i += 4 ) {
98
- vec_load4_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
99
-
100
98
vec_load_mult18_mma (& temp [0 ], & va0 [i + 0 ], v_x0 [ 0 ]);
101
99
100
+ vec_load4_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
101
+
102
102
vec_reduce8_mma (& temp [0 ], temp0 , v_alpha , vy0 );
103
103
104
104
vec_store4_pair (& v_y [(i * 2 ) + 0 ], vy0 );
105
105
}
106
106
#endif
107
107
108
108
for (; i < n8 ; i ++ ) {
109
- vec_load_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
110
-
111
109
vec_load_mult12_mma (& temp [0 ], & va0 [i ], v_x0 [ 0 ]);
112
110
111
+ vec_load_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
112
+
113
113
vec_reduce2_mma (& temp [0 ], temp0 , v_alpha , vy0 );
114
114
115
115
vec_store_pair (& v_y [(i * 2 ) + 0 ], vy0 );
116
116
}
117
117
118
118
n &= 7 ;
119
119
if (n > 4 ) {
120
+ vec_loadN_mult12_mma (& temp [0 ], & va0 [i ], v_x0 [ 0 ], n );
121
+
120
122
BLASLONG n3 = n & 3 ;
121
123
vec_loadN2_f32 (vy0 , & v_y [(i * 2 ) + 0 ], n3 );
122
124
123
- vec_loadN_mult12_mma (& temp [0 ], & va0 [i ], v_x0 [ 0 ], n );
124
-
125
125
vec_reduce2_mma (& temp [0 ], temp0 , v_alpha , vy0 );
126
126
127
127
vec_storeN2_f32 (vy0 , & v_y [(i * 2 ) + 0 ], n3 );
128
128
} else if (n ) {
129
- vy0 [0 ] = vec_loadN_f32 (& v_y [(i * 2 ) + 0 ], n );
130
-
131
129
vec_loadN_mult11_mma (& temp [0 ], & va0 [i ], v_x0 [ 0 ], n );
132
130
131
+ vy0 [0 ] = vec_loadN_f32 (& v_y [(i * 2 ) + 0 ], n );
132
+
133
133
vec_reduce1_mma (& temp [0 ], temp0 , v_alpha , vy0 );
134
134
135
135
vec_storeN_f32 (vy0 [0 ], & v_y [(i * 2 ) + 0 ], n );
@@ -163,20 +163,20 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
163
163
vec_make_mult1 (v_x0 , false);
164
164
165
165
for (; i + 8 <= n8 ; i += 8 ) {
166
- vec_load8_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
167
-
168
166
vec_load_mult288a_mma (& temp [0 ], & va0 [i + 0 ], & va1 [i + 0 ], & v_x0 [ 0 ]);
169
167
168
+ vec_load8_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
169
+
170
170
vec_reduce88_mma (& temp [0 ], temp0 + 0 , v_alpha , vy0 + 0 );
171
171
172
172
vec_store8_pair (& v_y [(i * 2 ) + 0 ], vy0 );
173
173
}
174
174
175
175
if (n8 & 4 ) {
176
- vec_load4_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
177
-
178
176
vec_load_mult284a_mma (& temp [0 ], & va0 [i + 0 ], & va1 [i + 0 ], & v_x0 [ 0 ]);
179
177
178
+ vec_load4_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
179
+
180
180
vec_reduce84_mma (& temp [0 ], temp0 , v_alpha , vy0 );
181
181
182
182
vec_store4_pair (& v_y [(i * 2 ) + 0 ], vy0 );
@@ -189,41 +189,41 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
189
189
v_x0 [0 ] = vec_loadN (x_bf , 2 );
190
190
191
191
for (; i + 4 <= n8 ; i += 4 ) {
192
- vec_load4_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
193
-
194
192
vec_load_mult28a_mma (& temp [0 ], & va0 [i + 0 ], & va1 [i + 0 ], v_x0 [ 0 ]);
195
193
194
+ vec_load4_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
195
+
196
196
vec_reduce8_mma (& temp [0 ], temp0 , v_alpha , vy0 );
197
197
198
198
vec_store4_pair (& v_y [(i * 2 ) + 0 ], vy0 );
199
199
}
200
200
#endif
201
201
202
202
for (; i < n8 ; i ++ ) {
203
- vec_load_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
204
-
205
203
vec_load_mult22a_mma (& temp [0 ], & va0 [i ], & va1 [i ], v_x0 [ 0 ]);
206
204
205
+ vec_load_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
206
+
207
207
vec_reduce2_mma (& temp [0 ], temp0 , v_alpha , vy0 );
208
208
209
209
vec_store_pair (& v_y [(i * 2 ) + 0 ], vy0 );
210
210
}
211
211
212
212
n &= 7 ;
213
213
if (n > 4 ) {
214
+ vec_loadN_mult22a_mma (& temp [0 ], & va0 [i ], & va1 [i ], v_x0 [ 0 ], n );
215
+
214
216
BLASLONG n3 = n & 3 ;
215
217
vec_loadN2_f32 (vy0 , & v_y [(i * 2 ) + 0 ], n3 );
216
218
217
- vec_loadN_mult22a_mma (& temp [0 ], & va0 [i ], & va1 [i ], v_x0 [ 0 ], n );
218
-
219
219
vec_reduce2_mma (& temp [0 ], temp0 , v_alpha , vy0 );
220
220
221
221
vec_storeN2_f32 (vy0 , & v_y [(i * 2 ) + 0 ], n3 );
222
222
} else if (n ) {
223
- vy0 [0 ] = vec_loadN_f32 (& v_y [(i * 2 ) + 0 ], n );
224
-
225
223
vec_loadN_mult11a_mma (& temp [0 ], & va0 [i ], & va1 [i ], v_x0 [ 0 ], n );
226
224
225
+ vy0 [0 ] = vec_loadN_f32 (& v_y [(i * 2 ) + 0 ], n );
226
+
227
227
vec_reduce1_mma (& temp [0 ], temp0 , v_alpha , vy0 );
228
228
229
229
vec_storeN_f32 (vy0 [0 ], & v_y [(i * 2 ) + 0 ], n );
@@ -261,22 +261,22 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
261
261
vec_make_mult2 (v_x0 );
262
262
263
263
for (; i + 8 <= n8 ; i += 8 ) {
264
- vec_load8_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
265
-
266
264
vec_load_mult288a_mma (& temp [0 ], & va0 [i + 0 ], & va1 [i + 0 ], & v_x0 [ 0 ]);
267
265
vec_load_mult288b_mma (& temp [0 ], & va2 [i + 0 ], & va3 [i + 0 ], & v_x0 [ 4 ]);
268
266
267
+ vec_load8_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
268
+
269
269
vec_reduce88_mma (& temp [0 ], temp0 + 0 , v_alpha , vy0 + 0 );
270
270
271
271
vec_store8_pair (& v_y [(i * 2 ) + 0 ], vy0 );
272
272
}
273
273
274
274
if (n8 & 4 ) {
275
- vec_load4_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
276
-
277
275
vec_load_mult284a_mma (& temp [0 ], & va0 [i + 0 ], & va1 [i + 0 ], & v_x0 [ 0 ]);
278
276
vec_load_mult284b_mma (& temp [0 ], & va2 [i + 0 ], & va3 [i + 0 ], & v_x0 [ 4 ]);
279
277
278
+ vec_load4_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
279
+
280
280
vec_reduce84_mma (& temp [0 ], temp0 , v_alpha , vy0 );
281
281
282
282
vec_store4_pair (& v_y [(i * 2 ) + 0 ], vy0 );
@@ -291,45 +291,45 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
291
291
v_x0 [ 4 ] = (vec_bf16 )vec_splat ((vec_f32 )v_x0 [0 ], 1 );
292
292
293
293
for (; i + 4 <= n8 ; i += 4 ) {
294
- vec_load4_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
295
-
296
294
vec_load_mult28a_mma (& temp [0 ], & va0 [i + 0 ], & va1 [i + 0 ], v_x0 [ 0 ]);
297
295
vec_load_mult28b_mma (& temp [0 ], & va2 [i + 0 ], & va3 [i + 0 ], v_x0 [ 4 ]);
298
296
297
+ vec_load4_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
298
+
299
299
vec_reduce8_mma (& temp [0 ], temp0 , v_alpha , vy0 );
300
300
301
301
vec_store4_pair (& v_y [(i * 2 ) + 0 ], vy0 );
302
302
}
303
303
#endif
304
304
305
305
for (; i < n8 ; i ++ ) {
306
- vec_load_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
307
-
308
306
vec_load_mult22a_mma (& temp [0 ], & va0 [i ], & va1 [i ], v_x0 [ 0 ]);
309
307
vec_load_mult22b_mma (& temp [0 ], & va2 [i ], & va3 [i ], v_x0 [ 4 ]);
310
308
309
+ vec_load_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
310
+
311
311
vec_reduce2_mma (& temp [0 ], temp0 , v_alpha , vy0 );
312
312
313
313
vec_store_pair (& v_y [(i * 2 ) + 0 ], vy0 );
314
314
}
315
315
316
316
n &= 7 ;
317
317
if (n > 4 ) {
318
- BLASLONG n3 = n & 3 ;
319
- vec_loadN2_f32 (vy0 , & v_y [(i * 2 ) + 0 ], n3 );
320
-
321
318
vec_loadN_mult22a_mma (& temp [0 ], & va0 [i ], & va1 [i ], v_x0 [ 0 ], n );
322
319
vec_loadN_mult22b_mma (& temp [0 ], & va2 [i ], & va3 [i ], v_x0 [ 4 ], n );
323
320
321
+ BLASLONG n3 = n & 3 ;
322
+ vec_loadN2_f32 (vy0 , & v_y [(i * 2 ) + 0 ], n3 );
323
+
324
324
vec_reduce2_mma (& temp [0 ], temp0 , v_alpha , vy0 );
325
325
326
326
vec_storeN2_f32 (vy0 , & v_y [(i * 2 ) + 0 ], n3 );
327
327
} else if (n ) {
328
- vy0 [0 ] = vec_loadN_f32 (& v_y [(i * 2 ) + 0 ], n );
329
-
330
328
vec_loadN_mult11a_mma (& temp [0 ], & va0 [i ], & va1 [i ], v_x0 [ 0 ], n );
331
329
vec_loadN_mult11b_mma (& temp [0 ], & va2 [i ], & va3 [i ], v_x0 [ 4 ], n );
332
330
331
+ vy0 [0 ] = vec_loadN_f32 (& v_y [(i * 2 ) + 0 ], n );
332
+
333
333
vec_reduce1_mma (& temp [0 ], temp0 , v_alpha , vy0 );
334
334
335
335
vec_storeN_f32 (vy0 [0 ], & v_y [(i * 2 ) + 0 ], n );
@@ -376,26 +376,26 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS
376
376
vec_make_mult4 (v_x0 );
377
377
378
378
for (; i + 8 <= n8 ; i += 8 ) {
379
- vec_load8_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
380
-
381
379
vec_load_mult288a_mma (& temp [0 ], & va0 [i + 0 ], & va1 [i + 0 ], & v_x0 [ 0 ]);
382
380
vec_load_mult288b_mma (& temp [0 ], & va2 [i + 0 ], & va3 [i + 0 ], & v_x0 [ 4 ]);
383
381
vec_load_mult288b_mma (& temp [0 ], & vb0 [i + 0 ], & vb1 [i + 0 ], & v_x0 [ 8 ]);
384
382
vec_load_mult288b_mma (& temp [0 ], & vb2 [i + 0 ], & vb3 [i + 0 ], & v_x0 [12 ]);
385
383
384
+ vec_load8_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
385
+
386
386
vec_reduce88_mma (& temp [0 ], temp0 + 0 , v_alpha , vy0 + 0 );
387
387
388
388
vec_store8_pair (& v_y [(i * 2 ) + 0 ], vy0 );
389
389
}
390
390
391
391
if (n8 & 4 ) {
392
- vec_load4_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
393
-
394
392
vec_load_mult284a_mma (& temp [0 ], & va0 [i + 0 ], & va1 [i + 0 ], & v_x0 [ 0 ]);
395
393
vec_load_mult284b_mma (& temp [0 ], & va2 [i + 0 ], & va3 [i + 0 ], & v_x0 [ 4 ]);
396
394
vec_load_mult284b_mma (& temp [0 ], & vb0 [i + 0 ], & vb1 [i + 0 ], & v_x0 [ 8 ]);
397
395
vec_load_mult284b_mma (& temp [0 ], & vb2 [i + 0 ], & vb3 [i + 0 ], & v_x0 [12 ]);
398
396
397
+ vec_load4_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
398
+
399
399
vec_reduce84_mma (& temp [0 ], temp0 , v_alpha , vy0 );
400
400
401
401
vec_store4_pair (& v_y [(i * 2 ) + 0 ], vy0 );
@@ -412,53 +412,53 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS
412
412
v_x0 [12 ] = (vec_bf16 )vec_splat ((vec_f32 )v_x0 [0 ], 3 );
413
413
414
414
for (; i + 4 <= n8 ; i += 4 ) {
415
- vec_load4_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
416
-
417
415
vec_load_mult28a_mma (& temp [0 ], & va0 [i + 0 ], & va1 [i + 0 ], v_x0 [ 0 ]);
418
416
vec_load_mult28b_mma (& temp [0 ], & va2 [i + 0 ], & va3 [i + 0 ], v_x0 [ 4 ]);
419
417
vec_load_mult28b_mma (& temp [0 ], & vb0 [i + 0 ], & vb1 [i + 0 ], v_x0 [ 8 ]);
420
418
vec_load_mult28b_mma (& temp [0 ], & vb2 [i + 0 ], & vb3 [i + 0 ], v_x0 [12 ]);
421
419
420
+ vec_load4_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
421
+
422
422
vec_reduce8_mma (& temp [0 ], temp0 , v_alpha , vy0 );
423
423
424
424
vec_store4_pair (& v_y [(i * 2 ) + 0 ], vy0 );
425
425
}
426
426
#endif
427
427
428
428
for (; i < n8 ; i ++ ) {
429
- vec_load_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
430
-
431
429
vec_load_mult22a_mma (& temp [0 ], & va0 [i ], & va1 [i ], v_x0 [ 0 ]);
432
430
vec_load_mult22b_mma (& temp [0 ], & va2 [i ], & va3 [i ], v_x0 [ 4 ]);
433
431
vec_load_mult22b_mma (& temp [0 ], & vb0 [i ], & vb1 [i ], v_x0 [ 8 ]);
434
432
vec_load_mult22b_mma (& temp [0 ], & vb2 [i ], & vb3 [i ], v_x0 [12 ]);
435
433
434
+ vec_load_pair (vy0 , & v_y [(i * 2 ) + 0 ]);
435
+
436
436
vec_reduce2_mma (& temp [0 ], temp0 , v_alpha , vy0 );
437
437
438
438
vec_store_pair (& v_y [(i * 2 ) + 0 ], vy0 );
439
439
}
440
440
441
441
n &= 7 ;
442
442
if (n > 4 ) {
443
- BLASLONG n3 = n & 3 ;
444
- vec_loadN2_f32 (vy0 , & v_y [(i * 2 ) + 0 ], n3 );
445
-
446
443
vec_loadN_mult22a_mma (& temp [0 ], & va0 [i ], & va1 [i ], v_x0 [ 0 ], n );
447
444
vec_loadN_mult22b_mma (& temp [0 ], & va2 [i ], & va3 [i ], v_x0 [ 4 ], n );
448
445
vec_loadN_mult22b_mma (& temp [0 ], & vb0 [i ], & vb1 [i ], v_x0 [ 8 ], n );
449
446
vec_loadN_mult22b_mma (& temp [0 ], & vb2 [i ], & vb3 [i ], v_x0 [12 ], n );
450
447
448
+ BLASLONG n3 = n & 3 ;
449
+ vec_loadN2_f32 (vy0 , & v_y [(i * 2 ) + 0 ], n3 );
450
+
451
451
vec_reduce2_mma (& temp [0 ], temp0 , v_alpha , vy0 );
452
452
453
453
vec_storeN2_f32 (vy0 , & v_y [(i * 2 ) + 0 ], n3 );
454
454
} else if (n ) {
455
- vy0 [0 ] = vec_loadN_f32 (& v_y [(i * 2 ) + 0 ], n );
456
-
457
455
vec_loadN_mult11a_mma (& temp [0 ], & va0 [i ], & va1 [i ], v_x0 [ 0 ], n );
458
456
vec_loadN_mult11b_mma (& temp [0 ], & va2 [i ], & va3 [i ], v_x0 [ 4 ], n );
459
457
vec_loadN_mult11b_mma (& temp [0 ], & vb0 [i ], & vb1 [i ], v_x0 [ 8 ], n );
460
458
vec_loadN_mult11b_mma (& temp [0 ], & vb2 [i ], & vb3 [i ], v_x0 [12 ], n );
461
459
460
+ vy0 [0 ] = vec_loadN_f32 (& v_y [(i * 2 ) + 0 ], n );
461
+
462
462
vec_reduce1_mma (& temp [0 ], temp0 , v_alpha , vy0 );
463
463
464
464
vec_storeN_f32 (vy0 [0 ], & v_y [(i * 2 ) + 0 ], n );
0 commit comments