Skip to content

Commit 270b45c

Browse files
shssfu75572
andauthored
use sycl_adapter in krnl_random (#915)
* use sycl_adapter in krnl_random * more cases Co-authored-by: u75572 <u75572@s001-n157.aidevcloud>
1 parent 6325553 commit 270b45c

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

dpnp/backend/kernels/dpnp_krnl_random.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
#include "dpnp_fptr.hpp"
3535
#include "dpnp_utils.hpp"
36+
#include "dpnpc_memory_adapter.hpp"
3637
#include "queue_sycl.hpp"
3738

3839
namespace mkl_blas = oneapi::mkl::blas;
@@ -209,7 +210,8 @@ void dpnp_rng_f_c(void* result, const _DataType df_num, const _DataType df_den,
209210
_DataType scale = 2.0 / df_num;
210211
_DataType* den = nullptr;
211212

212-
_DataType* result1 = reinterpret_cast<_DataType*>(result);
213+
DPNPC_ptr_adapter<_DataType> result1_ptr(result, size, true, true);
214+
_DataType* result1 = result1_ptr.get_ptr();
213215

214216
if (dpnp_queue_is_cpu_c())
215217
{
@@ -604,7 +606,8 @@ void dpnp_rng_noncentral_chisquare_c(void* result, const _DataType df, const _Da
604606
{
605607
return;
606608
}
607-
_DataType* result1 = reinterpret_cast<_DataType*>(result);
609+
DPNPC_ptr_adapter<_DataType> result1_ptr(result, size, true, true);
610+
_DataType* result1 = result1_ptr.get_ptr();
608611

609612
const _DataType d_zero = _DataType(0.0);
610613
const _DataType d_one = _DataType(1.0);
@@ -939,7 +942,8 @@ void dpnp_rng_rayleigh_c(void* result, const _DataType scale, const size_t size)
939942
const _DataType a = 0.0;
940943
const _DataType beta = 2.0;
941944

942-
_DataType* result1 = reinterpret_cast<_DataType*>(result);
945+
DPNPC_ptr_adapter<_DataType> result1_ptr(result, size, true, true);
946+
_DataType* result1 = result1_ptr.get_ptr();
943947

944948
mkl_rng::exponential<_DataType> distribution(a, beta);
945949

@@ -970,7 +974,8 @@ void dpnp_rng_shuffle_c(
970974
return;
971975
}
972976

973-
char* result1 = reinterpret_cast<char*>(result);
977+
DPNPC_ptr_adapter<char> result1_ptr(result, size, true, true);
978+
char* result1 = result1_ptr.get_ptr();
974979

975980
size_t uvec_size = high_dim_size - 1;
976981
double* Uvec = reinterpret_cast<double*>(dpnp_memory_alloc_c(uvec_size * sizeof(double)));
@@ -1248,7 +1253,9 @@ void dpnp_rng_vonmises_large_kappa_c(void* result, const _DataType mu, const _Da
12481253
{
12491254
return;
12501255
}
1251-
_DataType* result1 = reinterpret_cast<_DataType*>(result);
1256+
1257+
DPNPC_ptr_adapter<_DataType> result1_ptr(result, size, true, true);
1258+
_DataType* result1 = result1_ptr.get_ptr();
12521259

12531260
_DataType r_over_two_kappa, recip_two_kappa;
12541261
_DataType s_minus_one, hpt, r_over_two_kappa_minus_one, rho_minus_one;
@@ -1343,7 +1350,9 @@ void dpnp_rng_vonmises_small_kappa_c(void* result, const _DataType mu, const _Da
13431350
{
13441351
return;
13451352
}
1346-
_DataType* result1 = reinterpret_cast<_DataType*>(result);
1353+
1354+
DPNPC_ptr_adapter<_DataType> result1_ptr(result, size, true, true);
1355+
_DataType* result1 = result1_ptr.get_ptr();
13471356

13481357
_DataType rho_over_kappa, rho, r, s_kappa;
13491358
_DataType* Uvec = nullptr;
@@ -1538,7 +1547,9 @@ void dpnp_rng_zipf_c(void* result, const _DataType a, const size_t size)
15381547
long X;
15391548
const _DataType d_zero = 0.0;
15401549
const _DataType d_one = 1.0;
1541-
_DataType* result1 = reinterpret_cast<_DataType*>(result);
1550+
1551+
DPNPC_ptr_adapter<_DataType> result1_ptr(result, size, true, true);
1552+
_DataType* result1 = result1_ptr.get_ptr();
15421553

15431554
am1 = a - d_one;
15441555
b = pow(2.0, am1);

0 commit comments

Comments
 (0)