
Commit 2893d0a

Merge pull request #5211 from guoyuanplct/develop
Optimizing the Implementation of GEMV on the RISC-V V Extension
2 parents ed1e470 + 1ff303f commit 2893d0a
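
Summary of the change, as reflected in the diff below: for single precision the kernel moves from LMUL=4 (e32m4) to LMUL=8 (e32m8) intrinsics; a new fast path for unit strides (inc_x == 1 and inc_y == 1) unrolls the column loop by four, accumulates into zero-initialized vector registers, and applies alpha once per block of y with a final fused multiply-accumulate against the original y values. Separate paths remain for unit-stride y with strided x, and for strided y.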

File tree

1 file changed: +207 −97 lines changed

kernel/riscv64/gemv_n_vector.c

Lines changed: 207 additions & 97 deletions
@@ -27,13 +27,15 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include "common.h"
 #if !defined(DOUBLE)
-#define VSETVL(n) RISCV_RVV(vsetvl_e32m4)(n)
-#define FLOAT_V_T vfloat32m4_t
-#define VLEV_FLOAT RISCV_RVV(vle32_v_f32m4)
-#define VLSEV_FLOAT RISCV_RVV(vlse32_v_f32m4)
-#define VSEV_FLOAT RISCV_RVV(vse32_v_f32m4)
-#define VSSEV_FLOAT RISCV_RVV(vsse32_v_f32m4)
-#define VFMACCVF_FLOAT RISCV_RVV(vfmacc_vf_f32m4)
+#define VSETVL(n) RISCV_RVV(vsetvl_e32m8)(n)
+#define FLOAT_V_T vfloat32m8_t
+#define VLEV_FLOAT RISCV_RVV(vle32_v_f32m8)
+#define VLSEV_FLOAT RISCV_RVV(vlse32_v_f32m8)
+#define VSEV_FLOAT RISCV_RVV(vse32_v_f32m8)
+#define VSSEV_FLOAT RISCV_RVV(vsse32_v_f32m8)
+#define VFMACCVF_FLOAT RISCV_RVV(vfmacc_vf_f32m8)
+#define VFMUL_VF_FLOAT RISCV_RVV(vfmul_vf_f32m8)
+#define VFILL_ZERO_FLOAT RISCV_RVV(vfsub_vv_f32m8)
 #else
 #define VSETVL(n) RISCV_RVV(vsetvl_e64m4)(n)
 #define FLOAT_V_T vfloat64m4_t
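
Note: with LMUL=8, each e32m8 instruction operates on a group of eight architectural vector registers, so one stripmined iteration covers twice as many single-precision elements as the previous e32m4 configuration; the trade-off is that only four distinct m8 register groups (v0, v8, v16, v24) remain for the register allocator. The new VFILL_ZERO_FLOAT macro zeroes an accumulator by subtracting a vector from itself with vfsub.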
@@ -42,103 +44,211 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #define VSEV_FLOAT RISCV_RVV(vse64_v_f64m4)
 #define VSSEV_FLOAT RISCV_RVV(vsse64_v_f64m4)
 #define VFMACCVF_FLOAT RISCV_RVV(vfmacc_vf_f64m4)
+#define VFMUL_VF_FLOAT RISCV_RVV(vfmul_vf_f64m4)
+#define VFILL_ZERO_FLOAT RISCV_RVV(vfsub_vv_f64m4)
 #endif
 
 int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLONG lda, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT *buffer)
 {
-    BLASLONG i = 0, j = 0, k = 0;
-    BLASLONG ix = 0, iy = 0;
-
-    if(n < 0) return(0);
-    FLOAT *a_ptr = a;
-    FLOAT temp = 0.0;
-    FLOAT_V_T va0, va1, vy0, vy1;
-    unsigned int gvl = 0;
-    if(inc_y == 1){
-        gvl = VSETVL(m);
-        if(gvl <= m/2){
-            for(k=0,j=0; k<m/(2*gvl); k++){
-                a_ptr = a;
-                ix = 0;
-                vy0 = VLEV_FLOAT(&y[j], gvl);
-                vy1 = VLEV_FLOAT(&y[j+gvl], gvl);
-                for(i = 0; i < n; i++){
-                    temp = alpha * x[ix];
-                    va0 = VLEV_FLOAT(&a_ptr[j], gvl);
-                    vy0 = VFMACCVF_FLOAT(vy0, temp, va0, gvl);
-
-                    va1 = VLEV_FLOAT(&a_ptr[j+gvl], gvl);
-                    vy1 = VFMACCVF_FLOAT(vy1, temp, va1, gvl);
-                    a_ptr += lda;
-                    ix += inc_x;
-                }
-                VSEV_FLOAT(&y[j], vy0, gvl);
-                VSEV_FLOAT(&y[j+gvl], vy1, gvl);
-                j += gvl * 2;
-            }
-        }
-        //tail
-        for(;j < m;){
-            gvl = VSETVL(m-j);
-            a_ptr = a;
-            ix = 0;
-            vy0 = VLEV_FLOAT(&y[j], gvl);
-            for(i = 0; i < n; i++){
-                temp = alpha * x[ix];
-                va0 = VLEV_FLOAT(&a_ptr[j], gvl);
-                vy0 = VFMACCVF_FLOAT(vy0, temp, va0, gvl);
-
-                a_ptr += lda;
-                ix += inc_x;
-            }
-            VSEV_FLOAT(&y[j], vy0, gvl);
-            j += gvl;
+    BLASLONG i = 0, j = 0, k = 0;
+    BLASLONG ix = 0, iy = 0;
+
+    if(n < 0) return(0);
+    FLOAT *a_ptr = a;
+    FLOAT temp[4];
+    FLOAT_V_T va0, va1, vy0, vy1, vy0_temp, vy1_temp, temp_v, va0_0, va0_1, va1_0, va1_1, va2_0, va2_1, va3_0, va3_1;
+    unsigned int gvl = 0;
+    if(inc_y == 1 && inc_x == 1){
+        gvl = VSETVL(m);
+        if(gvl <= m/2){
+            for(k=0,j=0; k<m/(2*gvl); k++){
+                a_ptr = a;
+                ix = 0;
+                vy0_temp = VLEV_FLOAT(&y[j], gvl);
+                vy1_temp = VLEV_FLOAT(&y[j+gvl], gvl);
+                vy0 = VFILL_ZERO_FLOAT(vy0, vy0, gvl);
+                vy1 = VFILL_ZERO_FLOAT(vy1, vy1, gvl);
+                int i;
+
+                int remainder = n % 4;
+                for(i = 0; i < remainder; i++){
+                    temp[0] = x[ix];
+                    va0 = VLEV_FLOAT(&a_ptr[j], gvl);
+                    vy0 = VFMACCVF_FLOAT(vy0, temp[0], va0, gvl);
+
+                    va1 = VLEV_FLOAT(&a_ptr[j+gvl], gvl);
+                    vy1 = VFMACCVF_FLOAT(vy1, temp[0], va1, gvl);
+                    a_ptr += lda;
+                    ix++;
                 }
-    }else{
-        BLASLONG stride_y = inc_y * sizeof(FLOAT);
-        gvl = VSETVL(m);
-        if(gvl <= m/2){
-            BLASLONG inc_yv = inc_y * gvl;
-            for(k=0,j=0; k<m/(2*gvl); k++){
-                a_ptr = a;
-                ix = 0;
-                vy0 = VLSEV_FLOAT(&y[iy], stride_y, gvl);
-                vy1 = VLSEV_FLOAT(&y[iy+inc_yv], stride_y, gvl);
-                for(i = 0; i < n; i++){
-                    temp = alpha * x[ix];
-                    va0 = VLEV_FLOAT(&a_ptr[j], gvl);
-                    vy0 = VFMACCVF_FLOAT(vy0, temp, va0, gvl);
-
-                    va1 = VLEV_FLOAT(&a_ptr[j+gvl], gvl);
-                    vy1 = VFMACCVF_FLOAT(vy1, temp, va1, gvl);
-                    a_ptr += lda;
-                    ix += inc_x;
-                }
-                VSSEV_FLOAT(&y[iy], stride_y, vy0, gvl);
-                VSSEV_FLOAT(&y[iy+inc_yv], stride_y, vy1, gvl);
-                j += gvl * 2;
-                iy += inc_yv * 2;
-            }
+
+                for(i = remainder; i < n; i += 4){
+                    va0_0 = VLEV_FLOAT(&(a_ptr)[j], gvl);
+                    va0_1 = VLEV_FLOAT(&(a_ptr)[j+gvl], gvl);
+                    va1_0 = VLEV_FLOAT(&(a_ptr+lda * 1)[j], gvl);
+                    va1_1 = VLEV_FLOAT(&(a_ptr+lda * 1)[j+gvl], gvl);
+                    va2_0 = VLEV_FLOAT(&(a_ptr+lda * 2)[j], gvl);
+                    va2_1 = VLEV_FLOAT(&(a_ptr+lda * 2)[j+gvl], gvl);
+                    va3_0 = VLEV_FLOAT(&(a_ptr+lda * 3)[j], gvl);
+                    va3_1 = VLEV_FLOAT(&(a_ptr+lda * 3)[j+gvl], gvl);
+
+                    vy0 = VFMACCVF_FLOAT(vy0, x[ix], va0_0, gvl);
+                    vy1 = VFMACCVF_FLOAT(vy1, x[ix], va0_1, gvl);
+
+                    vy0 = VFMACCVF_FLOAT(vy0, x[ix+1], va1_0, gvl);
+                    vy1 = VFMACCVF_FLOAT(vy1, x[ix+1], va1_1, gvl);
+
+                    vy0 = VFMACCVF_FLOAT(vy0, x[ix+2], va2_0, gvl);
+                    vy1 = VFMACCVF_FLOAT(vy1, x[ix+2], va2_1, gvl);
+
+                    vy0 = VFMACCVF_FLOAT(vy0, x[ix+3], va3_0, gvl);
+                    vy1 = VFMACCVF_FLOAT(vy1, x[ix+3], va3_1, gvl);
+                    a_ptr += 4 * lda;
+                    ix += 4;
                }
-        //tail
-        for(;j < m;){
-            gvl = VSETVL(m-j);
-            a_ptr = a;
-            ix = 0;
-            vy0 = VLSEV_FLOAT(&y[j*inc_y], stride_y, gvl);
-            for(i = 0; i < n; i++){
-                temp = alpha * x[ix];
-                va0 = VLEV_FLOAT(&a_ptr[j], gvl);
-                vy0 = VFMACCVF_FLOAT(vy0, temp, va0, gvl);
-
-                a_ptr += lda;
-                ix += inc_x;
-            }
-            VSSEV_FLOAT(&y[j*inc_y], stride_y, vy0, gvl);
-            j += gvl;
+                vy0 = VFMACCVF_FLOAT(vy0_temp, alpha, vy0, gvl);
+                vy1 = VFMACCVF_FLOAT(vy1_temp, alpha, vy1, gvl);
+                VSEV_FLOAT(&y[j], vy0, gvl);
+                VSEV_FLOAT(&y[j+gvl], vy1, gvl);
+                j += gvl * 2;
+            }
+        }
+        //tail
+        if(gvl <= m - j){
+            a_ptr = a;
+            ix = 0;
+            vy0_temp = VLEV_FLOAT(&y[j], gvl);
+            vy0 = VFILL_ZERO_FLOAT(vy0, vy0, gvl);
+            int i;
+
+            int remainder = n % 4;
+            for(i = 0; i < remainder; i++){
+                temp[0] = x[ix];
+                va0 = VLEV_FLOAT(&a_ptr[j], gvl);
+                vy0 = VFMACCVF_FLOAT(vy0, temp[0], va0, gvl);
+                a_ptr += lda;
+                ix++;
+            }
+
+            for(i = remainder; i < n; i += 4){
+                va0_0 = VLEV_FLOAT(&(a_ptr)[j], gvl);
+                va1_0 = VLEV_FLOAT(&(a_ptr+lda * 1)[j], gvl);
+                va2_0 = VLEV_FLOAT(&(a_ptr+lda * 2)[j], gvl);
+                va3_0 = VLEV_FLOAT(&(a_ptr+lda * 3)[j], gvl);
+                vy0 = VFMACCVF_FLOAT(vy0, x[ix], va0_0, gvl);
+                vy0 = VFMACCVF_FLOAT(vy0, x[ix+1], va1_0, gvl);
+                vy0 = VFMACCVF_FLOAT(vy0, x[ix+2], va2_0, gvl);
+                vy0 = VFMACCVF_FLOAT(vy0, x[ix+3], va3_0, gvl);
+                a_ptr += 4 * lda;
+                ix += 4;
+            }
+            vy0 = VFMACCVF_FLOAT(vy0_temp, alpha, vy0, gvl);
+
+            VSEV_FLOAT(&y[j], vy0, gvl);
+
+            j += gvl;
+        }
+
+
+        for(;j < m;){
+            gvl = VSETVL(m-j);
+            a_ptr = a;
+            ix = 0;
+            vy0 = VLEV_FLOAT(&y[j], gvl);
+            for(i = 0; i < n; i++){
+                temp[0] = alpha * x[ix];
+                va0 = VLEV_FLOAT(&a_ptr[j], gvl);
+                vy0 = VFMACCVF_FLOAT(vy0, temp[0], va0, gvl);
+
+                a_ptr += lda;
+                ix += inc_x;
+            }
+            VSEV_FLOAT(&y[j], vy0, gvl);
+            j += gvl;
+        }
+    }else if (inc_y == 1 && inc_x != 1){
+        gvl = VSETVL(m);
+        if(gvl <= m/2){
+            for(k=0,j=0; k<m/(2*gvl); k++){
+                a_ptr = a;
+                ix = 0;
+                vy0 = VLEV_FLOAT(&y[j], gvl);
+                vy1 = VLEV_FLOAT(&y[j+gvl], gvl);
+                for(i = 0; i < n; i++){
+                    temp[0] = alpha * x[ix];
+                    va0 = VLEV_FLOAT(&a_ptr[j], gvl);
+                    vy0 = VFMACCVF_FLOAT(vy0, temp[0], va0, gvl);
+
+                    va1 = VLEV_FLOAT(&a_ptr[j+gvl], gvl);
+                    vy1 = VFMACCVF_FLOAT(vy1, temp[0], va1, gvl);
+                    a_ptr += lda;
+                    ix += inc_x;
                }
+                VSEV_FLOAT(&y[j], vy0, gvl);
+                VSEV_FLOAT(&y[j+gvl], vy1, gvl);
+                j += gvl * 2;
+            }
        }
-    return(0);
-}
+        //tail
+        for(;j < m;){
+            gvl = VSETVL(m-j);
+            a_ptr = a;
+            ix = 0;
+            vy0 = VLEV_FLOAT(&y[j], gvl);
+            for(i = 0; i < n; i++){
+                temp[0] = alpha * x[ix];
+                va0 = VLEV_FLOAT(&a_ptr[j], gvl);
+                vy0 = VFMACCVF_FLOAT(vy0, temp[0], va0, gvl);
 
+                a_ptr += lda;
+                ix += inc_x;
+            }
+            VSEV_FLOAT(&y[j], vy0, gvl);
+            j += gvl;
+        }
+    }else{
+        BLASLONG stride_y = inc_y * sizeof(FLOAT);
+        gvl = VSETVL(m);
+        if(gvl <= m/2){
+            BLASLONG inc_yv = inc_y * gvl;
+            for(k=0,j=0; k<m/(2*gvl); k++){
+                a_ptr = a;
+                ix = 0;
+                vy0 = VLSEV_FLOAT(&y[iy], stride_y, gvl);
+                vy1 = VLSEV_FLOAT(&y[iy+inc_yv], stride_y, gvl);
+                for(i = 0; i < n; i++){
+                    temp[0] = alpha * x[ix];
+                    va0 = VLEV_FLOAT(&a_ptr[j], gvl);
+                    vy0 = VFMACCVF_FLOAT(vy0, temp[0], va0, gvl);
 
+                    va1 = VLEV_FLOAT(&a_ptr[j+gvl], gvl);
+                    vy1 = VFMACCVF_FLOAT(vy1, temp[0], va1, gvl);
+                    a_ptr += lda;
+                    ix += inc_x;
+                }
+                VSSEV_FLOAT(&y[iy], stride_y, vy0, gvl);
+                VSSEV_FLOAT(&y[iy+inc_yv], stride_y, vy1, gvl);
+                j += gvl * 2;
+                iy += inc_yv * 2;
+            }
+        }
+        //tail
+        for(;j < m;){
+            gvl = VSETVL(m-j);
+            a_ptr = a;
+            ix = 0;
+            vy0 = VLSEV_FLOAT(&y[j*inc_y], stride_y, gvl);
+            for(i = 0; i < n; i++){
+                temp[0] = alpha * x[ix];
+                va0 = VLEV_FLOAT(&a_ptr[j], gvl);
+                vy0 = VFMACCVF_FLOAT(vy0, temp[0], va0, gvl);
+
+                a_ptr += lda;
+                ix += inc_x;
+            }
+            VSSEV_FLOAT(&y[j*inc_y], stride_y, vy0, gvl);
+            j += gvl;
+        }
+    }
+    return(0);
+}
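
The heart of the optimization is that the unit-stride path no longer multiplies each column term by alpha inside the inner loop: it accumulates the raw sum of A[:,i] * x[i] into a zeroed register and folds alpha in once at the end, against the original contents of y, while walking four columns of A per iteration. The following scalar model illustrates that structure; it is an illustrative sketch, not OpenBLAS code, and the name gemv_n_model is hypothetical.

#include <stdio.h>

/* Scalar model of the new unit-stride path: peel n % 4 columns, then
 * process four columns per iteration, and apply alpha exactly once at
 * the end (the role of VFMACCVF_FLOAT(vy0_temp, alpha, vy0, gvl)). */
static void gemv_n_model(int m, int n, float alpha, const float *a,
                         int lda, const float *x, float *y)
{
    for (int j = 0; j < m; j++) {      /* each j models one vector lane of y */
        float acc = 0.0f;              /* accumulator, like vy0 after VFILL_ZERO_FLOAT */
        int i = 0, remainder = n % 4;
        for (; i < remainder; i++)     /* peeled remainder columns, as in the kernel */
            acc += a[j + i * lda] * x[i];
        for (; i < n; i += 4) {        /* 4-way unrolled column loop */
            acc += a[j + (i + 0) * lda] * x[i + 0];
            acc += a[j + (i + 1) * lda] * x[i + 1];
            acc += a[j + (i + 2) * lda] * x[i + 2];
            acc += a[j + (i + 3) * lda] * x[i + 3];
        }
        y[j] += alpha * acc;           /* alpha deferred to a single final FMA */
    }
}

int main(void)
{
    float a[] = {1, 2, 3, 4, 5, 6};    /* 3x2 column-major A */
    float x[] = {1, 10};
    float y[] = {0.5f, 0.5f, 0.5f};
    gemv_n_model(3, 2, 2.0f, a, 3, x, y);
    for (int j = 0; j < 3; j++)
        printf("y[%d] = %g\n", j, y[j]);   /* expect 82.5, 104.5, 126.5 */
    return 0;
}

Mathematically, y += alpha * (sum of column contributions) equals the old per-column alpha * x[ix] scaling, though rounding may differ slightly because alpha is now applied once to the accumulated sum. One caveat the diff inherits: subtracting a vector register from itself yields zero only for finite contents (Inf - Inf and NaN - NaN are NaN), so the VFILL_ZERO_FLOAT idiom relies on no such bit patterns being present in the reused registers.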
