Skip to content

Commit 7311d93

Browse files
committed
Unroll TT further
1 parent a9edddb commit 7311d93

File tree

1 file changed

+202
-2
lines changed

1 file changed

+202
-2
lines changed

kernel/arm64/sgemm_small_kernel_tt_sve.c

Lines changed: 202 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -219,6 +219,7 @@ CNAME(BLASLONG M,
219219

220220
const BLASLONG v_m2 = M & -v_size2;
221221
const BLASLONG v_m1 = M & -v_size;
222+
const BLASLONG n8 = N & -8;
222223
const BLASLONG n4 = N & -4;
223224

224225
const int pack_a = M >= v_size2 && N >= 8 && K >= 8 ? 1 : 0;
@@ -238,23 +239,35 @@ CNAME(BLASLONG M,
238239
CREATE_A_POINTER(1, v_size);
239240

240241
BLASLONG j = 0;
241-
for (; j < n4; j += 4) {
242+
for (; j < n8; j += 8) {
242243

243244
CREATE_B_POINTER(0, 0);
244245
CREATE_B_POINTER(1, 1);
245246
CREATE_B_POINTER(2, 2);
246247
CREATE_B_POINTER(3, 3);
247-
UPDATE_B_POINTER(4);
248+
CREATE_B_POINTER(4, 4);
249+
CREATE_B_POINTER(5, 5);
250+
CREATE_B_POINTER(6, 6);
251+
CREATE_B_POINTER(7, 7);
252+
UPDATE_B_POINTER(8);
248253

249254
BLASLONG k = 0;
250255
DECLARE_RESULT_VECTOR(0, 0);
251256
DECLARE_RESULT_VECTOR(0, 1);
252257
DECLARE_RESULT_VECTOR(0, 2);
253258
DECLARE_RESULT_VECTOR(0, 3);
259+
DECLARE_RESULT_VECTOR(0, 4);
260+
DECLARE_RESULT_VECTOR(0, 5);
261+
DECLARE_RESULT_VECTOR(0, 6);
262+
DECLARE_RESULT_VECTOR(0, 7);
254263
DECLARE_RESULT_VECTOR(1, 0);
255264
DECLARE_RESULT_VECTOR(1, 1);
256265
DECLARE_RESULT_VECTOR(1, 2);
257266
DECLARE_RESULT_VECTOR(1, 3);
267+
DECLARE_RESULT_VECTOR(1, 4);
268+
DECLARE_RESULT_VECTOR(1, 5);
269+
DECLARE_RESULT_VECTOR(1, 6);
270+
DECLARE_RESULT_VECTOR(1, 7);
258271

259272
if (LIKELY(packed_a != NULL)) {
260273
if (j == 0) {
@@ -267,12 +280,21 @@ CNAME(BLASLONG M,
267280
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 0);
268281
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 0, 2, 0);
269282
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 0, 3, 0);
283+
QUADWORD_LOAD_B(4, 0);
284+
UPDATE_RESULT_VECTOR_QUADWORD(0, 4, 4, 0, 0);
285+
UPDATE_RESULT_VECTOR_QUADWORD(0, 5, 4, 1, 0);
286+
UPDATE_RESULT_VECTOR_QUADWORD(0, 6, 4, 2, 0);
287+
UPDATE_RESULT_VECTOR_QUADWORD(0, 7, 4, 3, 0);
270288
GATHER_LOAD_A(pg_true, 1, 0);
271289
VECTOR_PACK_A(1, 0);
272290
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 0);
273291
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 0);
274292
UPDATE_RESULT_VECTOR_QUADWORD(1, 2, 0, 2, 0);
275293
UPDATE_RESULT_VECTOR_QUADWORD(1, 3, 0, 3, 0);
294+
UPDATE_RESULT_VECTOR_QUADWORD(1, 4, 4, 0, 0);
295+
UPDATE_RESULT_VECTOR_QUADWORD(1, 5, 4, 1, 0);
296+
UPDATE_RESULT_VECTOR_QUADWORD(1, 6, 4, 2, 0);
297+
UPDATE_RESULT_VECTOR_QUADWORD(1, 7, 4, 3, 0);
276298
}
277299
} else {
278300
for (; k < K; k++) {
@@ -283,16 +305,102 @@ CNAME(BLASLONG M,
283305
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 0);
284306
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 0, 2, 0);
285307
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 0, 3, 0);
308+
QUADWORD_LOAD_B(4, 0);
309+
UPDATE_RESULT_VECTOR_QUADWORD(0, 4, 4, 0, 0);
310+
UPDATE_RESULT_VECTOR_QUADWORD(0, 5, 4, 1, 0);
311+
UPDATE_RESULT_VECTOR_QUADWORD(0, 6, 4, 2, 0);
312+
UPDATE_RESULT_VECTOR_QUADWORD(0, 7, 4, 3, 0);
286313
UNPACK_VECTOR_A(1, 0);
287314
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 0);
288315
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 0);
289316
UPDATE_RESULT_VECTOR_QUADWORD(1, 2, 0, 2, 0);
290317
UPDATE_RESULT_VECTOR_QUADWORD(1, 3, 0, 3, 0);
318+
UPDATE_RESULT_VECTOR_QUADWORD(1, 4, 4, 0, 0);
319+
UPDATE_RESULT_VECTOR_QUADWORD(1, 5, 4, 1, 0);
320+
UPDATE_RESULT_VECTOR_QUADWORD(1, 6, 4, 2, 0);
321+
UPDATE_RESULT_VECTOR_QUADWORD(1, 7, 4, 3, 0);
291322
}
292323
}
293324
} else {
294325
for (; k < K; k++) {
295326

327+
QUADWORD_LOAD_B(0, 0);
328+
GATHER_LOAD_A(pg_true, 0, 0);
329+
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 0);
330+
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 0);
331+
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 0, 2, 0);
332+
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 0, 3, 0);
333+
QUADWORD_LOAD_B(4, 0);
334+
UPDATE_RESULT_VECTOR_QUADWORD(0, 4, 4, 0, 0);
335+
UPDATE_RESULT_VECTOR_QUADWORD(0, 5, 4, 1, 0);
336+
UPDATE_RESULT_VECTOR_QUADWORD(0, 6, 4, 2, 0);
337+
UPDATE_RESULT_VECTOR_QUADWORD(0, 7, 4, 3, 0);
338+
GATHER_LOAD_A(pg_true, 1, 0);
339+
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 0);
340+
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 0);
341+
UPDATE_RESULT_VECTOR_QUADWORD(1, 2, 0, 2, 0);
342+
UPDATE_RESULT_VECTOR_QUADWORD(1, 3, 0, 3, 0);
343+
UPDATE_RESULT_VECTOR_QUADWORD(1, 4, 4, 0, 0);
344+
UPDATE_RESULT_VECTOR_QUADWORD(1, 5, 4, 1, 0);
345+
UPDATE_RESULT_VECTOR_QUADWORD(1, 6, 4, 2, 0);
346+
UPDATE_RESULT_VECTOR_QUADWORD(1, 7, 4, 3, 0);
347+
}
348+
}
349+
VECTOR_STORE(pg_true, 0, 0);
350+
VECTOR_STORE(pg_true, 0, 1);
351+
VECTOR_STORE(pg_true, 0, 2);
352+
VECTOR_STORE(pg_true, 0, 3);
353+
VECTOR_STORE(pg_true, 0, 4);
354+
VECTOR_STORE(pg_true, 0, 5);
355+
VECTOR_STORE(pg_true, 0, 6);
356+
VECTOR_STORE(pg_true, 0, 7);
357+
VECTOR_STORE(pg_true, 1, 0);
358+
VECTOR_STORE(pg_true, 1, 1);
359+
VECTOR_STORE(pg_true, 1, 2);
360+
VECTOR_STORE(pg_true, 1, 3);
361+
VECTOR_STORE(pg_true, 1, 4);
362+
VECTOR_STORE(pg_true, 1, 5);
363+
VECTOR_STORE(pg_true, 1, 6);
364+
VECTOR_STORE(pg_true, 1, 7);
365+
INCR_C_POINTER(0, 8);
366+
INCR_C_POINTER(1, 8);
367+
}
368+
for (; j < n4; j += 4) {
369+
370+
CREATE_B_POINTER(0, 0);
371+
CREATE_B_POINTER(1, 1);
372+
CREATE_B_POINTER(2, 2);
373+
CREATE_B_POINTER(3, 3);
374+
UPDATE_B_POINTER(4);
375+
376+
BLASLONG k = 0;
377+
DECLARE_RESULT_VECTOR(0, 0);
378+
DECLARE_RESULT_VECTOR(0, 1);
379+
DECLARE_RESULT_VECTOR(0, 2);
380+
DECLARE_RESULT_VECTOR(0, 3);
381+
DECLARE_RESULT_VECTOR(1, 0);
382+
DECLARE_RESULT_VECTOR(1, 1);
383+
DECLARE_RESULT_VECTOR(1, 2);
384+
DECLARE_RESULT_VECTOR(1, 3);
385+
386+
if (LIKELY(packed_a != NULL)) {
387+
for (; k < K; k++) {
388+
389+
QUADWORD_LOAD_B(0, 0);
390+
UNPACK_VECTOR_A(0, 0);
391+
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 0);
392+
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 0);
393+
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 0, 2, 0);
394+
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 0, 3, 0);
395+
UNPACK_VECTOR_A(1, 0);
396+
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 0);
397+
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 0);
398+
UPDATE_RESULT_VECTOR_QUADWORD(1, 2, 0, 2, 0);
399+
UPDATE_RESULT_VECTOR_QUADWORD(1, 3, 0, 3, 0);
400+
}
401+
} else {
402+
for (; k < K; k++) {
403+
296404
QUADWORD_LOAD_B(0, 0);
297405
GATHER_LOAD_A(pg_true, 0, 0);
298406
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 0);
@@ -361,6 +469,52 @@ CNAME(BLASLONG M,
361469
CREATE_A_POINTER(0, 0);
362470

363471
BLASLONG j = 0;
472+
for (; j < n8; j += 8) {
473+
474+
CREATE_B_POINTER(0, 0);
475+
CREATE_B_POINTER(1, 1);
476+
CREATE_B_POINTER(2, 2);
477+
CREATE_B_POINTER(3, 3);
478+
CREATE_B_POINTER(4, 4);
479+
CREATE_B_POINTER(5, 5);
480+
CREATE_B_POINTER(6, 6);
481+
CREATE_B_POINTER(7, 7);
482+
UPDATE_B_POINTER(8);
483+
484+
BLASLONG k = 0;
485+
DECLARE_RESULT_VECTOR(0, 0);
486+
DECLARE_RESULT_VECTOR(0, 1);
487+
DECLARE_RESULT_VECTOR(0, 2);
488+
DECLARE_RESULT_VECTOR(0, 3);
489+
DECLARE_RESULT_VECTOR(0, 4);
490+
DECLARE_RESULT_VECTOR(0, 5);
491+
DECLARE_RESULT_VECTOR(0, 6);
492+
DECLARE_RESULT_VECTOR(0, 7);
493+
494+
for (; k < K; k++) {
495+
496+
QUADWORD_LOAD_B(0, 0);
497+
GATHER_LOAD_A(pg_true, 0, 0);
498+
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 0);
499+
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 0);
500+
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 0, 2, 0);
501+
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 0, 3, 0);
502+
QUADWORD_LOAD_B(4, 0);
503+
UPDATE_RESULT_VECTOR_QUADWORD(0, 4, 4, 0, 0);
504+
UPDATE_RESULT_VECTOR_QUADWORD(0, 5, 4, 1, 0);
505+
UPDATE_RESULT_VECTOR_QUADWORD(0, 6, 4, 2, 0);
506+
UPDATE_RESULT_VECTOR_QUADWORD(0, 7, 4, 3, 0);
507+
}
508+
VECTOR_STORE(pg_true, 0, 0);
509+
VECTOR_STORE(pg_true, 0, 1);
510+
VECTOR_STORE(pg_true, 0, 2);
511+
VECTOR_STORE(pg_true, 0, 3);
512+
VECTOR_STORE(pg_true, 0, 4);
513+
VECTOR_STORE(pg_true, 0, 5);
514+
VECTOR_STORE(pg_true, 0, 6);
515+
VECTOR_STORE(pg_true, 0, 7);
516+
INCR_C_POINTER(0, 8);
517+
}
364518
for (; j < n4; j += 4) {
365519

366520
CREATE_B_POINTER(0, 0);
@@ -418,6 +572,52 @@ CNAME(BLASLONG M,
418572
CREATE_A_POINTER(0, 0);
419573

420574
BLASLONG j = 0;
575+
for (; j < n8; j += 8) {
576+
577+
CREATE_B_POINTER(0, 0);
578+
CREATE_B_POINTER(1, 1);
579+
CREATE_B_POINTER(2, 2);
580+
CREATE_B_POINTER(3, 3);
581+
CREATE_B_POINTER(4, 4);
582+
CREATE_B_POINTER(5, 5);
583+
CREATE_B_POINTER(6, 6);
584+
CREATE_B_POINTER(7, 7);
585+
UPDATE_B_POINTER(8);
586+
587+
BLASLONG k = 0;
588+
DECLARE_RESULT_VECTOR(0, 0);
589+
DECLARE_RESULT_VECTOR(0, 1);
590+
DECLARE_RESULT_VECTOR(0, 2);
591+
DECLARE_RESULT_VECTOR(0, 3);
592+
DECLARE_RESULT_VECTOR(0, 4);
593+
DECLARE_RESULT_VECTOR(0, 5);
594+
DECLARE_RESULT_VECTOR(0, 6);
595+
DECLARE_RESULT_VECTOR(0, 7);
596+
597+
for (; k < K; k++) {
598+
599+
QUADWORD_LOAD_B(0, 0);
600+
GATHER_LOAD_A(pg_tail, 0, 0);
601+
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 0);
602+
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 0);
603+
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 0, 2, 0);
604+
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 0, 3, 0);
605+
QUADWORD_LOAD_B(4, 0);
606+
UPDATE_RESULT_VECTOR_QUADWORD(0, 4, 4, 0, 0);
607+
UPDATE_RESULT_VECTOR_QUADWORD(0, 5, 4, 1, 0);
608+
UPDATE_RESULT_VECTOR_QUADWORD(0, 6, 4, 2, 0);
609+
UPDATE_RESULT_VECTOR_QUADWORD(0, 7, 4, 3, 0);
610+
}
611+
VECTOR_STORE(pg_tail, 0, 0);
612+
VECTOR_STORE(pg_tail, 0, 1);
613+
VECTOR_STORE(pg_tail, 0, 2);
614+
VECTOR_STORE(pg_tail, 0, 3);
615+
VECTOR_STORE(pg_tail, 0, 4);
616+
VECTOR_STORE(pg_tail, 0, 5);
617+
VECTOR_STORE(pg_tail, 0, 6);
618+
VECTOR_STORE(pg_tail, 0, 7);
619+
INCR_C_POINTER(0, 8);
620+
}
421621
for (; j < n4; j += 4) {
422622

423623
CREATE_B_POINTER(0, 0);

0 commit comments

Comments
 (0)