@@ -382,7 +382,7 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
382
382
{
383
383
BLASLONG i ,j ,k ;
384
384
FLOAT * C0 ,* C1 ,* C2 ,* C3 ;
385
- FLOAT * ptrba ,* ptrbb ;
385
+ FLOAT * ptrba ,* ptrbb , * tmpc ;
386
386
387
387
FLOAT loadb0 ,loadb1 ,loadb2 ,loadb3 ;
388
388
FLOAT load0 ,load1 ,load2 ,load3 ,load4 ,load5 ,load6 ,load7 ;
@@ -392,6 +392,7 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
392
392
FLOAT res8 ,res9 ,res10 ,res11 ;
393
393
FLOAT res12 ,res13 ,res14 ,res15 ;
394
394
395
+
395
396
for (j = 0 ; j < bn /4 ; j += 1 ){
396
397
C0 = C ;
397
398
C1 = C0 + ldc ;
@@ -942,53 +943,109 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
942
943
}
943
944
if (bm & 1 ){
944
945
ptrbb = bb ;
945
-
946
- res0 = 0 ;
947
-
948
- res4 = 0 ;
949
-
950
- res8 = 0 ;
951
-
952
- res12 = 0 ;
953
-
954
- for (k = 0 ; k < bk ; k += 1 ){
955
- loadb0 = ptrbb [0 ];
956
- loadb1 = ptrbb [1 ];
957
-
958
- load0 = ptrba [0 ];
959
-
960
- res0 = res0 + load0 * loadb0 ;
961
-
962
- res4 = res4 + load0 * loadb1 ;
963
-
964
- loadb2 = ptrbb [2 ];
965
- loadb3 = ptrbb [3 ];
966
-
967
- res8 = res8 + load0 * loadb2 ;
968
-
969
- res12 = res12 + load0 * loadb3 ;
970
-
971
- ptrba += 1 ;
972
- ptrbb += 4 ;
973
- }
974
-
975
- res0 = res0 * alpha ;
976
-
977
- res4 = res4 * alpha ;
946
+ //t0 for k
947
+ //ft0-ft3,ft4-ft7,v8-v15 for B, t1-t3 for PB1-3
948
+ //v0-v3,v4-v7 for A, t4-t6 for PA1-3
949
+ //v16-v31 for temp C
978
950
979
- res8 = res8 * alpha ;
951
+ FLOAT tmp [4 ];
952
+ tmpc = tmp ;
953
+ //t1-t3 for PB
954
+ //v0-v4 for A, v8-v11 for B
955
+ //v16-v19 for C
956
+ asm volatile (
957
+ "vsetvli zero, zero, e32,m1 \n\t"
958
+ "fmv.w.x ft11, zero \n\t"
959
+
960
+ "vfmv.v.f v16, ft11 \n\t"
961
+ "vfmv.v.f v17, ft11 \n\t"
962
+ "vfmv.v.f v18, ft11 \n\t"
963
+ "vfmv.v.f v19, ft11 \n\t"
964
+ //main loop: unrolled by 4 along k
980
965
981
- res12 = res12 * alpha ;
966
+ "srli t0, %[BK], 2 \n\t"
967
+ "blez t0, M1x4_TAIL \n\t"
982
968
983
- C0 [0 ] += res0 ;
984
- C1 [0 ] += res4 ;
985
- C2 [0 ] += res8 ;
986
- C3 [0 ] += res12 ;
969
+ "addi t1, %[PB], 4*4 \n\t"
970
+ "addi t2, %[PB], 8*4 \n\t"
971
+ "addi t3, %[PB], 12*4 \n\t"
972
+
973
+ ".align 4 \n\t"
974
+ "M1x4_MAINLOOP: \n\t"
987
975
976
+ "vle.v v4, (%[PA]) \n\t"
977
+ "addi %[PA], %[PA], 4*4 \n\t"
978
+ "vrgather.vi v0, v4, 0 \n\t"
979
+
980
+ "vle.v v8, (%[PB]) \n\t"
981
+ "addi %[PB], %[PB], 16*4 \n\t"
982
+ "vrgather.vi v1, v4, 1 \n\t"
983
+
984
+ "vle.v v9, (t1) \n\t"
985
+ "addi t1, t1, 16*4 \n\t"
986
+ "vrgather.vi v2, v4, 2 \n\t"
987
+
988
+ "vle.v v10, (t2) \n\t"
989
+ "addi t2, t2, 16*4 \n\t"
990
+ "vrgather.vi v3, v4, 3 \n\t"
991
+
992
+ "vle.v v11, (t3) \n\t"
993
+ "addi t3, t3, 16*4 \n\t"
994
+
995
+ "vfmacc.vv v16, v8, v0 \n\t"
996
+ "vfmacc.vv v17, v9, v1 \n\t"
997
+ "vfmacc.vv v18, v10, v2 \n\t"
998
+ "vfmacc.vv v19, v11, v3 \n\t"
999
+
1000
+ "addi t0, t0, -1 \n\t"
1001
+ "bgtz t0, M1x4_MAINLOOP \n\t"
1002
+
1003
+ "M1x4_TAIL: \n\t"
1004
+ "andi t0, %[BK], 3 \n\t"
1005
+ "blez t0, M1x4_SAVERESULT \n\t"
1006
+
1007
+ "M1x4_TAILLOOP: \n\t"
1008
+ "flw ft0, (%[PA]) \n\t"
1009
+ "addi %[PA], %[PA], 1*4 \n\t"
1010
+ "vle.v v8, (%[PB]) \n\t"
1011
+ "addi %[PB], %[PB], 4*4 \n\t"
1012
+ "vfmv.v.f v0, ft0 \n\t"
1013
+ "vfmacc.vv v16, v8, v0 \n\t"
1014
+
1015
+ "addi t0, t0, -1 \n\t"
1016
+ "bgtz t0, M1x4_TAILLOOP \n\t"
1017
+
1018
+ "M1x4_SAVERESULT: \n\t"
1019
+ //merge v16-v19
1020
+ "vfadd.vv v16, v16, v17 \n\t"
1021
+ "vfadd.vv v18, v18, v19 \n\t"
1022
+ "vfadd.vv v16, v16, v18 \n\t"
1023
+
1024
+ "vfmv.v.f v8, %[ALPHA] \n\t"
1025
+ "vfmul.vv v16, v8, v16 \n\t"
1026
+ "vse.v v16, (%[TMP_C]) \n\t"
1027
+ "M1x4_END: \n\t"
1028
+ :[TMP_C ]"+r" (tmpc ),
1029
+ [PA ]"+r" (ptrba ), [PB ]"+r" (ptrbb )
1030
+ :[ALPHA ]"f" (alpha ), [BK ]"r" (bk )
1031
+ :"cc" , "t0" , "t3" ,"t1" ,"t2" ,
1032
+ "ft0" , "ft11" ,
1033
+ "v0" , "v1" , "v2" , "v3" ,"v4" ,
1034
+ "v8" , "v9" , "v10" , "v11" ,
1035
+ "v16" , "v17" ,"v18" , "v19"
1036
+ );
1037
+
1038
+ C0 [0 ] += tmp [0 ];
1039
+ C1 [0 ] += tmp [1 ];
1040
+ C2 [0 ] += tmp [2 ];
1041
+ C3 [0 ] += tmp [3 ];
1042
+
1043
+ /* no need to advance the C pointers
988
1044
C0 += 1;
989
1045
C1 += 1;
990
1046
C2 += 1;
991
1047
C3 += 1;
1048
+ */
992
1049
}
993
1050
994
1051
k = bk <<2 ;
0 commit comments