Skip to content

Commit ba17758

Browse files
committed
fix axpy implementations where y has a stride of 0
1 parent 5266998 commit ba17758

File tree

2 files changed

+47
-3
lines changed

2 files changed

+47
-3
lines changed

kernel/riscv64/axpy_rvv.c

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,29 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3030
#if !defined(DOUBLE)
3131
#define VSETVL(n) __riscv_vsetvl_e32m8(n)
3232
#define FLOAT_V_T vfloat32m8_t
33+
#define FLOAT_V_M1_T vfloat32m1_t
3334
#define VLEV_FLOAT __riscv_vle32_v_f32m8
3435
#define VLSEV_FLOAT __riscv_vlse32_v_f32m8
3536
#define VSEV_FLOAT __riscv_vse32_v_f32m8
37+
#define VSEV_FLOAT_M1 __riscv_vse32_v_f32m1
3638
#define VSSEV_FLOAT __riscv_vsse32_v_f32m8
3739
#define VFMACCVF_FLOAT __riscv_vfmacc_vf_f32m8
40+
#define VFMVVF_FLOAT __riscv_vfmv_v_f_f32m8
41+
#define VFREDSUMVS_FLOAT __riscv_vfredusum_vs_f32m8_f32m1
42+
#define VFMVVF_FLOAT_M1 __riscv_vfmv_v_f_f32m1
3843
#else
3944
#define VSETVL(n) __riscv_vsetvl_e64m8(n)
4045
#define FLOAT_V_T vfloat64m8_t
46+
#define FLOAT_V_M1_T vfloat64m1_t
4147
#define VLEV_FLOAT __riscv_vle64_v_f64m8
4248
#define VLSEV_FLOAT __riscv_vlse64_v_f64m8
4349
#define VSEV_FLOAT __riscv_vse64_v_f64m8
50+
#define VSEV_FLOAT_M1 __riscv_vse64_v_f64m1
4451
#define VSSEV_FLOAT __riscv_vsse64_v_f64m8
4552
#define VFMACCVF_FLOAT __riscv_vfmacc_vf_f64m8
53+
#define VFMVVF_FLOAT __riscv_vfmv_v_f_f64m8
54+
#define VFREDSUMVS_FLOAT __riscv_vfredusum_vs_f64m8_f64m1
55+
#define VFMVVF_FLOAT_M1 __riscv_vfmv_v_f_f64m1
4656
#endif
4757

4858
int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT *dummy, BLASLONG dummy2)
@@ -76,7 +86,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS
7686
VSEV_FLOAT(y, vy, vl);
7787
}
7888

79-
} else if (1 == inc_x) {
89+
} else if (1 == inc_x && 0 != inc_y) {
8090

8191
BLASLONG stride_y = inc_y * sizeof(FLOAT);
8292

@@ -89,8 +99,20 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS
8999
VSSEV_FLOAT(y, stride_y, vy, vl);
90100
}
91101

92-
} else {
102+
} else if( 0 == inc_y ) {
103+
BLASLONG stride_x = inc_x * sizeof(FLOAT);
104+
size_t in_vl = VSETVL(n);
105+
vy = VFMVVF_FLOAT( y[0], in_vl );
93106

107+
for (size_t vl; n > 0; n -= vl, x += vl*inc_x) {
108+
vl = VSETVL(n);
109+
vx = VLSEV_FLOAT(x, stride_x, vl);
110+
vy = VFMACCVF_FLOAT(vy, da, vx, vl);
111+
}
112+
FLOAT_V_M1_T vres = VFMVVF_FLOAT_M1( 0.0f, 1 );
113+
vres = VFREDSUMVS_FLOAT( vy, vres, in_vl );
114+
VSEV_FLOAT_M1(y, vres, 1);
115+
} else {
94116
BLASLONG stride_x = inc_x * sizeof(FLOAT);
95117
BLASLONG stride_y = inc_y * sizeof(FLOAT);
96118

kernel/riscv64/axpy_vector.c

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,20 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
5151

5252
#define VSETVL JOIN(RISCV_RVV(vsetvl), _e, ELEN, LMUL, _)
5353
#define FLOAT_V_T JOIN(vfloat, ELEN, LMUL, _t, _)
54+
#define FLOAT_V_M1_T JOIN(vfloat, ELEN, m1, _t, _)
5455
#define VLEV_FLOAT JOIN(RISCV_RVV(vle), ELEN, _v_f, ELEN, LMUL)
5556
#define VLSEV_FLOAT JOIN(RISCV_RVV(vlse), ELEN, _v_f, ELEN, LMUL)
5657
#define VSEV_FLOAT JOIN(RISCV_RVV(vse), ELEN, _v_f, ELEN, LMUL)
5758
#define VSSEV_FLOAT JOIN(RISCV_RVV(vsse), ELEN, _v_f, ELEN, LMUL)
5859
#define VFMACCVF_FLOAT JOIN(RISCV_RVV(vfmacc), _vf_f, ELEN, LMUL, _)
60+
#define VFMVVF_FLOAT JOIN(RISCV_RVV(vfmv), _v_f_f, ELEN, LMUL, _)
61+
#define VFMVVF_FLOAT_M1 JOIN(RISCV_RVV(vfmv), _v_f_f, ELEN, m1, _)
62+
63+
#ifdef RISCV_0p10_INTRINSICS
64+
#define VFREDSUMVS_FLOAT(va, vb, gvl) JOIN(RISCV_RVV(vfredusum_vs_f), ELEN, LMUL, _f, JOIN2( ELEN, m1))(v_res, va, vb, gvl)
65+
#else
66+
#define VFREDSUMVS_FLOAT JOIN(RISCV_RVV(vfredusum_vs_f), ELEN, LMUL, _f, JOIN2( ELEN, m1))
67+
#endif
5968

6069
int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT *dummy, BLASLONG dummy2)
6170
{
@@ -123,7 +132,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS
123132
VSEV_FLOAT(&y[j], vy0, gvl);
124133
j += gvl;
125134
}
126-
}else if(inc_x == 1){
135+
} else if (1 == inc_x && 0 != inc_y) {
127136
stride_y = inc_y * sizeof(FLOAT);
128137
gvl = VSETVL(n);
129138
if(gvl <= n/2){
@@ -151,6 +160,19 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS
151160
VSSEV_FLOAT(&y[j*inc_y], stride_y, vy0, gvl);
152161
j += gvl;
153162
}
163+
} else if( 0 == inc_y ) {
164+
BLASLONG stride_x = inc_x * sizeof(FLOAT);
165+
size_t in_vl = VSETVL(n);
166+
vy0 = VFMVVF_FLOAT( y[0], in_vl );
167+
168+
for (size_t vl; n > 0; n -= vl, x += vl*inc_x) {
169+
vl = VSETVL(n);
170+
vx0 = VLSEV_FLOAT(x, stride_x, vl);
171+
vy0 = VFMACCVF_FLOAT(vy0, da, vx0, vl);
172+
}
173+
FLOAT_V_M1_T v_res = VFMVVF_FLOAT_M1( 0.0f, 1 );
174+
v_res = VFREDSUMVS_FLOAT( vy0, v_res, in_vl );
175+
y[0] = EXTRACT_FLOAT(v_res);
154176
}else{
155177
stride_x = inc_x * sizeof(FLOAT);
156178
stride_y = inc_y * sizeof(FLOAT);

0 commit comments

Comments
 (0)