@@ -199,12 +199,18 @@ template <typename T> bool test() {
199
199
200
200
sycl::buffer<std::complex<T>> data (testcases, sycl::range{N});
201
201
sycl::buffer<std::complex<T>> results (sycl::range{N});
202
+ sycl::buffer<std::complex<T>> exp_conj (sycl::range{N});
203
+ sycl::buffer<std::complex<T>> conj_exp (sycl::range{N});
202
204
203
205
q.submit ([&](sycl::handler &cgh) {
204
206
sycl::accessor acc_data (data, cgh, sycl::read_only);
205
- sycl::accessor acc (results, cgh, sycl::write_only);
207
+ sycl::accessor acc_results (results, cgh, sycl::write_only);
208
+ sycl::accessor acc_exp_conj (exp_conj, cgh, sycl::write_only);
209
+ sycl::accessor acc_conj_exp (conj_exp, cgh, sycl::write_only);
206
210
cgh.parallel_for (sycl::range{N}, [=](sycl::item<1 > it) {
207
- acc[it] = std::exp (acc_data[it]);
211
+ acc_results[it] = std::exp (acc_data[it]);
212
+ acc_exp_conj[it] = std::exp (std::conj (acc_data[it]));
213
+ acc_conj_exp[it] = std::conj (std::exp (acc_data[it]));
208
214
});
209
215
}).wait_and_throw ();
210
216
@@ -219,9 +225,18 @@ template <typename T> bool test() {
219
225
220
226
// Based on https://en.cppreference.com/w/cpp/numeric/complex/exp
221
227
// z below refers to the argument passed to std::exp(complex<T>)
222
- sycl::host_accessor acc (results);
228
+ sycl::host_accessor acc_results (results);
229
+ sycl::host_accessor acc_exp_conj (exp_conj);
230
+ sycl::host_accessor acc_conj_exp (conj_exp);
223
231
for (unsigned i = 0 ; i < N; ++i) {
224
- std::complex<T> r = acc[i];
232
+ // std::exp(std::conj(z)) == std::conj(std::exp(z))
233
+ // NAN is not equal to NAN in floating-point arithmetic, therefore compare
234
+ // only results without NAN
235
+ if (!std::isnan (acc_exp_conj[i].real ()) &&
236
+ !std::isnan (acc_exp_conj[i].imag ()))
237
+ CHECK (acc_exp_conj[i] == acc_conj_exp[i], passed, i);
238
+
239
+ std::complex<T> r = acc_results[i];
225
240
// If z is (+/-0, +0), the result is (1, +0)
226
241
if (testcases[i].real () == 0 && testcases[i].imag () == 0 &&
227
242
!std::signbit (testcases[i].imag ())) {
@@ -247,6 +262,33 @@ template <typename T> bool test() {
247
262
CHECK (r.imag () == 0 , passed, i);
248
263
CHECK (std::signbit (testcases[i].imag ()) == std::signbit (r.imag ()),
249
264
passed, i);
265
+ // If z is (-inf, y) (for any finite y), the result is +0cis(y) where
266
+ // cis(y) is cos(y) + isin(y)
267
+ } else if (std::isinf (testcases[i].real ()) &&
268
+ std::signbit (testcases[i].real ()) &&
269
+ std::isfinite (testcases[i].imag ())) {
270
+ CHECK (r.real () == 0 , passed, i)
271
+ CHECK (std::signbit (r.real ()) ==
272
+ std::signbit (std::cos (testcases[i].imag ())),
273
+ passed, i)
274
+ CHECK (r.imag () == 0 , passed, i)
275
+ CHECK (std::signbit (r.imag ()) ==
276
+ std::signbit (std::sin (testcases[i].imag ())),
277
+ passed, i)
278
+ // If z is (+inf, y) (for any finite nonzero y), the result is +∞cis(y)
279
+ // where cis(y) is cos(y) + isin(y)
280
+ } else if (std::isinf (testcases[i].real ()) &&
281
+ !std::signbit (testcases[i].real ()) &&
282
+ std::isfinite (testcases[i].imag ()) &&
283
+ testcases[i].imag () != 0 ) {
284
+ CHECK (std::isinf (r.real ()), passed, i)
285
+ CHECK (std::signbit (r.real ()) ==
286
+ std::signbit (std::cos (testcases[i].imag ())),
287
+ passed, i)
288
+ CHECK (std::isinf (r.imag ()), passed, i)
289
+ CHECK (std::signbit (r.imag ()) ==
290
+ std::signbit (std::sin (testcases[i].imag ())),
291
+ passed, i)
250
292
// If z is (-inf, +inf), the result is (+/-0, +/-0) (signs are
251
293
// unspecified)
252
294
} else if (std::isinf (testcases[i].real ()) && testcases[i].real () < 0 &&
0 commit comments