Skip to content

Commit 63625f6

Browse files
authored
Support array-scalar operations for 10 more funcs (#648)
* Support array-scalar operations for 10 more funcs
1 parent a452d3c commit 63625f6

18 files changed

+913
-854
lines changed

dpnp/backend/include/dpnp_gen_2arg_3type_tbl.hpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,17 +56,18 @@
5656

5757
#endif
5858

59-
MACRO_2ARG_3TYPES_OP(dpnp_add_c, input_elem1 + input_elem2, oneapi::mkl::vm::add)
60-
MACRO_2ARG_3TYPES_OP(dpnp_arctan2_c, cl::sycl::atan2((double)input_elem1, (double)input_elem2), oneapi::mkl::vm::atan2)
59+
MACRO_2ARG_3TYPES_OP(dpnp_add_c, input1_elem + input2_elem, oneapi::mkl::vm::add)
60+
MACRO_2ARG_3TYPES_OP(dpnp_arctan2_c, cl::sycl::atan2((double)input1_elem, (double)input2_elem), oneapi::mkl::vm::atan2)
6161
MACRO_2ARG_3TYPES_OP(dpnp_copysign_c,
62-
cl::sycl::copysign((double)input_elem1, (double)input_elem2),
62+
cl::sycl::copysign((double)input1_elem, (double)input2_elem),
6363
oneapi::mkl::vm::copysign)
64-
MACRO_2ARG_3TYPES_OP(dpnp_divide_c, input_elem1 / input_elem2, oneapi::mkl::vm::div)
65-
MACRO_2ARG_3TYPES_OP(dpnp_fmod_c, cl::sycl::fmod((double)input_elem1, (double)input_elem2), oneapi::mkl::vm::fmod)
66-
MACRO_2ARG_3TYPES_OP(dpnp_hypot_c, cl::sycl::hypot((double)input_elem1, (double)input_elem2), oneapi::mkl::vm::hypot)
67-
MACRO_2ARG_3TYPES_OP(dpnp_maximum_c, cl::sycl::max(input_elem1, input_elem2), oneapi::mkl::vm::fmax)
68-
MACRO_2ARG_3TYPES_OP(dpnp_minimum_c, cl::sycl::min(input_elem1, input_elem2), oneapi::mkl::vm::fmin)
69-
MACRO_2ARG_3TYPES_OP(dpnp_power_c, cl::sycl::pow((double)input_elem1, (double)input_elem2), oneapi::mkl::vm::pow)
70-
MACRO_2ARG_3TYPES_OP(dpnp_subtract_c, input_elem1 - input_elem2, oneapi::mkl::vm::sub)
64+
MACRO_2ARG_3TYPES_OP(dpnp_divide_c, input1_elem / input2_elem, oneapi::mkl::vm::div)
65+
MACRO_2ARG_3TYPES_OP(dpnp_fmod_c, cl::sycl::fmod((double)input1_elem, (double)input2_elem), oneapi::mkl::vm::fmod)
66+
MACRO_2ARG_3TYPES_OP(dpnp_hypot_c, cl::sycl::hypot((double)input1_elem, (double)input2_elem), oneapi::mkl::vm::hypot)
67+
MACRO_2ARG_3TYPES_OP(dpnp_maximum_c, cl::sycl::max(input1_elem, input2_elem), oneapi::mkl::vm::fmax)
68+
MACRO_2ARG_3TYPES_OP(dpnp_minimum_c, cl::sycl::min(input1_elem, input2_elem), oneapi::mkl::vm::fmin)
69+
MACRO_2ARG_3TYPES_OP(dpnp_multiply_c, input1_elem * input2_elem, oneapi::mkl::vm::mul)
70+
MACRO_2ARG_3TYPES_OP(dpnp_power_c, cl::sycl::pow((double)input1_elem, (double)input2_elem), oneapi::mkl::vm::pow)
71+
MACRO_2ARG_3TYPES_OP(dpnp_subtract_c, input1_elem - input2_elem, oneapi::mkl::vm::sub)
7172

7273
#undef MACRO_2ARG_3TYPES_OP

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -703,8 +703,16 @@ INP_DLLEXPORT void dpnp_invert_c(void* array1_in, void* result, size_t size);
703703

704704
#define MACRO_2ARG_1TYPE_OP(__name__, __operation__) \
705705
template <typename _DataType> \
706-
INP_DLLEXPORT void __name__( \
707-
void* result1, const void* array1, const size_t size1, const void* array2, const size_t size2);
706+
INP_DLLEXPORT void __name__(void* result_out, \
707+
const void* input1_in, \
708+
const size_t input1_size, \
709+
const size_t* input1_shape, \
710+
const size_t input1_shape_ndim, \
711+
const void* input2_in, \
712+
const size_t input2_size, \
713+
const size_t* input2_shape, \
714+
const size_t input2_shape_ndim, \
715+
const size_t* where);
708716

709717
#include <dpnp_gen_2arg_1type_tbl.hpp>
710718

@@ -721,8 +729,17 @@ INP_DLLEXPORT void dpnp_invert_c(void* array1_in, void* result, size_t size);
721729
#include <dpnp_gen_1arg_2type_tbl.hpp>
722730

723731
#define MACRO_2ARG_3TYPES_OP(__name__, __operation1__, __operation2__) \
724-
template <typename _DataType_input1, typename _DataType_input2, typename _DataType_output> \
725-
INP_DLLEXPORT void __name__(void* array1, void* array2, void* result1, size_t size);
732+
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2> \
733+
INP_DLLEXPORT void __name__(void* result_out, \
734+
const void* input1_in, \
735+
const size_t input1_size, \
736+
const size_t* input1_shape, \
737+
const size_t input1_shape_ndim, \
738+
const void* input2_in, \
739+
const size_t input2_size, \
740+
const size_t* input2_shape, \
741+
const size_t input2_shape_ndim, \
742+
const size_t* where);
726743

727744
#include <dpnp_gen_2arg_3type_tbl.hpp>
728745

@@ -762,33 +779,6 @@ INP_DLLEXPORT void dpnp_floor_divide_c(void* array1_in, void* array2_in, void* r
762779
template <typename _DataType_input, typename _DataType_output>
763780
INP_DLLEXPORT void dpnp_modf_c(void* array1_in, void* result1_out, void* result2_out, size_t size);
764781

765-
/**
766-
* @ingroup BACKEND_API
767-
* @brief multiply function.
768-
*
769-
* @param [out] result_out Output array.
770-
* @param [in] input1_in Input 1 either array or scalar.
771-
* @param [in] input1_size Number of elements in input 1.
772-
* @param [in] input1_shape Shape of input 1.
773-
* @param [in] input1_shape_ndim Size of shape 1.
774-
* @param [in] input2_in Input 2 either array or scalar.
775-
* @param [in] input2_size Number of elements in input 2.
776-
* @param [in] input2_shape Shape of input 2.
777-
* @param [in] input2_shape_ndim Size of shape 2.
778-
* @param [in] where Mask array.
779-
*/
780-
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
781-
INP_DLLEXPORT void dpnp_multiply_c(void* result_out,
782-
const void* input1_in,
783-
const size_t input1_size,
784-
const size_t* input1_shape,
785-
const size_t input1_shape_ndim,
786-
const void* input2_in,
787-
const size_t input2_size,
788-
const size_t* input2_shape,
789-
const size_t input2_shape_ndim,
790-
const size_t* where);
791-
792782
/**
793783
* @ingroup BACKEND_API
794784
* @brief Implementation of ones function

dpnp/backend/kernels/dpnp_krnl_bitwise.cpp

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,27 +70,41 @@ static void func_map_init_bitwise_1arg_1type(func_map_t& fmap)
7070
class __name__##_kernel; \
7171
\
7272
template <typename _DataType> \
73-
void __name__(void* result1, const void* array1_in, const size_t size1, const void* array2_in, const size_t size2) \
73+
void __name__(void* result_out, \
74+
const void* input1_in, \
75+
const size_t input1_size, \
76+
const size_t* input1_shape, \
77+
const size_t input1_shape_ndim, \
78+
const void* input2_in, \
79+
const size_t input2_size, \
80+
const size_t* input2_shape, \
81+
const size_t input2_shape_ndim, \
82+
const size_t* where) \
7483
{ \
75-
if (!size1 || !size2) \
84+
/* avoid warning unused variable*/ \
85+
(void)input1_shape; \
86+
(void)input1_shape_ndim; \
87+
(void)input2_shape; \
88+
(void)input2_shape_ndim; \
89+
(void)where; \
90+
\
91+
if (!input1_size || !input2_size) \
7692
{ \
7793
return; \
7894
} \
7995
\
8096
cl::sycl::event event; \
81-
const _DataType* array1 = reinterpret_cast<const _DataType*>(array1_in); \
82-
const _DataType* array2 = reinterpret_cast<const _DataType*>(array2_in); \
83-
_DataType* result = reinterpret_cast<_DataType*>(result1); \
97+
const _DataType* input1 = reinterpret_cast<const _DataType*>(input1_in); \
98+
const _DataType* input2 = reinterpret_cast<const _DataType*>(input2_in); \
99+
_DataType* result = reinterpret_cast<_DataType*>(result_out); \
84100
\
85-
const size_t gws_size = std::max(size1, size2); \
101+
const size_t gws_size = std::max(input1_size, input2_size); \
86102
cl::sycl::range<1> gws(gws_size); \
87103
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) { \
88104
size_t i = global_id[0]; /*for (size_t i = 0; i < size; ++i)*/ \
89-
{ \
90-
const _DataType input_elem1 = (size1 == 1) ? array1[0] : array1[i]; \
91-
const _DataType input_elem2 = (size2 == 1) ? array2[0] : array2[i]; \
92-
result[i] = __operation__; \
93-
} \
105+
const _DataType input_elem1 = (input1_size == 1) ? input1[0] : input1[i]; \
106+
const _DataType input_elem2 = (input2_size == 1) ? input2[0] : input2[i]; \
107+
result[i] = __operation__; \
94108
}; \
95109
\
96110
auto kernel_func = [&](cl::sycl::handler& cgh) { \

0 commit comments

Comments
 (0)