Skip to content

Commit 3fe0040

Browse files
authored
trace() to desc (#809)
1 parent 5e876eb commit 3fe0040

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

dpnp/dpnp_algo/dpnp_algo_arraycreation.pyx

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,25 +271,28 @@ cpdef utils.dpnp_descriptor dpnp_ones_like(result_shape, result_dtype):
271271
return call_fptr_1out(DPNP_FN_ONES_LIKE, utils._object_to_tuple(result_shape), result_dtype)
272272

273273

274-
cpdef dparray dpnp_trace(arr, offset=0, axis1=0, axis2=1, dtype=None, out=None):
274+
cpdef utils.dpnp_descriptor dpnp_trace(utils.dpnp_descriptor arr, offset=0, axis1=0, axis2=1, dtype=None, out=None):
275275
if dtype is None:
276276
dtype_ = arr.dtype
277277
else:
278278
dtype_ = dtype
279279

280280
cdef dparray diagonal_arr = dpnp.diagonal(arr, offset, axis1, axis2)
281+
cdef size_t diagonal_ndim = diagonal_arr.ndim
282+
cdef dparray_shape_type diagonal_shape = diagonal_arr.shape
281283

282284
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(arr.dtype)
283285
cdef DPNPFuncType param2_type = dpnp_dtype_to_DPNPFuncType(dtype_)
284286

285287
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_TRACE, param1_type, param2_type)
286288

287-
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
288-
cdef dparray result = dparray(diagonal_arr.shape[:-1], dtype=result_type)
289+
# ceate result array with type given by FPTR data
290+
cdef dparray_shape_type result_shape = diagonal_shape[:-1]
291+
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape, kernel_data.return_type, None)
289292

290293
cdef fptr_dpnp_trace_t func = <fptr_dpnp_trace_t > kernel_data.ptr
291294

292-
func(diagonal_arr.get_data(), result.get_data(), < size_t * > diagonal_arr._dparray_shape.data(), diagonal_arr.ndim)
295+
func(diagonal_arr.get_data(), result.get_data(), < size_t * > diagonal_shape.data(), diagonal_ndim)
293296

294297
return result
295298

dpnp/dpnp_iface_arraycreation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1131,7 +1131,7 @@ def trace(x1, offset=0, axis1=0, axis2=1, dtype=None, out=None):
11311131
elif out is not None:
11321132
pass
11331133
else:
1134-
return dpnp_trace(x1, offset, axis1, axis2, dtype, out) #.get_pyobj()
1134+
return dpnp_trace(x1_desc, offset, axis1, axis2, dtype, out).get_pyobj()
11351135

11361136
return call_origin(numpy.trace, x1, offset, axis1, axis2, dtype, out)
11371137

0 commit comments

Comments
 (0)