Skip to content

Commit 2b86d94

Browse files
authored
add trace (#651)
* add trace
1 parent 6dfc6db commit 2b86d94

File tree

8 files changed

+181
-1
lines changed

8 files changed

+181
-1
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,18 @@ INP_DLLEXPORT void dpnp_std_c(
705705
template <typename _DataType, typename _IndecesType>
706706
INP_DLLEXPORT void dpnp_take_c(void* array, void* indices, void* result, size_t size);
707707

708+
/**
709+
* @ingroup BACKEND_API
710+
* @brief math library implementation of trace function
711+
*
712+
* @param [in] array Input array with data.
713+
* @param [out] result Output array.
714+
* @param [in] shape Shape of input array.
715+
* @param [in] ndim Number of elements in array.shape.
716+
*/
717+
template <typename _DataType, typename _ResultType>
718+
INP_DLLEXPORT void dpnp_trace_c(const void* array, void* result, const size_t* shape, const size_t ndim);
719+
708720
/**
709721
* @ingroup BACKEND_API
710722
* @brief math library implementation of take function

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ enum class DPNPFuncName : size_t
199199
DPNP_FN_TAN, /**< Used in numpy.tan() implementation */
200200
DPNP_FN_TANH, /**< Used in numpy.tanh() implementation */
201201
DPNP_FN_TRANSPOSE, /**< Used in numpy.transpose() implementation */
202+
DPNP_FN_TRACE, /**< Used in numpy.trace() implementation */
202203
DPNP_FN_TRAPZ, /**< Used in numpy.trapz() implementation */
203204
DPNP_FN_TRI, /**< Used in numpy.tri() implementation */
204205
DPNP_FN_TRIL, /**< Used in numpy.tril() implementation */

dpnp/backend/kernels/dpnp_krnl_arraycreation.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,63 @@ void dpnp_vander_c(const void* array1_in, void* result1, const size_t size_in, c
211211
}
212212
}
213213

214+
template <typename _DataType, typename _ResultType>
215+
class dpnp_trace_c_kernel;
216+
217+
template <typename _DataType, typename _ResultType>
218+
void dpnp_trace_c(const void* array1_in, void* result1, const size_t* shape, const size_t ndim)
219+
{
220+
cl::sycl::event event;
221+
222+
if ((array1_in == nullptr) || (result1 == nullptr))
223+
{
224+
return;
225+
}
226+
227+
const _DataType* array_in = reinterpret_cast<const _DataType*>(array1_in);
228+
_ResultType* result = reinterpret_cast<_ResultType*>(result1);
229+
230+
if (shape == nullptr)
231+
{
232+
return;
233+
}
234+
235+
if (ndim == 0)
236+
{
237+
return;
238+
}
239+
240+
size_t size = 1;
241+
for (size_t i = 0; i < ndim - 1; ++i)
242+
{
243+
size *= shape[i];
244+
}
245+
246+
if (size == 0)
247+
{
248+
return;
249+
}
250+
251+
cl::sycl::range<1> gws(size);
252+
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
253+
size_t i = global_id[0];
254+
_DataType elem = 0;
255+
for (size_t j = 0; j < shape[ndim - 1]; ++j)
256+
{
257+
elem += array_in[i * shape[ndim - 1] + j];
258+
}
259+
result[i] = elem;
260+
};
261+
262+
auto kernel_func = [&](cl::sycl::handler& cgh) {
263+
cgh.parallel_for<class dpnp_trace_c_kernel<_DataType, _ResultType>>(gws, kernel_parallel_for_func);
264+
};
265+
266+
event = DPNP_QUEUE.submit(kernel_func);
267+
268+
event.wait();
269+
}
270+
214271
template <typename _DataType>
215272
class dpnp_tri_c_kernel;
216273

@@ -539,6 +596,23 @@ void func_map_init_arraycreation(func_map_t& fmap)
539596
fmap[DPNPFuncName::DPNP_FN_VANDER][eft_C128][eft_C128] = {
540597
eft_C128, (void*)dpnp_vander_c<std::complex<double>, std::complex<double>>};
541598

