Skip to content

Commit edbf093

Browse files
authored
Update zarch SCAL kernels to handle INF and NAN arguments (#4829)
* handle INF and NAN in input (for S/D only if DUMMY2 argument is set)
1 parent 136a4ed commit edbf093

File tree

4 files changed

+160
-78
lines changed

4 files changed

+160
-78
lines changed

kernel/zarch/cscal.c

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -234,22 +234,38 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
234234
} else {
235235

236236
while (j < n1) {
237-
238-
temp0 = -da_i * x[i + 1];
239-
x[i + 1] = da_i * x[i];
237+
if (isnan(x[i]) || isinf(x[i]))
238+
temp0 = NAN;
239+
else
240+
temp0 = -da_i * x[i + 1];
241+
if (!isinf(x[i + 1]))
242+
x[i + 1] = da_i * x[i];
243+
else
244+
x[i + 1] = NAN;
240245
x[i] = temp0;
241-
temp1 = -da_i * x[i + 1 + inc_x];
242-
x[i + 1 + inc_x] = da_i * x[i + inc_x];
246+
if (isnan(x[i+inc_x]) || isinf(x[i+inc_x]))
247+
temp1 = NAN;
248+
else
249+
temp1 = -da_i * x[i + 1 + inc_x];
250+
if (!isinf(x[i + 1 + inc_x]))
251+
x[i + 1 + inc_x] = da_i * x[i + inc_x];
252+
else
253+
x[i + 1 + inc_x] = NAN;
243254
x[i + inc_x] = temp1;
244255
i += 2 * inc_x;
245256
j += 2;
246257

247258
}
248259

249260
while (j < n) {
250-
251-
temp0 = -da_i * x[i + 1];
252-
x[i + 1] = da_i * x[i];
261+
if (isnan(x[i]) || isinf(x[i]))
262+
temp0 = NAN;
263+
else
264+
temp0 = -da_i * x[i + 1];
265+
if (isinf(x[i + 1]))
266+
x[i + 1] = NAN;
267+
else
268+
x[i + 1] = da_i * x[i];
253269
x[i] = temp0;
254270
i += inc_x;
255271
j++;
@@ -332,26 +348,42 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
332348
j = n1;
333349
}
334350

335-
if (da_r == 0.0) {
351+
if (da_r == 0.0 || isnan(da_r)) {
336352

337353
if (da_i == 0.0) {
338-
354+
float res = 0.0;
355+
if (isnan(da_r)) res = da_r;
339356
while (j < n) {
340357

341-
x[i] = 0.0;
342-
x[i + 1] = 0.0;
358+
x[i] = res;
359+
x[i + 1] = res;
343360
i += 2;
344361
j++;
345362

346363
}
364+
} else if (isinf(da_r)) {
365+
while(j < n)
366+
{
367+
368+
x[i]= NAN;
369+
x[i+1] = da_r;
370+
i += 2 ;
371+
j++;
372+
373+
}
347374

348375
} else {
349376

350377
while (j < n) {
351378

352379
temp0 = -da_i * x[i + 1];
353-
x[i + 1] = da_i * x[i];
354-
x[i] = temp0;
380+
if (isinf(x[i])) temp0 = NAN;
381+
if (!isinf(x[i + 1]))
382+
x[i + 1] = da_i * x[i];
383+
else
384+
x[i + 1] = NAN;
385+
if (x[i] == x[i])
386+
x[i] = temp0;
355387
i += 2;
356388
j++;
357389

kernel/zarch/dscal.c

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -96,20 +96,28 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x,
9696
if (inc_x == 1) {
9797

9898
if (da == 0.0) {
99-
100-
BLASLONG n1 = n & -16;
101-
if (n1 > 0) {
102-
103-
dscal_kernel_16_zero(n1, x);
104-
j = n1;
99+
100+
if (dummy2 == 0) {
101+
BLASLONG n1 = n & -16;
102+
if (n1 > 0) {
103+
dscal_kernel_16_zero(n1, x);
104+
j = n1;
105+
}
106+
107+
while (j < n) {
108+
x[j] = 0.0;
109+
j++;
110+
}
111+
} else {
112+
while (j < n) {
113+
if (isfinite(x[j]))
114+
x[j] = 0.0;
115+
else
116+
x[j] = NAN;
117+
j++;
118+
}
105119
}
106-
107-
while (j < n) {
108-
109-
x[j] = 0.0;
110-
j++;
111-
}
112-
120+
113121
} else {
114122

115123
BLASLONG n1 = n & -16;
@@ -127,23 +135,23 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x,
127135
} else {
128136

129137
if (da == 0.0) {
130-
138+
if (dummy2 == 0) {
131139
BLASLONG n1 = n & -4;
132-
133140
while (j < n1) {
134-
135141
x[i] = 0.0;
136142
x[i + inc_x] = 0.0;
137143
x[i + 2 * inc_x] = 0.0;
138144
x[i + 3 * inc_x] = 0.0;
139145

140146
i += inc_x * 4;
141147
j += 4;
142-
143148
}
149+
}
144150
while (j < n) {
145-
146-
x[i] = 0.0;
151+
if (dummy2==0 || isfinite(x[i]))
152+
x[i] = 0.0;
153+
else
154+
x[i] = NAN;
147155
i += inc_x;
148156
j++;
149157
}

kernel/zarch/sscal.c

Lines changed: 55 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -95,21 +95,31 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x,
9595

9696
if (inc_x == 1) {
9797

98-
if (da == 0.0) {
99-
100-
BLASLONG n1 = n & -32;
101-
if (n1 > 0) {
102-
103-
sscal_kernel_32_zero(n1, x);
104-
j = n1;
105-
}
106-
107-
while (j < n) {
108-
109-
x[j] = 0.0;
110-
j++;
98+
if (da == 0.0 || !isfinite(da)) {
99+
if (dummy2 == 0) {
100+
BLASLONG n1 = n & -32;
101+
if (n1 > 0) {
102+
103+
sscal_kernel_32_zero(n1, x);
104+
j = n1;
105+
}
106+
107+
while (j < n) {
108+
109+
x[j] = 0.0;
110+
j++;
111+
}
112+
} else {
113+
float res = 0.0;
114+
if (!isfinite(da)) res = NAN;
115+
while (j < n) {
116+
if (isfinite(x[i]))
117+
x[j] = res;
118+
else
119+
x[j] = NAN;
120+
j++;
121+
}
111122
}
112-
113123
} else {
114124

115125
BLASLONG n1 = n & -32;
@@ -126,26 +136,37 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x,
126136

127137
} else {
128138

129-
if (da == 0.0) {
130-
131-
BLASLONG n1 = n & -2;
132-
133-
while (j < n1) {
134-
135-
x[i] = 0.0;
136-
x[i + inc_x] = 0.0;
137-
138-
i += inc_x * 2;
139-
j += 2;
140-
141-
}
142-
while (j < n) {
143-
144-
x[i] = 0.0;
145-
i += inc_x;
146-
j++;
147-
}
148-
139+
if (da == 0.0 || !isfinite(da)) {
140+
if (dummy2 == 0) {
141+
BLASLONG n1 = n & -2;
142+
143+
while (j < n1) {
144+
145+
x[i] = 0.0;
146+
x[i + inc_x] = 0.0;
147+
148+
i += inc_x * 2;
149+
j += 2;
150+
151+
}
152+
while (j < n) {
153+
154+
x[i] = 0.0;
155+
i += inc_x;
156+
j++;
157+
}
158+
} else {
159+
while (j < n) {
160+
float res = 0.0;
161+
if (!isfinite(da)) res = NAN;
162+
if (isfinite(x[i]))
163+
x[i] = res;
164+
else
165+
x[i] = NAN;
166+
i += inc_x;
167+
j++;
168+
}
169+
}
149170
} else {
150171
BLASLONG n1 = n & -2;
151172

kernel/zarch/zscal.c

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -237,13 +237,19 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
237237
temp0 = NAN;
238238
else
239239
temp0 = -da_i * x[i + 1];
240-
x[i + 1] = da_i * x[i];
240+
if (!isinf(x[i + 1]))
241+
x[i + 1] = da_i * x[i];
242+
else
243+
x[i + 1] = NAN;
241244
x[i] = temp0;
242245
if (isnan(x[i + inc_x]) || isinf(x[i + inc_x]))
243246
temp1 = NAN;
244247
else
245248
temp1 = -da_i * x[i + 1 + inc_x];
246-
x[i + 1 + inc_x] = da_i * x[i + inc_x];
249+
if (!isinf(x[i + 1 + inc_x]))
250+
x[i + 1 + inc_x] = da_i * x[i + inc_x];
251+
else
252+
x[i + 1 + inc_x] = NAN;
247253
x[i + inc_x] = temp1;
248254
i += 2 * inc_x;
249255
j += 2;
@@ -256,7 +262,10 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
256262
temp0 = NAN;
257263
else
258264
temp0 = -da_i * x[i + 1];
259-
x[i + 1] = da_i * x[i];
265+
if (!isinf(x[i +1]))
266+
x[i + 1] = da_i * x[i];
267+
else
268+
x[i + 1] = NAN;
260269
x[i] = temp0;
261270
i += inc_x;
262271
j++;
@@ -330,7 +339,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
330339
zscal_kernel_8_zero(n1, x);
331340
else
332341
zscal_kernel_8(n1, da_r, da_i, x);
333-
else if (da_i == 0)
342+
else if (da_i == 0 && da_r == da_r)
334343
zscal_kernel_8_zero_i(n1, alpha, x);
335344
else
336345
zscal_kernel_8(n1, da_r, da_i, x);
@@ -339,29 +348,41 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
339348
j = n1;
340349
}
341350

342-
if (da_r == 0.0) {
351+
if (da_r == 0.0 || isnan(da_r)) {
343352

344353
if (da_i == 0.0) {
345-
354+
double res= 0.0;
355+
if (isnan(da_r)) res = da_r;
346356
while (j < n) {
347357

348-
x[i] = 0.0;
349-
x[i + 1] = 0.0;
358+
x[i] = res;
359+
x[i + 1] = res;
350360
i += 2;
351361
j++;
352362

353363
}
354364

365+
} else if (isinf(da_r)) {
366+
while (j < n) {
367+
x[i] = NAN;
368+
x[i + 1] = da_r;
369+
i += 2;
370+
j++;
371+
}
355372
} else {
356373

357374
while (j < n) {
358375

359-
if (isnan(x[i]) || isinf(x[i]))
376+
if (isinf(x[i]))
360377
temp0 = NAN;
361378
else
362379
temp0 = -da_i * x[i + 1];
363-
x[i + 1] = da_i * x[i];
364-
x[i] = temp0;
380+
if (!isinf(x[i + 1]))
381+
x[i + 1] = da_i * x[i];
382+
else
383+
x[i + 1] = NAN;
384+
if (x[i]==x[i])
385+
x[i] = temp0;
365386
i += 2;
366387
j++;
367388

0 commit comments

Comments
 (0)