Skip to content

Commit eb6f3a0

Browse files
committed
Common MMA code.
1 parent fb287d1 commit eb6f3a0

File tree

1 file changed

+40
-54
lines changed

1 file changed

+40
-54
lines changed

kernel/power/sbgemv_common_power10.c

Lines changed: 40 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -48,22 +48,20 @@ FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 in
4848

4949
FORCEINLINE void vec_load_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp)
5050
{
51-
vec_bf16 in01 = (vec_bf16)vec_load_vec(in0);
5251
vec_bf16 in11 = (vec_bf16)vec_load_vec(in1);
5352

54-
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp);
53+
vec_load_mult_mma(out, in0, inp);
54+
5555
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
5656
}
5757

5858
FORCEINLINE void vec_load_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp)
5959
{
60-
vec_bf16 in01 = (vec_bf16)vec_load_vec(in0);
61-
vec_bf16 in11 = (vec_bf16)vec_load_vec(in1);
6260
vec_bf16 in21 = (vec_bf16)vec_load_vec(in2);
6361
vec_bf16 in31 = (vec_bf16)vec_load_vec(in3);
6462

65-
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp);
66-
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
63+
vec_load_mult12a_mma(out, in0, in1, inp);
64+
6765
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp);
6866
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp);
6967
}
@@ -78,17 +76,21 @@ FORCEINLINE void vec_load_mult2_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *
7876
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]);
7977
}
8078

79+
FORCEINLINE void vec_mult2d_mma(__vector_quad *out, vec_bf16 *in01, vec_bf16 *in11, vec_bf16 *inp)
80+
{
81+
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
82+
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
83+
}
84+
8185
FORCEINLINE void vec_load_mult22_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
8286
{
8387
vec_bf16 in01[2], in11[2];
8488

8589
vec_load_pair((vec_f32 *)in01, (vec_f32 *)in0);
8690
vec_load_pair((vec_f32 *)in11, (vec_f32 *)in1);
8791

88-
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
89-
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
90-
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
91-
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
92+
vec_mult2d_mma(out, in01 + 0, in11 + 0, inp + 0);
93+
vec_mult2d_mma(out, in01 + 1, in11 + 1, inp + 1);
9294
}
9395

9496
FORCEINLINE void vec_load_mult24_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp)
@@ -100,26 +102,22 @@ FORCEINLINE void vec_load_mult24_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16
100102
vec_load_pair((vec_f32 *)in21, (vec_f32 *)in2);
101103
vec_load_pair((vec_f32 *)in31, (vec_f32 *)in3);
102104

103-
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
104-
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
105-
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[0], (vec_uc8)inp[0]);
106-
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[0], (vec_uc8)inp[0]);
107-
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
108-
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
109-
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[1], (vec_uc8)inp[1]);
110-
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[1], (vec_uc8)inp[1]);
105+
vec_mult2d_mma(out + 0, in01 + 0, in11 + 0, inp + 0);
106+
vec_mult2d_mma(out + 2, in21 + 0, in31 + 0, inp + 0);
107+
vec_mult2d_mma(out + 0, in01 + 1, in11 + 1, inp + 1);
108+
vec_mult2d_mma(out + 2, in21 + 1, in31 + 1, inp + 1);
111109
}
112110

113111
FORCEINLINE void vec_load_mult4_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
114112
{
115-
vec_bf16 in0[4];
113+
vec_bf16 in0[2];
116114

117-
vec_load_pair2(in0, in);
115+
vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 2));
118116

119-
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[0]);
120-
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]);
121-
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[2], (vec_uc8)inp[2]);
122-
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[3], (vec_uc8)inp[3]);
117+
vec_load_mult2_mma(out, in + 0, inp + 0);
118+
119+
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[2]);
120+
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[3]);
123121
}
124122

125123
FORCEINLINE void vec_load_mult42_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
@@ -129,14 +127,16 @@ FORCEINLINE void vec_load_mult42_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16
129127
vec_load_pair2(in01, in0);
130128
vec_load_pair2(in11, in1);
131129

132-
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
133-
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
134-
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
135-
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
136-
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[2], (vec_uc8)inp[2]);
137-
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[2], (vec_uc8)inp[2]);
138-
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[3], (vec_uc8)inp[3]);
139-
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[3], (vec_uc8)inp[3]);
130+
vec_mult2d_mma(out, in01 + 0, in11 + 0, inp + 0);
131+
vec_mult2d_mma(out, in01 + 1, in11 + 1, inp + 1);
132+
vec_mult2d_mma(out, in01 + 2, in11 + 2, inp + 2);
133+
vec_mult2d_mma(out, in01 + 3, in11 + 3, inp + 3);
134+
}
135+
136+
FORCEINLINE void vec_mult4d_mma(__vector_quad *out, vec_bf16 *in01, vec_bf16 *in11, vec_bf16 *in21, vec_bf16 *in31, vec_bf16 *inp)
137+
{
138+
vec_mult2d_mma(out + 0, in01, in11, inp);
139+
vec_mult2d_mma(out + 2, in21, in31, inp);
140140
}
141141

142142
FORCEINLINE void vec_load_mult44_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp)
@@ -148,22 +148,10 @@ FORCEINLINE void vec_load_mult44_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16
148148
vec_load_pair2(in21, in2);
149149
vec_load_pair2(in31, in3);
150150

151-
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
152-
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
153-
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[0], (vec_uc8)inp[0]);
154-
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[0], (vec_uc8)inp[0]);
155-
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
156-
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
157-
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[1], (vec_uc8)inp[1]);
158-
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[1], (vec_uc8)inp[1]);
159-
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[2], (vec_uc8)inp[2]);
160-
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[2], (vec_uc8)inp[2]);
161-
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[2], (vec_uc8)inp[2]);
162-
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[2], (vec_uc8)inp[2]);
163-
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[3], (vec_uc8)inp[3]);
164-
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[3], (vec_uc8)inp[3]);
165-
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[3], (vec_uc8)inp[3]);
166-
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[3], (vec_uc8)inp[3]);
151+
vec_mult4d_mma(out, in01 + 0, in11 + 0, in21 + 0, in31 + 0, inp + 0);
152+
vec_mult4d_mma(out, in01 + 1, in11 + 1, in21 + 1, in31 + 1, inp + 1);
153+
vec_mult4d_mma(out, in01 + 2, in11 + 2, in21 + 2, in31 + 2, inp + 2);
154+
vec_mult4d_mma(out, in01 + 3, in11 + 3, in21 + 3, in31 + 3, inp + 3);
167155
}
168156

169157
FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n)
@@ -175,22 +163,20 @@ FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 i
175163

176164
FORCEINLINE void vec_loadN_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp, BLASLONG n)
177165
{
178-
vec_bf16 in01 = (vec_bf16)vec_loadN(in0, n);
179166
vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n);
180167

181-
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp);
168+
vec_loadN_mult_mma(out, in0, inp, n);
169+
182170
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
183171
}
184172

185173
FORCEINLINE void vec_loadN_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp, BLASLONG n)
186174
{
187-
vec_bf16 in01 = (vec_bf16)vec_loadN(in0, n);
188-
vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n);
189175
vec_bf16 in21 = (vec_bf16)vec_loadN(in2, n);
190176
vec_bf16 in31 = (vec_bf16)vec_loadN(in3, n);
191177

192-
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp);
193-
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
178+
vec_loadN_mult12a_mma(out, in0, in1, inp, n);
179+
194180
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp);
195181
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp);
196182
}

0 commit comments

Comments
 (0)