Skip to content

Commit fb287d1

Browse files
committed
Common code.
1 parent 8ab6245 commit fb287d1

File tree

3 files changed

+263
-122
lines changed

3 files changed

+263
-122
lines changed

kernel/power/sbgemv_common_power10.c

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,41 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3333
#define USE_MERGE_MMA
3434
#endif
3535

36+
/* Load four consecutive bf16 vectors from in into in0, using two paired loads. */
FORCEINLINE void vec_load_pair2(vec_bf16 *in0, vec_bf16 *in)
{
  int i;

  /* Each vec_load_pair fills two adjacent 16-byte vectors. */
  for (i = 0; i < 4; i += 2) {
    vec_load_pair((vec_f32 *)(in0 + i), (vec_f32 *)(in + i));
  }
}
41+
3642
/* Load one bf16 vector from in and accumulate its product with inp
   into the MMA accumulator out (xvbf16ger2pp: positive multiply-accumulate). */
FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp)
{
  vec_bf16 a0 = (vec_bf16)vec_load_vec(in);

  __builtin_mma_xvbf16ger2pp(out, (vec_uc8)a0, (vec_uc8)inp);
}
4248

49+
/* Load one bf16 vector from each of two rows (in0, in1) and accumulate
   each product with the shared inp into its own accumulator (out[0], out[1]). */
FORCEINLINE void vec_load_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp)
{
  vec_bf16 a0 = (vec_bf16)vec_load_vec(in0);
  vec_bf16 a1 = (vec_bf16)vec_load_vec(in1);

  __builtin_mma_xvbf16ger2pp(&out[0], (vec_uc8)a0, (vec_uc8)inp);
  __builtin_mma_xvbf16ger2pp(&out[1], (vec_uc8)a1, (vec_uc8)inp);
}
57+
58+
/* Load one bf16 vector from each of four rows and accumulate each product
   with the shared inp into its own accumulator (out[0..3]).
   Expressed as two 2-row updates (vec_load_mult12a_mma) to avoid duplicating
   the load/accumulate sequence; each accumulator receives exactly the same
   single update as before, so results are unchanged. */
FORCEINLINE void vec_load_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp)
{
  vec_load_mult12a_mma(out + 0, in0, in1, inp);
  vec_load_mult12a_mma(out + 2, in2, in3, inp);
}
70+
4371
FORCEINLINE void vec_load_mult2_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
4472
{
4573
vec_bf16 in0[2];
@@ -50,13 +78,123 @@ FORCEINLINE void vec_load_mult2_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *
5078
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]);
5179
}
5280

81+
/* Load two bf16 vectors from each of two rows and accumulate the products
   with inp[0..1] into out[0] and out[1].  Updates are grouped per
   accumulator; each accumulator still sees inp[0] before inp[1], so the
   accumulation order (and result) per accumulator is unchanged. */
FORCEINLINE void vec_load_mult22_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
{
  vec_bf16 a[2], b[2];

  vec_load_pair((vec_f32 *)a, (vec_f32 *)in0);
  vec_load_pair((vec_f32 *)b, (vec_f32 *)in1);

  __builtin_mma_xvbf16ger2pp(&out[0], (vec_uc8)a[0], (vec_uc8)inp[0]);
  __builtin_mma_xvbf16ger2pp(&out[0], (vec_uc8)a[1], (vec_uc8)inp[1]);
  __builtin_mma_xvbf16ger2pp(&out[1], (vec_uc8)b[0], (vec_uc8)inp[0]);
  __builtin_mma_xvbf16ger2pp(&out[1], (vec_uc8)b[1], (vec_uc8)inp[1]);
}
93+
94+
/* Load two bf16 vectors from each of four rows and accumulate the products
   with inp[0..1] into out[0..3].
   Expressed as two 2-row updates (vec_load_mult22_mma) to remove the
   duplicated load/accumulate sequence; each accumulator still receives
   inp[0] before inp[1], so per-accumulator results are unchanged. */
FORCEINLINE void vec_load_mult24_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp)
{
  vec_load_mult22_mma(out + 0, in0, in1, inp);
  vec_load_mult22_mma(out + 2, in2, in3, inp);
}
112+
113+
/* Load four bf16 vectors from in and accumulate the four products with
   inp[0..3] into the single accumulator out, in ascending order (the same
   update order as before, so the accumulated result is identical). */
