Skip to content

Commit d246a3d

Browse files
ENH: update random.shuffle; on device (#676)
* ENH: update random.shuffle; on device
1 parent 2b86d94 commit d246a3d

File tree

1 file changed

+34
-21
lines changed

1 file changed

+34
-21
lines changed

dpnp/backend/kernels/dpnp_krnl_random.cpp

Lines changed: 34 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -946,17 +946,20 @@ template <typename _DataType>
946946
void dpnp_rng_shuffle_c(
947947
void* result, const size_t itemsize, const size_t ndim, const size_t high_dim_size, const size_t size)
948948
{
949-
if (!(size) || !(high_dim_size > 1))
949+
if (!result)
950950
{
951951
return;
952952
}
953953

954-
char* result1 = reinterpret_cast<char*>(result);
954+
if (!size || !ndim || !(high_dim_size > 1))
955+
{
956+
return;
957+
}
955958

956-
double* Uvec = nullptr;
959+
char* result1 = reinterpret_cast<char*>(result);
957960

958961
size_t uvec_size = high_dim_size - 1;
959-
Uvec = reinterpret_cast<double*>(dpnp_memory_alloc_c(uvec_size * sizeof(double)));
962+
double* Uvec = reinterpret_cast<double*>(dpnp_memory_alloc_c(uvec_size * sizeof(double)));
960963
mkl_rng::uniform<double> uniform_distribution(0.0, 1.0);
961964
auto uniform_event = mkl_rng::generate(uniform_distribution, DPNP_RNG_ENGINE, uvec_size, Uvec);
962965
uniform_event.wait();
@@ -966,42 +969,52 @@ void dpnp_rng_shuffle_c(
966969
// Fast, statically typed path: shuffle the underlying buffer.
967970
// Only for non-empty, 1d objects of class ndarray (subclasses such
968971
// as MaskedArrays may not support this approach).
969-
// TODO
970-
// kernel
971-
char* buf = nullptr;
972-
buf = reinterpret_cast<char*>(dpnp_memory_alloc_c(itemsize * sizeof(char)));
972+
char* buf = reinterpret_cast<char*>(dpnp_memory_alloc_c(itemsize * sizeof(char)));
973973
for (size_t i = uvec_size; i > 0; i--)
974974
{
975975
size_t j = (size_t)(floor((i + 1) * Uvec[i - 1]));
976-
memcpy(buf, result1 + j * itemsize, itemsize);
977-
memcpy(result1 + j * itemsize, result1 + i * itemsize, itemsize);
978-
memcpy(result1 + i * itemsize, buf, itemsize);
976+
if (i != j)
977+
{
978+
auto memcpy1 =
979+
DPNP_QUEUE.submit([&](cl::sycl::handler& h) { h.memcpy(buf, result1 + j * itemsize, itemsize); });
980+
auto memcpy2 = DPNP_QUEUE.submit([&](cl::sycl::handler& h) {
981+
h.depends_on({memcpy1});
982+
h.memcpy(result1 + j * itemsize, result1 + i * itemsize, itemsize);
983+
});
984+
auto memcpy3 = DPNP_QUEUE.submit([&](cl::sycl::handler& h) {
985+
h.depends_on({memcpy2});
986+
h.memcpy(result1 + i * itemsize, buf, itemsize);
987+
});
988+
memcpy3.wait();
989+
}
979990
}
980-
981991
dpnp_memory_free_c(buf);
982992
}
983993
else
984994
{
985995
// Multidimensional ndarrays require a bounce buffer.
986-
// TODO
987-
// kernel
988-
char* buf = nullptr;
989996
size_t step_size = (size / high_dim_size) * itemsize; // size in bytes for x[i] element
990-
buf = reinterpret_cast<char*>(dpnp_memory_alloc_c(step_size * sizeof(char)));
997+
char* buf = reinterpret_cast<char*>(dpnp_memory_alloc_c(step_size * sizeof(char)));
991998
for (size_t i = uvec_size; i > 0; i--)
992999
{
9931000
size_t j = (size_t)(floor((i + 1) * Uvec[i - 1]));
9941001
if (j < i)
9951002
{
996-
memcpy(buf, result1 + j * step_size, step_size);
997-
memcpy(result1 + j * step_size, result1 + i * step_size, step_size);
998-
memcpy(result1 + i * step_size, buf, step_size);
1003+
auto memcpy1 =
1004+
DPNP_QUEUE.submit([&](cl::sycl::handler& h) { h.memcpy(buf, result1 + j * step_size, step_size); });
1005+
auto memcpy2 = DPNP_QUEUE.submit([&](cl::sycl::handler& h) {
1006+
h.depends_on({memcpy1});
1007+
h.memcpy(result1 + j * step_size, result1 + i * step_size, step_size);
1008+
});
1009+
auto memcpy3 = DPNP_QUEUE.submit([&](cl::sycl::handler& h) {
1010+
h.depends_on({memcpy2});
1011+
h.memcpy(result1 + i * step_size, buf, step_size);
1012+
});
1013+
memcpy3.wait();
9991014
}
10001015
}
1001-
10021016
dpnp_memory_free_c(buf);
10031017
}
1004-
10051018
dpnp_memory_free_c(Uvec);
10061019
}
10071020

0 commit comments

Comments
 (0)