@@ -152,6 +152,14 @@ FORCEINLINE void vec_reduce84_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_a
   vec_reduce44_mma(&out[0], &temp[0], v_alpha, vy0 + 0);
   vec_reduce44_mma(&out[1], &temp[4], v_alpha, vy0 + 1);
 }
+
+FORCEINLINE void vec_reduce88_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
+{
+  vec_reduce44_mma(&out[0], &temp[ 0], v_alpha, vy0 + 0);
+  vec_reduce44_mma(&out[1], &temp[ 4], v_alpha, vy0 + 1);
+  vec_reduce44_mma(&out[2], &temp[ 8], v_alpha, vy0 + 8);
+  vec_reduce44_mma(&out[3], &temp[12], v_alpha, vy0 + 9);
+}
 #endif
 
 FORCEINLINE void vec_mult11a_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp)
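For orientation, here is a minimal call-site sketch for the new vec_reduce88_mma (not part of this commit; acc, temp, vy0 and alpha are illustrative names, and the file's vec_f32/__vector_quad typedefs are assumed). It only shows the argument shapes: four accumulators, sixteen vec_f32 of scratch, and a vy0 array spanning at least ten vec_f32 elements, since the helper reads vy0 + 8 and vy0 + 9.

    __vector_quad acc[4];        /* four 4x4 MMA accumulators            */
    vec_f32 temp[16];            /* scratch: 4 vec_f32 per accumulator   */
    vec_f32 vy0[10];             /* y pairs loaded by the caller         */
    float alpha = 1.0f;          /* illustrative scalar                  */
    vec_f32 v_alpha = { alpha, alpha, alpha, alpha };

    /* ... accumulate into acc[0..3] with the vec_load_mult*_mma helpers ... */
    vec_reduce88_mma(&acc[0], temp, v_alpha, vy0);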
@@ -341,6 +349,32 @@ FORCEINLINE void vec_load_mult284b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf
   vec_mult44b_mma(out, in0 + 0, in1 + 0, inp + 0);
   vec_mult44b_mma(out, in0 + 2, in1 + 2, inp + 2);
 }
+
+FORCEINLINE void vec_load_mult288a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp)
+{
+  vec_bf16 in0[8], in1[8];
+
+  vec_load4_mma(in0 + 0, in1 + 0, ina + 0, inb + 0);
+  vec_load4_mma(in0 + 4, in1 + 4, ina + 4, inb + 4);
+
+  vec_mult44a_mma(out + 0, in0 + 0, in1 + 0, inp + 0);
+  vec_mult44a_mma(out + 2, in0 + 4, in1 + 4, inp + 0);
+  vec_mult44b_mma(out + 0, in0 + 2, in1 + 2, inp + 2);
+  vec_mult44b_mma(out + 2, in0 + 6, in1 + 6, inp + 2);
+}
+
+FORCEINLINE void vec_load_mult288b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp)
+{
+  vec_bf16 in0[8], in1[8];
+
+  vec_load4_mma(in0 + 0, in1 + 0, ina + 0, inb + 0);
+  vec_load4_mma(in0 + 4, in1 + 4, ina + 4, inb + 4);
+
+  vec_mult44b_mma(out + 0, in0 + 0, in1 + 0, inp + 0);
+  vec_mult44b_mma(out + 2, in0 + 4, in1 + 4, inp + 0);
+  vec_mult44b_mma(out + 0, in0 + 2, in1 + 2, inp + 2);
+  vec_mult44b_mma(out + 2, in0 + 6, in1 + 6, inp + 2);
+}
 #endif
 
 FORCEINLINE void vec_loadN_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
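A hedged usage sketch for the two new loaders (hypothetical; the pointer names and block offsets below are assumptions, not taken from this commit). Judging from the calls above, the "a" variant mixes vec_mult44a_mma and vec_mult44b_mma, so it both primes and extends the four accumulators on the first block, while the "b" variant uses vec_mult44b_mma throughout and only adds onto accumulators that are already primed.

    __vector_quad acc[4];
    vec_bf16 *ina, *inb, *inp;   /* two matrix rows and the packed x vector, set up by the caller */

    vec_load_mult288a_mma(acc, ina, inb, inp);               /* first block: prime + accumulate */
    vec_load_mult288b_mma(acc, ina + 8, inb + 8, inp + 4);   /* later block: accumulate only    */
    /* the +8 / +4 strides are assumed for illustration only */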
@@ -381,49 +415,54 @@ FORCEINLINE void vec_store8_pair(vec_f32 *v_y, vec_f32 *vy0)
 }
 
 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
-#define VEC_SHIFT(data, shift)  vec_sld(data, data, 16 - shift)
-#else
-#define VEC_SHIFT(data, shift)  vec_sld(data, data, shift)
-#endif
+#define VEC_SHIFT(data, shift)  vec_sldw(data, data, 4 - shift)
 
-typedef __vector unsigned int vec_ui32;
+#define MASK_0  0xf000
+#define MASK_1  0x0f00
+#define MASK_2  0x00f0
+#define MASK_3  0x000f
+#else
+#define VEC_SHIFT(data, shift)  vec_sldw(data, data, shift)
 
-static vec_ui32 mask_0 = { 0xffffffff, 0x00000000, 0x00000000, 0x00000000 };
-static vec_ui32 mask_1 = { 0x00000000, 0xffffffff, 0x00000000, 0x00000000 };
-static vec_ui32 mask_2 = { 0x00000000, 0x00000000, 0xffffffff, 0x00000000 };
-static vec_ui32 mask_3 = { 0x00000000, 0x00000000, 0x00000000, 0xffffffff };
+#define MASK_0  0x000f
+#define MASK_1  0x00f0
+#define MASK_2  0x0f00
+#define MASK_3  0xf000
+#endif
 
-FORCEINLINE void vec_make_mult1(vec_bf16 *v_x0)
+FORCEINLINE void vec_make_mult1(vec_bf16 *v_x0, const bool mask)
 {
-  v_x0[ 0] = vec_and(v_x0[0], (vec_bf16)mask_0);
+  if (mask) {
+    v_x0[ 0] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_0));
+  }
 
-  v_x0[ 1] = VEC_SHIFT(v_x0[ 0], 4);
-  v_x0[ 2] = VEC_SHIFT(v_x0[ 0], 8);
-  v_x0[ 3] = VEC_SHIFT(v_x0[ 0], 12);
+  v_x0[ 1] = VEC_SHIFT(v_x0[ 0], 1);
+  v_x0[ 2] = VEC_SHIFT(v_x0[ 0], 2);
+  v_x0[ 3] = VEC_SHIFT(v_x0[ 0], 3);
 }
 
 FORCEINLINE void vec_make_mult2(vec_bf16 *v_x0)
 {
-  v_x0[ 5] = vec_and(v_x0[0], (vec_bf16)mask_1);
-  vec_make_mult1(v_x0);
+  v_x0[ 5] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_1));
+  vec_make_mult1(v_x0, true);
 
-  v_x0[ 4] = VEC_SHIFT(v_x0[ 5], 12);
-  v_x0[ 6] = VEC_SHIFT(v_x0[ 5], 4);
-  v_x0[ 7] = VEC_SHIFT(v_x0[ 5], 8);
+  v_x0[ 4] = VEC_SHIFT(v_x0[ 5], 3);
+  v_x0[ 6] = VEC_SHIFT(v_x0[ 5], 1);
+  v_x0[ 7] = VEC_SHIFT(v_x0[ 5], 2);
 }
 
 FORCEINLINE void vec_make_mult4(vec_bf16 *v_x0)
 {
-  v_x0[10] = vec_and(v_x0[0], (vec_bf16)mask_2);
-  v_x0[15] = vec_and(v_x0[0], (vec_bf16)mask_3);
+  v_x0[10] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_2));
+  v_x0[15] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_3));
   vec_make_mult2(v_x0);
 
-  v_x0[ 8] = VEC_SHIFT(v_x0[10], 8);
-  v_x0[ 9] = VEC_SHIFT(v_x0[10], 12);
-  v_x0[11] = VEC_SHIFT(v_x0[10], 4);
-  v_x0[12] = VEC_SHIFT(v_x0[15], 4);
-  v_x0[13] = VEC_SHIFT(v_x0[15], 8);
-  v_x0[14] = VEC_SHIFT(v_x0[15], 12);
+  v_x0[ 8] = VEC_SHIFT(v_x0[10], 2);
+  v_x0[ 9] = VEC_SHIFT(v_x0[10], 3);
+  v_x0[11] = VEC_SHIFT(v_x0[10], 1);
+  v_x0[12] = VEC_SHIFT(v_x0[15], 1);
+  v_x0[13] = VEC_SHIFT(v_x0[15], 2);
+  v_x0[14] = VEC_SHIFT(v_x0[15], 3);
 }
 #endif
 
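For readers comparing the two versions of this block, the substitutions are one-for-one: vec_genbm expands the 16-bit MASK_* literals into the same word-wide byte masks the old static vec_ui32 constants spelled out, and VEC_SHIFT now counts 4-byte words (vec_sldw) instead of bytes (vec_sld), so the old shift amounts 4/8/12 become 1/2/3. vec_make_mult1 also gains a mask flag; vec_make_mult2 passes true to preserve the old masking behavior, while other callers can presumably skip the vec_and when their input is already masked. A short equivalence summary, written as a comment and derived only from the diff above (little-endian case shown):

    /* Old form                                 New form
     * mask_0 = { 0xffffffff, 0, 0, 0 }         vec_genbm(MASK_0), MASK_0 = 0x000f
     * VEC_SHIFT(v, 4)  == vec_sld(v, v, 4)     VEC_SHIFT(v, 1) == vec_sldw(v, v, 1)
     * VEC_SHIFT(v, 8)  == vec_sld(v, v, 8)     VEC_SHIFT(v, 2) == vec_sldw(v, v, 2)
     * VEC_SHIFT(v, 12) == vec_sld(v, v, 12)    VEC_SHIFT(v, 3) == vec_sldw(v, v, 3)
     */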