599+
fmap[DPNPFuncName::DPNP_FN_TRACE][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_trace_c<int, int>};
600+
fmap[DPNPFuncName::DPNP_FN_TRACE][eft_LNG][eft_INT] = {eft_INT, (void*)dpnp_trace_c<long, int>};
601+
fmap[DPNPFuncName::DPNP_FN_TRACE][eft_FLT][eft_INT] = {eft_INT, (void*)dpnp_trace_c<float, int>};
602+
fmap[DPNPFuncName::DPNP_FN_TRACE][eft_DBL][eft_INT] = {eft_INT, (void*)dpnp_trace_c<double, int>};
603+
fmap[DPNPFuncName::DPNP_FN_TRACE][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_trace_c<int, long>};
604+
fmap[DPNPFuncName::DPNP_FN_TRACE][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_trace_c<long, long>};
605+
fmap[DPNPFuncName::DPNP_FN_TRACE][eft_FLT][eft_LNG] = {eft_LNG, (void*)dpnp_trace_c<float, long>};
606+
fmap[DPNPFuncName::DPNP_FN_TRACE][eft_DBL][eft_LNG] = {eft_LNG, (void*)dpnp_trace_c<double, long>};
607+
fmap[DPNPFuncName::DPNP_FN_TRACE][eft_INT][eft_FLT] = {eft_FLT, (void*)dpnp_trace_c<int, float>};
608+
fmap[DPNPFuncName::DPNP_FN_TRACE][eft_LNG][eft_FLT] = {eft_FLT, (void*)dpnp_trace_c<long, float>};
609+
fmap[DPNPFuncName::DPNP_FN_TRACE][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_trace_c<float, float>};
610+
fmap[DPNPFuncName::DPNP_FN_TRACE][eft_DBL][eft_FLT] = {eft_FLT, (void*)dpnp_trace_c<double, float>};
611+
fmap[DPNPFuncName::DPNP_FN_TRACE][eft_INT][eft_DBL] = {eft_DBL, (void*)dpnp_trace_c<int, double>};
612+
fmap[DPNPFuncName::DPNP_FN_TRACE][eft_LNG][eft_DBL] = {eft_DBL, (void*)dpnp_trace_c<long, double>};
613+
fmap[DPNPFuncName::DPNP_FN_TRACE][eft_FLT][eft_DBL] = {eft_DBL, (void*)dpnp_trace_c<float, double>};
614+
fmap[DPNPFuncName::DPNP_FN_TRACE][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_trace_c<double, double>};
615+
542616
fmap[DPNPFuncName::DPNP_FN_TRI][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_tri_c<int>};
543617
fmap[DPNPFuncName::DPNP_FN_TRI][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_tri_c<long>};
544618
fmap[DPNPFuncName::DPNP_FN_TRI][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_tri_c<float>};

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
171171
DPNP_FN_TAKE
172172
DPNP_FN_TAN
173173
DPNP_FN_TANH
174+
DPNP_FN_TRACE
174175
DPNP_FN_TRANSPOSE
175176
DPNP_FN_TRAPZ
176177
DPNP_FN_TRI

