Skip to content

Commit a3cac9c

Browse files
committed
Update sgemm kernel 1x4 for C910.
1 parent 7834c10 commit a3cac9c

File tree

1 file changed

+97
-40
lines changed

1 file changed

+97
-40
lines changed

kernel/riscv64/sgemm_kernel_16x4_c910v.c

Lines changed: 97 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
382382
{
383383
BLASLONG i,j,k;
384384
FLOAT *C0,*C1,*C2,*C3;
385-
FLOAT *ptrba,*ptrbb;
385+
FLOAT *ptrba,*ptrbb, *tmpc;
386386

387387
FLOAT loadb0,loadb1,loadb2,loadb3;
388388
FLOAT load0,load1,load2,load3,load4,load5,load6,load7;
@@ -392,6 +392,7 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
392392
FLOAT res8,res9,res10,res11;
393393
FLOAT res12,res13,res14,res15;
394394

395+
395396
for (j=0; j<bn/4; j+=1){
396397
C0 = C;
397398
C1 = C0+ldc;
@@ -942,53 +943,109 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
942943
}
943944
if(bm&1){
944945
ptrbb = bb;
945-
946-
res0 = 0;
947-
948-
res4 = 0;
949-
950-
res8 = 0;
951-
952-
res12 = 0;
953-
954-
for(k=0; k<bk; k+=1){
955-
loadb0 = ptrbb[0];
956-
loadb1 = ptrbb[1];
957-
958-
load0 = ptrba[0];
959-
960-
res0 = res0 + load0 * loadb0;
961-
962-
res4 = res4 + load0 * loadb1;
963-
964-
loadb2 = ptrbb[2];
965-
loadb3 = ptrbb[3];
966-
967-
res8 = res8 + load0 * loadb2;
968-
969-
res12 = res12 + load0 * loadb3;
970-
971-
ptrba += 1;
972-
ptrbb += 4;
973-
}
974-
975-
res0 = res0 * alpha;
976-
977-
res4 = res4 * alpha;
946+
//t0 for k
947+
//ft0-ft3,ft4-ft7,v8-v15 for B, t1-t3 for PB1-3
948+
//v0-v3,v4-v7 for A, t4-t6 for PA1-3
949+
//v16-v31 for temp C
978950

979-
res8 = res8 * alpha;
951+
FLOAT tmp[4];
952+
tmpc=tmp;
953+
//t1-t3 for PB
954+
//v0-v4 for A, v8-v11 for B
955+
//v16-v19 for C
956+
asm volatile(
957+
"vsetvli zero, zero, e32,m1 \n\t"
958+
"fmv.w.x ft11, zero \n\t"
959+
960+
"vfmv.v.f v16, ft11 \n\t"
961+
"vfmv.v.f v17, ft11 \n\t"
962+
"vfmv.v.f v18, ft11 \n\t"
963+
"vfmv.v.f v19, ft11 \n\t"
964+
//main loop is unrolled by 4: each iteration consumes 4 k-steps
980965

981-
res12 = res12 * alpha;
966+
"srli t0, %[BK], 2 \n\t"
967+
"blez t0, M1x4_TAIL \n\t"
982968

983-
C0[0] += res0;
984-
C1[0] += res4;
985-
C2[0] += res8;
986-
C3[0] += res12;
969+
"addi t1, %[PB], 4*4 \n\t"
970+
"addi t2, %[PB], 8*4 \n\t"
971+
"addi t3, %[PB], 12*4 \n\t"
972+
973+
".align 4 \n\t"
974+
"M1x4_MAINLOOP: \n\t"
987975

976+
"vle.v v4, (%[PA]) \n\t"
977+
"addi %[PA], %[PA], 4*4 \n\t"
978+
"vrgather.vi v0, v4, 0 \n\t"
979+
980+
"vle.v v8, (%[PB]) \n\t"
981+
"addi %[PB], %[PB], 16*4 \n\t"
982+
"vrgather.vi v1, v4, 1 \n\t"
983+
984+
"vle.v v9, (t1) \n\t"
985+
"addi t1, t1, 16*4 \n\t"
986+
"vrgather.vi v2, v4, 2 \n\t"
987+
988+
"vle.v v10, (t2) \n\t"
989+
"addi t2, t2, 16*4 \n\t"
990+
"vrgather.vi v3, v4, 3 \n\t"
991+
992+
"vle.v v11, (t3) \n\t"
993+
"addi t3, t3, 16*4 \n\t"
994+
995+
"vfmacc.vv v16, v8, v0 \n\t"
996+
"vfmacc.vv v17, v9, v1 \n\t"
997+
"vfmacc.vv v18, v10, v2 \n\t"
998+
"vfmacc.vv v19, v11, v3 \n\t"
999+
1000+
"addi t0, t0, -1 \n\t"
1001+
"bgtz t0, M1x4_MAINLOOP \n\t"
1002+
1003+
"M1x4_TAIL: \n\t"
1004+
"andi t0, %[BK], 3 \n\t"
1005+
"blez t0, M1x4_SAVERESULT \n\t"
1006+
1007+
"M1x4_TAILLOOP: \n\t"
1008+
"flw ft0, (%[PA]) \n\t"
1009+
"addi %[PA], %[PA], 1*4 \n\t"
1010+
"vle.v v8, (%[PB]) \n\t"
1011+
"addi %[PB], %[PB], 4*4 \n\t"
1012+
"vfmv.v.f v0, ft0 \n\t"
1013+
"vfmacc.vv v16, v8, v0 \n\t"
1014+
1015+
"addi t0, t0, -1 \n\t"
1016+
"bgtz t0, M1x4_TAILLOOP \n\t"
1017+
1018+
"M1x4_SAVERESULT: \n\t"
1019+
//merge v16-v19
1020+
"vfadd.vv v16, v16, v17 \n\t"
1021+
"vfadd.vv v18, v18, v19 \n\t"
1022+
"vfadd.vv v16, v16, v18 \n\t"
1023+
1024+
"vfmv.v.f v8, %[ALPHA] \n\t"
1025+
"vfmul.vv v16, v8, v16 \n\t"
1026+
"vse.v v16, (%[TMP_C]) \n\t"
1027+
"M1x4_END: \n\t"
1028+
:[TMP_C]"+r"(tmpc),
1029+
[PA]"+r"(ptrba), [PB]"+r"(ptrbb)
1030+
:[ALPHA]"f"(alpha), [BK]"r"(bk)
1031+
:"cc", "t0", "t3","t1","t2",
1032+
"ft0", "ft11",
1033+
"v0", "v1", "v2", "v3","v4",
1034+
"v8", "v9", "v10", "v11",
1035+
"v16", "v17","v18", "v19"
1036+
);
1037+
1038+
C0[0] += tmp[0];
1039+
C1[0] += tmp[1];
1040+
C2[0] += tmp[2];
1041+
C3[0] += tmp[3];
1042+
1043+
/* no need to advance the C pointers here
9881044
C0 += 1;
9891045
C1 += 1;
9901046
C2 += 1;
9911047
C3 += 1;
1048+
*/
9921049
}
9931050

9941051
k = bk<<2;

0 commit comments

Comments
 (0)