Skip to content

Commit 123e0df

Browse files
committed
Neoverse N2 sbgemm: (1) modify the algorithm to resolve multithreading failures; (2) perform no memory allocation inside the sbgemm kernel; (3) optimize the case where alpha == 1.0f.
1 parent bc37284 commit 123e0df

File tree

5 files changed

+753
-387
lines changed

5 files changed

+753
-387
lines changed

kernel/arm64/sbgemm_kernel_8x4_neoversen2.c

Lines changed: 11 additions & 289 deletions
Original file line numberDiff line numberDiff line change
@@ -30,294 +30,16 @@
3030

3131
#include "common.h"
3232

33-
int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C,
34-
BLASLONG ldc) {
35-
// printf("m: %d, n: %d, k: %d\n", m, n, k);
36-
BLASLONG padk = (k + 3) & ~3;
37-
BLASLONG padm = (m + 1) & ~1;
38-
BLASLONG padn = (n + 1) & ~1;
39-
FLOAT *RC = (FLOAT *)calloc(padm * padn, sizeof(float));
40-
BLASLONG nldc = padm;
41-
42-
IFLOAT *ptr_a = A;
43-
IFLOAT *ptr_b = B;
44-
FLOAT *ptr_c = RC;
45-
46-
IFLOAT *ptr_a0, *ptr_a1, *ptr_a2, *ptr_a3;
47-
IFLOAT *ptr_b0, *ptr_b1;
48-
FLOAT *ptr_c00, *ptr_c10, *ptr_c20, *ptr_c30, *ptr_c01, *ptr_c11, *ptr_c21, *ptr_c31;
49-
50-
svbfloat16_t ma0, ma1, ma2, ma3, mb0, mb1;
51-
svfloat32_t mc00, mc01, mc10, mc11, mc20, mc21, mc30, mc31;
52-
svbool_t pg16 = svptrue_b16();
53-
svbool_t pg32 = svptrue_b32();
54-
svfloat32_t svalpha = svdup_f32(alpha);
55-
56-
uint32_t off_c[] = {0, (uint32_t)nldc, 1, (uint32_t)nldc + 1}; // 00 01 10 11
57-
svuint32_t off_vc = svld1_u32(pg32, off_c);
58-
59-
for (BLASLONG j = 0; j < padn / 4; j++) {
60-
ptr_c00 = ptr_c;
61-
ptr_c10 = ptr_c00 + 2;
62-
ptr_c20 = ptr_c10 + 2;
63-
ptr_c30 = ptr_c20 + 2;
64-
ptr_c01 = ptr_c + 2 * nldc;
65-
ptr_c11 = ptr_c01 + 2;
66-
ptr_c21 = ptr_c11 + 2;
67-
ptr_c31 = ptr_c21 + 2;
68-
ptr_c += 4 * nldc;
69-
70-
ptr_a = A;
71-
72-
for (BLASLONG i = 0; i < padm / 8; i++) {
73-
ptr_a0 = ptr_a;
74-
ptr_a1 = ptr_a0 + 2 * padk;
75-
ptr_a2 = ptr_a1 + 2 * padk;
76-
ptr_a3 = ptr_a2 + 2 * padk;
77-
ptr_a += 8 * padk;
78-
79-
ptr_b0 = ptr_b;
80-
ptr_b1 = ptr_b0 + 2 * padk;
81-
82-
mc00 = svdup_f32(0);
83-
mc01 = svdup_f32(0);
84-
mc10 = svdup_f32(0);
85-
mc11 = svdup_f32(0);
86-
mc20 = svdup_f32(0);
87-
mc21 = svdup_f32(0);
88-
mc30 = svdup_f32(0);
89-
mc31 = svdup_f32(0);
90-
91-
for (BLASLONG p = 0; p < padk / 4; p++) {
92-
ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0);
93-
ma1 = svld1_bf16(pg16, (bfloat16_t *)ptr_a1);
94-
ma2 = svld1_bf16(pg16, (bfloat16_t *)ptr_a2);
95-
ma3 = svld1_bf16(pg16, (bfloat16_t *)ptr_a3);
96-
mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0);
97-
mb1 = svld1_bf16(pg16, (bfloat16_t *)ptr_b1);
98-
99-
mc00 = svbfmmla(mc00, ma0, mb0);
100-
mc10 = svbfmmla(mc10, ma1, mb0);
101-
mc20 = svbfmmla(mc20, ma2, mb0);
102-
mc30 = svbfmmla(mc30, ma3, mb0);
103-
mc01 = svbfmmla(mc01, ma0, mb1);
104-
mc11 = svbfmmla(mc11, ma1, mb1);
105-
mc21 = svbfmmla(mc21, ma2, mb1);
106-
mc31 = svbfmmla(mc31, ma3, mb1);
107-
108-
ptr_a0 += 8;
109-
ptr_a1 += 8;
110-
ptr_a2 += 8;
111-
ptr_a3 += 8;
112-
ptr_b0 += 8;
113-
ptr_b1 += 8;
114-
}
115-
svst1_scatter_index(pg32, ptr_c00, off_vc, mc00);
116-
svst1_scatter_index(pg32, ptr_c10, off_vc, mc10);
117-
svst1_scatter_index(pg32, ptr_c20, off_vc, mc20);
118-
svst1_scatter_index(pg32, ptr_c30, off_vc, mc30);
119-
svst1_scatter_index(pg32, ptr_c01, off_vc, mc01);
120-
svst1_scatter_index(pg32, ptr_c11, off_vc, mc11);
121-
svst1_scatter_index(pg32, ptr_c21, off_vc, mc21);
122-
svst1_scatter_index(pg32, ptr_c31, off_vc, mc31);
123-
124-
ptr_c00 += 8;
125-
ptr_c10 += 8;
126-
ptr_c20 += 8;
127-
ptr_c30 += 8;
128-
ptr_c01 += 8;
129-
ptr_c11 += 8;
130-
ptr_c21 += 8;
131-
ptr_c31 += 8;
132-
}
133-
134-
if (padm & 4) {
135-
// rest 4 or 6
136-
ptr_a0 = ptr_a;
137-
ptr_a1 = ptr_a0 + 2 * padk;
138-
ptr_a += 4 * padk;
139-
140-
ptr_b0 = ptr_b;
141-
ptr_b1 = ptr_b0 + 2 * padk;
142-
143-
mc00 = svdup_f32(0);
144-
mc01 = svdup_f32(0);
145-
mc10 = svdup_f32(0);
146-
mc11 = svdup_f32(0);
147-
for (BLASLONG p = 0; p < padk / 4; p++) {
148-
ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0);
149-
ma1 = svld1_bf16(pg16, (bfloat16_t *)ptr_a1);
150-
mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0);
151-
mb1 = svld1_bf16(pg16, (bfloat16_t *)ptr_b1);
152-
153-
mc00 = svbfmmla(mc00, ma0, mb0);
154-
mc10 = svbfmmla(mc10, ma1, mb0);
155-
mc01 = svbfmmla(mc01, ma0, mb1);
156-
mc11 = svbfmmla(mc11, ma1, mb1);
157-
158-
ptr_a0 += 8;
159-
ptr_a1 += 8;
160-
ptr_b0 += 8;
161-
ptr_b1 += 8;
162-
}
163-
svst1_scatter_index(pg32, ptr_c00, off_vc, mc00);
164-
svst1_scatter_index(pg32, ptr_c10, off_vc, mc10);
165-
svst1_scatter_index(pg32, ptr_c01, off_vc, mc01);
166-
svst1_scatter_index(pg32, ptr_c11, off_vc, mc11);
167-
168-
ptr_c00 += 4;
169-
ptr_c10 += 4;
170-
ptr_c01 += 4;
171-
ptr_c11 += 4;
172-
}
173-
174-
if (padm & 2) {
175-
// rest 2
176-
ptr_a0 = ptr_a;
177-
178-
ptr_b0 = ptr_b;
179-
ptr_b1 = ptr_b0 + 2 * padk;
180-
181-
mc00 = svdup_f32(0);
182-
mc01 = svdup_f32(0);
183-
for (BLASLONG p = 0; p < padk / 4; p++) {
184-
ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0);
185-
mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0);
186-
mb1 = svld1_bf16(pg16, (bfloat16_t *)ptr_b1);
187-
mc00 = svbfmmla(mc00, ma0, mb0);
188-
mc01 = svbfmmla(mc01, ma0, mb1);
189-
ptr_a0 += 8;
190-
ptr_b0 += 8;
191-
ptr_b1 += 8;
192-
}
193-
svst1_scatter_index(pg32, ptr_c00, off_vc, mc00);
194-
svst1_scatter_index(pg32, ptr_c01, off_vc, mc01);
195-
ptr_c00 += 2;
196-
ptr_c01 += 2;
197-
}
198-
199-
ptr_b += 4 * padk;
200-
}
201-
202-
if (padn & 2) {
203-
// rest 2
204-
ptr_c00 = ptr_c;
205-
ptr_c10 = ptr_c00 + 2;
206-
ptr_c20 = ptr_c10 + 2;
207-
ptr_c30 = ptr_c20 + 2;
208-
ptr_c += 2 * nldc;
209-
210-
ptr_a = A;
211-
212-
for (BLASLONG i = 0; i < padm / 8; i++) {
213-
ptr_a0 = ptr_a;
214-
ptr_a1 = ptr_a0 + 2 * padk;
215-
ptr_a2 = ptr_a1 + 2 * padk;
216-
ptr_a3 = ptr_a2 + 2 * padk;
217-
ptr_a += 8 * padk;
218-
219-
ptr_b0 = ptr_b;
220-
221-
mc00 = svdup_f32(0);
222-
mc10 = svdup_f32(0);
223-
mc20 = svdup_f32(0);
224-
mc30 = svdup_f32(0);
225-
226-
for (BLASLONG p = 0; p < padk / 4; p++) {
227-
ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0);
228-
ma1 = svld1_bf16(pg16, (bfloat16_t *)ptr_a1);
229-
ma2 = svld1_bf16(pg16, (bfloat16_t *)ptr_a2);
230-
ma3 = svld1_bf16(pg16, (bfloat16_t *)ptr_a3);
231-
mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0);
232-
mc00 = svbfmmla(mc00, ma0, mb0);
233-
mc10 = svbfmmla(mc10, ma1, mb0);
234-
mc20 = svbfmmla(mc20, ma2, mb0);
235-
mc30 = svbfmmla(mc30, ma3, mb0);
236-
ptr_a0 += 8;
237-
ptr_a1 += 8;
238-
ptr_a2 += 8;
239-
ptr_a3 += 8;
240-
ptr_b0 += 8;
241-
}
242-
svst1_scatter_index(pg32, ptr_c00, off_vc, mc00);
243-
svst1_scatter_index(pg32, ptr_c10, off_vc, mc10);
244-
svst1_scatter_index(pg32, ptr_c20, off_vc, mc20);
245-
svst1_scatter_index(pg32, ptr_c30, off_vc, mc30);
246-
ptr_c00 += 8;
247-
ptr_c10 += 8;
248-
ptr_c20 += 8;
249-
ptr_c30 += 8;
250-
}
251-
252-
if (padm & 4) {
253-
ptr_a0 = ptr_a;
254-
ptr_a1 = ptr_a0 + 2 * padk;
255-
ptr_a += 4 * padk;
256-
257-
ptr_b0 = ptr_b;
258-
259-
mc00 = svdup_f32(0);
260-
mc10 = svdup_f32(0);
261-
for (BLASLONG p = 0; p < padk / 4; p++) {
262-
ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0);
263-
ma1 = svld1_bf16(pg16, (bfloat16_t *)ptr_a1);
264-
mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0);
265-
mc00 = svbfmmla(mc00, ma0, mb0);
266-
mc10 = svbfmmla(mc10, ma1, mb0);
267-
ptr_a0 += 8;
268-
ptr_a1 += 8;
269-
ptr_b0 += 8;
270-
}
271-
svst1_scatter_index(pg32, ptr_c00, off_vc, mc00);
272-
svst1_scatter_index(pg32, ptr_c10, off_vc, mc10);
273-
ptr_c00 += 4;
274-
ptr_c10 += 4;
275-
}
276-
277-
if (padm & 2) {
278-
ptr_a0 = ptr_a;
279-
ptr_a += 2 * padk;
280-
ptr_b0 = ptr_b;
281-
mc00 = svdup_f32(0);
282-
for (BLASLONG p = 0; p < padk / 4; p++) {
283-
ma0 = svld1_bf16(pg16, (bfloat16_t *)ptr_a0);
284-
mb0 = svld1_bf16(pg16, (bfloat16_t *)ptr_b0);
285-
mc00 = svbfmmla(mc00, ma0, mb0);
286-
ptr_a0 += 8;
287-
ptr_b0 += 8;
288-
}
289-
svst1_scatter_index(pg32, ptr_c00, off_vc, mc00);
290-
ptr_c00 += 2;
291-
}
292-
293-
ptr_b += 2 * padk;
294-
}
295-
296-
FLOAT *org_c = C;
297-
FLOAT *raw_c = RC;
298-
FLOAT *org_c0, *raw_c0;
299-
svfloat32_t org_vc0, raw_vc0;
300-
for (BLASLONG j = 0; j < n; j++) {
301-
org_c0 = org_c;
302-
raw_c0 = raw_c;
303-
org_c += ldc;
304-
raw_c += nldc;
305-
BLASLONG i;
306-
for (i = 0; i < m / 4; i++) {
307-
org_vc0 = svld1_f32(pg32, org_c0);
308-
raw_vc0 = svld1_f32(pg32, raw_c0);
309-
org_vc0 = svmad_z(pg32, svalpha, raw_vc0,
310-
org_vc0); // alpha * raw + org, raw -> a * b
311-
svst1_f32(pg32, org_c0, org_vc0);
312-
org_c0 += 4;
313-
raw_c0 += 4;
314-
}
315-
for (i = 0; i < (m & 3); i++) {
316-
*org_c0 += alpha * (*raw_c0);
317-
org_c0++;
318-
raw_c0++;
319-
}
320-
}
321-
33+
#define ALPHA_ONE
34+
#include "sbgemm_kernel_8x4_neoversen2_impl.c"
35+
#undef ALPHA_ONE
36+
#include "sbgemm_kernel_8x4_neoversen2_impl.c"
37+
38+
int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
39+
FLOAT *C, BLASLONG ldc) {
40+
if (alpha == 1.0f)
41+
return sbgemm_kernel_neoversen2_alpha_one(m, n, k, alpha, A, B, C, ldc);
42+
else
43+
return sbgemm_kernel_neoversen2_alpha(m, n, k, alpha, A, B, C, ldc);
32244
return 0;
32345
}

0 commit comments

Comments
 (0)