Skip to content

Commit 9984c5c

Browse files
committed
Clean up k2 removal more and unroll SGEMM more
1 parent b1c9faf commit 9984c5c

File tree

5 files changed: +458 additions, -32 deletions
kernel/arm64/dgemm_small_kernel_tn_sve.c

Lines changed: 5 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -80,25 +80,12 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
8080
float64x2_t a##m##_k##offset_k = vld1q_dup_f64(&A_ELEMENT_K(m, offset_k));
8181
#define LOAD_A1(m, offset_k) \
8282
float64_t a##m##_k##offset_k = A_ELEMENT_K(m, offset_k);
83-
#define VECTOR_LOAD_B_K2(n, offset_k) \
84-
float64x2_t b##k##n##_k##offset_k = vld1q_f64(&B_ELEMENT_K(n, offset_k));
85-
#define TRANSPOSE_B2_K2(n0, n1, offset_k0, offset_k1) \
86-
float64x2_t b##n0##_k##offset_k0 = \
87-
vzip1q_f64(b##k##n0##_k##offset_k0, b##k##n1##_k##offset_k0); \
88-
float64x2_t b##n0##_k##offset_k1 = \
89-
vzip2q_f64(b##k##n0##_k##offset_k0, b##k##n1##_k##offset_k0);
90-
91-
#define SCALE_B2_K2(n0, offset_k0, offset_k1) \
92-
svfloat64_t b##s##n0##_k##offset_k0 = svdup_neonq_f64(b##n0##_k##offset_k0); \
93-
svfloat64_t b##s##n0##_k##offset_k1 = svdup_neonq_f64(b##n0##_k##offset_k1);
9483
#define GATHER_LOAD_B2(n, offset_k) \
9584
float64x2_t b##n##_k##offset_k = vdupq_n_f64(B_ELEMENT_K(n, offset_k)); \
9685
b##n##_k##offset_k = \
9786
vsetq_lane_f64(B_ELEMENT_K(n + 1, offset_k), b##n##_k##offset_k, 1);
9887
#define VECTOR_UNPACK_B2(n, offset_k) \
9988
float64x2_t b##n##_k##offset_k = vld1q_f64(&PACK_ELEMENT_K(n, offset_k));
100-
#define VECTOR_PACK_B2(n, offset_k) \
101-
vst1q_f64(&PACK_ELEMENT_K(n, offset_k), b##n##_k##offset_k);
10289
#define PACK_B0(n, offset_k) \
10390
PACK_ELEMENT_K(n, offset_k) = vget_lane_f64(b##n##_k##offset_k, 0);
10491
#define UPDATE_RESULT_VECTOR2(m, n, offset_k) \
@@ -128,9 +115,6 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
128115
svfloat64_t b##s##n##_k##offset_k = svdup_f64(B_ELEMENT_K(n, offset_k));
129116
#define VECTOR_LOAD_A(pg, m, offset_k) \
130117
svfloat64_t a##s##m##_k##offset_k = svld1(pg, &A_ELEMENT_K(m, offset_k));
131-
#define QUADWORD_LOAD_B(n, offset_k) \
132-
svfloat64_t b##s##n##_k##offset_k = \
133-
svld1rq(pg_true, &B_ELEMENT_K(n, offset_k));
134118
#define GATHER_LOAD_A(pg, m, offset_k) \
135119
svfloat64_t a##s##m##_k##offset_k = \
136120
svld1_gather_index(pg, &A_ELEMENT_K(m, offset_k), lda_vec);
@@ -226,7 +210,6 @@ CNAME(BLASLONG M,
226210
const BLASLONG v_m1 = M & -v_size;
227211
const BLASLONG n4 = N & -4;
228212
const BLASLONG n2 = N & -2;
229-
const BLASLONG k2 = K & -2;
230213

231214
const int pack_a = M >= v_size2 && N >= 8 && K >= 8 ? 1 : 0;
232215
FLOAT* packed_a =
@@ -266,6 +249,7 @@ CNAME(BLASLONG M,
266249
if (LIKELY(packed_a != NULL)) {
267250
if (j == 0) {
268251
for (; k < K; k++) {
252+
269253
BROADCAST_LOAD_B(0, 0);
270254
GATHER_LOAD_A(pg_true, 0, 0);
271255
VECTOR_PACK_A(0, 0);
@@ -285,6 +269,7 @@ CNAME(BLASLONG M,
285269
}
286270
} else {
287271
for (; k < K; k++) {
272+
288273
BROADCAST_LOAD_B(0, 0);
289274
UNPACK_VECTOR_A(0, 0);
290275
UPDATE_RESULT_VECTOR(pg_true, 0, 0, 0);
@@ -345,6 +330,7 @@ CNAME(BLASLONG M,
345330

346331
if (LIKELY(packed_a != NULL)) {
347332
for (; k < K; k++) {
333+
348334
BROADCAST_LOAD_B(0, 0);
349335
UNPACK_VECTOR_A(0, 0);
350336
UPDATE_RESULT_VECTOR(pg_true, 0, 0, 0);
@@ -356,6 +342,7 @@ CNAME(BLASLONG M,
356342
}
357343
} else {
358344
for (; k < K; k++) {
345+
359346
BROADCAST_LOAD_B(0, 0);
360347
GATHER_LOAD_A(pg_true, 0, 0);
361348
UPDATE_RESULT_VECTOR(pg_true, 0, 0, 0);
@@ -580,4 +567,4 @@ CNAME(BLASLONG M,
580567
free(packed_a);
581568

582569
return 0;
583-
}
570+
}

kernel/arm64/sgemm_small_kernel_nn_sve.c

Lines changed: 164 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -237,6 +237,7 @@ CNAME(BLASLONG M,
237237
#endif
238238
{
239239
const uint64_t v_size = svcntw();
240+
const uint64_t v_size2 = v_size * 2;
240241
const svbool_t pg_true = svptrue_b32();
241242
const svbool_t pg_quad = svwhilelt_b32(0, 4);
242243
const svbool_t pg_first = svwhilelt_b32(0, 1);
@@ -245,10 +246,11 @@ CNAME(BLASLONG M,
245246
const svfloat32_t beta_vec = svdup_f32(beta);
246247
#endif
247248
const BLASLONG n4 = N & -4;
249+
const BLASLONG v_m2 = M & -v_size2;
248250
const BLASLONG v_m1 = M & -v_size;
249251
const BLASLONG k4 = K & -4;
250252

251-
const int pack_b = M >= v_size && N >= 8 && K >= 8 ? 1 : 0;
253+
const int pack_b = M >= v_size2 && N >= 8 && K >= 8 ? 1 : 0;
252254
FLOAT* packed_b =
253255
(pack_b) ? packed_b = (FLOAT*)malloc(K * 4 * sizeof(FLOAT)) : NULL;
254256

@@ -269,16 +271,21 @@ CNAME(BLASLONG M,
269271
CREATE_B_POINTER(3, 3);
270272

271273
BLASLONG i = 0;
272-
for (; i < v_m1; i += v_size) {
274+
for (; i < v_m2; i += v_size2) {
273275

274276
CREATE_A_POINTER(0, 0);
275-
UPDATE_A_POINTER(v_size);
277+
CREATE_A_POINTER(1, v_size);
278+
UPDATE_A_POINTER(v_size2);
276279

277280
BLASLONG k = 0;
278281
DECLARE_RESULT_VECTOR(0, 0);
279282
DECLARE_RESULT_VECTOR(0, 1);
280283
DECLARE_RESULT_VECTOR(0, 2);
281284
DECLARE_RESULT_VECTOR(0, 3);
285+
DECLARE_RESULT_VECTOR(1, 0);
286+
DECLARE_RESULT_VECTOR(1, 1);
287+
DECLARE_RESULT_VECTOR(1, 2);
288+
DECLARE_RESULT_VECTOR(1, 3);
282289

283290
if (LIKELY(packed_b != NULL)) {
284291
if (i == 0) {
@@ -314,6 +321,26 @@ CNAME(BLASLONG M,
314321
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 3);
315322
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 0, 2, 3);
316323
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 0, 3, 3);
324+
VECTOR_LOAD_A(pg_true, 1, 0);
325+
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 0);
326+
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 0);
327+
UPDATE_RESULT_VECTOR_QUADWORD(1, 2, 0, 2, 0);
328+
UPDATE_RESULT_VECTOR_QUADWORD(1, 3, 0, 3, 0);
329+
VECTOR_LOAD_A(pg_true, 1, 1);
330+
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 1);
331+
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 1);
332+
UPDATE_RESULT_VECTOR_QUADWORD(1, 2, 0, 2, 1);
333+
UPDATE_RESULT_VECTOR_QUADWORD(1, 3, 0, 3, 1);
334+
VECTOR_LOAD_A(pg_true, 1, 2);
335+
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 2);
336+
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 2);
337+
UPDATE_RESULT_VECTOR_QUADWORD(1, 2, 0, 2, 2);
338+
UPDATE_RESULT_VECTOR_QUADWORD(1, 3, 0, 3, 2);
339+
VECTOR_LOAD_A(pg_true, 1, 3);
340+
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 3);
341+
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 3);
342+
UPDATE_RESULT_VECTOR_QUADWORD(1, 2, 0, 2, 3);
343+
UPDATE_RESULT_VECTOR_QUADWORD(1, 3, 0, 3, 3);
317344
}
318345
for (; k < K; k++) {
319346

@@ -324,12 +351,17 @@ CNAME(BLASLONG M,
324351
BROADCAST_LOAD_B(1, 0);
325352
PACK_B(1, 0);
326353
UPDATE_RESULT_VECTOR(pg_true, 0, 1, 0);
354+
VECTOR_LOAD_A(pg_true, 1, 0);
355+
UPDATE_RESULT_VECTOR(pg_true, 1, 0, 0);
356+
UPDATE_RESULT_VECTOR(pg_true, 1, 1, 0);
327357
BROADCAST_LOAD_B(2, 0);
328358
PACK_B(2, 0);
329359
UPDATE_RESULT_VECTOR(pg_true, 0, 2, 0);
360+
UPDATE_RESULT_VECTOR(pg_true, 1, 2, 0);
330361
BROADCAST_LOAD_B(3, 0);
331362
PACK_B(3, 0);
332363
UPDATE_RESULT_VECTOR(pg_true, 0, 3, 0);
364+
UPDATE_RESULT_VECTOR(pg_true, 1, 3, 0);
333365
}
334366
} else {
335367
for (; k < K; k++) {
@@ -340,11 +372,118 @@ CNAME(BLASLONG M,
340372
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 0);
341373
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 0, 2, 0);
342374
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 0, 3, 0);
375+
VECTOR_LOAD_A(pg_true, 1, 0);
376+
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 0);
377+
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 0);
378+
UPDATE_RESULT_VECTOR_QUADWORD(1, 2, 0, 2, 0);
379+
UPDATE_RESULT_VECTOR_QUADWORD(1, 3, 0, 3, 0);
343380
}
344381
}
345382
} else {
346383
for (; k < k4; k += 4) {
347384

385+
VECTOR_LOAD_B_K4(0, 0);
386+
VECTOR_LOAD_B_K4(1, 0);
387+
VECTOR_LOAD_B_K4(2, 0);
388+
VECTOR_LOAD_B_K4(3, 0);
389+
TRANSPOSE_B4_K4(0, 1, 2, 3, 0, 1, 2, 3);
390+
SCALE_B4_K4(0, 0, 1, 2, 3);
391+
VECTOR_LOAD_A(pg_true, 0, 0);
392+
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 0);
393+
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 0);
394+
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 0, 2, 0);
395+
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 0, 3, 0);
396+
VECTOR_LOAD_A(pg_true, 0, 1);
397+
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 1);
398+
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 1);
399+
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 0, 2, 1);
400+
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 0, 3, 1);
401+
VECTOR_LOAD_A(pg_true, 0, 2);
402+
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 2);
403+
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 2);
404+
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 0, 2, 2);
405+
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 0, 3, 2);
406+
VECTOR_LOAD_A(pg_true, 0, 3);
407+
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 3);
408+
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 3);
409+
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 0, 2, 3);
410+
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 0, 3, 3);
411+
VECTOR_LOAD_A(pg_true, 1, 0);
412+
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 0);
413+
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 0);
414+
UPDATE_RESULT_VECTOR_QUADWORD(1, 2, 0, 2, 0);
415+
UPDATE_RESULT_VECTOR_QUADWORD(1, 3, 0, 3, 0);
416+
VECTOR_LOAD_A(pg_true, 1, 1);
417+
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 1);
418+
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 1);
419+
UPDATE_RESULT_VECTOR_QUADWORD(1, 2, 0, 2, 1);
420+
UPDATE_RESULT_VECTOR_QUADWORD(1, 3, 0, 3, 1);
421+
VECTOR_LOAD_A(pg_true, 1, 2);
422+
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 2);
423+
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 2);
424+
UPDATE_RESULT_VECTOR_QUADWORD(1, 2, 0, 2, 2);
425+
UPDATE_RESULT_VECTOR_QUADWORD(1, 3, 0, 3, 2);
426+
VECTOR_LOAD_A(pg_true, 1, 3);
427+
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 3);
428+
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 3);
429+
UPDATE_RESULT_VECTOR_QUADWORD(1, 2, 0, 2, 3);
430+
UPDATE_RESULT_VECTOR_QUADWORD(1, 3, 0, 3, 3);
431+
}
432+
for (; k < K; k++) {
433+
434+
BROADCAST_LOAD_B(0, 0);
435+
VECTOR_LOAD_A(pg_true, 0, 0);
436+
UPDATE_RESULT_VECTOR(pg_true, 0, 0, 0);
437+
BROADCAST_LOAD_B(1, 0);
438+
UPDATE_RESULT_VECTOR(pg_true, 0, 1, 0);
439+
VECTOR_LOAD_A(pg_true, 1, 0);
440+
UPDATE_RESULT_VECTOR(pg_true, 1, 0, 0);
441+
UPDATE_RESULT_VECTOR(pg_true, 1, 1, 0);
442+
BROADCAST_LOAD_B(2, 0);
443+
UPDATE_RESULT_VECTOR(pg_true, 0, 2, 0);
444+
UPDATE_RESULT_VECTOR(pg_true, 1, 2, 0);
445+
BROADCAST_LOAD_B(3, 0);
446+
UPDATE_RESULT_VECTOR(pg_true, 0, 3, 0);
447+
UPDATE_RESULT_VECTOR(pg_true, 1, 3, 0);
448+
}
449+
}
450+
VECTOR_STORE(pg_true, 0, 0);
451+
VECTOR_STORE(pg_true, 0, 1);
452+
VECTOR_STORE(pg_true, 0, 2);
453+
VECTOR_STORE(pg_true, 0, 3);
454+
VECTOR_STORE(pg_true, 1, 0);
455+
VECTOR_STORE(pg_true, 1, 1);
456+
VECTOR_STORE(pg_true, 1, 2);
457+
VECTOR_STORE(pg_true, 1, 3);
458+
INCR_C_POINTER(0, v_size2);
459+
INCR_C_POINTER(1, v_size2);
460+
INCR_C_POINTER(2, v_size2);
461+
INCR_C_POINTER(3, v_size2);
462+
}
463+
for (; i < v_m1; i += v_size) {
464+
465+
CREATE_A_POINTER(0, 0);
466+
UPDATE_A_POINTER(v_size);
467+
468+
BLASLONG k = 0;
469+
DECLARE_RESULT_VECTOR(0, 0);
470+
DECLARE_RESULT_VECTOR(0, 1);
471+
DECLARE_RESULT_VECTOR(0, 2);
472+
DECLARE_RESULT_VECTOR(0, 3);
473+
474+
if (LIKELY(packed_b != NULL)) {
475+
for (; k < K; k++) {
476+
477+
UNPACK_QUADWORD_B(0, 0);
478+
VECTOR_LOAD_A(pg_true, 0, 0);
479+
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 0);
480+
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 0);
481+
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 0, 2, 0);
482+
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 0, 3, 0);
483+
}
484+
} else {
485+
for (; k < k4; k += 4) {
486+
348487
VECTOR_LOAD_B_K4(0, 0);
349488
VECTOR_LOAD_B_K4(1, 0);
350489
VECTOR_LOAD_B_K4(2, 0);
@@ -478,6 +617,28 @@ CNAME(BLASLONG M,
478617
CREATE_B_POINTER(0, 0);
479618

480619
BLASLONG i = 0;
620+
for (; i < v_m2; i += v_size2) {
621+
622+
CREATE_A_POINTER(0, 0);
623+
CREATE_A_POINTER(1, v_size);
624+
UPDATE_A_POINTER(v_size2);
625+
626+
BLASLONG k = 0;
627+
DECLARE_RESULT_VECTOR(0, 0);
628+
DECLARE_RESULT_VECTOR(1, 0);
629+
630+
for (; k < K; k++) {
631+
632+
BROADCAST_LOAD_B(0, 0);
633+
VECTOR_LOAD_A(pg_true, 0, 0);
634+
UPDATE_RESULT_VECTOR(pg_true, 0, 0, 0);
635+
VECTOR_LOAD_A(pg_true, 1, 0);
636+
UPDATE_RESULT_VECTOR(pg_true, 1, 0, 0);
637+
}
638+
VECTOR_STORE(pg_true, 0, 0);
639+
VECTOR_STORE(pg_true, 1, 0);
640+
INCR_C_POINTER(0, v_size2);
641+
}
481642
for (; i < v_m1; i += v_size) {
482643

483644
CREATE_A_POINTER(0, 0);

0 commit comments

Comments (0)