FORCEINLINE void vec_load_mult4_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
{
  vec_bf16 a[4];
  int k;

  vec_load_pair2(a, in);

  for (k = 0; k < 4; k++)
    __builtin_mma_xvbf16ger2pp(out, (vec_uc8)a[k], (vec_uc8)inp[k]);
}
124+
125+
/* Load four bf16 vectors from each of two rows and accumulate products with
   inp[0..3] into out[0] and out[1].  The k-outer / j-inner loop reproduces
   the original update sequence exactly (inp[k] applied to out[0] then out[1],
   for k = 0..3), so results are unchanged. */
FORCEINLINE void vec_load_mult42_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
{
  vec_bf16 a[2][4];
  int j, k;

  vec_load_pair2(a[0], in0);
  vec_load_pair2(a[1], in1);

  for (k = 0; k < 4; k++)
    for (j = 0; j < 2; j++)
      __builtin_mma_xvbf16ger2pp(&out[j], (vec_uc8)a[j][k], (vec_uc8)inp[k]);
}
141+
142+
/* Load four bf16 vectors from each of four rows and accumulate products with
   inp[0..3] into out[0..3].
   Expressed as two 2-row updates (vec_load_mult42_mma) to remove the 16-line
   duplicated accumulate sequence; each accumulator still receives inp[0..3]
   in ascending order, so per-accumulator results are unchanged. */
FORCEINLINE void vec_load_mult44_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp)
{
  vec_load_mult42_mma(out + 0, in0, in1, inp);
  vec_load_mult42_mma(out + 2, in2, in3, inp);
}
168+
53169
/* Partial-width variant: load n elements from in (vec_loadN) and accumulate
   the product with inp into the accumulator out. */
FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n)
{
  vec_bf16 a0 = vec_loadN(in, n);

  __builtin_mma_xvbf16ger2pp(out, (vec_uc8)a0, (vec_uc8)inp);
}
59175

176+
/* Partial-width variant: load n elements from each of two rows and accumulate
   each product with the shared inp into its own accumulator (out[0], out[1]). */
FORCEINLINE void vec_loadN_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp, BLASLONG n)
{
  vec_bf16 a0 = (vec_bf16)vec_loadN(in0, n);
  vec_bf16 a1 = (vec_bf16)vec_loadN(in1, n);

  __builtin_mma_xvbf16ger2pp(&out[0], (vec_uc8)a0, (vec_uc8)inp);
  __builtin_mma_xvbf16ger2pp(&out[1], (vec_uc8)a1, (vec_uc8)inp);
}
184+
185+
/* Partial-width variant: load n elements from each of four rows and accumulate
   each product with the shared inp into its own accumulator (out[0..3]).
   Expressed as two 2-row updates (vec_loadN_mult12a_mma) to avoid duplicating
   the load/accumulate sequence; each accumulator receives exactly the same
   single update as before, so results are unchanged. */
FORCEINLINE void vec_loadN_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp, BLASLONG n)
{
  vec_loadN_mult12a_mma(out + 0, in0, in1, inp, n);
  vec_loadN_mult12a_mma(out + 2, in2, in3, inp, n);
}
197+
60198
FORCEINLINE void vec_mult1_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp)
61199
{
62200
vec_bf16 in00 = vec_mergeh(in0, in0);

kernel/power/sbgemv_n_power10.c

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,12 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
119119
if (n > 4) {
120120
vec_loadN_mult12_mma(&temp[0], &va0[i], v_x0[ 0], n);
121121

122-
BLASLONG n3 = n & 3;
123-
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);
122+
n &= 3;
123+
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n);
124124

125125
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
126126

127-
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
127+
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n);
128128
} else if (n) {
129129
vec_loadN_mult11_mma(&temp[0], &va0[i], v_x0[ 0], n);
130130

@@ -213,12 +213,12 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
213213
if (n > 4) {
214214
vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n);
215215

216-
BLASLONG n3 = n & 3;
217-
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);
216+
n &= 3;
217+
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n);
218218

219219
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
220220

221-
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
221+
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n);
222222
} else if (n) {
223223
vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n);
224224

@@ -318,12 +318,12 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
318318
vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n);
319319
vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n);
320320

321-
BLASLONG n3 = n & 3;
322-
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);
321+
n &= 3;
322+
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n);
323323

324324
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
325325

326-
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
326+
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n);
327327
} else if (n) {
328328
vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n);
329329
vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n);
@@ -445,12 +445,12 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS
445445
vec_loadN_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8], n);
446446
vec_loadN_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12], n);
447447

448-
BLASLONG n3 = n & 3;
449-
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);
448+
n &= 3;
449+
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n);
450450

451451
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
452452

453-
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
453+
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n);
454454
} else if (n) {
455455
vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n);
456456
vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n);

0 commit comments

Comments
 (0)