Skip to content

Commit 19fcc96

Browse files
authored
Remove new_version from functions (#673)
* add new version to three funcs
1 parent f0d2192 commit 19fcc96

11 files changed

+351
-187
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 118 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -218,26 +218,59 @@ INP_DLLEXPORT void dpnp_elemwise_absolute_c(void* array1_in, void* result1, size
218218
* @ingroup BACKEND_API
219219
* @brief Custom implementation of dot function
220220
*
221-
* @param [in] array1 Input array.
222-
* @param [in] array2 Input array.
223-
* @param [out] result1 Output array.
224-
* @param [in] size Number of elements in input arrays.
225-
*/
226-
template <typename _DataType_input1, typename _DataType_input2, typename _DataType_output>
227-
INP_DLLEXPORT void dpnp_dot_c(void* array1, void* array2, void* result1, size_t size);
221+
* @param [out] result_out Output array.
222+
* @param [in] input1_in First input array.
223+
* @param [in] input1_size Size of first input array.
224+
* @param [in] input1_shape Shape of first input array.
225+
* @param [in] input1_shape_ndim Number of first array dimensions.
226+
* @param [in] input2_in Second input array.
227+
* @param [in] input2_size Shape of second input array.
228+
* @param [in] input2_shape Shape of first input array.
229+
* @param [in] input2_shape_ndim Number of second array dimensions.
230+
* @param [in] where Mask array.
231+
* @param [out] result1 Output array.
232+
* @param [in] size Number of elements in input arrays.
233+
*/
234+
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
235+
INP_DLLEXPORT void dpnp_dot_c(void* result_out,
236+
const void* input1_in,
237+
const size_t input1_size,
238+
const size_t* input1_shape,
239+
const size_t input1_shape_ndim,
240+
const void* input2_in,
241+
const size_t input2_size,
242+
const size_t* input2_shape,
243+
const size_t input2_shape_ndim,
244+
const size_t* where);
228245

229246
/**
230247
* @ingroup BACKEND_API
231248
* @brief Custom implementation of cross function
232249
*
233-
* @param [in] array1_in First input array.
234-
* @param [in] array2_in Second input array.
235-
* @param [out] result1 Output array.
236-
* @param [in] size Number of elements in input arrays.
237-
*
238-
*/
239-
template <typename _DataType_input1, typename _DataType_input2, typename _DataType_output>
240-
INP_DLLEXPORT void dpnp_cross_c(void* array1_in, void* array2_in, void* result1, size_t size);
250+
* @param [out] result_out Output array.
251+
* @param [in] input1_in First input array.
252+
* @param [in] input1_size Size of first input array.
253+
* @param [in] input1_shape Shape of first input array.
254+
* @param [in] input1_shape_ndim Number of first array dimensions.
255+
* @param [in] input2_in Second input array.
256+
* @param [in] input2_size Shape of second input array.
257+
* @param [in] input2_shape Shape of first input array.
258+
* @param [in] input2_shape_ndim Number of second array dimensions.
259+
* @param [in] where Mask array.
260+
* @param [out] result1 Output array.
261+
* @param [in] size Number of elements in input arrays.
262+
*/
263+
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
264+
INP_DLLEXPORT void dpnp_cross_c(void* result_out,
265+
const void* input1_in,
266+
const size_t input1_size,
267+
const size_t* input1_shape,
268+
const size_t input1_shape_ndim,
269+
const void* input2_in,
270+
const size_t input2_size,
271+
const size_t* input2_shape,
272+
const size_t input2_shape_ndim,
273+
const size_t* where);
241274

242275
/**
243276
* @ingroup BACKEND_API
@@ -436,13 +469,30 @@ INP_DLLEXPORT void dpnp_cholesky_c(void* array1_in, void* result1, const size_t
436469
* @ingroup BACKEND_API
437470
* @brief correlate function
438471
*
439-
* @param [in] array1_in Input array 1.
440-
* @param [in] array2_in Input array 2.
441-
* @param [out] result Output array.
442-
* @param [in] size Number of elements in input arrays.
443-
*/
444-
template <typename _DataType_input1, typename _DataType_input2, typename _DataType_output>
445-
INP_DLLEXPORT void dpnp_correlate_c(void* array1_in, void* array2_in, void* result, size_t size);
472+
* @param [out] result_out Output array.
473+
* @param [in] input1_in First input array.
474+
* @param [in] input1_size Size of first input array.
475+
* @param [in] input1_shape Shape of first input array.
476+
* @param [in] input1_shape_ndim Number of first array dimensions.
477+
* @param [in] input2_in Second input array.
478+
* @param [in] input2_size Shape of second input array.
479+
* @param [in] input2_shape Shape of first input array.
480+
* @param [in] input2_shape_ndim Number of second array dimensions.
481+
* @param [in] where Mask array.
482+
* @param [out] result1 Output array.
483+
* @param [in] size Number of elements in input arrays.
484+
*/
485+
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
486+
INP_DLLEXPORT void dpnp_correlate_c(void* result_out,
487+
const void* input1_in,
488+
const size_t input1_size,
489+
const size_t* input1_shape,
490+
const size_t input1_shape_ndim,
491+
const void* input2_in,
492+
const size_t input2_size,
493+
const size_t* input2_shape,
494+
const size_t input2_shape_ndim,
495+
const size_t* where);
446496

447497
/**
448498
* @ingroup BACKEND_API
@@ -782,13 +832,30 @@ INP_DLLEXPORT void dpnp_fill_diagonal_c(void* array1_in, void* val, size_t* shap
782832
* @ingroup BACKEND_API
783833
* @brief floor_divide function.
784834
*
785-
* @param [in] array1_in Input array 1.
786-
* @param [in] array2_in Input array 2.
787-
* @param [out] result1 Output array.
788-
* @param [in] size Number of elements in input arrays.
835+
* @param [out] result_out Output array.
836+
* @param [in] input1_in First input array.
837+
* @param [in] input1_size Size of first input array.
838+
* @param [in] input1_shape Shape of first input array.
839+
* @param [in] input1_shape_ndim Number of first array dimensions.
840+
* @param [in] input2_in Second input array.
841+
* @param [in] input2_size Shape of second input array.
842+
* @param [in] input2_shape Shape of first input array.
843+
* @param [in] input2_shape_ndim Number of second array dimensions.
844+
* @param [in] where Mask array.
845+
* @param [out] result1 Output array.
846+
* @param [in] size Number of elements in input arrays.
789847
*/
790848
template <typename _DataType_input1, typename _DataType_input2, typename _DataType_output>
791-
INP_DLLEXPORT void dpnp_floor_divide_c(void* array1_in, void* array2_in, void* result1, size_t size);
849+
INP_DLLEXPORT void dpnp_floor_divide_c(void* result_out,
850+
const void* input1_in,
851+
const size_t input1_size,
852+
const size_t* input1_shape,
853+
const size_t input1_shape_ndim,
854+
const void* input2_in,
855+
const size_t input2_size,
856+
const size_t* input2_shape,
857+
const size_t input2_shape_ndim,
858+
const size_t* where);
792859

793860
/**
794861
* @ingroup BACKEND_API
@@ -826,13 +893,30 @@ INP_DLLEXPORT void dpnp_ones_like_c(void* result, size_t size);
826893
* @ingroup BACKEND_API
827894
* @brief remainder function.
828895
*
829-
* @param [in] array1_in Input array 1.
830-
* @param [in] array2_in Input array 2.
831-
* @param [out] result1 Output array.
832-
* @param [in] size Number of elements in input arrays.
833-
*/
834-
template <typename _DataType_input1, typename _DataType_input2, typename _DataType_output>
835-
INP_DLLEXPORT void dpnp_remainder_c(void* array1_in, void* array2_in, void* result1, size_t size);
896+
* @param [out] result_out Output array.
897+
* @param [in] input1_in First input array.
898+
* @param [in] input1_size Size of first input array.
899+
* @param [in] input1_shape Shape of first input array.
900+
* @param [in] input1_shape_ndim Number of first array dimensions.
901+
* @param [in] input2_in Second input array.
902+
* @param [in] input2_size Shape of second input array.
903+
* @param [in] input2_shape Shape of first input array.
904+
* @param [in] input2_shape_ndim Number of second array dimensions.
905+
* @param [in] where Mask array.
906+
* @param [out] result1 Output array.
907+
* @param [in] size Number of elements in input arrays.
908+
*/
909+
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
910+
INP_DLLEXPORT void dpnp_remainder_c(void* result_out,
911+
const void* input1_in,
912+
const size_t input1_size,
913+
const size_t* input1_shape,
914+
const size_t input1_shape_ndim,
915+
const void* input2_in,
916+
const size_t input2_size,
917+
const size_t* input2_shape,
918+
const size_t input2_shape_ndim,
919+
const size_t* where);
836920

837921
/**
838922
* @ingroup BACKEND_API

dpnp/backend/kernels/dpnp_krnl_common.cpp

Lines changed: 46 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,32 @@ void dpnp_astype_c(const void* array1_in, void* result1, const size_t size)
7474
template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2, typename _KernelNameSpecialization3>
7575
class dpnp_dot_c_kernel;
7676

77-
template <typename _DataType_input1, typename _DataType_input2, typename _DataType_output>
78-
void dpnp_dot_c(void* array1_in, void* array2_in, void* result1, size_t size)
77+
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
78+
void dpnp_dot_c(void* result_out,
79+
const void* input1_in,
80+
const size_t input1_size,
81+
const size_t* input1_shape,
82+
const size_t input1_shape_ndim,
83+
const void* input2_in,
84+
const size_t input2_size,
85+
const size_t* input2_shape,
86+
const size_t input2_shape_ndim,
87+
const size_t* where)
7988
{
89+
90+
(void)input1_shape;
91+
(void)input1_shape_ndim;
92+
(void)input2_size;
93+
(void)input2_shape;
94+
(void)input2_shape_ndim;
95+
(void)where;
96+
8097
cl::sycl::event event;
81-
_DataType_input1* array_1 = reinterpret_cast<_DataType_input1*>(array1_in);
82-
_DataType_input2* array_2 = reinterpret_cast<_DataType_input2*>(array2_in);
83-
_DataType_output* result = reinterpret_cast<_DataType_output*>(result1);
98+
_DataType_input1* input1 = reinterpret_cast<_DataType_input1*>(const_cast<void*>(input1_in));
99+
_DataType_input2* input2 = reinterpret_cast<_DataType_input2*>(const_cast<void*>(input2_in));
100+
_DataType_output* result = reinterpret_cast<_DataType_output*>(result_out);
84101

85-
if (!size)
102+
if (!input1_size)
86103
{
87104
return;
88105
}
@@ -92,29 +109,29 @@ void dpnp_dot_c(void* array1_in, void* array2_in, void* result1, size_t size)
92109
std::is_same<_DataType_output, _DataType_input1>::value)
93110
{
94111
event = mkl_blas::dot(DPNP_QUEUE,
95-
size,
96-
array_1,
97-
1, // array_1 stride
98-
array_2,
99-
1, // array_2 stride
112+
input1_size,
113+
input1,
114+
1, // input1 stride
115+
input2,
116+
1, // input2 stride
100117
result);
101118
event.wait();
102119
}
103120
else
104121
{
105122
_DataType_output* local_mem =
106-
reinterpret_cast<_DataType_output*>(dpnp_memory_alloc_c(size * sizeof(_DataType_output)));
123+
reinterpret_cast<_DataType_output*>(dpnp_memory_alloc_c(input1_size * sizeof(_DataType_output)));
107124

108125
// what about reduction??
109-
cl::sycl::range<1> gws(size);
126+
cl::sycl::range<1> gws(input1_size);
110127

111128
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
112129
const size_t index = global_id[0];
113-
local_mem[index] = array_1[index] * array_2[index];
130+
local_mem[index] = input1[index] * input2[index];
114131
};
115132

116133
auto kernel_func = [&](cl::sycl::handler& cgh) {
117-
cgh.parallel_for<class dpnp_dot_c_kernel<_DataType_input1, _DataType_input2, _DataType_output>>(
134+
cgh.parallel_for<class dpnp_dot_c_kernel<_DataType_output, _DataType_input1, _DataType_input2>>(
118135
gws, kernel_parallel_for_func);
119136
};
120137

@@ -123,11 +140,11 @@ void dpnp_dot_c(void* array1_in, void* array2_in, void* result1, size_t size)
123140
event.wait();
124141

125142
auto policy = oneapi::dpl::execution::make_device_policy<
126-
class dpnp_dot_c_kernel<_DataType_input1, _DataType_input2, _DataType_output>>(DPNP_QUEUE);
143+
class dpnp_dot_c_kernel<_DataType_output, _DataType_input1, _DataType_input2>>(DPNP_QUEUE);
127144

128145
_DataType_output accumulator = 0;
129146
accumulator =
130-
std::reduce(policy, local_mem, local_mem + size, _DataType_output(0), std::plus<_DataType_output>());
147+
std::reduce(policy, local_mem, local_mem + input1_size, _DataType_output(0), std::plus<_DataType_output>());
131148
policy.queue().wait();
132149

133150
result[0] = accumulator;
@@ -389,20 +406,20 @@ void func_map_init_linalg(func_map_t& fmap)
389406
eft_C128, (void*)dpnp_astype_c<std::complex<double>, std::complex<double>>};
390407

391408
fmap[DPNPFuncName::DPNP_FN_DOT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_dot_c<int, int, int>};
392-
fmap[DPNPFuncName::DPNP_FN_DOT][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_dot_c<int, long, long>};
393-
fmap[DPNPFuncName::DPNP_FN_DOT][eft_INT][eft_FLT] = {eft_DBL, (void*)dpnp_dot_c<int, float, double>};
394-
fmap[DPNPFuncName::DPNP_FN_DOT][eft_INT][eft_DBL] = {eft_DBL, (void*)dpnp_dot_c<int, double, double>};
395-
fmap[DPNPFuncName::DPNP_FN_DOT][eft_LNG][eft_INT] = {eft_LNG, (void*)dpnp_dot_c<long, int, long>};
409+
fmap[DPNPFuncName::DPNP_FN_DOT][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_dot_c<long, int, long>};
410+
fmap[DPNPFuncName::DPNP_FN_DOT][eft_INT][eft_FLT] = {eft_DBL, (void*)dpnp_dot_c<double, int, float>};
411+
fmap[DPNPFuncName::DPNP_FN_DOT][eft_INT][eft_DBL] = {eft_DBL, (void*)dpnp_dot_c<double, int, double>};
412+
fmap[DPNPFuncName::DPNP_FN_DOT][eft_LNG][eft_INT] = {eft_LNG, (void*)dpnp_dot_c<long, long, int>};
396413
fmap[DPNPFuncName::DPNP_FN_DOT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_dot_c<long, long, long>};
397-
fmap[DPNPFuncName::DPNP_FN_DOT][eft_LNG][eft_FLT] = {eft_DBL, (void*)dpnp_dot_c<long, float, double>};
398-
fmap[DPNPFuncName::DPNP_FN_DOT][eft_LNG][eft_DBL] = {eft_DBL, (void*)dpnp_dot_c<long, double, double>};
399-
fmap[DPNPFuncName::DPNP_FN_DOT][eft_FLT][eft_INT] = {eft_DBL, (void*)dpnp_dot_c<float, int, double>};
400-
fmap[DPNPFuncName::DPNP_FN_DOT][eft_FLT][eft_LNG] = {eft_DBL, (void*)dpnp_dot_c<float, long, double>};
414+
fmap[DPNPFuncName::DPNP_FN_DOT][eft_LNG][eft_FLT] = {eft_DBL, (void*)dpnp_dot_c<double, long, float>};
415+
fmap[DPNPFuncName::DPNP_FN_DOT][eft_LNG][eft_DBL] = {eft_DBL, (void*)dpnp_dot_c<double, long, double>};
416+
fmap[DPNPFuncName::DPNP_FN_DOT][eft_FLT][eft_INT] = {eft_DBL, (void*)dpnp_dot_c<double, float, int>};
417+
fmap[DPNPFuncName::DPNP_FN_DOT][eft_FLT][eft_LNG] = {eft_DBL, (void*)dpnp_dot_c<double, float, long>};
401418
fmap[DPNPFuncName::DPNP_FN_DOT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_dot_c<float, float, float>};
402-
fmap[DPNPFuncName::DPNP_FN_DOT][eft_FLT][eft_DBL] = {eft_DBL, (void*)dpnp_dot_c<float, double, double>};
403-
fmap[DPNPFuncName::DPNP_FN_DOT][eft_DBL][eft_INT] = {eft_DBL, (void*)dpnp_dot_c<double, int, double>};
404-
fmap[DPNPFuncName::DPNP_FN_DOT][eft_DBL][eft_LNG] = {eft_DBL, (void*)dpnp_dot_c<double, long, double>};
405-
fmap[DPNPFuncName::DPNP_FN_DOT][eft_DBL][eft_FLT] = {eft_DBL, (void*)dpnp_dot_c<double, float, double>};
419+
fmap[DPNPFuncName::DPNP_FN_DOT][eft_FLT][eft_DBL] = {eft_DBL, (void*)dpnp_dot_c<double, float, double>};
420+
fmap[DPNPFuncName::DPNP_FN_DOT][eft_DBL][eft_INT] = {eft_DBL, (void*)dpnp_dot_c<double, double, int>};
421+
fmap[DPNPFuncName::DPNP_FN_DOT][eft_DBL][eft_LNG] = {eft_DBL, (void*)dpnp_dot_c<double, double, long>};
422+
fmap[DPNPFuncName::DPNP_FN_DOT][eft_DBL][eft_FLT] = {eft_DBL, (void*)dpnp_dot_c<double, double, float>};
406423
fmap[DPNPFuncName::DPNP_FN_DOT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_dot_c<double, double, double>};
407424

408425
fmap[DPNPFuncName::DPNP_FN_EIG][eft_INT][eft_INT] = {eft_DBL, (void*)dpnp_eig_c<int, double>};

0 commit comments

Comments
 (0)