Skip to content

Commit bc37284

Browse files
committed
format code
1 parent 55d686d commit bc37284

File tree

3 files changed

+60
-50
lines changed

3 files changed

+60
-50
lines changed

kernel/arm64/sbgemm_kernel_8x4_neoversen2.c

Lines changed: 56 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -27,15 +27,16 @@
2727
* *****************************************************************************/
2828

2929
#include <arm_sve.h>
30+
3031
#include "common.h"
3132

32-
int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
33-
FLOAT *C, BLASLONG ldc) {
33+
int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C,
34+
BLASLONG ldc) {
3435
// printf("m: %d, n: %d, k: %d\n", m, n, k);
3536
BLASLONG padk = (k + 3) & ~3;
3637
BLASLONG padm = (m + 1) & ~1;
3738
BLASLONG padn = (n + 1) & ~1;
38-
FLOAT *RC = (FLOAT *) calloc(padm * padn, sizeof(float));
39+
FLOAT *RC = (FLOAT *)calloc(padm * padn, sizeof(float));
3940
BLASLONG nldc = padm;
4041

4142
IFLOAT *ptr_a = A;
@@ -52,10 +53,10 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
5253
svbool_t pg32 = svptrue_b32();
5354
svfloat32_t svalpha = svdup_f32(alpha);
5455

55-
uint32_t off_c[] = {0, (uint32_t) nldc, 1, (uint32_t) nldc + 1}; // 00 01 10 11
56+
uint32_t off_c[] = {0, (uint32_t)nldc, 1, (uint32_t)nldc + 1}; // 00 01 10 11
5657
svuint32_t off_vc = svld1_u32(pg32, off_c);
5758

58-
for (BLASLONG j = 0; j < padn/4; j++) {
59+
for (BLASLONG j = 0; j < padn / 4; j++) {
5960
ptr_c00 = ptr_c;
6061
ptr_c10 = ptr_c00 + 2;
6162
ptr_c20 = ptr_c10 + 2;
@@ -68,7 +69,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
6869

6970
ptr_a = A;
7071

71-
for (BLASLONG i = 0; i < padm/8; i++) {
72+
for (BLASLONG i = 0; i < padm / 8; i++) {
7273
ptr_a0 = ptr_a;
7374
ptr_a1 = ptr_a0 + 2 * padk;
7475
ptr_a2 = ptr_a1 + 2 * padk;
@@ -78,18 +79,22 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
7879
ptr_b0 = ptr_b;
7980
ptr_b1 = ptr_b0 + 2 * padk;
8081

81-
mc00 = svdup_f32(0); mc01 = svdup_f32(0);
82-
mc10 = svdup_f32(0); mc11 = svdup_f32(0);
83-
mc20 = svdup_f32(0); mc21 = svdup_f32(0);
84-
mc30 = svdup_f32(0); mc31 = svdup_f32(0);
85-
86-
for (BLASLONG p = 0; p < padk/4; p++) {
87-
ma0 = svld1_bf16(pg16, (bfloat16_t *) ptr_a0);
88-
ma1 = svld1_bf16(pg16, (bfloat16_t *) ptr_a1);
89-
ma2 = svld1_bf16(pg16, (bfloat16_t *) ptr_a2);
90-
ma3 = svld1_bf16(pg16, (bfloat16_t *) ptr_a3);
91-
mb0 = svld1_bf16(pg16, (bfloat16_t *) ptr_b0);
92-
mb1 = svld1_bf16(pg16, (bfloat16_t *) ptr_b1);
82+
mc00 = svdup_f32(0);
83+
mc01 = svdup_f32(0);
84+
mc10 = svdup_f32(0);
85+
mc11 = svdup_f32(0);
86+
mc20 = svdup_f32(0);
87+
mc21 = svdup_f32(0);
88+
mc30 = svdup_f32(0);
89+
mc31 = svdup_f32(0);
90+
91+
for (BLASLONG p = 0; p < padk / 4; p++) {
92+
ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0);
93+
ma1 = svld1_bf16(pg16, (bfloat16_t *)ptr_a1);
94+
ma2 = svld1_bf16(pg16, (bfloat16_t *)ptr_a2);
95+
ma3 = svld1_bf16(pg16, (bfloat16_t *)ptr_a3);
96+
mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0);
97+
mb1 = svld1_bf16(pg16, (bfloat16_t *)ptr_b1);
9398

9499
mc00 = svbfmmla(mc00, ma0, mb0);
95100
mc10 = svbfmmla(mc10, ma1, mb0);
@@ -135,13 +140,15 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
135140
ptr_b0 = ptr_b;
136141
ptr_b1 = ptr_b0 + 2 * padk;
137142

138-
mc00 = svdup_f32(0); mc01 = svdup_f32(0);
139-
mc10 = svdup_f32(0); mc11 = svdup_f32(0);
140-
for (BLASLONG p = 0; p < padk/4; p++) {
141-
ma0 = svld1_bf16(pg16, (bfloat16_t *) ptr_a0);
142-
ma1 = svld1_bf16(pg16, (bfloat16_t *) ptr_a1);
143-
mb0 = svld1_bf16(pg16, (bfloat16_t *) ptr_b0);
144-
mb1 = svld1_bf16(pg16, (bfloat16_t *) ptr_b1);
143+
mc00 = svdup_f32(0);
144+
mc01 = svdup_f32(0);
145+
mc10 = svdup_f32(0);
146+
mc11 = svdup_f32(0);
147+
for (BLASLONG p = 0; p < padk / 4; p++) {
148+
ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0);
149+
ma1 = svld1_bf16(pg16, (bfloat16_t *)ptr_a1);
150+
mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0);
151+
mb1 = svld1_bf16(pg16, (bfloat16_t *)ptr_b1);
145152

146153
mc00 = svbfmmla(mc00, ma0, mb0);
147154
mc10 = svbfmmla(mc10, ma1, mb0);
@@ -171,11 +178,12 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
171178
ptr_b0 = ptr_b;
172179
ptr_b1 = ptr_b0 + 2 * padk;
173180

174-
mc00 = svdup_f32(0); mc01 = svdup_f32(0);
175-
for (BLASLONG p = 0; p < padk/4; p++) {
176-
ma0 = svld1_bf16(pg16, (bfloat16_t *) ptr_a0);
177-
mb0 = svld1_bf16(pg16, (bfloat16_t *) ptr_b0);
178-
mb1 = svld1_bf16(pg16, (bfloat16_t *) ptr_b1);
181+
mc00 = svdup_f32(0);
182+
mc01 = svdup_f32(0);
183+
for (BLASLONG p = 0; p < padk / 4; p++) {
184+
ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0);
185+
mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0);
186+
mb1 = svld1_bf16(pg16, (bfloat16_t *)ptr_b1);
179187
mc00 = svbfmmla(mc00, ma0, mb0);
180188
mc01 = svbfmmla(mc01, ma0, mb1);
181189
ptr_a0 += 8;
@@ -189,7 +197,6 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
189197
}
190198

191199
ptr_b += 4 * padk;
192-
193200
}
194201

195202
if (padn & 2) {
@@ -202,7 +209,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
202209

203210
ptr_a = A;
204211

205-
for (BLASLONG i = 0; i < padm/8; i++) {
212+
for (BLASLONG i = 0; i < padm / 8; i++) {
206213
ptr_a0 = ptr_a;
207214
ptr_a1 = ptr_a0 + 2 * padk;
208215
ptr_a2 = ptr_a1 + 2 * padk;
@@ -216,12 +223,12 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
216223
mc20 = svdup_f32(0);
217224
mc30 = svdup_f32(0);
218225

219-
for (BLASLONG p = 0; p < padk/4; p++) {
220-
ma0 = svld1_bf16(pg16, (bfloat16_t *) ptr_a0);
221-
ma1 = svld1_bf16(pg16, (bfloat16_t *) ptr_a1);
222-
ma2 = svld1_bf16(pg16, (bfloat16_t *) ptr_a2);
223-
ma3 = svld1_bf16(pg16, (bfloat16_t *) ptr_a3);
224-
mb0 = svld1_bf16(pg16, (bfloat16_t *) ptr_b0);
226+
for (BLASLONG p = 0; p < padk / 4; p++) {
227+
ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0);
228+
ma1 = svld1_bf16(pg16, (bfloat16_t *)ptr_a1);
229+
ma2 = svld1_bf16(pg16, (bfloat16_t *)ptr_a2);
230+
ma3 = svld1_bf16(pg16, (bfloat16_t *)ptr_a3);
231+
mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0);
225232
mc00 = svbfmmla(mc00, ma0, mb0);
226233
mc10 = svbfmmla(mc10, ma1, mb0);
227234
mc20 = svbfmmla(mc20, ma2, mb0);
@@ -251,10 +258,10 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
251258

252259
mc00 = svdup_f32(0);
253260
mc10 = svdup_f32(0);
254-
for (BLASLONG p = 0; p < padk/4; p++) {
255-
ma0 = svld1_bf16(pg16, (bfloat16_t *) ptr_a0);
256-
ma1 = svld1_bf16(pg16, (bfloat16_t *) ptr_a1);
257-
mb0 = svld1_bf16(pg16, (bfloat16_t *) ptr_b0);
261+
for (BLASLONG p = 0; p < padk / 4; p++) {
262+
ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0);
263+
ma1 = svld1_bf16(pg16, (bfloat16_t *)ptr_a1);
264+
mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0);
258265
mc00 = svbfmmla(mc00, ma0, mb0);
259266
mc10 = svbfmmla(mc10, ma1, mb0);
260267
ptr_a0 += 8;
@@ -272,9 +279,9 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
272279
ptr_a += 2 * padk;
273280
ptr_b0 = ptr_b;
274281
mc00 = svdup_f32(0);
275-
for (BLASLONG p = 0; p < padk/4; p++) {
276-
ma0 = svld1_bf16(pg16, (bfloat16_t *) ptr_a0);
277-
mb0 = svld1_bf16(pg16, (bfloat16_t *) ptr_b0);
282+
for (BLASLONG p = 0; p < padk / 4; p++) {
283+
ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0);
284+
mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0);
278285
mc00 = svbfmmla(mc00, ma0, mb0);
279286
ptr_a0 += 8;
280287
ptr_b0 += 8;
@@ -296,10 +303,11 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
296303
org_c += ldc;
297304
raw_c += nldc;
298305
BLASLONG i;
299-
for (i = 0; i < m/4; i++) {
306+
for (i = 0; i < m / 4; i++) {
300307
org_vc0 = svld1_f32(pg32, org_c0);
301308
raw_vc0 = svld1_f32(pg32, raw_c0);
302-
org_vc0 = svmad_z(pg32, svalpha, raw_vc0, org_vc0); // alpha * raw + org, raw -> a * b
309+
org_vc0 = svmad_z(pg32, svalpha, raw_vc0,
310+
org_vc0); // alpha * raw + org, raw -> a * b
303311
svst1_f32(pg32, org_c0, org_vc0);
304312
org_c0 += 4;
305313
raw_c0 += 4;
@@ -310,5 +318,6 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
310318
raw_c0++;
311319
}
312320
}
321+
313322
return 0;
314323
}

kernel/arm64/sbgemm_ncopy_neoversen2.c

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -97,7 +97,7 @@ int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) {
9797
*(b_offset + 6) = 0;
9898
*(b_offset + 7) = 0;
9999
a_offset += 4;
100-
b_offset += 8;
100+
b_offset += 4;
101101
}
102102
if (i < m) {
103103
*(b_offset + 4) = 0;

kernel/arm64/sbgemm_tcopy_neoversen2.c

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -87,7 +87,7 @@ int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) {
8787
b_offset += 8;
8888
}
8989
}
90-
if (j < n) { // padding 2
90+
if (j < n) { // rest 1
9191
BLASLONG i = 0;
9292
for (; i < m4; i += 4) {
9393
*(b_offset + 0) = *(a_offset + 0);
@@ -98,7 +98,7 @@ int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) {
9898
*(b_offset + 5) = 0;
9999
*(b_offset + 6) = 0;
100100
*(b_offset + 7) = 0;
101-
b_offset += 8;
101+
b_offset += 4;
102102
a_offset += 4 * lda;
103103
}
104104
if (i < m) {
@@ -113,5 +113,6 @@ int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) {
113113
*(b_offset + 3) = 0;
114114
}
115115
}
116+
116117
return 0;
117118
}

0 commit comments

Comments
 (0)