Skip to content

Commit eb71cc5

Browse files
MAINT: refactor random module (#664)
1 parent c90b74c commit eb71cc5

File tree

2 files changed: +152 additions, -140 deletions (lines changed)

2 files changed

+152
-140
lines changed

dpnp/backend/kernels/dpnp_krnl_random.cpp

Lines changed: 108 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,7 @@ void dpnp_rng_srand_c(size_t seed)
7272
}
7373

7474
template <typename _DataType>
75-
INP_DLLEXPORT void dpnp_rng_beta_c(void* result, const _DataType a, const _DataType b, const size_t size)
75+
void dpnp_rng_beta_c(void* result, const _DataType a, const _DataType b, const size_t size)
7676
{
7777
if (!size)
7878
{
@@ -116,7 +116,6 @@ void dpnp_rng_binomial_c(void* result, const int ntrial, const double p, const s
116116
{
117117
return;
118118
}
119-
_DataType* result1 = reinterpret_cast<_DataType*>(result);
120119

121120
if (ntrial == 0 || p == 0)
122121
{
@@ -131,6 +130,7 @@ void dpnp_rng_binomial_c(void* result, const int ntrial, const double p, const s
131130
}
132131
else
133132
{
133+
_DataType* result1 = reinterpret_cast<_DataType*>(result);
134134
if (dpnp_queue_is_cpu_c())
135135
{
136136
mkl_rng::binomial<_DataType> distribution(ntrial, p);
@@ -253,29 +253,34 @@ void dpnp_rng_f_c(void* result, const _DataType df_num, const _DataType df_den,
253253
template <typename _DataType>
254254
void dpnp_rng_gamma_c(void* result, const _DataType shape, const _DataType scale, const size_t size)
255255
{
256-
if (!size)
256+
if (!size || result == nullptr)
257257
{
258258
return;
259259
}
260260

261-
// set displacement a
262-
const _DataType a = (_DataType(0.0));
263-
264-
_DataType* result1 = reinterpret_cast<_DataType*>(result);
265-
266-
if (dpnp_queue_is_cpu_c())
261+
if (shape == 0.0 || scale == 0.0)
267262
{
268-
mkl_rng::gamma<_DataType> distribution(shape, a, scale);
269-
// perform generation
270-
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
271-
event_out.wait();
263+
dpnp_zeros_c<_DataType>(result, size);
272264
}
273265
else
274266
{
275-
int errcode = vdRngGamma(VSL_RNG_METHOD_GAMMA_GNORM, get_rng_stream(), size, result1, shape, a, scale);
276-
if (errcode != VSL_STATUS_OK)
267+
_DataType* result1 = reinterpret_cast<_DataType*>(result);
268+
const _DataType a = (_DataType(0.0));
269+
270+
if (dpnp_queue_is_cpu_c())
271+
{
272+
mkl_rng::gamma<_DataType> distribution(shape, a, scale);
273+
// perform generation
274+
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
275+
event_out.wait();
276+
}
277+
else
277278
{
278-
throw std::runtime_error("DPNP RNG Error: dpnp_rng_gamma_c() failed.");
279+
int errcode = vdRngGamma(VSL_RNG_METHOD_GAMMA_GNORM, get_rng_stream(), size, result1, shape, a, scale);
280+
if (errcode != VSL_STATUS_OK)
281+
{
282+
throw std::runtime_error("DPNP RNG Error: dpnp_rng_gamma_c() failed.");
283+
}
279284
}
280285
}
281286
}
@@ -298,16 +303,23 @@ void dpnp_rng_gaussian_c(void* result, const _DataType mean, const _DataType std
298303
template <typename _DataType>
299304
void dpnp_rng_geometric_c(void* result, const float p, const size_t size)
300305
{
301-
if (!size)
306+
if (!size || !result)
302307
{
303308
return;
304309
}
305-
_DataType* result1 = reinterpret_cast<_DataType*>(result);
306310

307-
mkl_rng::geometric<_DataType> distribution(p);
308-
// perform generation
309-
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
310-
event_out.wait();
311+
if (p == 1.0)
312+
{
313+
dpnp_ones_c<_DataType>(result, size);
314+
}
315+
else
316+
{
317+
_DataType* result1 = reinterpret_cast<_DataType*>(result);
318+
mkl_rng::geometric<_DataType> distribution(p);
319+
// perform generation
320+
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
321+
event_out.wait();
322+
}
311323
}
312324

313325
template <typename _KernelNameSpecialization>
@@ -316,82 +328,114 @@ class dpnp_blas_scal_c_kernel;
316328
template <typename _DataType>
317329
void dpnp_rng_gumbel_c(void* result, const double loc, const double scale, const size_t size)
318330
{
319-
cl::sycl::event event;
320-
if (!size)
331+
if (!size || !result)
321332
{
322333
return;
323334
}
324335

325-
const _DataType alpha = (_DataType(-1.0));
326-
std::int64_t incx = 1;
327-
_DataType* result1 = reinterpret_cast<_DataType*>(result);
328-
double negloc = loc * (double(-1.0));
329-
330-
mkl_rng::gumbel<_DataType> distribution(negloc, scale);
331-
event = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
332-
event.wait();
333-
334-
// OK for CPU and segfault for GPU device
335-
// event = mkl_blas::scal(DPNP_QUEUE, size, alpha, result1, incx);
336-
if (dpnp_queue_is_cpu_c())
336+
if (scale == 0.0)
337337
{
338-
event = mkl_blas::scal(DPNP_QUEUE, size, alpha, result1, incx);
338+
_DataType* fill_value = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(sizeof(_DataType)));
339+
fill_value[0] = static_cast<_DataType>(loc);
340+
dpnp_initval_c<_DataType>(result, fill_value, size);
341+
dpnp_memory_free_c(fill_value);
339342
}
340343
else
341344
{
342-
// for (size_t i = 0; i < size; i++) result1[i] *= alpha;
343-
cl::sycl::range<1> gws(size);
344-
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
345-
size_t i = global_id[0];
346-
result1[i] *= alpha;
347-
};
348-
auto kernel_func = [&](cl::sycl::handler& cgh) {
349-
cgh.parallel_for<class dpnp_blas_scal_c_kernel<_DataType>>(gws, kernel_parallel_for_func);
350-
};
351-
event = DPNP_QUEUE.submit(kernel_func);
345+
const _DataType alpha = (_DataType(-1.0));
346+
std::int64_t incx = 1;
347+
_DataType* result1 = reinterpret_cast<_DataType*>(result);
348+
double negloc = loc * (double(-1.0));
349+
350+
mkl_rng::gumbel<_DataType> distribution(negloc, scale);
351+
auto event_distribution = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
352+
353+
// OK for CPU and segfault for GPU device
354+
// event = mkl_blas::scal(DPNP_QUEUE, size, alpha, result1, incx);
355+
cl::sycl::event prod_event;
356+
if (dpnp_queue_is_cpu_c())
357+
{
358+
prod_event = mkl_blas::scal(DPNP_QUEUE, size, alpha, result1, incx, {event_distribution});
359+
}
360+
else
361+
{
362+
// for (size_t i = 0; i < size; i++) result1[i] *= alpha;
363+
cl::sycl::range<1> gws(size);
364+
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
365+
size_t i = global_id[0];
366+
result1[i] *= alpha;
367+
};
368+
auto kernel_func = [&](cl::sycl::handler& cgh) {
369+
cgh.depends_on({event_distribution});
370+
cgh.parallel_for<class dpnp_blas_scal_c_kernel<_DataType>>(gws, kernel_parallel_for_func);
371+
};
372+
prod_event = DPNP_QUEUE.submit(kernel_func);
373+
}
374+
prod_event.wait();
352375
}
353-
event.wait();
354376
}
355377

356378
template <typename _DataType>
357379
void dpnp_rng_hypergeometric_c(void* result, const int l, const int s, const int m, const size_t size)
358380
{
359-
if (!size)
381+
if (!size || !result)
360382
{
361383
return;
362384
}
363-
_DataType* result1 = reinterpret_cast<_DataType*>(result);
364385

365-
if (dpnp_queue_is_cpu_c())
386+
if (m == 0)
366387
{
367-
mkl_rng::hypergeometric<_DataType> distribution(l, s, m);
368-
// perform generation
369-
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
370-
event_out.wait();
388+
dpnp_zeros_c<_DataType>(result, size);
389+
}
390+
else if (l == m)
391+
{
392+
_DataType* fill_value = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(sizeof(_DataType)));
393+
fill_value[0] = static_cast<_DataType>(s);
394+
dpnp_initval_c<_DataType>(result, fill_value, size);
395+
dpnp_memory_free_c(fill_value);
371396
}
372397
else
373398
{
374-
int errcode = viRngHypergeometric(VSL_RNG_METHOD_HYPERGEOMETRIC_H2PE, get_rng_stream(), size, result1, l, s, m);
375-
if (errcode != VSL_STATUS_OK)
399+
_DataType* result1 = reinterpret_cast<_DataType*>(result);
400+
if (dpnp_queue_is_cpu_c())
376401
{
377-
throw std::runtime_error("DPNP RNG Error: dpnp_rng_hypergeometric_c() failed.");
402+
mkl_rng::hypergeometric<_DataType> distribution(l, s, m);
403+
// perform generation
404+
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
405+
event_out.wait();
406+
}
407+
else
408+
{
409+
int errcode =
410+
viRngHypergeometric(VSL_RNG_METHOD_HYPERGEOMETRIC_H2PE, get_rng_stream(), size, result1, l, s, m);
411+
if (errcode != VSL_STATUS_OK)
412+
{
413+
throw std::runtime_error("DPNP RNG Error: dpnp_rng_hypergeometric_c() failed.");
414+
}
378415
}
379416
}
380417
}
381418

382419
template <typename _DataType>
383420
void dpnp_rng_laplace_c(void* result, const double loc, const double scale, const size_t size)
384421
{
385-
if (!size)
422+
if (!size || !result)
386423
{
387424
return;
388425
}
389-
_DataType* result1 = reinterpret_cast<_DataType*>(result);
390426

391-
mkl_rng::laplace<_DataType> distribution(loc, scale);
392-
// perform generation
393-
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
394-
event_out.wait();
427+
if (scale == 0.0)
428+
{
429+
dpnp_zeros_c<_DataType>(result, size);
430+
}
431+
else
432+
{
433+
_DataType* result1 = reinterpret_cast<_DataType*>(result);
434+
mkl_rng::laplace<_DataType> distribution(loc, scale);
435+
// perform generation
436+
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
437+
event_out.wait();
438+
}
395439
}
396440

397441
template <typename _KernelNameSpecialization>

0 commit comments

Comments
 (0)