Skip to content

Commit 7ec3c16

Browse files
committed
Remove beta from optimized functions.
1 parent 7cc00f6 commit 7ec3c16

File tree

6 files changed

+101
-96
lines changed

6 files changed

+101
-96
lines changed

Makefile.system

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -282,6 +282,7 @@ GEMM_GEMV_FORWARD = 1
282282
endif
283283
ifeq ($(ARCH), power)
284284
GEMM_GEMV_FORWARD = 1
285+
GEMM_GEMV_FORWARD_BF16 = 1
285286
endif
286287

287288
ifeq ($(SMALL_MATRIX_OPT), 1)

kernel/power/sbgemv_common.c

Lines changed: 64 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -122,7 +122,10 @@ FORCEINLINE void copy_x(BLASLONG n, IFLOAT *src, IFLOAT *dest, BLASLONG inc_src)
122122
FORCEINLINE void copy_y_beta(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, FLOAT beta)
123123
{
124124
if (beta == 0) {
125-
memset(dest, 0, sizeof(FLOAT) * n);
125+
for (BLASLONG i = 0; i < n; i++) {
126+
*dest++ = (FLOAT)0;
127+
src += inc_src;
128+
}
126129
} else if (beta == 1) {
127130
for (BLASLONG i = 0; i < n; i++) {
128131
*dest++ = *src;
@@ -163,4 +166,64 @@ FORCEINLINE void move_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest)
163166
dest += inc_dest;
164167
}
165168
}
169+
170+
/*
 * BF16GEMV_N_beta: write beta * input_vector[0..n) into output_vector[0..n).
 *
 *   beta == 0 : output is zero-filled with memset; input is not read.
 *   beta == 1 : input is copied with memcpy, skipped when both pointers
 *               are equal.
 *   otherwise : every element is multiplied by beta using vec_f32
 *               arithmetic, in whole pairs first and then a partial tail.
 *
 * NOTE(review): both pointers are cast to vec_f32 *; whether the
 * vec_load/vec_store helpers tolerate unaligned addresses is not visible
 * here - confirm. memcpy also requires the two ranges not to partially
 * overlap; only the fully-identical case is guarded.
 */
static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vector, FLOAT beta)
171+
{
172+
if (beta == 0) {
173+
memset(output_vector, 0, sizeof(FLOAT) * n);
174+
} else if (beta == 1) {
175+
if (output_vector != input_vector) {
176+
memcpy(output_vector, input_vector, sizeof(FLOAT) * n);
177+
}
178+
} else {
179+
/* The same beta in each of the four initialised lanes. */
vec_f32 b = { beta, beta, beta, beta };
180+
181+
vec_f32 *in = (vec_f32 *)input_vector;
182+
vec_f32 *out = (vec_f32 *)output_vector;
183+
184+
/* Number of 8-element groups; each group is handled as a pair of vec_f32
 * (presumably 4 floats per vec_f32 - confirm against its typedef). */
BLASLONG n8 = n / 8;
185+
BLASLONG i = 0;
186+
vec_f32 v_inp0[2];
187+
188+
/* Unrolled by four: each pass covers four pairs, in[2*i] .. in[2*i + 7]. */
for (; i + 4 <= n8; i += 4) {
189+
vec_f32 v_inp1[2], v_inp2[2], v_inp3[2];
190+
vec_load_pair(v_inp0, &in[(i * 2) + 0]);
191+
vec_load_pair(v_inp1, &in[(i * 2) + 2]);
192+
vec_load_pair(v_inp2, &in[(i * 2) + 4]);
193+
vec_load_pair(v_inp3, &in[(i * 2) + 6]);
194+
v_inp0[0] *= b;
195+
v_inp0[1] *= b;
196+
v_inp1[0] *= b;
197+
v_inp1[1] *= b;
198+
v_inp2[0] *= b;
199+
v_inp2[1] *= b;
200+
v_inp3[0] *= b;
201+
v_inp3[1] *= b;
202+
vec_store_pair(&out[(i * 2) + 0], v_inp0);
203+
vec_store_pair(&out[(i * 2) + 2], v_inp1);
204+
vec_store_pair(&out[(i * 2) + 4], v_inp2);
205+
vec_store_pair(&out[(i * 2) + 6], v_inp3);
206+
}
207+
208+
/* Remaining whole pairs, one per pass. */
for (; i < n8; i++) {
209+
vec_load_pair(v_inp0, &in[(i * 2) + 0]);
210+
v_inp0[0] *= b;
211+
v_inp0[1] *= b;
212+
vec_store_pair(&out[(i * 2) + 0], v_inp0);
213+
}
214+
215+
/* i == n8 here, so in[i * 2] is the first vector not covered by whole
 * pairs; n is reduced to the leftover element count. */
n &= 7;
216+
/* More than 4 left: n3 = n & 3 is the count handed to the N2 helpers.
 * Presumably they process one full vector plus n3 elements of the
 * second - confirm against vec_loadN2_f32 / vec_storeN2_f32. */
if (n > 4) {
217+
BLASLONG n3 = n & 3;
218+
vec_loadN2_f32(v_inp0, &in[(i * 2) + 0], n3);
219+
v_inp0[0] *= b;
220+
v_inp0[1] *= b;
221+
vec_storeN2_f32(v_inp0, &out[(i * 2) + 0], n3);
222+
/* 1..4 left: a single vector is loaded and stored with count n
 * (partial-access semantics of the helpers assumed - confirm). */
} else if (n) {
223+
v_inp0[0] = vec_loadN_f32(&in[(i * 2) + 0], n);
224+
v_inp0[0] *= b;
225+
vec_storeN_f32(v_inp0[0], &out[(i * 2) + 0], n);
226+
}
227+
}
228+
}
166229
#endif

