@@ -195,7 +195,7 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
   vec_f32 temp6 = { 0, 0, 0, 0 };
   vec_f32 temp7 = { 0, 0, 0, 0 };
   vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 };
-  vec_f32 inp[2];
+  vec_f32 inp[2], inp0[2], inp1[2], inp2[2], inp3[2], inp4[2], inp5[2], inp6[2], inp7[2];

   BLASLONG lda4 = lda << 2;
   a0 = ap;
@@ -220,29 +220,61 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL

   for (; i < n8; i++) {
     vec_load_vec2(&v_x[i], inp, zero);
-
-    temp0 += vec_load_mult(&va0[i], inp, zero);
-    temp1 += vec_load_mult(&va1[i], inp, zero);
-    temp2 += vec_load_mult(&va2[i], inp, zero);
-    temp3 += vec_load_mult(&va3[i], inp, zero);
-    temp4 += vec_load_mult(&va4[i], inp, zero);
-    temp5 += vec_load_mult(&va5[i], inp, zero);
-    temp6 += vec_load_mult(&va6[i], inp, zero);
-    temp7 += vec_load_mult(&va7[i], inp, zero);
+    vec_load_vec2(&va0[i], inp0, zero);
+    vec_load_vec2(&va1[i], inp1, zero);
+    vec_load_vec2(&va2[i], inp2, zero);
+    vec_load_vec2(&va3[i], inp3, zero);
+    vec_load_vec2(&va4[i], inp4, zero);
+    vec_load_vec2(&va5[i], inp5, zero);
+    vec_load_vec2(&va6[i], inp6, zero);
+    vec_load_vec2(&va7[i], inp7, zero);
+
+    temp0 += (inp[0] * inp0[0]);
+    temp1 += (inp[0] * inp1[0]);
+    temp2 += (inp[0] * inp2[0]);
+    temp3 += (inp[0] * inp3[0]);
+    temp4 += (inp[0] * inp4[0]);
+    temp5 += (inp[0] * inp5[0]);
+    temp6 += (inp[0] * inp6[0]);
+    temp7 += (inp[0] * inp7[0]);
+    temp0 += (inp[1] * inp0[1]);
+    temp1 += (inp[1] * inp1[1]);
+    temp2 += (inp[1] * inp2[1]);
+    temp3 += (inp[1] * inp3[1]);
+    temp4 += (inp[1] * inp4[1]);
+    temp5 += (inp[1] * inp5[1]);
+    temp6 += (inp[1] * inp6[1]);
+    temp7 += (inp[1] * inp7[1]);
   }

   n &= 7;
   if (n > 4) {
     vec_loadN_vec2(&v_x[i], inp, n, zero);
-
-    temp0 += vec_loadN_mult(&va0[i], inp, n, zero);
-    temp1 += vec_loadN_mult(&va1[i], inp, n, zero);
-    temp2 += vec_loadN_mult(&va2[i], inp, n, zero);
-    temp3 += vec_loadN_mult(&va3[i], inp, n, zero);
-    temp4 += vec_loadN_mult(&va4[i], inp, n, zero);
-    temp5 += vec_loadN_mult(&va5[i], inp, n, zero);
-    temp6 += vec_loadN_mult(&va6[i], inp, n, zero);
-    temp7 += vec_loadN_mult(&va7[i], inp, n, zero);
+    vec_loadN_vec2(&va0[i], inp0, n, zero);
+    vec_loadN_vec2(&va1[i], inp1, n, zero);
+    vec_loadN_vec2(&va2[i], inp2, n, zero);
+    vec_loadN_vec2(&va3[i], inp3, n, zero);
+    vec_loadN_vec2(&va4[i], inp4, n, zero);
+    vec_loadN_vec2(&va5[i], inp5, n, zero);
+    vec_loadN_vec2(&va6[i], inp6, n, zero);
+    vec_loadN_vec2(&va7[i], inp7, n, zero);
+
+    temp0 += (inp[0] * inp0[0]);
+    temp1 += (inp[0] * inp1[0]);
+    temp2 += (inp[0] * inp2[0]);
+    temp3 += (inp[0] * inp3[0]);
+    temp4 += (inp[0] * inp4[0]);
+    temp5 += (inp[0] * inp5[0]);
+    temp6 += (inp[0] * inp6[0]);
+    temp7 += (inp[0] * inp7[0]);
+    temp0 += (inp[1] * inp0[1]);
+    temp1 += (inp[1] * inp1[1]);
+    temp2 += (inp[1] * inp2[1]);
+    temp3 += (inp[1] * inp3[1]);
+    temp4 += (inp[1] * inp4[1]);
+    temp5 += (inp[1] * inp5[1]);
+    temp6 += (inp[1] * inp6[1]);
+    temp7 += (inp[1] * inp7[1]);
   } else if (n) {
     inp[0] = vec_loadNHi(&v_x[i], n, zero);
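The change keeps the arithmetic of the transposed BF16 GEMV loop but splits the fused vec_load_mult / vec_loadN_mult helpers into plain vec_load_vec2 / vec_loadN_vec2 loads of each column, followed by explicit products of the two widened vec_f32 halves. Below is a minimal scalar sketch of the per-column accumulation the loop performs; it is illustrative only, and the helper bf16_to_f32() and the plain arrays are assumptions for the sketch, not OpenBLAS API.

/* Scalar model of the work done per column in the VSX loop above:
 * each bf16 element of x and of column j is widened to float and
 * multiply-accumulated into the per-column partial sum temp[j]. */
#include <stdint.h>

/* Assumed helper: widen an IEEE bfloat16 bit pattern to float. */
static float bf16_to_f32(uint16_t b)
{
    union { uint32_t u; float f; } v = { (uint32_t)b << 16 };
    return v.f;
}

static void gemv_t_block8(const uint16_t *x, const uint16_t *a[8],
                          long n, float temp[8])
{
    for (long k = 0; k < n; k++) {
        float xk = bf16_to_f32(x[k]);
        for (int j = 0; j < 8; j++)
            temp[j] += xk * bf16_to_f32(a[j][k]);  /* same MAC as temp0..temp7 */
    }
}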