
Commit b1aa2e1

Merge pull request #4802 from markdryan/markdryan/rvv_axpby_incy0

Fix axpby_rvv kernels for cases where inc_y = 0

2 parents a3c10c6 + 67bf4b6
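
For context: axpby computes y := alpha*x + beta*y. With inc_y = 0, all n updates read and write the same element y[0], so the operation degenerates into a scalar recurrence; a stride-0 vector load/store pair cannot reproduce that accumulation (every lane would see the original y[0]), which is why the kernels below fall back to a scalar loop. A minimal sketch of the expected inc_y == 0 semantics (axpby_incy0_ref is a hypothetical reference helper, not part of the patch):

    #include <stdio.h>

    /* Reference semantics of axpby when inc_y == 0: all n updates hit y[0],
       so the result is the n-step recurrence v = v*beta + alpha*x[i*inc_x]. */
    static void axpby_incy0_ref(int n, double alpha, const double *x, int inc_x,
                                double beta, double *y)
    {
        double v = y[0];
        for (int i = 0; i < n; i++)
            v = v * beta + alpha * x[i * inc_x];
        y[0] = v;
    }

    int main(void)
    {
        double x[4] = {1.0, 2.0, 3.0, 4.0};
        double y0 = 5.0;
        axpby_incy0_ref(4, 2.0, x, 1, 0.5, &y0);
        /* 5*0.5+2 = 4.5; 4.5*0.5+4 = 6.25; 6.25*0.5+6 = 9.125; 9.125*0.5+8 = 12.5625 */
        printf("%f\n", y0);   /* prints 12.562500 */
        return 0;
    }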

2 files changed: +78, -43 lines

kernel/riscv64/axpby_rvv.c

Lines changed: 12 additions & 0 deletions
@@ -114,6 +114,11 @@ int CNAME(BLASLONG n, FLOAT alpha, FLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *
             vy = VFMULVF_FLOAT(vy, beta, vl);
             VSEV_FLOAT (y, vy, vl);
         }
+    } else if (inc_y == 0) {
+        FLOAT vf = y[0];
+        for (; n > 0; n--)
+            vf *= beta;
+        y[0] = vf;
     } else {
         BLASLONG stride_y = inc_y * sizeof(FLOAT);
         for (size_t vl; n > 0; n -= vl, y += vl*inc_y) {

@@ -134,6 +139,13 @@ int CNAME(BLASLONG n, FLOAT alpha, FLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *
             vy = VFMACCVF_FLOAT(vy, alpha, vx, vl);
             VSEV_FLOAT (y, vy, vl);
         }
+    } else if (inc_y == 0) {
+        FLOAT vf = y[0];
+        for (; n > 0; n--) {
+            vf = (vf * beta) + (x[0] * alpha);
+            x += inc_x;
+        }
+        y[0] = vf;
     } else if (1 == inc_x) {
         BLASLONG stride_y = inc_y * sizeof(FLOAT);
         for (size_t vl; n > 0; n -= vl, x += vl, y += vl*inc_y) {
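
How this path is reached from the user-facing API, as a hedged sketch: it assumes OpenBLAS's cblas_daxpby extension as declared in cblas.h, built with something like cc test.c -lopenblas:

    #include <cblas.h>
    #include <stdio.h>

    int main(void)
    {
        double x[3] = {1.0, 2.0, 3.0};
        double y[1] = {10.0};

        /* incY = 0: every update targets y[0], giving the recurrence
           v = v*beta + alpha*x[i]: 10*2+1 = 21; 21*2+2 = 44; 44*2+3 = 91. */
        cblas_daxpby(3, 1.0, x, 1, 2.0, y, 0);
        printf("y[0] = %f\n", y[0]);   /* 91.000000 with the fixed kernel */
        return 0;
    }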

kernel/riscv64/zaxpby_rvv.c

Lines changed: 66 additions & 43 deletions
@@ -79,8 +79,10 @@ int CNAME(BLASLONG n, FLOAT alpha_r, FLOAT alpha_i, FLOAT *x, BLASLONG inc_x, FL
 
     BLASLONG stride_x = inc_x2 * sizeof(FLOAT);
     BLASLONG stride_y = inc_y2 * sizeof(FLOAT);
+    BLASLONG ix;
     FLOAT_V_T vx0, vx1, vy0, vy1;
     FLOAT_VX2_T vxx2, vyx2;
+    FLOAT temp;
 
     if ( beta_r == 0.0 && beta_i == 0.0)
     {

@@ -125,53 +127,74 @@ int CNAME(BLASLONG n, FLOAT alpha_r, FLOAT alpha_i, FLOAT *x, BLASLONG inc_x, FL
 
     if ( alpha_r == 0.0 && alpha_i == 0.0 )
     {
-        for (size_t vl; n > 0; n -= vl, y += vl*inc_y2)
-        {
-            vl = VSETVL(n);
-
-            vyx2 = VLSSEG_FLOAT(y, stride_y, vl);
-            vy0 = VGET_VX2(vyx2, 0);
-            vy1 = VGET_VX2(vyx2, 1);
-
-            v0 = VFMULVF_FLOAT(vy1, beta_i, vl);
-            v0 = VFMSACVF_FLOAT(v0, beta_r, vy0, vl);
-
-            v1 = VFMULVF_FLOAT(vy1, beta_r, vl);
-            v1 = VFMACCVF_FLOAT(v1, beta_i, vy0, vl);
-
-            v_x2 = VSET_VX2(v_x2, 0, v0);
-            v_x2 = VSET_VX2(v_x2, 1, v1);
-            VSSSEG_FLOAT(y, stride_y, v_x2, vl);
+        if ( inc_y == 0 ) {
+            for (; n > 0; n--)
+            {
+                temp = (beta_r * y[0] - beta_i * y[1]);
+                y[1] = (beta_r * y[1] + beta_i * y[0]);
+                y[0] = temp;
+            }
+        } else {
+            for (size_t vl; n > 0; n -= vl, y += vl*inc_y2)
+            {
+                vl = VSETVL(n);
+
+                vyx2 = VLSSEG_FLOAT(y, stride_y, vl);
+                vy0 = VGET_VX2(vyx2, 0);
+                vy1 = VGET_VX2(vyx2, 1);
+
+                v0 = VFMULVF_FLOAT(vy1, beta_i, vl);
+                v0 = VFMSACVF_FLOAT(v0, beta_r, vy0, vl);
+
+                v1 = VFMULVF_FLOAT(vy1, beta_r, vl);
+                v1 = VFMACCVF_FLOAT(v1, beta_i, vy0, vl);
+
+                v_x2 = VSET_VX2(v_x2, 0, v0);
+                v_x2 = VSET_VX2(v_x2, 1, v1);
+                VSSSEG_FLOAT(y, stride_y, v_x2, vl);
+            }
         }
     }
     else
     {
-        for (size_t vl; n > 0; n -= vl, x += vl*inc_x2, y += vl*inc_y2)
-        {
-            vl = VSETVL(n);
-
-            vxx2 = VLSSEG_FLOAT(x, stride_x, vl);
-            vyx2 = VLSSEG_FLOAT(y, stride_y, vl);
-
-            vx0 = VGET_VX2(vxx2, 0);
-            vx1 = VGET_VX2(vxx2, 1);
-            vy0 = VGET_VX2(vyx2, 0);
-            vy1 = VGET_VX2(vyx2, 1);
-
-            v0 = VFMULVF_FLOAT(vx0, alpha_r, vl);
-            v0 = VFNMSACVF_FLOAT(v0, alpha_i, vx1, vl);
-            v0 = VFMACCVF_FLOAT(v0, beta_r, vy0, vl);
-            v0 = VFNMSACVF_FLOAT(v0, beta_i, vy1, vl);
-
-            v1 = VFMULVF_FLOAT(vx1, alpha_r, vl);
-            v1 = VFMACCVF_FLOAT(v1, alpha_i, vx0, vl);
-            v1 = VFMACCVF_FLOAT(v1, beta_r, vy1, vl);
-            v1 = VFMACCVF_FLOAT(v1, beta_i, vy0, vl);
-
-            v_x2 = VSET_VX2(v_x2, 0, v0);
-            v_x2 = VSET_VX2(v_x2, 1, v1);
-
-            VSSSEG_FLOAT(y, stride_y, v_x2, vl);
+        if ( inc_y == 0 ) {
+            ix = 0;
+            for (; n > 0; n--) {
+                temp = (alpha_r * x[ix] - alpha_i * x[ix+1]) +
+                       (beta_r * y[0] - beta_i * y[1]);
+                y[1] = (alpha_r * x[ix+1] + alpha_i * x[ix]) +
+                       (beta_r * y[1] + beta_i * y[0]);
+                y[0] = temp;
+                ix += inc_x2;
+            }
+        } else {
+            for (size_t vl; n > 0; n -= vl, x += vl*inc_x2, y += vl*inc_y2)
+            {
+                vl = VSETVL(n);
+
+                vxx2 = VLSSEG_FLOAT(x, stride_x, vl);
+                vyx2 = VLSSEG_FLOAT(y, stride_y, vl);
+
+                vx0 = VGET_VX2(vxx2, 0);
+                vx1 = VGET_VX2(vxx2, 1);
+                vy0 = VGET_VX2(vyx2, 0);
+                vy1 = VGET_VX2(vyx2, 1);
+
+                v0 = VFMULVF_FLOAT(vx0, alpha_r, vl);
+                v0 = VFNMSACVF_FLOAT(v0, alpha_i, vx1, vl);
+                v0 = VFMACCVF_FLOAT(v0, beta_r, vy0, vl);
+                v0 = VFNMSACVF_FLOAT(v0, beta_i, vy1, vl);
+
+                v1 = VFMULVF_FLOAT(vx1, alpha_r, vl);
+                v1 = VFMACCVF_FLOAT(v1, alpha_i, vx0, vl);
+                v1 = VFMACCVF_FLOAT(v1, beta_r, vy1, vl);
+                v1 = VFMACCVF_FLOAT(v1, beta_i, vy0, vl);
+
+                v_x2 = VSET_VX2(v_x2, 0, v0);
+                v_x2 = VSET_VX2(v_x2, 1, v1);
+
+                VSSSEG_FLOAT(y, stride_y, v_x2, vl);
+            }
         }
     }
 }
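
The complex kernel implements the same recurrence, with a full complex multiply per step (note that temp holds the new real part so the update of y[1] still reads the old y[0]). A hedged usage sketch, assuming OpenBLAS's cblas_zaxpby extension, which takes alpha and beta as pointers to interleaved (re, im) pairs:

    #include <cblas.h>
    #include <stdio.h>

    int main(void)
    {
        double x[4] = {1.0, 1.0, 2.0, -1.0};   /* x = [1+i, 2-i], interleaved */
        double y[2] = {1.0, 0.0};              /* single complex y element */
        double alpha[2] = {1.0, 0.0};          /* alpha = 1 */
        double beta[2]  = {0.0, 1.0};          /* beta  = i */

        /* incY = 0: step 1: y = i*(1+0i) + (1+1i) = 1+2i
                     step 2: y = i*(1+2i) + (2-1i) = 0+0i */
        cblas_zaxpby(2, alpha, x, 1, beta, y, 0);
        printf("y = %f + %fi\n", y[0], y[1]);  /* 0.000000 + 0.000000i */
        return 0;
    }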
