Skip to content

Commit f31bea0

Browse files
authored
Merge pull request #4419 from martin-frbg/issue4413
[WIP] Add fixes and utests for ZSCAL with NaN or Inf arguments
2 parents: 3599f2d + 20413ee; commit: f31bea0

File tree

15 files changed: +172 / -42 lines changed

15 files changed: +172 / -42 lines changed

kernel/arm64/zscal.S

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -223,7 +223,7 @@ zscal_begin:
223223
fcmp DA_I, #0.0
224224
beq .Lzscal_kernel_RI_zero
225225

226-
b .Lzscal_kernel_R_zero
226+
// b .Lzscal_kernel_R_zero
227227

228228
.Lzscal_kernel_R_non_zero:
229229

kernel/mips/KERNEL.P5600

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -103,8 +103,10 @@ endif
103103
ifdef HAVE_MSA
104104
SSCALKERNEL = ../mips/sscal_msa.c
105105
DSCALKERNEL = ../mips/dscal_msa.c
106-
CSCALKERNEL = ../mips/cscal_msa.c
107-
ZSCALKERNEL = ../mips/zscal_msa.c
106+
#CSCALKERNEL = ../mips/cscal_msa.c
107+
#ZSCALKERNEL = ../mips/zscal_msa.c
108+
CSCALKERNEL = ../mips/zscal.c
109+
ZSCALKERNEL = ../mips/zscal.c
108110
else
109111
SSCALKERNEL = ../mips/scal.c
110112
DSCALKERNEL = ../mips/scal.c

kernel/mips/zscal.c

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r,FLOAT da_i, F
4747
else
4848
{
4949
temp = - da_i * x[ip+1] ;
50+
if (isnan(x[ip]) || isinf(x[ip])) temp = NAN;
5051
x[ip+1] = da_i * x[ip] ;
5152
}
5253
}
@@ -63,8 +64,11 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r,FLOAT da_i, F
6364
x[ip+1] = da_r * x[ip+1] + da_i * x[ip] ;
6465
}
6566
}
66-
x[ip] = temp;
67-
67+
if ( da_r != da_r )
68+
x[ip] = da_r;
69+
else
70+
x[ip] = temp;
71+
6872
ip += inc_x2;
6973
}
7074

kernel/riscv64/zscal.c

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r,FLOAT da_i, F
6060
else
6161
{
6262
temp = - da_i * x[ip+1] ;
63+
if (isnan(x[ip]) || isinf(x[ip])) temp = NAN;
6364
x[ip+1] = da_i * x[ip] ;
6465
}
6566
}

