Skip to content

Commit fb949fb

Browse files
ENH: update random.logistic backend; kernels (#666)
1 parent 9496e32 commit fb949fb

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

dpnp/backend/kernels/dpnp_krnl_random.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -394,30 +394,38 @@ void dpnp_rng_laplace_c(void* result, const double loc, const double scale, cons
394394
event_out.wait();
395395
}
396396

397+
template <typename _KernelNameSpecialization>
398+
class dpnp_rng_logistic_c_kernel;
399+
397400
/* Logistic(loc, scale) ~ loc + scale * log(u/(1.0 - u)) */
398401
template <typename _DataType>
399402
void dpnp_rng_logistic_c(void* result, const double loc, const double scale, const size_t size)
400403
{
401-
if (!size)
404+
if (!size || !result)
402405
{
403406
return;
404407
}
405-
cl::sycl::vector_class<cl::sycl::event> no_deps;
406408

407409
const _DataType d_zero = _DataType(0.0);
408410
const _DataType d_one = _DataType(1.0);
409411

410412
_DataType* result1 = reinterpret_cast<_DataType*>(result);
411413

412414
mkl_rng::uniform<_DataType> distribution(d_zero, d_one);
413-
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
414-
event_out.wait();
415-
416-
for (size_t i = 0; i < size; i++)
417-
result1[i] = log(result1[i] / (1.0 - result1[i]));
415+
auto event_distribution = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
418416

419-
for (size_t i = 0; i < size; i++)
417+
cl::sycl::range<1> gws(size);
418+
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
419+
size_t i = global_id[0];
420+
result1[i] = cl::sycl::log(result1[i] / (1.0 - result1[i]));
420421
result1[i] = loc + scale * result1[i];
422+
};
423+
auto kernel_func = [&](cl::sycl::handler& cgh) {
424+
cgh.depends_on({event_distribution});
425+
cgh.parallel_for<class dpnp_rng_logistic_c_kernel<_DataType>>(gws, kernel_parallel_for_func);
426+
};
427+
auto event = DPNP_QUEUE.submit(kernel_func);
428+
event.wait();
421429
}
422430

423431
template <typename _DataType>

0 commit comments

Comments
 (0)