dpnp/dpnp_algo/dpnp_algo_arraycreation.pyx

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ __all__ += [
5252
"dpnp_meshgrid",
5353
"dpnp_ones",
5454
"dpnp_ones_like",
55+
"dpnp_trace",
5556
"dpnp_tri",
5657
"dpnp_tril",
5758
"dpnp_triu",
@@ -64,6 +65,7 @@ __all__ += [
6465
ctypedef void(*custom_1in_1out_func_ptr_t)(void *, void * , const int , size_t * , size_t * , const size_t, const size_t)
6566
ctypedef void(*ftpr_custom_vander_1in_1out_t)(void *, void *, size_t, size_t, int)
6667
ctypedef void(*custom_indexing_1out_func_ptr_t)(void * , const size_t , const size_t , const int)
68+
ctypedef void(*fptr_dpnp_trace_t)(const void *, void * , const size_t * , const size_t)
6769

6870

6971
cpdef dparray dpnp_copy(dparray x1, order, subok):
@@ -270,6 +272,29 @@ cpdef dparray dpnp_ones_like(result_shape, result_dtype):
270272
return call_fptr_1out(DPNP_FN_ONES_LIKE, result_shape, result_dtype)
271273

272274

275+
cpdef dparray dpnp_trace(arr, offset=0, axis1=0, axis2=1, dtype=None, out=None):
276+
if dtype is None:
277+
dtype_ = arr.dtype
278+
else:
279+
dtype_ = dtype
280+
281+
cdef dparray diagonal_arr = dpnp.diagonal(arr, offset, axis1, axis2)
282+
283+
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(arr.dtype)
284+
cdef DPNPFuncType param2_type = dpnp_dtype_to_DPNPFuncType(dtype_)
285+
286+
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_TRACE, param1_type, param2_type)
287+
288+
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
289+
cdef dparray result = dparray(diagonal_arr.shape[:-1], dtype=result_type)
290+
291+
cdef fptr_dpnp_trace_t func = <fptr_dpnp_trace_t > kernel_data.ptr
292+
293+
func(diagonal_arr.get_data(), result.get_data(), < size_t * > diagonal_arr._dparray_shape.data(), diagonal_arr.ndim)
294+
295+
return result
296+
297+
273298
cpdef dparray dpnp_tri(N, M=None, k=0, dtype=numpy.float):
274299
if M is None:
275300
M = N

dpnp/dpnp_iface_arraycreation.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
"ogrid",
7676
"ones",
7777
"ones_like",
78+
"trace",
7879
"tri",
7980
"tril",
8081
"triu",
@@ -1072,6 +1073,36 @@ def ones_like(x1, dtype=None, order='C', subok=False, shape=None):
10721073
return numpy.ones_like(x1, dtype, order, subok, shape)
10731074

10741075

1076+
def trace(arr, offset=0, axis1=0, axis2=1, dtype=None, out=None):
1077+
"""
1078+
Return the sum along diagonals of the array.
1079+
1080+
For full documentation refer to :obj:`numpy.trace`.
1081+
1082+
Limitations
1083+
-----------
1084+
Input array is supported as :obj:`dpnp.ndarray`.
1085+
Parameters ``axis1``, ``axis2``, ``out`` and ``dtype`` are supported only with default values.
1086+
"""
1087+
if not use_origin_backend():
1088+
if not isinstance(arr, dparray):
1089+
pass
1090+
elif arr.size == 0:
1091+
pass
1092+
elif arr.ndim < 2:
1093+
pass
1094+
elif axis1 != 0:
1095+
pass
1096+
elif axis2 != 1:
1097+
pass
1098+
elif out is not None and (not isinstance(out, dparray) or (isinstance(out, dparray) and out.shape != arr.shape)):
1099+
pass
1100+
else:
1101+
return dpnp_trace(arr, offset, axis1, axis2, dtype, out)
1102+
1103+
return call_origin(numpy.trace, arr, offset, axis1, axis2, dtype, out)
1104+
1105+
10751106
def tri(N, M=None, k=0, dtype=numpy.float, **kwargs):
10761107
"""
10771108
An array with ones at and below the given diagonal and zeros elsewhere.

tests/test_arraycreation.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,43 @@ def test_loadtxt(type):
151151
numpy.testing.assert_array_equal(dpnp_res, np_res)
152152

153153

154+
@pytest.mark.parametrize("dtype",
155+
[numpy.float64, numpy.float32, numpy.int64, numpy.int32],
156+
ids=['float64', 'float32', 'int64', 'int32'])
157+
@pytest.mark.parametrize("type",
158+
[numpy.float64, numpy.float32, numpy.int64, numpy.int32],
159+
ids=['float64', 'float32', 'int64', 'int32'])
160+
@pytest.mark.parametrize("offset",
161+
[0, 1],
162+
ids=['0', '1'])
163+
@pytest.mark.parametrize("array",
164+
[[[0, 0], [0, 0]],
165+
[[1, 2], [1, 2]],
166+
[[1, 2], [3, 4]],
167+
[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
168+
[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]],
169+
[[[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]],
170+
[[[[1, 2], [3, 4]], [[1, 2], [2, 1]]], [[[1, 3], [3, 1]], [[0, 1], [1, 3]]]],
171+
[[[[1, 2, 3], [3, 4, 5]], [[1, 2, 3], [2, 1, 0]]], [
172+
[[1, 3, 5], [3, 1, 0]], [[0, 1, 2], [1, 3, 4]]]],
173+
[[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], [[[13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24]]]]],
174+
ids=['[[0, 0], [0, 0]]',
175+
'[[1, 2], [1, 2]]',
176+
'[[1, 2], [3, 4]]',
177+
'[[0, 1, 2], [3, 4, 5], [6, 7, 8]]',
178+
'[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]',
179+
'[[[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]]',
180+
'[[[[1, 2], [3, 4]], [[1, 2], [2, 1]]], [[[1, 3], [3, 1]], [[0, 1], [1, 3]]]]',
181+
'[[[[1, 2, 3], [3, 4, 5]], [[1, 2, 3], [2, 1, 0]]], [[[1, 3, 5], [3, 1, 0]], [[0, 1, 2], [1, 3, 4]]]]',
182+
'[[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], [[[13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24]]]]'])
183+
def test_trace(array, offset, type, dtype):
184+
a = numpy.array(array, type)
185+
ia = dpnp.array(array, type)
186+
expected = numpy.trace(a, offset=offset, dtype=dtype)
187+
result = dpnp.trace(ia, offset=offset, dtype=dtype)
188+
numpy.testing.assert_array_equal(expected, result)
189+
190+
154191
@pytest.mark.parametrize("N",
155192
[0, 1, 2, 3, 4],
156193
ids=['0', '1', '2', '3', '4'])

tests_external/skipped_tests_numpy.tbl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1145,7 +1145,6 @@ tests/test_numeric.py::TestNonarrayArgs::test_reshape
11451145
tests/test_numeric.py::TestNonarrayArgs::test_searchsorted
11461146
tests/test_numeric.py::TestNonarrayArgs::test_size
11471147
tests/test_numeric.py::TestNonarrayArgs::test_std
1148-
tests/test_numeric.py::TestNonarrayArgs::test_trace
11491148
tests/test_numeric.py::TestNonarrayArgs::test_var
11501149
tests/test_numeric.py::TestNonzero::test_array_method
11511150
tests/test_numeric.py::TestNonzero::test_count_nonzero_axis

0 commit comments

Comments
 (0)