kernel/power/sbgemv_n.c

Lines changed: 0 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -27,65 +27,6 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2727

2828
#ifndef SBGEMV_N_COMMON_C
2929
#define SBGEMV_N_COMMON_C
30-
static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vector, FLOAT beta)
31-
{
32-
if (beta == 0) {
33-
memset(output_vector, 0, sizeof(FLOAT) * n);
34-
} else if (beta == 1) {
35-
if (output_vector != input_vector) {
36-
memcpy(output_vector, input_vector, sizeof(FLOAT) * n);
37-
}
38-
} else {
39-
vec_f32 b = { beta, beta, beta, beta };
40-
41-
vec_f32 *in = (vec_f32 *)input_vector;
42-
vec_f32 *out = (vec_f32 *)output_vector;
43-
44-
BLASLONG n8 = n / 8;
45-
BLASLONG i = 0;
46-
vec_f32 v_inp0[2];
47-
48-
for (; i + 4 <= n8; i += 4) {
49-
vec_f32 v_inp1[2], v_inp2[2], v_inp3[2];
50-
vec_load_pair(v_inp0, &in[(i * 2) + 0]);
51-
vec_load_pair(v_inp1, &in[(i * 2) + 2]);
52-
vec_load_pair(v_inp2, &in[(i * 2) + 4]);
53-
vec_load_pair(v_inp3, &in[(i * 2) + 6]);
54-
v_inp0[0] *= b;
55-
v_inp0[1] *= b;
56-
v_inp1[0] *= b;
57-
v_inp1[1] *= b;
58-
v_inp2[0] *= b;
59-
v_inp2[1] *= b;
60-
v_inp3[0] *= b;
61-
v_inp3[1] *= b;
62-
vec_store_pair(&out[(i * 2) + 0], v_inp0);
63-
vec_store_pair(&out[(i * 2) + 2], v_inp1);
64-
vec_store_pair(&out[(i * 2) + 4], v_inp2);
65-
vec_store_pair(&out[(i * 2) + 6], v_inp3);
66-
}
67-
68-
for (; i < n8; i++) {
69-
vec_load_pair(v_inp0, &in[(i * 2) + 0]);
70-
v_inp0[0] *= b;
71-
v_inp0[1] *= b;
72-
vec_store_pair(&out[(i * 2) + 0], v_inp0);
73-
}
74-
75-
n &= 7;
76-
if (n > 4) {
77-
BLASLONG n3 = n & 3;
78-
vec_loadN2_f32(v_inp0, &in[(i * 2) + 0], n3);
79-
v_inp0[0] *= b;
80-
v_inp0[1] *= b;
81-
vec_storeN2_f32(v_inp0, &out[(i * 2) + 0], n3);
82-
} else if (n) {
83-
v_inp0[0] = vec_loadN_f32(&in[(i * 2) + 0], n);
84-
v_inp0[0] *= b;
85-
vec_storeN_f32(v_inp0[0], &out[(i * 2) + 0], n);
86-
}
87-
}
88-
}
8930

9031
#if (defined(_ARCH_PWR10) && (defined(USE_BFGEMV_8_N_MMA) || (!defined(USE_BFGEMV_N_MMA) && defined(USE_BFGEMV_8_N_VSX)))) || (!defined(_ARCH_PWR10) && defined(USE_BFGEMV_8_N_VSX))
9132
#define USE_N_8

kernel/power/sbgemv_t.c

Lines changed: 16 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
4141

4242
if ((m < 1) || (n < 1)) return 0;
4343

44+
if (inc_y == 1) {
45+
BF16GEMV_N_beta(n, y, y, beta);
46+
}
47+
4448
xbuffer = buffer;
4549

4650
BLASLONG lda4 = lda << 2;
@@ -58,42 +62,45 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
5862
}
5963

6064
a_ptr = a;
65+
a += NB;
6166
y_ptr = y;
6267

6368
if (inc_x != 1) {
6469
copy_x(NB, x, xbuffer, inc_x);
70+
x += NB * inc_x;
6571
} else {
6672
xbuffer = x;
73+
x += NB;
6774
}
6875