kernel/riscv64/zscal_vector.c

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -80,6 +80,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r,FLOAT da_i, F
8080
j += gvl;
8181
ix += inc_x * 2 * gvl;
8282
}
83+
#if 0
8384
}else if(da_r == 0.0){
8485
gvl = VSETVL(n);
8586
BLASLONG stride_x = inc_x * 2 * sizeof(FLOAT);
@@ -97,6 +98,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r,FLOAT da_i, F
9798
j += gvl;
9899
ix += inc_xv;
99100
}
101+
#endif
100102
if(j < n){
101103
gvl = VSETVL(n-j);
102104
v0 = VLSEV_FLOAT(&x[ix], stride_x, gvl);

kernel/x86/zscal.S

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -98,7 +98,7 @@
9898
fcomip %st(1), %st
9999
ffreep %st(0)
100100
jne .L30
101-
101+
jp .L30
102102
EMMS
103103

104104
pxor %mm0, %mm0

kernel/x86/zscal_sse.S

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,7 @@
8787
xorps %xmm7, %xmm7
8888
comiss %xmm0, %xmm7
8989
jne .L100 # Alpha_r != ZERO
90+
jp .L100 # Alpha_r NaN
9091

9192
comiss %xmm1, %xmm7
9293
jne .L100 # Alpha_i != ZERO

kernel/x86/zscal_sse2.S

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -98,6 +98,7 @@
9898
xorps %xmm7, %xmm7
9999
comisd %xmm0, %xmm7
100100
jne .L100
101+
jp .L100
101102

102103
comisd %xmm1, %xmm7
103104
jne .L100

kernel/x86_64/zscal.c

Lines changed: 31 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3939
#endif
4040

4141
#include "common.h"
42-
42+
#include <float.h>
4343

4444
#if defined (SKYLAKEX) || defined (COOPERLAKE) || defined (SAPPHIRERAPIDS)
4545
#include "zscal_microk_skylakex-2.c"
@@ -222,12 +222,10 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
222222

223223
if ( da_r == 0.0 )
224224
{
225-
226225
BLASLONG n1 = n & -2;
227226

228227
if ( da_i == 0.0 )
229228
{
230-
231229
while(j < n1)
232230
{
233231

@@ -253,7 +251,6 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
253251
}
254252
else
255253
{
256-
257254
while(j < n1)
258255
{
259256

@@ -356,49 +353,59 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
356353

357354
alpha[0] = da_r;
358355
alpha[1] = da_i;
359-
356+
360357
if ( da_r == 0.0 )
361358
if ( da_i == 0 )
362359
zscal_kernel_8_zero(n1 , alpha , x);
363360
else
364-
zscal_kernel_8_zero_r(n1 , alpha , x);
361+
// zscal_kernel_8_zero_r(n1 , alpha , x);
362+
zscal_kernel_8(n1 , alpha , x);
365363
else
366-
if ( da_i == 0 )
364+
if ( da_i == 0 && da_r == da_r)
367365
zscal_kernel_8_zero_i(n1 , alpha , x);
368366
else
369367
zscal_kernel_8(n1 , alpha , x);
370-
368+
}
371369
i = n1 << 1;
372370
j = n1;
373-
}
374-
375-
376-
if ( da_r == 0.0 )
371+
372+
if ( da_r == 0.0 || da_r != da_r )
377373
{
378-
379374
if ( da_i == 0.0 )
380375
{
381-
376+
FLOAT res=0.0;
377+
if (da_r != da_r) res= da_r;
382378
while(j < n)
383379
{
384-
385-
x[i]=0.0;
386-
x[i+1]=0.0;
380+
x[i]=res;
381+
x[i+1]=res;
387382
i += 2 ;
388383
j++;
389384

390385
}
391386

392387
}
393-
else
388+
else if (da_r < -FLT_MAX || da_r > FLT_MAX) {
389+
while(j < n)
390+
{
391+
x[i]= NAN;
392+
x[i+1] = da_r;
393+
i += 2 ;
394+
j++;
395+
396+
}
397+
398+
} else
394399
{
395400

396401
while(j < n)
397402
{
398-
399403
temp0 = -da_i * x[i+1];
404+
if (x[i] < -FLT_MAX || x[i] > FLT_MAX)
405+
temp0 = NAN;
400406
x[i+1] = da_i * x[i];
401-
x[i] = temp0;
407+
if ( x[i] == x[i]) //preserve NaN
408+
x[i] = temp0;
402409
i += 2 ;
403410
j++;
404411

@@ -409,28 +416,24 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
409416
}
410417
else
411418
{
412-
413-
if ( da_i == 0.0 )
419+
if (da_i == 0.0)
414420
{
415-
416-
while(j < n)
417-
{
421+
while(j < n)
422+
{
418423

419424
temp0 = da_r * x[i];
420425
x[i+1] = da_r * x[i+1];
421426
x[i] = temp0;
422427
i += 2 ;
423428
j++;
424429

425-
}
426-
430+
}
427431
}
428432
else
429433
{
430434

431435
while(j < n)
432436
{
433-
434437
temp0 = da_r * x[i] - da_i * x[i+1];
435438
x[i+1] = da_r * x[i+1] + da_i * x[i];
436439
x[i] = temp0;
@@ -445,5 +448,3 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
445448

446449
return(0);
447450
}
448-
449-

kernel/x86_64/zscal_sse2.S

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,7 @@
8282
pxor %xmm15, %xmm15
8383
comisd %xmm0, %xmm15
8484
jne .L100
85+
jp .L100
8586

8687
comisd %xmm1, %xmm15
8788
jne .L100

0 commit comments

Comments
 (0)