Skip to content

Commit 2d76330

Browse files
authored
use sycl_adapter in krnl_mathematical (#895)
1 parent 575ba79 commit 2d76330

File tree

1 file changed

+25
-13
lines changed

1 file changed

+25
-13
lines changed

dpnp/backend/kernels/dpnp_krnl_mathematical.cpp

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ void dpnp_around_c(const void* input_in, void* result_out, const size_t input_si
4949
}
5050

5151
cl::sycl::event event;
52-
_DataType* input = reinterpret_cast<_DataType*>(const_cast<void*>(input_in));
52+
DPNPC_ptr_adapter<_DataType> input1_ptr(input_in, input_size);
53+
_DataType* input = input1_ptr.get_ptr();
5354
_DataType* result = reinterpret_cast<_DataType*>(result_out);
5455

5556
if constexpr (std::is_same<_DataType, double>::value || std::is_same<_DataType, float>::value)
@@ -175,8 +176,10 @@ void dpnp_cumprod_c(void* array1_in, void* result1, size_t size)
175176
return;
176177
}
177178

178-
_DataType_input* array1 = reinterpret_cast<_DataType_input*>(array1_in);
179-
_DataType_output* result = reinterpret_cast<_DataType_output*>(result1);
179+
DPNPC_ptr_adapter<_DataType_input> input1_ptr(array1_in, size, true);
180+
DPNPC_ptr_adapter<_DataType_output> result_ptr(result1, size, true, true);
181+
_DataType_input* array1 = input1_ptr.get_ptr();
182+
_DataType_output* result = result_ptr.get_ptr();
180183

181184
_DataType_output cur_res = 1;
182185

@@ -200,8 +203,10 @@ void dpnp_cumsum_c(void* array1_in, void* result1, size_t size)
200203
return;
201204
}
202205

203-
_DataType_input* array1 = reinterpret_cast<_DataType_input*>(array1_in);
204-
_DataType_output* result = reinterpret_cast<_DataType_output*>(result1);
206+
DPNPC_ptr_adapter<_DataType_input> input1_ptr(array1_in, size, true);
207+
DPNPC_ptr_adapter<_DataType_output> result_ptr(result1, size, true, true);
208+
_DataType_input* array1 = input1_ptr.get_ptr();
209+
_DataType_output* result = result_ptr.get_ptr();
205210

206211
_DataType_output cur_res = 0;
207212

@@ -236,8 +241,10 @@ void dpnp_floor_divide_c(void* result_out,
236241
return;
237242
}
238243

239-
_DataType_input1* input1_data = reinterpret_cast<_DataType_input1*>(const_cast<void*>(input1_in));
240-
_DataType_input2* input2_data = reinterpret_cast<_DataType_input2*>(const_cast<void*>(input2_in));
244+
DPNPC_ptr_adapter<_DataType_input1> input1_ptr(input1_in, input1_size);
245+
DPNPC_ptr_adapter<_DataType_input2> input2_ptr(input2_in, input2_size);
246+
_DataType_input1* input1_data = input1_ptr.get_ptr();
247+
_DataType_input2* input2_data = input2_ptr.get_ptr();
241248
_DataType_output* result = reinterpret_cast<_DataType_output*>(result_out);
242249

243250
std::vector<size_t> result_shape =
@@ -307,7 +314,8 @@ template <typename _DataType_input, typename _DataType_output>
307314
void dpnp_modf_c(void* array1_in, void* result1_out, void* result2_out, size_t size)
308315
{
309316
cl::sycl::event event;
310-
_DataType_input* array1 = reinterpret_cast<_DataType_input*>(array1_in);
317+
DPNPC_ptr_adapter<_DataType_input> input1_ptr(array1_in, size);
318+
_DataType_input* array1 = input1_ptr.get_ptr();
311319
_DataType_output* result1 = reinterpret_cast<_DataType_output*>(result1_out);
312320
_DataType_output* result2 = reinterpret_cast<_DataType_output*>(result2_out);
313321

@@ -359,8 +367,10 @@ void dpnp_remainder_c(void* result_out,
359367
return;
360368
}
361369

362-
_DataType_input1* input1_data = reinterpret_cast<_DataType_input1*>(const_cast<void*>(input1_in));
363-
_DataType_input2* input2_data = reinterpret_cast<_DataType_input2*>(const_cast<void*>(input2_in));
370+
DPNPC_ptr_adapter<_DataType_input1> input1_ptr(input1_in, input1_size);
371+
DPNPC_ptr_adapter<_DataType_input2> input2_ptr(input2_in, input2_size);
372+
_DataType_input1* input1_data = input1_ptr.get_ptr();
373+
_DataType_input2* input2_data = input2_ptr.get_ptr();
364374
_DataType_output* result = reinterpret_cast<_DataType_output*>(result_out);
365375

366376
std::vector<size_t> result_shape = get_result_shape(input1_shape, input1_shape_ndim,
@@ -441,13 +451,15 @@ void dpnp_trapz_c(
441451
}
442452

443453
cl::sycl::event event;
444-
_DataType_input1* array1 = reinterpret_cast<_DataType_input1*>(const_cast<void*>(array1_in));
445-
_DataType_input2* array2 = reinterpret_cast<_DataType_input2*>(const_cast<void*>(array2_in));
454+
DPNPC_ptr_adapter<_DataType_input1> input1_ptr(array1_in, array1_size);
455+
DPNPC_ptr_adapter<_DataType_input2> input2_ptr(array2_in, array2_size);
456+
_DataType_input1* array1 = input1_ptr.get_ptr();
457+
_DataType_input2* array2 = input2_ptr.get_ptr();
446458
_DataType_output* result = reinterpret_cast<_DataType_output*>(result1);
447459

448460
if (array1_size < 2)
449461
{
450-
result[0] = 0;
462+
result[0] = 0; // TODO make it on SYCL QUEUE via memcpy
451463
return;
452464
}
453465

0 commit comments

Comments
 (0)