Skip to content

Commit a9edddb

Browse files
committed
Unroll TN further
1 parent 9984c5c commit a9edddb

File tree

1 file changed

+229
-2
lines changed

1 file changed

+229
-2
lines changed

kernel/arm64/sgemm_small_kernel_tn_sve.c

Lines changed: 229 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ CNAME(BLASLONG M,
218218

219219
const BLASLONG v_m2 = M & -v_size2;
220220
const BLASLONG v_m1 = M & -v_size;
221+
const BLASLONG n8 = N & -8;
221222
const BLASLONG n4 = N & -4;
222223

223224
const int pack_a = M >= v_size2 && N >= 8 && K >= 8 ? 1 : 0;
@@ -237,23 +238,35 @@ CNAME(BLASLONG M,
237238
CREATE_A_POINTER(1, v_size);
238239

239240
BLASLONG j = 0;
240-
for (; j < n4; j += 4) {
241+
for (; j < n8; j += 8) {
241242

242243
CREATE_B_POINTER(0, 0);
243244
CREATE_B_POINTER(1, 1);
244245
CREATE_B_POINTER(2, 2);
245246
CREATE_B_POINTER(3, 3);
246-
UPDATE_B_POINTER(4);
247+
CREATE_B_POINTER(4, 4);
248+
CREATE_B_POINTER(5, 5);
249+
CREATE_B_POINTER(6, 6);
250+
CREATE_B_POINTER(7, 7);
251+
UPDATE_B_POINTER(8);
247252

248253
BLASLONG k = 0;
249254
DECLARE_RESULT_VECTOR(0, 0);
250255
DECLARE_RESULT_VECTOR(0, 1);
251256
DECLARE_RESULT_VECTOR(0, 2);
252257
DECLARE_RESULT_VECTOR(0, 3);
258+
DECLARE_RESULT_VECTOR(0, 4);
259+
DECLARE_RESULT_VECTOR(0, 5);
260+
DECLARE_RESULT_VECTOR(0, 6);
261+
DECLARE_RESULT_VECTOR(0, 7);
253262
DECLARE_RESULT_VECTOR(1, 0);
254263
DECLARE_RESULT_VECTOR(1, 1);
255264
DECLARE_RESULT_VECTOR(1, 2);
256265
DECLARE_RESULT_VECTOR(1, 3);
266+
DECLARE_RESULT_VECTOR(1, 4);
267+
DECLARE_RESULT_VECTOR(1, 5);
268+
DECLARE_RESULT_VECTOR(1, 6);
269+
DECLARE_RESULT_VECTOR(1, 7);
257270

258271
if (LIKELY(packed_a != NULL)) {
259272
if (j == 0) {
@@ -275,6 +288,18 @@ CNAME(BLASLONG M,
275288
BROADCAST_LOAD_B(3, 0);
276289
UPDATE_RESULT_VECTOR(pg_true, 0, 3, 0);
277290
UPDATE_RESULT_VECTOR(pg_true, 1, 3, 0);
291+
BROADCAST_LOAD_B(4, 0);
292+
UPDATE_RESULT_VECTOR(pg_true, 0, 4, 0);
293+
UPDATE_RESULT_VECTOR(pg_true, 1, 4, 0);
294+
BROADCAST_LOAD_B(5, 0);
295+
UPDATE_RESULT_VECTOR(pg_true, 0, 5, 0);
296+
UPDATE_RESULT_VECTOR(pg_true, 1, 5, 0);
297+
BROADCAST_LOAD_B(6, 0);
298+
UPDATE_RESULT_VECTOR(pg_true, 0, 6, 0);
299+
UPDATE_RESULT_VECTOR(pg_true, 1, 6, 0);
300+
BROADCAST_LOAD_B(7, 0);
301+
UPDATE_RESULT_VECTOR(pg_true, 0, 7, 0);
302+
UPDATE_RESULT_VECTOR(pg_true, 1, 7, 0);
278303
}
279304
} else {
280305
for (; k < K; k++) {
@@ -293,11 +318,109 @@ CNAME(BLASLONG M,
293318
BROADCAST_LOAD_B(3, 0);
294319
UPDATE_RESULT_VECTOR(pg_true, 0, 3, 0);
295320
UPDATE_RESULT_VECTOR(pg_true, 1, 3, 0);
321+
BROADCAST_LOAD_B(4, 0);
322+
UPDATE_RESULT_VECTOR(pg_true, 0, 4, 0);
323+
UPDATE_RESULT_VECTOR(pg_true, 1, 4, 0);
324+
BROADCAST_LOAD_B(5, 0);
325+
UPDATE_RESULT_VECTOR(pg_true, 0, 5, 0);
326+
UPDATE_RESULT_VECTOR(pg_true, 1, 5, 0);
327+
BROADCAST_LOAD_B(6, 0);
328+
UPDATE_RESULT_VECTOR(pg_true, 0, 6, 0);
329+
UPDATE_RESULT_VECTOR(pg_true, 1, 6, 0);
330+
BROADCAST_LOAD_B(7, 0);
331+
UPDATE_RESULT_VECTOR(pg_true, 0, 7, 0);
332+
UPDATE_RESULT_VECTOR(pg_true, 1, 7, 0);
296333
}
297334
}
298335
} else {
299336
for (; k < K; k++) {
300337

338+
BROADCAST_LOAD_B(0, 0);
339+
GATHER_LOAD_A(pg_true, 0, 0);
340+
UPDATE_RESULT_VECTOR(pg_true, 0, 0, 0);
341+
BROADCAST_LOAD_B(1, 0);
342+
UPDATE_RESULT_VECTOR(pg_true, 0, 1, 0);
343+
GATHER_LOAD_A(pg_true, 1, 0);
344+
UPDATE_RESULT_VECTOR(pg_true, 1, 0, 0);
345+
UPDATE_RESULT_VECTOR(pg_true, 1, 1, 0);
346+
BROADCAST_LOAD_B(2, 0);
347+
UPDATE_RESULT_VECTOR(pg_true, 0, 2, 0);
348+
UPDATE_RESULT_VECTOR(pg_true, 1, 2, 0);
349+
BROADCAST_LOAD_B(3, 0);
350+
UPDATE_RESULT_VECTOR(pg_true, 0, 3, 0);
351+
UPDATE_RESULT_VECTOR(pg_true, 1, 3, 0);
352+
BROADCAST_LOAD_B(4, 0);
353+
UPDATE_RESULT_VECTOR(pg_true, 0, 4, 0);
354+
UPDATE_RESULT_VECTOR(pg_true, 1, 4, 0);
355+
BROADCAST_LOAD_B(5, 0);
356+
UPDATE_RESULT_VECTOR(pg_true, 0, 5, 0);
357+
UPDATE_RESULT_VECTOR(pg_true, 1, 5, 0);
358+
BROADCAST_LOAD_B(6, 0);
359+
UPDATE_RESULT_VECTOR(pg_true, 0, 6, 0);
360+
UPDATE_RESULT_VECTOR(pg_true, 1, 6, 0);
361+
BROADCAST_LOAD_B(7, 0);
362+
UPDATE_RESULT_VECTOR(pg_true, 0, 7, 0);
363+
UPDATE_RESULT_VECTOR(pg_true, 1, 7, 0);
364+
}
365+
}
366+
VECTOR_STORE(pg_true, 0, 0);
367+
VECTOR_STORE(pg_true, 0, 1);
368+
VECTOR_STORE(pg_true, 0, 2);
369+
VECTOR_STORE(pg_true, 0, 3);
370+
VECTOR_STORE(pg_true, 0, 4);
371+
VECTOR_STORE(pg_true, 0, 5);
372+
VECTOR_STORE(pg_true, 0, 6);
373+
VECTOR_STORE(pg_true, 0, 7);
374+
VECTOR_STORE(pg_true, 1, 0);
375+
VECTOR_STORE(pg_true, 1, 1);
376+
VECTOR_STORE(pg_true, 1, 2);
377+
VECTOR_STORE(pg_true, 1, 3);
378+
VECTOR_STORE(pg_true, 1, 4);
379+
VECTOR_STORE(pg_true, 1, 5);
380+
VECTOR_STORE(pg_true, 1, 6);
381+
VECTOR_STORE(pg_true, 1, 7);
382+
INCR_C_POINTER(0, 8);
383+
INCR_C_POINTER(1, 8);
384+
}
385+
for (; j < n4; j += 4) {
386+
387+
CREATE_B_POINTER(0, 0);
388+
CREATE_B_POINTER(1, 1);
389+
CREATE_B_POINTER(2, 2);
390+
CREATE_B_POINTER(3, 3);
391+
UPDATE_B_POINTER(4);
392+
393+
BLASLONG k = 0;
394+
DECLARE_RESULT_VECTOR(0, 0);
395+
DECLARE_RESULT_VECTOR(0, 1);
396+
DECLARE_RESULT_VECTOR(0, 2);
397+
DECLARE_RESULT_VECTOR(0, 3);
398+
DECLARE_RESULT_VECTOR(1, 0);
399+
DECLARE_RESULT_VECTOR(1, 1);
400+
DECLARE_RESULT_VECTOR(1, 2);
401+
DECLARE_RESULT_VECTOR(1, 3);
402+
403+
if (LIKELY(packed_a != NULL)) {
404+
for (; k < K; k++) {
405+
406+
BROADCAST_LOAD_B(0, 0);
407+
UNPACK_VECTOR_A(0, 0);
408+
UPDATE_RESULT_VECTOR(pg_true, 0, 0, 0);
409+
BROADCAST_LOAD_B(1, 0);
410+
UPDATE_RESULT_VECTOR(pg_true, 0, 1, 0);
411+
UNPACK_VECTOR_A(1, 0);
412+
UPDATE_RESULT_VECTOR(pg_true, 1, 0, 0);
413+
UPDATE_RESULT_VECTOR(pg_true, 1, 1, 0);
414+
BROADCAST_LOAD_B(2, 0);
415+
UPDATE_RESULT_VECTOR(pg_true, 0, 2, 0);
416+
UPDATE_RESULT_VECTOR(pg_true, 1, 2, 0);
417+
BROADCAST_LOAD_B(3, 0);
418+
UPDATE_RESULT_VECTOR(pg_true, 0, 3, 0);
419+
UPDATE_RESULT_VECTOR(pg_true, 1, 3, 0);
420+
}
421+
} else {
422+
for (; k < K; k++) {
423+
301424
BROADCAST_LOAD_B(0, 0);
302425
GATHER_LOAD_A(pg_true, 0, 0);
303426
UPDATE_RESULT_VECTOR(pg_true, 0, 0, 0);
@@ -369,6 +492,58 @@ CNAME(BLASLONG M,
369492
CREATE_A_POINTER(0, 0);
370493

371494
BLASLONG j = 0;
495+
for (; j < n8; j += 8) {
496+
497+
CREATE_B_POINTER(0, 0);
498+
CREATE_B_POINTER(1, 1);
499+
CREATE_B_POINTER(2, 2);
500+
CREATE_B_POINTER(3, 3);
501+
CREATE_B_POINTER(4, 4);
502+
CREATE_B_POINTER(5, 5);
503+
CREATE_B_POINTER(6, 6);
504+
CREATE_B_POINTER(7, 7);
505+
UPDATE_B_POINTER(8);
506+
507+
BLASLONG k = 0;
508+
DECLARE_RESULT_VECTOR(0, 0);
509+
DECLARE_RESULT_VECTOR(0, 1);
510+
DECLARE_RESULT_VECTOR(0, 2);
511+
DECLARE_RESULT_VECTOR(0, 3);
512+
DECLARE_RESULT_VECTOR(0, 4);
513+
DECLARE_RESULT_VECTOR(0, 5);
514+
DECLARE_RESULT_VECTOR(0, 6);
515+
DECLARE_RESULT_VECTOR(0, 7);
516+
517+
for (; k < K; k++) {
518+
519+
BROADCAST_LOAD_B(0, 0);
520+
GATHER_LOAD_A(pg_true, 0, 0);
521+
UPDATE_RESULT_VECTOR(pg_true, 0, 0, 0);
522+
BROADCAST_LOAD_B(1, 0);
523+
UPDATE_RESULT_VECTOR(pg_true, 0, 1, 0);
524+
BROADCAST_LOAD_B(2, 0);
525+
UPDATE_RESULT_VECTOR(pg_true, 0, 2, 0);
526+
BROADCAST_LOAD_B(3, 0);
527+
UPDATE_RESULT_VECTOR(pg_true, 0, 3, 0);
528+
BROADCAST_LOAD_B(4, 0);
529+
UPDATE_RESULT_VECTOR(pg_true, 0, 4, 0);
530+
BROADCAST_LOAD_B(5, 0);
531+
UPDATE_RESULT_VECTOR(pg_true, 0, 5, 0);
532+
BROADCAST_LOAD_B(6, 0);
533+
UPDATE_RESULT_VECTOR(pg_true, 0, 6, 0);
534+
BROADCAST_LOAD_B(7, 0);
535+
UPDATE_RESULT_VECTOR(pg_true, 0, 7, 0);
536+
}
537+
VECTOR_STORE(pg_true, 0, 0);
538+
VECTOR_STORE(pg_true, 0, 1);
539+
VECTOR_STORE(pg_true, 0, 2);
540+
VECTOR_STORE(pg_true, 0, 3);
541+
VECTOR_STORE(pg_true, 0, 4);
542+
VECTOR_STORE(pg_true, 0, 5);
543+
VECTOR_STORE(pg_true, 0, 6);
544+
VECTOR_STORE(pg_true, 0, 7);
545+
INCR_C_POINTER(0, 8);
546+
}
372547
for (; j < n4; j += 4) {
373548

374549
CREATE_B_POINTER(0, 0);
@@ -429,6 +604,58 @@ CNAME(BLASLONG M,
429604
CREATE_A_POINTER(0, 0);
430605

431606
BLASLONG j = 0;
607+
for (; j < n8; j += 8) {
608+
609+
CREATE_B_POINTER(0, 0);
610+
CREATE_B_POINTER(1, 1);
611+
CREATE_B_POINTER(2, 2);
612+
CREATE_B_POINTER(3, 3);
613+
CREATE_B_POINTER(4, 4);
614+
CREATE_B_POINTER(5, 5);
615+
CREATE_B_POINTER(6, 6);
616+
CREATE_B_POINTER(7, 7);
617+
UPDATE_B_POINTER(8);
618+
619+
BLASLONG k = 0;
620+
DECLARE_RESULT_VECTOR(0, 0);
621+
DECLARE_RESULT_VECTOR(0, 1);
622+
DECLARE_RESULT_VECTOR(0, 2);
623+
DECLARE_RESULT_VECTOR(0, 3);
624+
DECLARE_RESULT_VECTOR(0, 4);
625+
DECLARE_RESULT_VECTOR(0, 5);
626+
DECLARE_RESULT_VECTOR(0, 6);
627+
DECLARE_RESULT_VECTOR(0, 7);
628+
629+
for (; k < K; k++) {
630+
631+
BROADCAST_LOAD_B(0, 0);
632+
GATHER_LOAD_A(pg_tail, 0, 0);
633+
UPDATE_RESULT_VECTOR(pg_tail, 0, 0, 0);
634+
BROADCAST_LOAD_B(1, 0);
635+
UPDATE_RESULT_VECTOR(pg_tail, 0, 1, 0);
636+
BROADCAST_LOAD_B(2, 0);
637+
UPDATE_RESULT_VECTOR(pg_tail, 0, 2, 0);
638+
BROADCAST_LOAD_B(3, 0);
639+
UPDATE_RESULT_VECTOR(pg_tail, 0, 3, 0);
640+
BROADCAST_LOAD_B(4, 0);
641+
UPDATE_RESULT_VECTOR(pg_tail, 0, 4, 0);
642+
BROADCAST_LOAD_B(5, 0);
643+
UPDATE_RESULT_VECTOR(pg_tail, 0, 5, 0);
644+
BROADCAST_LOAD_B(6, 0);
645+
UPDATE_RESULT_VECTOR(pg_tail, 0, 6, 0);
646+
BROADCAST_LOAD_B(7, 0);
647+
UPDATE_RESULT_VECTOR(pg_tail, 0, 7, 0);
648+
}
649+
VECTOR_STORE(pg_tail, 0, 0);
650+
VECTOR_STORE(pg_tail, 0, 1);
651+
VECTOR_STORE(pg_tail, 0, 2);
652+
VECTOR_STORE(pg_tail, 0, 3);
653+
VECTOR_STORE(pg_tail, 0, 4);
654+
VECTOR_STORE(pg_tail, 0, 5);
655+
VECTOR_STORE(pg_tail, 0, 6);
656+
VECTOR_STORE(pg_tail, 0, 7);
657+
INCR_C_POINTER(0, 8);
658+
}
432659
for (; j < n4; j += 4) {
433660

434661
CREATE_B_POINTER(0, 0);

0 commit comments

Comments
 (0)