6976
if (inc_y == 1) {
7077
#ifdef USE_T_8
7178
for (BLASLONG j = 0; j + 8 <= n; j += 8) {
72-
BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta);
79+
BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
7380
y_ptr += 8;
7481
a_ptr += lda8;
7582
}
7683
if (n & 4) {
7784
#else
7885
for (BLASLONG j = 0; j + 4 <= n; j += 4) {
7986
#endif
80-
BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta);
87+
BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
8188
y_ptr += 4;
8289
a_ptr += lda4;
8390
}
8491
if (n & 2) {
85-
BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta);
92+
BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
8693
y_ptr += 2;
8794
a_ptr += (lda * 2);
8895
}
8996
if (n & 1) {
90-
BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta);
97+
BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
9198
}
9299
} else {
93100
#ifdef USE_T_8
94101
for (BLASLONG j = 0; j + 8 <= n; j += 8) {
95102
memset(ybuffer, 0, sizeof(FLOAT) * 8);
96-
BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta);
103+
BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
97104
copy_y(8, ybuffer, y_ptr, inc_y, beta);
98105
y_ptr += 8 * inc_y;
99106
a_ptr += lda8;
@@ -103,28 +110,25 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
103110
for (BLASLONG j = 0; j + 4 <= n; j += 4) {
104111
#endif
105112
memset(ybuffer, 0, sizeof(FLOAT) * 4);
106-
BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta);
113+
BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
107114
copy_y(4, ybuffer, y_ptr, inc_y, beta);
108115
y_ptr += 4 * inc_y;
109116
a_ptr += lda4;
110117
}
111118
if (n & 2) {
112119
memset(ybuffer, 0, sizeof(FLOAT) * 4);
113-
BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta);
120+
BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
114121
copy_y(2, ybuffer, y_ptr, inc_y, beta);
115122
y_ptr += 2 * inc_y;
116123
a_ptr += (lda * 2);
117124
}
118125
if (n & 1) {
119126
memset(ybuffer, 0, sizeof(FLOAT) * 4);
120-
BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta);
127+
BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
121128
copy_y(1, ybuffer, y_ptr, inc_y, beta);
122129
}
130+
beta = (FLOAT)1;
123131
}
124-
125-
a += NB;
126-
x += NB * inc_x;
127-
beta = (FLOAT)1;
128132
}
129133

130134
return 0;

kernel/power/sbgemv_t_power10.c

Lines changed: 10 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
4343

4444
#define USE_BFGEMV_8_T_MMA
4545

46-
static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
46+
static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
4747
{
4848
IFLOAT *a0;
4949
vec_bf16 *va0, *v_x;
@@ -90,10 +90,10 @@ static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
9090

9191
__builtin_mma_disassemble_acc((void*)temp00, &temp0);
9292

93-
y[0] = (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])) + (beta * y[0]);
93+
y[0] += (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3]));
9494
}
9595

96-
static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
96+
static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
9797
{
9898
IFLOAT *a0, *a1;
9999
vec_bf16 *va0, *va1, *v_x;
@@ -142,11 +142,11 @@ static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
142142

143143
vec_reduce_2(temp00, &temp0[0]);
144144

145-
y[0] = (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])) + (beta * y[0]);
146-
y[1] = (alpha * (temp00[4][0] + temp00[5][1] + temp00[6][2] + temp00[7][3])) + (beta * y[1]);
145+
y[0] += (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3]));
146+
y[1] += (alpha * (temp00[4][0] + temp00[5][1] + temp00[6][2] + temp00[7][3]));
147147
}
148148

149-
static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
149+
static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
150150
{
151151
IFLOAT *a0, *a1, *a2, *a3;
152152
vec_bf16 *va0, *va1, *va2, *va3, *v_x;
@@ -201,7 +201,6 @@ static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
201201

202202
vec_f32 t0, t1, t2, t3, t4, t5, t6, t7;
203203
vec_f32 a = { alpha, alpha, alpha, alpha };
204-
vec_f32 b = { beta, beta, beta, beta };
205204
vec_f32 *v_y = (vec_f32 *) y;
206205

207206
t0 = vec_mergeh(temp00[ 0], temp00[ 4]);
@@ -219,11 +218,11 @@ static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
219218

220219
t0 += t2 + t4 + t6;
221220

222-
v_y[0] = (a * t0) + (b * v_y[0]);
221+
v_y[0] += (a * t0);
223222
}
224223

225224
#ifdef USE_BFGEMV_8_T_MMA
226-
static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
225+
static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
227226
{
228227
IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7;
229228
vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x;
@@ -291,7 +290,6 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
291290

292291
vec_f32 t0, t1, t2, t3, t4, t5, t6, t7, t10, t11, t12, t13, t14, t15, t16, t17;
293292
vec_f32 a = { alpha, alpha, alpha, alpha };
294-
vec_f32 b = { beta, beta, beta, beta };
295293
vec_f32 *v_y = (vec_f32 *) y;
296294

297295
t0 = vec_mergeh(temp00[ 0], temp00[ 4]);
@@ -326,8 +324,8 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
326324

327325
vec_f32 inp2[2];
328326
vec_load_pair(inp2, v_y);
329-
inp2[0] = (a * t0) + (b * inp2[0]);
330-
inp2[1] = (a * t10) + (b * inp2[1]);
327+
inp2[0] += (a * t0);
328+
inp2[1] += (a * t10);
331329
vec_store_pair(v_y, inp2);
332330
}
333331
#endif

0 commit comments

Comments
 (0)