@@ -101,8 +101,7 @@ static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
101
101
vec_f32 temp00 [4 * 2 ];
102
102
vec_bf16 inp [4 ];
103
103
104
- __builtin_mma_xxsetaccz (& temp0 [0 ]);
105
- __builtin_mma_xxsetaccz (& temp0 [1 ]);
104
+ vec_setzero_2 (& temp0 [0 ]);
106
105
107
106
a0 = ap ;
108
107
a1 = ap + lda ;
@@ -141,8 +140,7 @@ static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
141
140
vec_loadN_mult12a_mma (& temp0 [0 ], & va0 [i ], & va1 [i ], inp [0 ], n );
142
141
}
143
142
144
- __builtin_mma_disassemble_acc ((void * )(temp00 + 0 ), & temp0 [0 ]);
145
- __builtin_mma_disassemble_acc ((void * )(temp00 + 4 ), & temp0 [1 ]);
143
+ vec_reduce_2 (temp00 , & temp0 [0 ]);
146
144
147
145
y [0 ] = (alpha * (temp00 [0 ][0 ] + temp00 [1 ][1 ] + temp00 [2 ][2 ] + temp00 [3 ][3 ])) + (beta * y [0 ]);
148
146
y [1 ] = (alpha * (temp00 [4 ][0 ] + temp00 [5 ][1 ] + temp00 [6 ][2 ] + temp00 [7 ][3 ])) + (beta * y [1 ]);
@@ -156,10 +154,7 @@ static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
156
154
vec_f32 temp00 [4 * 4 ];
157
155
vec_bf16 inp [4 ];
158
156
159
- __builtin_mma_xxsetaccz (& temp0 [0 ]);
160
- __builtin_mma_xxsetaccz (& temp0 [1 ]);
161
- __builtin_mma_xxsetaccz (& temp0 [2 ]);
162
- __builtin_mma_xxsetaccz (& temp0 [3 ]);
157
+ vec_setzero_4 (& temp0 [0 ]);
163
158
164
159
a0 = ap ;
165
160
a1 = ap + lda ;
@@ -202,10 +197,7 @@ static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
202
197
vec_loadN_mult14_mma (& temp0 [0 ], & va0 [i ], & va1 [i ], & va2 [i ], & va3 [i ], inp [0 ], n );
203
198
}
204
199
205
- __builtin_mma_disassemble_acc ((void * )(temp00 + 0 ), & temp0 [0 ]);
206
- __builtin_mma_disassemble_acc ((void * )(temp00 + 4 ), & temp0 [1 ]);
207
- __builtin_mma_disassemble_acc ((void * )(temp00 + 8 ), & temp0 [2 ]);
208
- __builtin_mma_disassemble_acc ((void * )(temp00 + 12 ), & temp0 [3 ]);
200
+ vec_reduce_4 (temp00 , & temp0 [0 ]);
209
201
210
202
vec_f32 t0 , t1 , t2 , t3 , t4 , t5 , t6 , t7 ;
211
203
vec_f32 a = { alpha , alpha , alpha , alpha };
@@ -239,23 +231,17 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
239
231
vec_f32 temp00 [4 * 8 ];
240
232
vec_bf16 inp [4 ];
241
233
242
- __builtin_mma_xxsetaccz (& temp0 [0 ]);
243
- __builtin_mma_xxsetaccz (& temp0 [1 ]);
244
- __builtin_mma_xxsetaccz (& temp0 [2 ]);
245
- __builtin_mma_xxsetaccz (& temp0 [3 ]);
246
- __builtin_mma_xxsetaccz (& temp0 [4 ]);
247
- __builtin_mma_xxsetaccz (& temp0 [5 ]);
248
- __builtin_mma_xxsetaccz (& temp0 [6 ]);
249
- __builtin_mma_xxsetaccz (& temp0 [7 ]);
234
+ vec_setzero_8 (& temp0 [0 ]);
250
235
236
+ BLASLONG lda4 = lda << 2 ;
251
237
a0 = ap ;
252
238
a1 = ap + lda ;
253
239
a2 = a1 + lda ;
254
240
a3 = a2 + lda ;
255
- a4 = a3 + lda ;
256
- a5 = a4 + lda ;
257
- a6 = a5 + lda ;
258
- a7 = a6 + lda ;
241
+ a4 = a0 + lda4 ;
242
+ a5 = a1 + lda4 ;
243
+ a6 = a2 + lda4 ;
244
+ a7 = a3 + lda4 ;
259
245
va0 = (vec_bf16 * )a0 ;
260
246
va1 = (vec_bf16 * )a1 ;
261
247
va2 = (vec_bf16 * )a2 ;
@@ -301,14 +287,7 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
301
287
vec_loadN_mult14_mma (& temp0 [4 ], & va4 [i ], & va5 [i ], & va6 [i ], & va7 [i ], inp [0 ], n );
302
288
}
303
289
304
- __builtin_mma_disassemble_acc ((void * )(temp00 + 0 ), & temp0 [0 ]);
305
- __builtin_mma_disassemble_acc ((void * )(temp00 + 4 ), & temp0 [1 ]);
306
- __builtin_mma_disassemble_acc ((void * )(temp00 + 8 ), & temp0 [2 ]);
307
- __builtin_mma_disassemble_acc ((void * )(temp00 + 12 ), & temp0 [3 ]);
308
- __builtin_mma_disassemble_acc ((void * )(temp00 + 16 ), & temp0 [4 ]);
309
- __builtin_mma_disassemble_acc ((void * )(temp00 + 20 ), & temp0 [5 ]);
310
- __builtin_mma_disassemble_acc ((void * )(temp00 + 24 ), & temp0 [6 ]);
311
- __builtin_mma_disassemble_acc ((void * )(temp00 + 28 ), & temp0 [7 ]);
290
+ vec_reduce_8 (temp00 , & temp0 [0 ]);
312
291
313
292
vec_f32 t0 , t1 , t2 , t3 , t4 , t5 , t6 , t7 , t10 , t11 , t12 , t13 , t14 , t15 , t16 , t17 ;
314
293
vec_f32 a = { alpha , alpha , alpha , alpha };
0 commit comments