Commit df19375

Almost final code for MMA.
1 parent 05aa63e commit df19375

2 files changed: +80 additions, -54 deletions

kernel/power/sbgemv_common_power10.c

Lines changed: 66 additions & 27 deletions
@@ -152,6 +152,14 @@ FORCEINLINE void vec_reduce84_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_a
  vec_reduce44_mma(&out[0], &temp[0], v_alpha, vy0 + 0);
  vec_reduce44_mma(&out[1], &temp[4], v_alpha, vy0 + 1);
}
+
+FORCEINLINE void vec_reduce88_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
+{
+  vec_reduce44_mma(&out[0], &temp[ 0], v_alpha, vy0 + 0);
+  vec_reduce44_mma(&out[1], &temp[ 4], v_alpha, vy0 + 1);
+  vec_reduce44_mma(&out[2], &temp[ 8], v_alpha, vy0 + 8);
+  vec_reduce44_mma(&out[3], &temp[12], v_alpha, vy0 + 9);
+}
#endif

FORCEINLINE void vec_mult11a_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp)
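
The new vec_reduce88_mma above is a straight fusion of the two vec_reduce84_mma calls that the GEMV kernels in sbgemv_n_power10.c (the second file in this commit) previously issued back to back: both forms expand to the same four vec_reduce44_mma calls on temp[0..3], temp0 + {0,4,8,12} and vy0 + {0,1,8,9}. A side-by-side sketch of the call-site change, for illustration only (temp, temp0, v_alpha and vy0 are the caller's locals, not new names):

  /* before: two 8x4 reductions per block of eight vy0 vectors */
  vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);
  vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8);

  /* after: one 8x8 reduction that expands to the identical vec_reduce44_mma sequence */
  vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);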
@@ -341,6 +349,32 @@ FORCEINLINE void vec_load_mult284b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf
  vec_mult44b_mma(out, in0 + 0, in1 + 0, inp + 0);
  vec_mult44b_mma(out, in0 + 2, in1 + 2, inp + 2);
}
+
+FORCEINLINE void vec_load_mult288a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp)
+{
+  vec_bf16 in0[8], in1[8];
+
+  vec_load4_mma(in0 + 0, in1 + 0, ina + 0, inb + 0);
+  vec_load4_mma(in0 + 4, in1 + 4, ina + 4, inb + 4);
+
+  vec_mult44a_mma(out + 0, in0 + 0, in1 + 0, inp + 0);
+  vec_mult44a_mma(out + 2, in0 + 4, in1 + 4, inp + 0);
+  vec_mult44b_mma(out + 0, in0 + 2, in1 + 2, inp + 2);
+  vec_mult44b_mma(out + 2, in0 + 6, in1 + 6, inp + 2);
+}
+
+FORCEINLINE void vec_load_mult288b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp)
+{
+  vec_bf16 in0[8], in1[8];
+
+  vec_load4_mma(in0 + 0, in1 + 0, ina + 0, inb + 0);
+  vec_load4_mma(in0 + 4, in1 + 4, ina + 4, inb + 4);
+
+  vec_mult44b_mma(out + 0, in0 + 0, in1 + 0, inp + 0);
+  vec_mult44b_mma(out + 2, in0 + 4, in1 + 4, inp + 0);
+  vec_mult44b_mma(out + 0, in0 + 2, in1 + 2, inp + 2);
+  vec_mult44b_mma(out + 2, in0 + 6, in1 + 6, inp + 2);
+}
#endif

FORCEINLINE void vec_loadN_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
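
Likewise, vec_load_mult288a_mma and vec_load_mult288b_mma take over what previously required two vec_load_mult284a/b_mma calls each: they first pull in all eight vec_bf16 input pairs with two vec_load4_mma calls and only then issue the four vec_mult44a/b_mma steps, which appears intended to get the loads further ahead of the MMA math. The call-site effect, illustrated with the BF16GEMV_N_MMA_2 loop changed further down in this commit:

  /* before: two load-and-multiply calls over four vectors each */
  vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);
  vec_load_mult284a_mma(&temp[2], &va0[i + 4], &va1[i + 4], &v_x0[ 0]);

  /* after: one fused load-and-multiply over eight vectors */
  vec_load_mult288a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);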
@@ -381,49 +415,54 @@ FORCEINLINE void vec_store8_pair(vec_f32 *v_y, vec_f32 *vy0)
}

#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
-#define VEC_SHIFT(data, shift) vec_sld(data, data, 16 - shift)
-#else
-#define VEC_SHIFT(data, shift) vec_sld(data, data, shift)
-#endif
+#define VEC_SHIFT(data, shift) vec_sldw(data, data, 4 - shift)

-typedef __vector unsigned int vec_ui32;
+#define MASK_0 0xf000
+#define MASK_1 0x0f00
+#define MASK_2 0x00f0
+#define MASK_3 0x000f
+#else
+#define VEC_SHIFT(data, shift) vec_sldw(data, data, shift)

-static vec_ui32 mask_0 = { 0xffffffff, 0x00000000, 0x00000000, 0x00000000 };
-static vec_ui32 mask_1 = { 0x00000000, 0xffffffff, 0x00000000, 0x00000000 };
-static vec_ui32 mask_2 = { 0x00000000, 0x00000000, 0xffffffff, 0x00000000 };
-static vec_ui32 mask_3 = { 0x00000000, 0x00000000, 0x00000000, 0xffffffff };
+#define MASK_0 0x000f
+#define MASK_1 0x00f0
+#define MASK_2 0x0f00
+#define MASK_3 0xf000
+#endif

-FORCEINLINE void vec_make_mult1(vec_bf16 *v_x0)
+FORCEINLINE void vec_make_mult1(vec_bf16 *v_x0, const bool mask)
{
-  v_x0[ 0] = vec_and(v_x0[0], (vec_bf16)mask_0);
+  if (mask) {
+    v_x0[ 0] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_0));
+  }

-  v_x0[ 1] = VEC_SHIFT(v_x0[ 0], 4);
-  v_x0[ 2] = VEC_SHIFT(v_x0[ 0], 8);
-  v_x0[ 3] = VEC_SHIFT(v_x0[ 0], 12);
+  v_x0[ 1] = VEC_SHIFT(v_x0[ 0], 1);
+  v_x0[ 2] = VEC_SHIFT(v_x0[ 0], 2);
+  v_x0[ 3] = VEC_SHIFT(v_x0[ 0], 3);
}

FORCEINLINE void vec_make_mult2(vec_bf16 *v_x0)
{
-  v_x0[ 5] = vec_and(v_x0[0], (vec_bf16)mask_1);
-  vec_make_mult1(v_x0);
+  v_x0[ 5] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_1));
+  vec_make_mult1(v_x0, true);

-  v_x0[ 4] = VEC_SHIFT(v_x0[ 5], 12);
-  v_x0[ 6] = VEC_SHIFT(v_x0[ 5], 4);
-  v_x0[ 7] = VEC_SHIFT(v_x0[ 5], 8);
+  v_x0[ 4] = VEC_SHIFT(v_x0[ 5], 3);
+  v_x0[ 6] = VEC_SHIFT(v_x0[ 5], 1);
+  v_x0[ 7] = VEC_SHIFT(v_x0[ 5], 2);
}

FORCEINLINE void vec_make_mult4(vec_bf16 *v_x0)
{
-  v_x0[10] = vec_and(v_x0[0], (vec_bf16)mask_2);
-  v_x0[15] = vec_and(v_x0[0], (vec_bf16)mask_3);
+  v_x0[10] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_2));
+  v_x0[15] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_3));
  vec_make_mult2(v_x0);

-  v_x0[ 8] = VEC_SHIFT(v_x0[10], 8);
-  v_x0[ 9] = VEC_SHIFT(v_x0[10], 12);
-  v_x0[11] = VEC_SHIFT(v_x0[10], 4);
-  v_x0[12] = VEC_SHIFT(v_x0[15], 4);
-  v_x0[13] = VEC_SHIFT(v_x0[15], 8);
-  v_x0[14] = VEC_SHIFT(v_x0[15], 12);
+  v_x0[ 8] = VEC_SHIFT(v_x0[10], 2);
+  v_x0[ 9] = VEC_SHIFT(v_x0[10], 3);
+  v_x0[11] = VEC_SHIFT(v_x0[10], 1);
+  v_x0[12] = VEC_SHIFT(v_x0[15], 1);
+  v_x0[13] = VEC_SHIFT(v_x0[15], 2);
+  v_x0[14] = VEC_SHIFT(v_x0[15], 3);
}
#endif
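
Beyond the 8x8 helpers, this file also reworks VEC_SHIFT and the lane masks: rotations now use the word-granular vec_sldw instead of the byte-granular vec_sld, so the shift arguments used by vec_make_mult* become word counts (1/2/3) instead of byte counts (4/8/12), and the four static vec_ui32 mask constants are replaced by POWER10 vec_genbm() byte masks built from 16-bit MASK_* immediates, mirrored between the big- and little-endian branches. The stand-alone check below is a sketch added for illustration, not part of the commit; it assumes a VSX-capable POWER target (e.g. gcc -mcpu=power8) and deliberately leaves out vec_genbm, which needs a POWER10 compiler target:

#include <altivec.h>
#include <stdio.h>
#include <string.h>

typedef __vector unsigned int vec_u32;

int main(void)
{
  vec_u32 data = { 0x00000000, 0x11111111, 0x22222222, 0x33333333 };

  /* Old little-endian macro body: rotate by a byte count, always a multiple of 4. */
  vec_u32 by_bytes = vec_sld(data, data, 8);   /* 8 bytes */

  /* New little-endian macro body: rotate by the equivalent word count. */
  vec_u32 by_words = vec_sldw(data, data, 2);  /* 2 words = 8 bytes */

  puts(memcmp(&by_bytes, &by_words, sizeof(by_bytes)) == 0 ? "identical rotation"
                                                           : "mismatch");
  return 0;
}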

kernel/power/sbgemv_n_power10.c

Lines changed: 14 additions & 27 deletions
@@ -28,9 +28,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef SBGEMV_N_MMA_C
#define SBGEMV_N_MMA_C

-#if !defined(_AIX) || defined(__clang__)
#define USE_BFGEMV_N_MMA
-#endif

#ifdef USE_BFGEMV_N_MMA
#include "sbgemv_common_power10.c"
@@ -67,16 +65,15 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
  v_x0[0] = vec_loadN(x_bf, 1);
  vec_f32 vy0[2*4*2];

-  vec_make_mult1(v_x0);
+  vec_make_mult1(v_x0, false);

  for (; i + 8 <= n8; i += 8) {
    vec_load8_pair(vy0, &v_y[(i * 2) + 0]);

    vec_load_mult184_mma(&temp[0], &va0[i + 0], &v_x0[ 0]);
    vec_load_mult184_mma(&temp[2], &va0[i + 4], &v_x0[ 0]);

-    vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);
-    vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8);
+    vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);

    vec_store8_pair(&v_y[(i * 2) + 0], vy0);
  }
@@ -163,16 +160,14 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
  vec_f32 vy0[2*4*2];
  v_x0[0] = vec_loadN(x_bf, 2);

-  vec_make_mult1(v_x0);
+  vec_make_mult1(v_x0, false);

  for (; i + 8 <= n8; i += 8) {
    vec_load8_pair(vy0, &v_y[(i * 2) + 0]);

-    vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);
-    vec_load_mult284a_mma(&temp[2], &va0[i + 4], &va1[i + 4], &v_x0[ 0]);
+    vec_load_mult288a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);

-    vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);
-    vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8);
+    vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);

    vec_store8_pair(&v_y[(i * 2) + 0], vy0);
  }
@@ -268,13 +263,10 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
  for (; i + 8 <= n8; i += 8) {
    vec_load8_pair(vy0, &v_y[(i * 2) + 0]);

-    vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);
-    vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]);
-    vec_load_mult284a_mma(&temp[2], &va0[i + 4], &va1[i + 4], &v_x0[ 0]);
-    vec_load_mult284b_mma(&temp[2], &va2[i + 4], &va3[i + 4], &v_x0[ 4]);
+    vec_load_mult288a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);
+    vec_load_mult288b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]);

-    vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);
-    vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8);
+    vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);

    vec_store8_pair(&v_y[(i * 2) + 0], vy0);
  }
@@ -386,17 +378,12 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS
  for (; i + 8 <= n8; i += 8) {
    vec_load8_pair(vy0, &v_y[(i * 2) + 0]);

-    vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);
-    vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]);
-    vec_load_mult284b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], &v_x0[ 8]);
-    vec_load_mult284b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], &v_x0[12]);
-    vec_load_mult284a_mma(&temp[2], &va0[i + 4], &va1[i + 4], &v_x0[ 0]);
-    vec_load_mult284b_mma(&temp[2], &va2[i + 4], &va3[i + 4], &v_x0[ 4]);
-    vec_load_mult284b_mma(&temp[2], &vb0[i + 4], &vb1[i + 4], &v_x0[ 8]);
-    vec_load_mult284b_mma(&temp[2], &vb2[i + 4], &vb3[i + 4], &v_x0[12]);
-
-    vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);
-    vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8);
+    vec_load_mult288a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);
+    vec_load_mult288b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]);
+    vec_load_mult288b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], &v_x0[ 8]);
+    vec_load_mult288b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], &v_x0[12]);
+
+    vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);

    vec_store8_pair(&v_y[(i * 2) + 0], vy0);
}
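
All of the BF16GEMV_N_MMA_{1,2,4,8} variants compute the same GEMV and differ only in how many columns of A they consume per call; this commit changes their blocking so each inner-loop iteration issues one fused 8x8 load/multiply/reduce group instead of two 8x4 groups. For orientation, a plain-C reference of the underlying operation, y += alpha * A * x with bfloat16 A and x and float y; this is a sketch only, not the kernel (the column-major layout and the function names here are illustrative, and the beta scaling of y is handled outside this diff):

#include <stdint.h>
#include <stdio.h>
#include <string.h>

typedef uint16_t bf16_t;               /* raw bfloat16 bit pattern */

/* bfloat16 is the upper 16 bits of an IEEE-754 float32. */
static float bf16_to_float(bf16_t h)
{
  uint32_t u = (uint32_t)h << 16;
  float f;
  memcpy(&f, &u, sizeof(f));
  return f;
}

static bf16_t float_to_bf16(float f)   /* truncating conversion, good enough here */
{
  uint32_t u;
  memcpy(&u, &f, sizeof(u));
  return (bf16_t)(u >> 16);
}

/* y[i] += alpha * sum_j A[i + j*lda] * x[j] for an m x n column-major A. */
static void sbgemv_n_ref(int m, int n, float alpha,
                         const bf16_t *a, int lda, const bf16_t *x, float *y)
{
  for (int j = 0; j < n; j++) {
    float xj = alpha * bf16_to_float(x[j]);
    for (int i = 0; i < m; i++)
      y[i] += bf16_to_float(a[i + (size_t)j * lda]) * xj;
  }
}

int main(void)
{
  /* 2x2 example: A = [[1, 2], [3, 4]] stored column-major, x = [1, 1], alpha = 1. */
  bf16_t a[4] = { float_to_bf16(1.0f), float_to_bf16(3.0f),
                  float_to_bf16(2.0f), float_to_bf16(4.0f) };
  bf16_t x[2] = { float_to_bf16(1.0f), float_to_bf16(1.0f) };
  float y[2] = { 0.0f, 0.0f };

  sbgemv_n_ref(2, 2, 1.0f, a, 2, x, y);
  printf("y = [%g, %g]\n", y[0], y[1]);  /* expect [3, 7] */
  return 0;
}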

Comments (0)