Skip to content

Commit cda7fcc

Browse files
authored
Return type desc 1 (#794)
* use descriptor as return type in aggregated pyx functions
1 parent 1d927cc commit cda7fcc

14 files changed

+253
-233
lines changed

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -245,12 +245,12 @@ cdef dpnp_DPNPFuncType_to_dtype(size_t type)
245245
"""
246246
Bitwise functions
247247
"""
248-
cpdef dparray dpnp_bitwise_and(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
249-
cpdef dparray dpnp_bitwise_or(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
250-
cpdef dparray dpnp_bitwise_xor(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
251-
cpdef dparray dpnp_invert(dpnp_descriptor x1)
252-
cpdef dparray dpnp_left_shift(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
253-
cpdef dparray dpnp_right_shift(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
248+
cpdef dpnp_descriptor dpnp_bitwise_and(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
249+
cpdef dpnp_descriptor dpnp_bitwise_or(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
250+
cpdef dpnp_descriptor dpnp_bitwise_xor(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
251+
cpdef dpnp_descriptor dpnp_invert(dpnp_descriptor x1)
252+
cpdef dpnp_descriptor dpnp_left_shift(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
253+
cpdef dpnp_descriptor dpnp_right_shift(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
254254

255255

256256
"""
@@ -287,17 +287,17 @@ cpdef dparray dpnp_init_val(shape, dtype, value)
287287
"""
288288
Mathematical functions
289289
"""
290-
cpdef dparray dpnp_add(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
291-
cpdef dparray dpnp_arctan2(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*, dparray out=*, object where=*)
292-
cpdef dparray dpnp_divide(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
293-
cpdef dparray dpnp_hypot(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
294-
cpdef dparray dpnp_maximum(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
295-
cpdef dparray dpnp_minimum(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
296-
cpdef dparray dpnp_multiply(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
297-
cpdef dparray dpnp_negative(dpnp_descriptor array1)
298-
cpdef dparray dpnp_power(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*, dparray out=*, object where=*)
299-
cpdef dparray dpnp_remainder(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
300-
cpdef dparray dpnp_subtract(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
290+
cpdef dpnp_descriptor dpnp_add(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
291+
cpdef dpnp_descriptor dpnp_arctan2(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*, dparray out=*, object where=*)
292+
cpdef dpnp_descriptor dpnp_divide(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
293+
cpdef dpnp_descriptor dpnp_hypot(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
294+
cpdef dpnp_descriptor dpnp_maximum(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
295+
cpdef dpnp_descriptor dpnp_minimum(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
296+
cpdef dpnp_descriptor dpnp_multiply(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
297+
cpdef dpnp_descriptor dpnp_negative(dpnp_descriptor array1)
298+
cpdef dpnp_descriptor dpnp_power(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*, dparray out=*, object where=*)
299+
cpdef dpnp_descriptor dpnp_remainder(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
300+
cpdef dpnp_descriptor dpnp_subtract(object x1_obj, object x2_obj, object dtype=*, dparray out=*, object where=*)
301301

302302

303303
"""
@@ -318,8 +318,8 @@ cpdef dparray dpnp_min(dparray a, axis)
318318
"""
319319
Sorting functions
320320
"""
321-
cpdef dparray dpnp_argsort(dpnp_descriptor array1)
322-
cpdef dparray dpnp_sort(dpnp_descriptor array1)
321+
cpdef dpnp_descriptor dpnp_argsort(dpnp_descriptor array1)
322+
cpdef dpnp_descriptor dpnp_sort(dpnp_descriptor array1)
323323

324324
"""
325325
Searching functions
@@ -330,28 +330,28 @@ cpdef dparray dpnp_argmin(dpnp_descriptor array1)
330330
"""
331331
Trigonometric functions
332332
"""
333-
cpdef dparray dpnp_arccos(dpnp_descriptor array1)
334-
cpdef dparray dpnp_arccosh(dpnp_descriptor array1)
335-
cpdef dparray dpnp_arcsin(dpnp_descriptor array1, dparray out)
336-
cpdef dparray dpnp_arcsinh(dpnp_descriptor array1)
337-
cpdef dparray dpnp_arctan(dpnp_descriptor array1, dparray out)
338-
cpdef dparray dpnp_arctanh(dpnp_descriptor array1)
339-
cpdef dparray dpnp_cbrt(dpnp_descriptor array1)
340-
cpdef dparray dpnp_cos(dpnp_descriptor array1, dparray out)
341-
cpdef dparray dpnp_cosh(dpnp_descriptor array1)
342-
cpdef dparray dpnp_degrees(dpnp_descriptor array1)
343-
cpdef dparray dpnp_exp(dpnp_descriptor array1, dparray out)
344-
cpdef dparray dpnp_exp2(dpnp_descriptor array1)
345-
cpdef dparray dpnp_expm1(dpnp_descriptor array1)
346-
cpdef dparray dpnp_log(dpnp_descriptor array1, dparray out)
347-
cpdef dparray dpnp_log10(dpnp_descriptor array1)
348-
cpdef dparray dpnp_log1p(dpnp_descriptor array1)
349-
cpdef dparray dpnp_log2(dpnp_descriptor array1)
350-
cpdef dparray dpnp_radians(dpnp_descriptor array1)
351-
cpdef dparray dpnp_recip(dpnp_descriptor array1)
352-
cpdef dparray dpnp_sin(dpnp_descriptor array1, dparray out)
353-
cpdef dparray dpnp_sinh(dpnp_descriptor array1)
354-
cpdef dparray dpnp_sqrt(dpnp_descriptor array1)
355-
cpdef dparray dpnp_square(dpnp_descriptor array1)
356-
cpdef dparray dpnp_tan(dpnp_descriptor array1, dparray out)
357-
cpdef dparray dpnp_tanh(dpnp_descriptor array1)
333+
cpdef dpnp_descriptor dpnp_arccos(dpnp_descriptor array1)
334+
cpdef dpnp_descriptor dpnp_arccosh(dpnp_descriptor array1)
335+
cpdef dpnp_descriptor dpnp_arcsin(dpnp_descriptor array1, dparray out)
336+
cpdef dpnp_descriptor dpnp_arcsinh(dpnp_descriptor array1)
337+
cpdef dpnp_descriptor dpnp_arctan(dpnp_descriptor array1, dparray out)
338+
cpdef dpnp_descriptor dpnp_arctanh(dpnp_descriptor array1)
339+
cpdef dpnp_descriptor dpnp_cbrt(dpnp_descriptor array1)
340+
cpdef dpnp_descriptor dpnp_cos(dpnp_descriptor array1, dparray out)
341+
cpdef dpnp_descriptor dpnp_cosh(dpnp_descriptor array1)
342+
cpdef dpnp_descriptor dpnp_degrees(dpnp_descriptor array1)
343+
cpdef dpnp_descriptor dpnp_exp(dpnp_descriptor array1, dparray out)
344+
cpdef dpnp_descriptor dpnp_exp2(dpnp_descriptor array1)
345+
cpdef dpnp_descriptor dpnp_expm1(dpnp_descriptor array1)
346+
cpdef dpnp_descriptor dpnp_log(dpnp_descriptor array1, dparray out)
347+
cpdef dpnp_descriptor dpnp_log10(dpnp_descriptor array1)
348+
cpdef dpnp_descriptor dpnp_log1p(dpnp_descriptor array1)
349+
cpdef dpnp_descriptor dpnp_log2(dpnp_descriptor array1)
350+
cpdef dpnp_descriptor dpnp_radians(dpnp_descriptor array1)
351+
cpdef dpnp_descriptor dpnp_recip(dpnp_descriptor array1)
352+
cpdef dpnp_descriptor dpnp_sin(dpnp_descriptor array1, dparray out)
353+
cpdef dpnp_descriptor dpnp_sinh(dpnp_descriptor array1)
354+
cpdef dpnp_descriptor dpnp_sqrt(dpnp_descriptor array1)
355+
cpdef dpnp_descriptor dpnp_square(dpnp_descriptor array1)
356+
cpdef dpnp_descriptor dpnp_tan(dpnp_descriptor array1, dparray out)
357+
cpdef dpnp_descriptor dpnp_tanh(dpnp_descriptor array1)

dpnp/dpnp_algo/dpnp_algo.pyx

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -236,16 +236,18 @@ cdef dpnp_DPNPFuncType_to_dtype(size_t type):
236236
utils.checker_throw_type_error("dpnp_DPNPFuncType_to_dtype", type)
237237

238238

239-
cdef dparray call_fptr_1out(DPNPFuncName fptr_name, dparray_shape_type result_shape, result_dtype):
239+
cdef utils.dpnp_descriptor call_fptr_1out(DPNPFuncName fptr_name,
240+
dparray_shape_type result_shape,
241+
result_dtype):
240242

241-
# Convert string type names (dparray.dtype) to C enum DPNPFuncType
243+
# Convert type to C enum DPNPFuncType
242244
cdef DPNPFuncType dtype_in = dpnp_dtype_to_DPNPFuncType(result_dtype)
243245

244246
# get the FPTR data structure
245247
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(fptr_name, dtype_in, dtype_in)
246248

247249
# Create result array with type given by FPTR data
248-
cdef dparray result = utils.create_output_array(result_shape, kernel_data.return_type, None)
250+
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape, kernel_data.return_type, None)
249251

250252
cdef fptr_1out_t func = <fptr_1out_t > kernel_data.ptr
251253
# Call FPTR function
@@ -254,27 +256,31 @@ cdef dparray call_fptr_1out(DPNPFuncName fptr_name, dparray_shape_type result_sh
254256
return result
255257

256258

257-
cdef dparray call_fptr_1in_1out(DPNPFuncName fptr_name, utils.dpnp_descriptor x1, dparray_shape_type result_shape, dparray out=None, func_name=None):
259+
cdef utils.dpnp_descriptor call_fptr_1in_1out(DPNPFuncName fptr_name,
260+
utils.dpnp_descriptor x1,
261+
dparray_shape_type result_shape,
262+
dparray out=None,
263+
func_name=None):
258264

259-
""" Convert string type names (dparray.dtype) to C enum DPNPFuncType """
265+
""" Convert type (x1.dtype) to C enum DPNPFuncType """
260266
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype)
261267

262268
""" get the FPTR data structure """
263269
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(fptr_name, param1_type, param1_type)
264270

265271
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
266272

267-
cdef dparray result
273+
cdef utils.dpnp_descriptor result
268274

269275
if out is None:
270276
""" Create result array with type given by FPTR data """
271-
result = utils.create_output_array(result_shape, kernel_data.return_type, None)
277+
result = utils.create_output_descriptor(result_shape, kernel_data.return_type, None)
272278
else:
273279
if out.dtype != result_type:
274280
utils.checker_throw_value_error(func_name, 'out.dtype', out.dtype, result_type)
275281
if out.shape != result_shape:
276282
utils.checker_throw_value_error(func_name, 'out.shape', out.shape, result_shape)
277-
result = out
283+
result = dpnp_descriptor(out)
278284

279285
cdef fptr_1in_1out_t func = <fptr_1in_1out_t > kernel_data.ptr
280286

@@ -283,9 +289,15 @@ cdef dparray call_fptr_1in_1out(DPNPFuncName fptr_name, utils.dpnp_descriptor x1
283289
return result
284290

285291

286-
cdef dparray call_fptr_2in_1out(DPNPFuncName fptr_name, utils.dpnp_descriptor x1_obj, utils.dpnp_descriptor x2_obj,
287-
object dtype=None, dparray out=None, object where=True, func_name=None):
288-
# Convert string type names (dparray.dtype) to C enum DPNPFuncType
292+
cdef utils.dpnp_descriptor call_fptr_2in_1out(DPNPFuncName fptr_name,
293+
utils.dpnp_descriptor x1_obj,
294+
utils.dpnp_descriptor x2_obj,
295+
object dtype=None,
296+
dparray out=None,
297+
object where=True,
298+
func_name=None):
299+
300+
# Convert type (x1_obj.dtype) to C enum DPNPFuncType
289301
cdef DPNPFuncType x1_c_type = dpnp_dtype_to_DPNPFuncType(x1_obj.dtype)
290302
cdef DPNPFuncType x2_c_type = dpnp_dtype_to_DPNPFuncType(x2_obj.dtype)
291303

@@ -298,17 +310,18 @@ cdef dparray call_fptr_2in_1out(DPNPFuncName fptr_name, utils.dpnp_descriptor x1
298310
cdef dparray_shape_type x1_shape = x1_obj.shape
299311
cdef dparray_shape_type x2_shape = x2_obj.shape
300312
cdef dparray_shape_type result_shape = utils.get_common_shape(x1_shape, x2_shape)
301-
cdef dparray result
313+
cdef utils.dpnp_descriptor result
302314

303315
if out is None:
304316
""" Create result array with type given by FPTR data """
305-
result = utils.create_output_array(result_shape, kernel_data.return_type, None)
317+
result = utils.create_output_descriptor(result_shape, kernel_data.return_type, None)
306318
else:
307319
if out.dtype != result_type:
308320
utils.checker_throw_value_error(func_name, 'out.dtype', out.dtype, result_type)
309321
if out.shape != result_shape:
310322
utils.checker_throw_value_error(func_name, 'out.shape', out.shape, result_shape)
311-
result = out
323+
324+
result = dpnp_descriptor(out)
312325

313326
""" Call FPTR function """
314327
cdef fptr_2in_1out_t func = <fptr_2in_1out_t > kernel_data.ptr

dpnp/dpnp_algo/dpnp_algo_arraycreation.pyx

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,13 @@ ctypedef void(*custom_indexing_1out_func_ptr_t)(void * , const size_t , const si
6262
ctypedef void(*fptr_dpnp_trace_t)(const void *, void * , const size_t * , const size_t)
6363

6464

65-
cpdef dparray dpnp_copy(utils.dpnp_descriptor x1, order, subok):
65+
cpdef utils.dpnp_descriptor dpnp_copy(utils.dpnp_descriptor x1, order, subok):
6666
return call_fptr_1in_1out(DPNP_FN_COPY, x1, x1.shape)
6767

6868

69-
cpdef dparray dpnp_diag(dparray v, int k):
69+
cpdef utils.dpnp_descriptor dpnp_diag(utils.dpnp_descriptor v, int k):
70+
cdef dparray_shape_type input_shape = v.shape
71+
7072
if v.ndim == 1:
7173
n = v.shape[0] + abs(k)
7274

@@ -78,7 +80,8 @@ cpdef dparray dpnp_diag(dparray v, int k):
7880

7981
shape_result = (n, )
8082

81-
cdef dparray result = dpnp.zeros(shape_result, dtype=v.dtype)
83+
result_obj = dpnp.zeros(shape_result, dtype=v.dtype) # TODO need to call dpnp_zero instead
84+
cdef utils.dpnp_descriptor result = dpnp_descriptor(result_obj)
8285

8386
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(v.dtype)
8487

@@ -87,8 +90,9 @@ cpdef dparray dpnp_diag(dparray v, int k):
8790
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
8891

8992
cdef custom_1in_1out_func_ptr_t func = <custom_1in_1out_func_ptr_t > kernel_data.ptr
93+
cdef dparray_shape_type result_shape = result.shape
9094

91-
func(v.get_data(), result.get_data(), k, < size_t * > v._dparray_shape.data(), < size_t * > result._dparray_shape.data(), v.ndim, result.ndim)
95+
func(v.get_data(), result.get_data(), k, < size_t * > input_shape.data(), < size_t * > result_shape.data(), v.ndim, result.ndim)
9296

9397
return result
9498

@@ -258,11 +262,11 @@ cpdef list dpnp_meshgrid(xi, copy, sparse, indexing):
258262
return result
259263

260264

261-
cpdef dparray dpnp_ones(result_shape, result_dtype):
265+
cpdef utils.dpnp_descriptor dpnp_ones(result_shape, result_dtype):
262266
return call_fptr_1out(DPNP_FN_ONES, utils._object_to_tuple(result_shape), result_dtype)
263267

264268

265-
cpdef dparray dpnp_ones_like(result_shape, result_dtype):
269+
cpdef utils.dpnp_descriptor dpnp_ones_like(result_shape, result_dtype):
266270
return call_fptr_1out(DPNP_FN_ONES_LIKE, utils._object_to_tuple(result_shape), result_dtype)
267271

268272

@@ -368,9 +372,9 @@ cpdef dparray dpnp_vander(dparray x1, int N, int increasing):
368372
return result
369373

370374

371-
cpdef dparray dpnp_zeros(result_shape, result_dtype):
375+
cpdef utils.dpnp_descriptor dpnp_zeros(result_shape, result_dtype):
372376
return call_fptr_1out(DPNP_FN_ZEROS, utils._object_to_tuple(result_shape), result_dtype)
373377

374378

375-
cpdef dparray dpnp_zeros_like(result_shape, result_dtype):
379+
cpdef utils.dpnp_descriptor dpnp_zeros_like(result_shape, result_dtype):
376380
return call_fptr_1out(DPNP_FN_ZEROS_LIKE, utils._object_to_tuple(result_shape), result_dtype)

dpnp/dpnp_algo/dpnp_algo_bitwise.pyx

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,24 +44,24 @@ __all__ += [
4444
]
4545

4646

47-
cpdef dparray dpnp_bitwise_and(object x1_obj, object x2_obj, object dtype=None, dparray out=None, object where=True):
47+
cpdef utils.dpnp_descriptor dpnp_bitwise_and(object x1_obj, object x2_obj, object dtype=None, dparray out=None, object where=True):
4848
return call_fptr_2in_1out(DPNP_FN_BITWISE_AND, x1_obj, x2_obj, dtype=dtype, out=out, where=where)
4949

5050

51-
cpdef dparray dpnp_bitwise_or(object x1_obj, object x2_obj, object dtype=None, dparray out=None, object where=True):
51+
cpdef utils.dpnp_descriptor dpnp_bitwise_or(object x1_obj, object x2_obj, object dtype=None, dparray out=None, object where=True):
5252
return call_fptr_2in_1out(DPNP_FN_BITWISE_OR, x1_obj, x2_obj, dtype=dtype, out=out, where=where)
5353

5454

55-
cpdef dparray dpnp_bitwise_xor(object x1_obj, object x2_obj, object dtype=None, dparray out=None, object where=True):
55+
cpdef utils.dpnp_descriptor dpnp_bitwise_xor(object x1_obj, object x2_obj, object dtype=None, dparray out=None, object where=True):
5656
return call_fptr_2in_1out(DPNP_FN_BITWISE_XOR, x1_obj, x2_obj, dtype=dtype, out=out, where=where)
5757

5858

59-
cpdef dparray dpnp_invert(dpnp_descriptor arr):
59+
cpdef utils.dpnp_descriptor dpnp_invert(dpnp_descriptor arr):
6060
return call_fptr_1in_1out(DPNP_FN_INVERT, arr, arr.shape)
6161

6262

63-
cpdef dparray dpnp_left_shift(object x1_obj, object x2_obj, object dtype=None, dparray out=None, object where=True):
63+
cpdef utils.dpnp_descriptor dpnp_left_shift(object x1_obj, object x2_obj, object dtype=None, dparray out=None, object where=True):
6464
return call_fptr_2in_1out(DPNP_FN_LEFT_SHIFT, x1_obj, x2_obj, dtype=dtype, out=out, where=where)
6565

66-
cpdef dparray dpnp_right_shift(object x1_obj, object x2_obj, object dtype=None, dparray out=None, object where=True):
66+
cpdef utils.dpnp_descriptor dpnp_right_shift(object x1_obj, object x2_obj, object dtype=None, dparray out=None, object where=True):
6767
return call_fptr_2in_1out(DPNP_FN_RIGHT_SHIFT, x1_obj, x2_obj, dtype=dtype, out=out, where=where)

dpnp/dpnp_algo/dpnp_algo_linearalgebra.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ cpdef dparray dpnp_dot(dpnp_descriptor in_array1, dpnp_descriptor in_array2):
6565
if dim1 == 0 or dim2 == 0:
6666
x1_desc = dpnp.get_dpnp_descriptor(in_array1)
6767
x2_desc = dpnp.get_dpnp_descriptor(in_array2)
68-
return dpnp_multiply(x1_desc, x2_desc)
68+
return dpnp_multiply(x1_desc, x2_desc).get_pyobj()
6969

7070
cdef size_t size1 = 0
7171
cdef size_t size2 = 0

0 commit comments

Comments
 (0)