Skip to content

Commit 113e6be

Browse files
authored
add kernel for cross function (#535)
* start to implement
1 parent 915adf7 commit 113e6be

File tree

5 files changed

+52
-23
lines changed

5 files changed

+52
-23
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,19 @@ INP_DLLEXPORT void dpnp_elemwise_absolute_c(void* array1_in, void* result1, size
156156
template <typename _DataType_input1, typename _DataType_input2, typename _DataType_output>
157157
INP_DLLEXPORT void dpnp_dot_c(void* array1, void* array2, void* result1, size_t size);
158158

159+
/**
160+
* @ingroup BACKEND_API
161+
* @brief Custom implementation of cross function
162+
*
163+
* @param [in] array1_in First input array.
164+
* @param [in] array2_in Second input array.
165+
* @param [out] result1 Output array.
166+
* @param [in] size Number of elements in input arrays.
167+
*
168+
*/
169+
template <typename _DataType_input1, typename _DataType_input2, typename _DataType_output>
170+
INP_DLLEXPORT void dpnp_cross_c(void* array1_in, void* array2_in, void* result1, size_t size);
171+
159172
/**
160173
* @ingroup BACKEND_API
161174
* @brief Sum of array elements

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ enum class DPNPFuncName : size_t
8383
DPNP_FN_COS, /**< Used in numpy.cos() implementation */
8484
DPNP_FN_COSH, /**< Used in numpy.cosh() implementation */
8585
DPNP_FN_COV, /**< Used in numpy.cov() implementation */
86+
DPNP_FN_CROSS, /**< Used in numpy.cross() implementation */
8687
DPNP_FN_DEGREES, /**< Used in numpy.degrees() implementation */
8788
DPNP_FN_DET, /**< Used in numpy.linalg.det() implementation */
8889
DPNP_FN_DIVIDE, /**< Used in numpy.divide() implementation */

dpnp/backend/kernels/dpnp_krnl_mathematical.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,25 @@ template void dpnp_elemwise_absolute_c<float>(void* array1_in, void* result1, si
8383
template void dpnp_elemwise_absolute_c<long>(void* array1_in, void* result1, size_t size);
8484
template void dpnp_elemwise_absolute_c<int>(void* array1_in, void* result1, size_t size);
8585

86+
template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2, typename _KernelNameSpecialization3>
87+
class dpnp_cross_c_kernel;
88+
89+
template <typename _DataType_input1, typename _DataType_input2, typename _DataType_output>
90+
void dpnp_cross_c(void* array1_in, void* array2_in, void* result1, size_t size)
91+
{
92+
_DataType_input1* array1 = reinterpret_cast<_DataType_input1*>(array1_in);
93+
_DataType_input2* array2 = reinterpret_cast<_DataType_input2*>(array2_in);
94+
_DataType_output* result = reinterpret_cast<_DataType_output*>(result1);
95+
96+
result[0] = array1[1] * array2[2] - array1[2] * array2[1];
97+
98+
result[1] = array1[2] * array2[0] - array1[0] * array2[2];
99+
100+
result[2] = array1[0] * array2[1] - array1[1] * array2[0];
101+
102+
return;
103+
}
104+
86105
template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2, typename _KernelNameSpecialization3>
87106
class dpnp_floor_divide_c_kernel;
88107

@@ -214,6 +233,23 @@ void func_map_init_mathematical(func_map_t& fmap)
214233
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_elemwise_absolute_c<float>};
215234
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_elemwise_absolute_c<double>};
216235

236+
fmap[DPNPFuncName::DPNP_FN_CROSS][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_cross_c<int, int, int>};
237+
fmap[DPNPFuncName::DPNP_FN_CROSS][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_cross_c<int, long, long>};
238+
fmap[DPNPFuncName::DPNP_FN_CROSS][eft_INT][eft_FLT] = {eft_DBL, (void*)dpnp_cross_c<int, float, double>};
239+
fmap[DPNPFuncName::DPNP_FN_CROSS][eft_INT][eft_DBL] = {eft_DBL, (void*)dpnp_cross_c<int, double, double>};
240+
fmap[DPNPFuncName::DPNP_FN_CROSS][eft_LNG][eft_INT] = {eft_LNG, (void*)dpnp_cross_c<long, int, long>};
241+
fmap[DPNPFuncName::DPNP_FN_CROSS][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_cross_c<long, long, long>};
242+
fmap[DPNPFuncName::DPNP_FN_CROSS][eft_LNG][eft_FLT] = {eft_DBL, (void*)dpnp_cross_c<long, float, double>};
243+
fmap[DPNPFuncName::DPNP_FN_CROSS][eft_LNG][eft_DBL] = {eft_DBL, (void*)dpnp_cross_c<long, double, double>};
244+
fmap[DPNPFuncName::DPNP_FN_CROSS][eft_FLT][eft_INT] = {eft_DBL, (void*)dpnp_cross_c<float, int, double>};
245+
fmap[DPNPFuncName::DPNP_FN_CROSS][eft_FLT][eft_LNG] = {eft_DBL, (void*)dpnp_cross_c<float, long, double>};
246+
fmap[DPNPFuncName::DPNP_FN_CROSS][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_cross_c<float, float, float>};
247+
fmap[DPNPFuncName::DPNP_FN_CROSS][eft_FLT][eft_DBL] = {eft_DBL, (void*)dpnp_cross_c<float, double, double>};
248+
fmap[DPNPFuncName::DPNP_FN_CROSS][eft_DBL][eft_INT] = {eft_DBL, (void*)dpnp_cross_c<double, int, double>};
249+
fmap[DPNPFuncName::DPNP_FN_CROSS][eft_DBL][eft_LNG] = {eft_DBL, (void*)dpnp_cross_c<double, long, double>};
250+
fmap[DPNPFuncName::DPNP_FN_CROSS][eft_DBL][eft_FLT] = {eft_DBL, (void*)dpnp_cross_c<double, float, double>};
251+
fmap[DPNPFuncName::DPNP_FN_CROSS][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_cross_c<double, double, double>};
252+
217253
fmap[DPNPFuncName::DPNP_FN_FLOOR_DIVIDE][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_floor_divide_c<int, int, int>};
218254
fmap[DPNPFuncName::DPNP_FN_FLOOR_DIVIDE][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_floor_divide_c<int, long, long>};
219255
fmap[DPNPFuncName::DPNP_FN_FLOOR_DIVIDE][eft_INT][eft_FLT] = {eft_DBL,

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
5656
DPNP_FN_COS
5757
DPNP_FN_COSH
5858
DPNP_FN_COV
59+
DPNP_FN_CROSS
5960
DPNP_FN_DEGREES
6061
DPNP_FN_DET
6162
DPNP_FN_DIVIDE

dpnp/dpnp_algo/dpnp_algo_mathematical.pyx

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -136,29 +136,7 @@ cpdef dparray dpnp_copysign(dparray x1, dparray x2):
136136

137137

138138
cpdef dparray dpnp_cross(dparray x1, dparray x2):
139-
140-
types_map = {
141-
(dpnp.int32, dpnp.int32): dpnp.int32,
142-
(dpnp.int32, dpnp.int64): dpnp.int64,
143-
(dpnp.int64, dpnp.int32): dpnp.int64,
144-
(dpnp.int64, dpnp.int64): dpnp.int64,
145-
(dpnp.float32, dpnp.float32): dpnp.float32,
146-
}
147-
148-
res_type = types_map.get((x1.dtype.type, x2.dtype.type), dpnp.float64)
149-
150-
cdef dparray result = dparray(3, dtype=res_type)
151-
152-
cur_res = x1[1] * x2[2] - x1[2] * x2[1]
153-
result._setitem_scalar(0, cur_res)
154-
155-
cur_res = x1[2] * x2[0] - x1[0] * x2[2]
156-
result._setitem_scalar(1, cur_res)
157-
158-
cur_res = x1[0] * x2[1] - x1[1] * x2[0]
159-
result._setitem_scalar(2, cur_res)
160-
161-
return result
139+
return call_fptr_2in_1out(DPNP_FN_CROSS, x1, x2, x1.shape)
162140

163141

164142
cpdef dparray dpnp_cumprod(dparray x1, bint usenan=False):

0 commit comments

Comments
 (0)