Skip to content

Commit 5424ba0

Browse files
ENH: update random.binomial (#650)
* ENH: update random.binomial
1 parent 63625f6 commit 5424ba0

File tree

2 files changed

+36
-25
lines changed

2 files changed

+36
-25
lines changed

dpnp/backend/kernels/dpnp_krnl_random.cpp

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,27 +108,45 @@ INP_DLLEXPORT void dpnp_rng_beta_c(void* result, const _DataType a, const _DataT
108108
template <typename _DataType>
109109
void dpnp_rng_binomial_c(void* result, const int ntrial, const double p, const size_t size)
110110
{
111+
if (result == nullptr)
112+
{
113+
return;
114+
}
111115
if (!size)
112116
{
113117
return;
114118
}
115119
_DataType* result1 = reinterpret_cast<_DataType*>(result);
116120

117-
if (dpnp_queue_is_cpu_c())
121+
if (ntrial == 0 || p == 0)
118122
{
119-
mkl_rng::binomial<_DataType> distribution(ntrial, p);
120-
// perform generation
121-
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
122-
event_out.wait();
123+
dpnp_zeros_c<_DataType>(result, size);
124+
}
125+
else if (p == 1)
126+
{
127+
_DataType* fill_value = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(sizeof(_DataType)));
128+
fill_value[0] = static_cast<_DataType>(ntrial);
129+
dpnp_initval_c<_DataType>(result, fill_value, size);
130+
dpnp_memory_free_c(fill_value);
123131
}
124132
else
125133
{
126-
int errcode = viRngBinomial(VSL_RNG_METHOD_BINOMIAL_BTPE, get_rng_stream(), size, result1, ntrial, p);
127-
if (errcode != VSL_STATUS_OK)
134+
if (dpnp_queue_is_cpu_c())
128135
{
129-
throw std::runtime_error("DPNP RNG Error: dpnp_rng_binomial_c() failed.");
136+
mkl_rng::binomial<_DataType> distribution(ntrial, p);
137+
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
138+
event_out.wait();
139+
}
140+
else
141+
{
142+
int errcode = viRngBinomial(VSL_RNG_METHOD_BINOMIAL_BTPE, get_rng_stream(), size, result1, ntrial, p);
143+
if (errcode != VSL_STATUS_OK)
144+
{
145+
throw std::runtime_error("DPNP RNG Error: dpnp_rng_binomial_c() failed.");
146+
}
130147
}
131148
}
149+
return;
132150
}
133151

134152
template <typename _DataType>

dpnp/random/dpnp_algo_random.pyx

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -177,26 +177,19 @@ cpdef dparray dpnp_rng_binomial(int ntrial, double p, size):
177177
cdef DPNPFuncData kernel_data
178178
cdef fptr_dpnp_rng_binomial_c_1out_t func
179179

180-
if ntrial == 0 or p == 0.0:
181-
result = dparray(size, dtype=dtype)
182-
result.fill(0.0)
183-
elif p == 1.0:
184-
result = dparray(size, dtype=dtype)
185-
result.fill(ntrial)
186-
else:
187-
# convert string type names (dparray.dtype) to C enum DPNPFuncType
188-
param1_type = dpnp_dtype_to_DPNPFuncType(dtype)
180+
# convert string type names (dparray.dtype) to C enum DPNPFuncType
181+
param1_type = dpnp_dtype_to_DPNPFuncType(dtype)
189182

190-
# get the FPTR data structure
191-
kernel_data = get_dpnp_function_ptr(DPNP_FN_RNG_BINOMIAL, param1_type, param1_type)
183+
# get the FPTR data structure
184+
kernel_data = get_dpnp_function_ptr(DPNP_FN_RNG_BINOMIAL, param1_type, param1_type)
192185

193-
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
194-
# ceate result array with type given by FPTR data
195-
result = dparray(size, dtype=result_type)
186+
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
187+
# ceate result array with type given by FPTR data
188+
result = dparray(size, dtype=result_type)
196189

197-
func = <fptr_dpnp_rng_binomial_c_1out_t > kernel_data.ptr
198-
# call FPTR function
199-
func(result.get_data(), ntrial, p, result.size)
190+
func = <fptr_dpnp_rng_binomial_c_1out_t > kernel_data.ptr
191+
# call FPTR function
192+
func(result.get_data(), ntrial, p, result.size)
200193

201194
return result
202195

0 commit comments

Comments
 (0)