Commit c37509c

Author: tingbo.liao (committed)
Optimize the nrm2_rvv function to further improve performance.
Signed-off-by: tingbo.liao <tingbo.liao@starfivetech.com>
1 parent a107547 commit c37509c

File tree: 1 file changed (+204 −166 lines)


kernel/riscv64/nrm2_rvv.c

Lines changed: 204 additions & 166 deletions
@@ -27,185 +27,223 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include "common.h"
 
-#if defined(DOUBLE)
-#define VSETVL __riscv_vsetvl_e64m4
-#define FLOAT_V_T vfloat64m4_t
-#define FLOAT_V_T_M1 vfloat64m1_t
-#define VLEV_FLOAT __riscv_vle64_v_f64m4
-#define VLSEV_FLOAT __riscv_vlse64_v_f64m4
-#define VFMVVF_FLOAT __riscv_vfmv_v_f_f64m4
-#define VFMVSF_FLOAT __riscv_vfmv_s_f_f64m4
-#define VFMVVF_FLOAT_M1 __riscv_vfmv_v_f_f64m1
-#define MASK_T vbool16_t
-#define VFABS __riscv_vfabs_v_f64m4
-#define VMFNE __riscv_vmfne_vf_f64m4_b16
-#define VMFGT __riscv_vmfgt_vv_f64m4_b16
-#define VMFEQ __riscv_vmfeq_vf_f64m4_b16
-#define VCPOP __riscv_vcpop_m_b16
-#define VFREDMAX __riscv_vfredmax_vs_f64m4_f64m1
-#define VFREDMIN __riscv_vfredmin_vs_f64m4_f64m1
-#define VFIRST __riscv_vfirst_m_b16
-#define VRGATHER __riscv_vrgather_vx_f64m4
-#define VFDIV __riscv_vfdiv_vv_f64m4
-#define VFDIV_M __riscv_vfdiv_vv_f64m4_mu
-#define VFMUL __riscv_vfmul_vv_f64m4
-#define VFMUL_M __riscv_vfmul_vv_f64m4_mu
-#define VFMACC __riscv_vfmacc_vv_f64m4
-#define VFMACC_M __riscv_vfmacc_vv_f64m4_mu
-#define VMSBF __riscv_vmsbf_m_b16
-#define VMSOF __riscv_vmsof_m_b16
-#define VMAND __riscv_vmand_mm_b16
-#define VMANDN __riscv_vmand_mm_b16
-#define VFREDSUM __riscv_vfredusum_vs_f64m4_f64m1
-#define VMERGE __riscv_vmerge_vvm_f64m4
-#define VSEV_FLOAT __riscv_vse64_v_f64m4
-#define EXTRACT_FLOAT0_V(v) __riscv_vfmv_f_s_f64m4_f64(v)
-#define ABS fabs
-#else
-#define VSETVL __riscv_vsetvl_e32m4
+#if !defined(DOUBLE)
+#define VSETVL(n) __riscv_vsetvl_e32m4(n)
+#define VSETVL_MAX __riscv_vsetvlmax_e32m4()
 #define FLOAT_V_T vfloat32m4_t
 #define FLOAT_V_T_M1 vfloat32m1_t
+#define MASK_T vbool8_t
 #define VLEV_FLOAT __riscv_vle32_v_f32m4
 #define VLSEV_FLOAT __riscv_vlse32_v_f32m4
+#define VFREDSUM_FLOAT __riscv_vfredusum_vs_f32m4_f32m1_tu
+#define VFMACCVV_FLOAT_TU __riscv_vfmacc_vv_f32m4_tu
 #define VFMVVF_FLOAT __riscv_vfmv_v_f_f32m4
-#define VFMVSF_FLOAT __riscv_vfmv_s_f_f32m4
 #define VFMVVF_FLOAT_M1 __riscv_vfmv_v_f_f32m1
-#define MASK_T vbool8_t
-#define VFABS __riscv_vfabs_v_f32m4
-#define VMFNE __riscv_vmfne_vf_f32m4_b8
-#define VMFGT __riscv_vmfgt_vv_f32m4_b8
-#define VMFEQ __riscv_vmfeq_vf_f32m4_b8
-#define VCPOP __riscv_vcpop_m_b8
-#define VFREDMAX __riscv_vfredmax_vs_f32m4_f32m1
-#define VFREDMIN __riscv_vfredmin_vs_f32m4_f32m1
-#define VFIRST __riscv_vfirst_m_b8
-#define VRGATHER __riscv_vrgather_vx_f32m4
-#define VFDIV __riscv_vfdiv_vv_f32m4
-#define VFDIV_M __riscv_vfdiv_vv_f32m4_mu
-#define VFMUL __riscv_vfmul_vv_f32m4
-#define VFMUL_M __riscv_vfmul_vv_f32m4_mu
-#define VFMACC __riscv_vfmacc_vv_f32m4
-#define VFMACC_M __riscv_vfmacc_vv_f32m4_mu
-#define VMSBF __riscv_vmsbf_m_b8
-#define VMSOF __riscv_vmsof_m_b8
-#define VMAND __riscv_vmand_mm_b8
-#define VMANDN __riscv_vmand_mm_b8
-#define VFREDSUM __riscv_vfredusum_vs_f32m4_f32m1
-#define VMERGE __riscv_vmerge_vvm_f32m4
-#define VSEV_FLOAT __riscv_vse32_v_f32m4
-#define EXTRACT_FLOAT0_V(v) __riscv_vfmv_f_s_f32m4_f32(v)
+#define VMFIRSTM __riscv_vfirst_m_b8
+#define VFREDMAXVS_FLOAT_TU __riscv_vfredmax_vs_f32m4_f32m1_tu
+#define VFMVFS_FLOAT __riscv_vfmv_f_s_f32m1_f32
+#define VMFGTVF_FLOAT __riscv_vmfgt_vf_f32m4_b8
+#define VFDIVVF_FLOAT __riscv_vfdiv_vf_f32m4
+#define VFABSV_FLOAT __riscv_vfabs_v_f32m4
 #define ABS fabsf
+#else
+#define VSETVL(n) __riscv_vsetvl_e64m4(n)
+#define VSETVL_MAX __riscv_vsetvlmax_e64m4()
+#define FLOAT_V_T vfloat64m4_t
+#define FLOAT_V_T_M1 vfloat64m1_t
+#define MASK_T vbool16_t
+#define VLEV_FLOAT __riscv_vle64_v_f64m4
+#define VLSEV_FLOAT __riscv_vlse64_v_f64m4
+#define VFREDSUM_FLOAT __riscv_vfredusum_vs_f64m4_f64m1_tu
+#define VFMACCVV_FLOAT_TU __riscv_vfmacc_vv_f64m4_tu
+#define VFMVVF_FLOAT __riscv_vfmv_v_f_f64m4
+#define VFMVVF_FLOAT_M1 __riscv_vfmv_v_f_f64m1
+#define VMFIRSTM __riscv_vfirst_m_b16
+#define VFREDMAXVS_FLOAT_TU __riscv_vfredmax_vs_f64m4_f64m1_tu
+#define VFMVFS_FLOAT __riscv_vfmv_f_s_f64m1_f64
+#define VMFGTVF_FLOAT __riscv_vmfgt_vf_f64m4_b16
+#define VFDIVVF_FLOAT __riscv_vfdiv_vf_f64m4
+#define VFABSV_FLOAT __riscv_vfabs_v_f64m4
+#define ABS fabs
 #endif
 
 FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x)
 {
-    BLASLONG i=0;
-
-    if (n <= 0 || inc_x == 0) return(0.0);
-    if(n == 1) return (ABS(x[0]));
-
-    unsigned int gvl = 0;
-
-    MASK_T nonzero_mask;
-    MASK_T scale_mask;
-
-    gvl = VSETVL(n);
-    FLOAT_V_T v0;
-    FLOAT_V_T v_ssq = VFMVVF_FLOAT(0, gvl);
-    FLOAT_V_T v_scale = VFMVVF_FLOAT(0, gvl);
-
-    FLOAT scale = 0;
-    FLOAT ssq = 0;
-    unsigned int stride_x = inc_x * sizeof(FLOAT);
-    int idx = 0;
-
-    if( n >= gvl && inc_x > 0 ) // don't pay overheads if we're not doing useful work
-    {
-        for(i=0; i<n/gvl; i++){
-            v0 = VLSEV_FLOAT( &x[idx], stride_x, gvl );
-            nonzero_mask = VMFNE( v0, 0, gvl );
-            v0 = VFABS( v0, gvl );
-            scale_mask = VMFGT( v0, v_scale, gvl );
-
-            // assume scale changes are relatively infrequent
-
-            // unclear if the vcpop+branch is actually a win
-            // since the operations being skipped are predicated anyway
-            // need profiling to confirm
-            if( VCPOP(scale_mask, gvl) )
-            {
-                v_scale = VFDIV_M( scale_mask, v_scale, v_scale, v0, gvl );
-                v_scale = VFMUL_M( scale_mask, v_scale, v_scale, v_scale, gvl );
-                v_ssq = VFMUL_M( scale_mask, v_ssq, v_ssq, v_scale, gvl );
-                v_scale = VMERGE( v_scale, v0, scale_mask, gvl );
-            }
-            v0 = VFDIV_M( nonzero_mask, v0, v0, v_scale, gvl );
-            v_ssq = VFMACC_M( nonzero_mask, v_ssq, v0, v0, gvl );
-            idx += inc_x * gvl;
-        }
-
-        // we have gvl elements which we accumulated independently, with independent scales
-        // we need to combine these
-        // naive sort so we process small values first to avoid losing information
-        // could use vector sort extensions where available, but we're dealing with gvl elts at most
-
-        FLOAT * out_ssq = alloca(gvl*sizeof(FLOAT));
-        FLOAT * out_scale = alloca(gvl*sizeof(FLOAT));
-        VSEV_FLOAT( out_ssq, v_ssq, gvl );
-        VSEV_FLOAT( out_scale, v_scale, gvl );
-        for( int a = 0; a < (gvl-1); ++a )
-        {
-            int smallest = a;
-            for( size_t b = a+1; b < gvl; ++b )
-                if( out_scale[b] < out_scale[smallest] )
-                    smallest = b;
-            if( smallest != a )
-            {
-                FLOAT tmp1 = out_ssq[a];
-                FLOAT tmp2 = out_scale[a];
-                out_ssq[a] = out_ssq[smallest];
-                out_scale[a] = out_scale[smallest];
-                out_ssq[smallest] = tmp1;
-                out_scale[smallest] = tmp2;
-            }
-        }
-
-        int a = 0;
-        while( a<gvl && out_scale[a] == 0 )
-            ++a;
-
-        if( a < gvl )
-        {
-            ssq = out_ssq[a];
-            scale = out_scale[a];
-            ++a;
-            for( ; a < gvl; ++a )
-            {
-                ssq = ssq * ( scale / out_scale[a] ) * ( scale / out_scale[a] ) + out_ssq[a];
-                scale = out_scale[a];
-            }
-        }
-    }
-
-    //finish any tail using scalar ops
-    i*=gvl*inc_x;
-    n*=inc_x;
+    if (n <= 0 || inc_x == 0) return(0.0);
+    if ( n == 1 ) return( ABS(x[0]) );
+
+    BLASLONG i = 0, j = 0;
+    FLOAT scale = 0.0, ssq = 0.0;
+
+    if( inc_x > 0 ){
+        FLOAT_V_T vr, v0, v_zero;
+        unsigned int gvl = 0;
+        FLOAT_V_T_M1 v_res, v_z0;
+        gvl = VSETVL_MAX;
+        v_res = VFMVVF_FLOAT_M1(0, gvl);
+        v_z0 = VFMVVF_FLOAT_M1(0, gvl);
+        MASK_T mask;
+        BLASLONG index = 0;
+
+        if (inc_x == 1) {
+            gvl = VSETVL(n);
+            vr = VFMVVF_FLOAT(0, gvl);
+            v_zero = VFMVVF_FLOAT(0, gvl);
+            for (i = 0, j = 0; i < n / gvl; i++) {
+                v0 = VLEV_FLOAT(&x[j], gvl);
+                // fabs(vector)
+                v0 = VFABSV_FLOAT(v0, gvl);
+                // if scale change
+                mask = VMFGTVF_FLOAT(v0, scale, gvl);
+                index = VMFIRSTM(mask, gvl);
+                if (index == -1) { // no elements greater than scale
+                    if (scale != 0.0) {
+                        v0 = VFDIVVF_FLOAT(v0, scale, gvl);
+                        vr = VFMACCVV_FLOAT_TU(vr, v0, v0, gvl);
+                    }
+                }
+                else { // found greater element
+                    // ssq in vector vr: vr[0]
+                    v_res = VFREDSUM_FLOAT(v_res, vr, v_z0, gvl);
+                    // total ssq before current vector
+                    ssq += VFMVFS_FLOAT(v_res);
+                    // find max
+                    v_res = VFREDMAXVS_FLOAT_TU(v_res, v0, v_z0, gvl);
+                    // update ssq before max_index
+                    ssq = ssq * (scale / VFMVFS_FLOAT(v_res)) * (scale / VFMVFS_FLOAT(v_res));
+                    // update scale
+                    scale = VFMVFS_FLOAT(v_res);
+                    // ssq in vector vr
+                    v0 = VFDIVVF_FLOAT(v0, scale, gvl);
+                    vr = VFMACCVV_FLOAT_TU(v_zero, v0, v0, gvl);
+                }
+                j += gvl;
+            }
+            // ssq in vector vr: vr[0]
+            v_res = VFREDSUM_FLOAT(v_res, vr, v_z0, gvl);
+            // total ssq now
+            ssq += VFMVFS_FLOAT(v_res);
+
+            // tail processing
+            if(j < n){
+                gvl = VSETVL(n-j);
+                v0 = VLEV_FLOAT(&x[j], gvl);
+                // fabs(vector)
+                v0 = VFABSV_FLOAT(v0, gvl);
+                // if scale change
+                mask = VMFGTVF_FLOAT(v0, scale, gvl);
+                index = VMFIRSTM(mask, gvl);
+                if (index == -1) { // no elements greater than scale
+                    if(scale != 0.0)
+                        v0 = VFDIVVF_FLOAT(v0, scale, gvl);
+                } else { // found greater element
+                    // find max
+                    v_res = VFREDMAXVS_FLOAT_TU(v_res, v0, v_z0, gvl);
+                    // update ssq before max_index
+                    ssq = ssq * (scale / VFMVFS_FLOAT(v_res))*(scale / VFMVFS_FLOAT(v_res));
+                    // update scale
+                    scale = VFMVFS_FLOAT(v_res);
+                    v0 = VFDIVVF_FLOAT(v0, scale, gvl);
+                }
+                vr = VFMACCVV_FLOAT_TU(v_zero, v0, v0, gvl);
+                // ssq in vector vr: vr[0]
+                v_res = VFREDSUM_FLOAT(v_res, vr, v_z0, gvl);
+                // total ssq now
+                ssq += VFMVFS_FLOAT(v_res);
+            }
+        }
+        else {
+            gvl = VSETVL(n);
+            vr = VFMVVF_FLOAT(0, gvl);
+            v_zero = VFMVVF_FLOAT(0, gvl);
+            unsigned int stride_x = inc_x * sizeof(FLOAT);
+            int idx = 0, inc_v = inc_x * gvl;
+            for (i = 0, j = 0; i < n / gvl; i++) {
+                v0 = VLSEV_FLOAT(&x[idx], stride_x, gvl);
+                // fabs(vector)
+                v0 = VFABSV_FLOAT(v0, gvl);
+                // if scale change
+                mask = VMFGTVF_FLOAT(v0, scale, gvl);
+                index = VMFIRSTM(mask, gvl);
+                if (index == -1) { // no elements greater than scale
+                    if(scale != 0.0){
+                        v0 = VFDIVVF_FLOAT(v0, scale, gvl);
+                        vr = VFMACCVV_FLOAT_TU(vr, v0, v0, gvl);
+                    }
+                }
+                else { // found greater element
+                    // ssq in vector vr: vr[0]
+                    v_res = VFREDSUM_FLOAT(v_res, vr, v_z0, gvl);
+                    // total ssq before current vector
+                    ssq += VFMVFS_FLOAT(v_res);
+                    // find max
+                    v_res = VFREDMAXVS_FLOAT_TU(v_res, v0, v_z0, gvl);
+                    // update ssq before max_index
+                    ssq = ssq * (scale / VFMVFS_FLOAT(v_res))*(scale / VFMVFS_FLOAT(v_res));
+                    // update scale
+                    scale = VFMVFS_FLOAT(v_res);
+                    // ssq in vector vr
+                    v0 = VFDIVVF_FLOAT(v0, scale, gvl);
+                    vr = VFMACCVV_FLOAT_TU(v_zero, v0, v0, gvl);
+                }
+                j += gvl;
+                idx += inc_v;
+            }
+            // ssq in vector vr: vr[0]
+            v_res = VFREDSUM_FLOAT(v_res, vr, v_z0, gvl);
+            // total ssq now
+            ssq += VFMVFS_FLOAT(v_res);
+
+            // tail processing
+            if (j < n) {
+                gvl = VSETVL(n-j);
+                v0 = VLSEV_FLOAT(&x[idx], stride_x, gvl);
+                // fabs(vector)
+                v0 = VFABSV_FLOAT(v0, gvl);
+                // if scale change
+                mask = VMFGTVF_FLOAT(v0, scale, gvl);
+                index = VMFIRSTM(mask, gvl);
+                if(index == -1) { // no elements greater than scale
+                    if(scale != 0.0) {
+                        v0 = VFDIVVF_FLOAT(v0, scale, gvl);
+                        vr = VFMACCVV_FLOAT_TU(v_zero, v0, v0, gvl);
+                    }
+                }
+                else { // found greater element
+                    // find max
+                    v_res = VFREDMAXVS_FLOAT_TU(v_res, v0, v_z0, gvl);
+                    // update ssq before max_index
+                    ssq = ssq * (scale / VFMVFS_FLOAT(v_res))*(scale / VFMVFS_FLOAT(v_res));
+                    // update scale
+                    scale = VFMVFS_FLOAT(v_res);
+                    v0 = VFDIVVF_FLOAT(v0, scale, gvl);
+                    vr = VFMACCVV_FLOAT_TU(v_zero, v0, v0, gvl);
+                }
+                // ssq in vector vr: vr[0]
+                v_res = VFREDSUM_FLOAT(v_res, vr, v_z0, gvl);
+                // total ssq now
+                ssq += VFMVFS_FLOAT(v_res);
+            }
+        }
+    }
+    else{
+        // using scalar ops when inc_x < 0
+        n *= inc_x;
     while(abs(i) < abs(n)){
-        if ( x[i] != 0.0 ){
-            FLOAT absxi = ABS( x[i] );
-            if ( scale < absxi ){
-                ssq = 1 + ssq * ( scale / absxi ) * ( scale / absxi );
-                scale = absxi ;
-            }
-            else{
-                ssq += ( absxi/scale ) * ( absxi/scale );
-            }
-
-        }
-
-        i += inc_x;
+            if ( x[i] != 0.0 ){
+                FLOAT absxi = ABS( x[i] );
+                if ( scale < absxi ){
+                    ssq = 1 + ssq * ( scale / absxi ) * ( scale / absxi );
+                    scale = absxi ;
+                }
+                else{
+                    ssq += ( absxi/scale ) * ( absxi/scale );
+                }
+
+            }
+            i += inc_x;
     }
-
+    }
     return(scale * sqrt(ssq));
 }
 
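Notes on the algorithm

Both the removed and the added kernel compute the Euclidean norm with the same scaled sum-of-squares recurrence that the scalar fallback at the bottom of the function uses: keep scale (the largest magnitude seen so far) and ssq, maintaining the invariant sum(x[k]^2) == scale^2 * ssq, so nothing much larger than 1.0 is ever squared. A minimal scalar sketch of that recurrence (plain C, double precision, inc_x >= 1; the function name is illustrative, not part of this file):

#include <math.h>

/* Reference scale/ssq recurrence: after each step, the sum of squares
 * of the elements processed so far equals scale^2 * ssq. */
static double nrm2_ref(long n, const double *x, long inc_x)
{
    double scale = 0.0, ssq = 0.0;
    for (long k = 0; k < n; k++) {
        double absxi = fabs(x[k * inc_x]);
        if (absxi == 0.0) continue;            /* zeros contribute nothing */
        if (scale < absxi) {
            /* new maximum: re-express the accumulated sum in the new scale */
            ssq = 1.0 + ssq * (scale / absxi) * (scale / absxi);
            scale = absxi;
        } else {
            ssq += (absxi / scale) * (absxi / scale);
        }
    }
    return scale * sqrt(ssq);
}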
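The new kernel applies that recurrence one vector block at a time instead of one element at a time: VMFGTVF_FLOAT and VMFIRSTM ask whether any lane of the current block exceeds scale, and only then does it fold the running accumulator vr into ssq (VFREDSUM_FLOAT), take the block maximum as the new scale (VFREDMAXVS_FLOAT_TU), and rescale the scalar ssq once. The removed code instead carried per-lane v_scale/v_ssq vectors and merged the gvl partial results with a sort after the loop. A plain-C model of the new block strategy, with blk standing in for gvl (a sketch under those assumptions, not the kernel itself):

#include <math.h>
#include <stddef.h>

static double nrm2_block_model(size_t n, const double *x, size_t blk)
{
    double scale = 0.0, ssq = 0.0;
    for (size_t j = 0; j < n; j += blk) {
        size_t len = (n - j < blk) ? (n - j) : blk;  /* VSETVL-style tail */
        double blkmax = 0.0;
        for (size_t k = 0; k < len; k++) {           /* VFABSV + VMFGTVF/VFREDMAXVS */
            double a = fabs(x[j + k]);
            if (a > blkmax) blkmax = a;
        }
        if (blkmax > scale) {                        /* VMFIRSTM found a bigger lane */
            if (scale != 0.0)                        /* rescale the old sum once */
                ssq *= (scale / blkmax) * (scale / blkmax);
            scale = blkmax;
        }
        if (scale == 0.0) continue;                  /* every element so far is zero */
        double vr = 0.0;
        for (size_t k = 0; k < len; k++) {           /* VFDIVVF + VFMACCVV_TU */
            double a = fabs(x[j + k]) / scale;
            vr += a * a;
        }
        ssq += vr;                                   /* VFREDSUM fold */
    }
    return scale * sqrt(ssq);
}

One difference worth noting: the model folds vr into ssq on every block for clarity, while the kernel keeps vr live across blocks and only reduces when the scale changes or at the end, which is mathematically equivalent but cheaper.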

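The rescaling is about robustness as much as speed: a naive sum of squares overflows for inputs near the top of the floating-point range, while the scale/ssq form keeps every intermediate near 1.0. A quick smoke test through the public BLAS interface, assuming a build linked against OpenBLAS (the compile line varies by setup, e.g. cc test_nrm2.c -lopenblas):

#include <stdio.h>
#include <cblas.h>

int main(void)
{
    /* 3e300^2 and 4e300^2 both overflow double (max ~1.8e308),
     * but the scaled recurrence should still return 5e300 */
    double x[] = { 3e300, 4e300 };
    printf("nrm2 = %g\n", cblas_dnrm2(2, x, 1));
    return 0;
}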