Skip to content

Commit f08814d

Browse files
authored
ELEMWISE 2arg_3types add MKL kernels (#131)
* ELEMWISE 2arg_3types add MKL kernels
1 parent b26d624 commit f08814d

File tree

3 files changed

+38
-29
lines changed

3 files changed

+38
-29
lines changed

dpnp/backend/backend_iface.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ INP_DLLEXPORT void custom_var_c(
410410

411411
#include <custom_1arg_1type_tbl.hpp>
412412

413-
#define MACRO_CUSTOM_2ARG_3TYPES_OP(__name__, __operation__) \
413+
#define MACRO_CUSTOM_2ARG_3TYPES_OP(__name__, __operation1__, __operation2__) \
414414
template <typename _DataType_input1, typename _DataType_input2, typename _DataType_output> \
415415
INP_DLLEXPORT void custom_elemwise_##__name__##_c(void* array1, void* array2, void* result1, size_t size);
416416

dpnp/backend/custom_2arg_3type_tbl.hpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,22 +31,23 @@
3131
* Parameters:
3232
* - public name of the function and kernel name
3333
* - element-wise expression used by the generic SYCL kernel to calculate the result
34+
* - MKL (oneapi::mkl::vm) function used instead when both inputs have the same float or double type
3435
*
3536
*/
3637

3738
#ifndef MACRO_CUSTOM_2ARG_3TYPES_OP
3839
#error "MACRO_CUSTOM_2ARG_3TYPES_OP is not defined"
3940
#endif
4041

41-
MACRO_CUSTOM_2ARG_3TYPES_OP(add, input_elem1 + input_elem2)
42-
MACRO_CUSTOM_2ARG_3TYPES_OP(arctan2, cl::sycl::atan2((double)input_elem1, (double)input_elem2))
43-
MACRO_CUSTOM_2ARG_3TYPES_OP(divide, input_elem1 / input_elem2)
44-
MACRO_CUSTOM_2ARG_3TYPES_OP(fmod, cl::sycl::fmod((double)input_elem1, (double)input_elem2))
45-
MACRO_CUSTOM_2ARG_3TYPES_OP(hypot, cl::sycl::hypot((double)input_elem1, (double)input_elem2))
46-
MACRO_CUSTOM_2ARG_3TYPES_OP(maximum, cl::sycl::max(input_elem1, input_elem2))
47-
MACRO_CUSTOM_2ARG_3TYPES_OP(minimum, cl::sycl::min(input_elem1, input_elem2))
48-
MACRO_CUSTOM_2ARG_3TYPES_OP(multiply, input_elem1* input_elem2)
49-
MACRO_CUSTOM_2ARG_3TYPES_OP(power, cl::sycl::powr((double)input_elem1, (double)input_elem2))
50-
MACRO_CUSTOM_2ARG_3TYPES_OP(subtract, input_elem1 - input_elem2)
42+
MACRO_CUSTOM_2ARG_3TYPES_OP(add, input_elem1 + input_elem2, oneapi::mkl::vm::add)
43+
MACRO_CUSTOM_2ARG_3TYPES_OP(arctan2, cl::sycl::atan2((double)input_elem1, (double)input_elem2), oneapi::mkl::vm::atan2)
44+
MACRO_CUSTOM_2ARG_3TYPES_OP(divide, input_elem1 / input_elem2, oneapi::mkl::vm::div)
45+
MACRO_CUSTOM_2ARG_3TYPES_OP(fmod, cl::sycl::fmod((double)input_elem1, (double)input_elem2), oneapi::mkl::vm::fmod)
46+
MACRO_CUSTOM_2ARG_3TYPES_OP(hypot, cl::sycl::hypot((double)input_elem1, (double)input_elem2), oneapi::mkl::vm::hypot)
47+
MACRO_CUSTOM_2ARG_3TYPES_OP(maximum, cl::sycl::max(input_elem1, input_elem2), oneapi::mkl::vm::fmax)
48+
MACRO_CUSTOM_2ARG_3TYPES_OP(minimum, cl::sycl::min(input_elem1, input_elem2), oneapi::mkl::vm::fmin)
49+
MACRO_CUSTOM_2ARG_3TYPES_OP(multiply, input_elem1* input_elem2, oneapi::mkl::vm::mul)
50+
MACRO_CUSTOM_2ARG_3TYPES_OP(power, cl::sycl::powr((double)input_elem1, (double)input_elem2), oneapi::mkl::vm::pow)
51+
MACRO_CUSTOM_2ARG_3TYPES_OP(subtract, input_elem1 - input_elem2, oneapi::mkl::vm::sub)
5152

5253
#undef MACRO_CUSTOM_2ARG_3TYPES_OP

dpnp/backend/custom_kernels_elemwise.cpp

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@
105105
#include <custom_1arg_1type_tbl.hpp>
106106

107107
/* ========================================================================== */
108-
#define MACRO_CUSTOM_2ARG_3TYPES_OP(__name__, __operation__) \
108+
#define MACRO_CUSTOM_2ARG_3TYPES_OP(__name__, __operation1__, __operation2__) \
109109
template <typename _KernelNameSpecialization1, \
110110
typename _KernelNameSpecialization2, \
111111
typename _KernelNameSpecialization3> \
@@ -119,23 +119,31 @@
119119
_DataType_input2* array2 = reinterpret_cast<_DataType_input2*>(array2_in); \
120120
_DataType_output* result = reinterpret_cast<_DataType_output*>(result1); \
121121
\
122-
cl::sycl::range<1> gws(size); \
123-
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) { \
124-
size_t i = global_id[0]; /*for (size_t i = 0; i < size; ++i)*/ \
125-
{ \
126-
_DataType_output input_elem1 = array1[i]; \
127-
_DataType_output input_elem2 = array2[i]; \
128-
result[i] = __operation__; \
129-
} \
130-
}; \
131-
\
132-
auto kernel_func = [&](cl::sycl::handler& cgh) { \
133-
cgh.parallel_for< \
134-
class custom_elemwise_##__name__##_c_kernel<_DataType_input1, _DataType_input2, _DataType_output>>( \
135-
gws, kernel_parallel_for_func); \
136-
}; \
137-
\
138-
event = DPNP_QUEUE.submit(kernel_func); \
122+
if constexpr ((std::is_same<_DataType_input1, double>::value || std::is_same<_DataType_input1, float>::value) \
123+
&& std::is_same<_DataType_input2, _DataType_input1>::value) \
124+
{ \
125+
event = __operation2__(DPNP_QUEUE, size, array1, array2, result); \
126+
} \
127+
else \
128+
{ \
129+
cl::sycl::range<1> gws(size); \
130+
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) { \
131+
size_t i = global_id[0]; /*for (size_t i = 0; i < size; ++i)*/ \
132+
{ \
133+
_DataType_output input_elem1 = array1[i]; \
134+
_DataType_output input_elem2 = array2[i]; \
135+
result[i] = __operation1__; \
136+
} \
137+
}; \
138+
\
139+
auto kernel_func = [&](cl::sycl::handler& cgh) { \
140+
cgh.parallel_for<class custom_elemwise_##__name__##_c_kernel<_DataType_input1, _DataType_input2, \
141+
_DataType_output>>( \
142+
gws, kernel_parallel_for_func); \
143+
}; \
144+
\
145+
event = DPNP_QUEUE.submit(kernel_func); \
146+
} \
139147
\
140148
event.wait(); \
141149
} \

0 commit comments

Comments
 (0)