Skip to content

Commit 0d592e7

Browse files
authored
Tril/triu functions to desc (#807)
1 parent b0e9ba1 commit 0d592e7

File tree

2 files changed

+48
-46
lines changed

2 files changed

+48
-46
lines changed

dpnp/dpnp_algo/dpnp_algo_arraycreation.pyx

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -318,54 +318,55 @@ cpdef dparray dpnp_tri(N, M=None, k=0, dtype=numpy.float):
318318
return result
319319

320320

321-
cpdef dparray dpnp_tril(dparray m, int k):
321+
cpdef utils.dpnp_descriptor dpnp_tril(utils.dpnp_descriptor m, int k):
322+
cdef dparray_shape_type input_shape = m.shape
323+
cdef dparray_shape_type result_shape
324+
322325
if m.ndim == 1:
323326
result_shape = (m.shape[0], m.shape[0])
324327
else:
325328
result_shape = m.shape
326329

327-
result_ndim = len(result_shape)
328-
cdef dparray result = dparray(result_shape, dtype=m.dtype)
329-
330330
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(m.dtype)
331-
332331
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_TRIL, param1_type, param1_type)
333332

334-
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
333+
# ceate result array with type given by FPTR data
334+
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape, kernel_data.return_type, None)
335335

336336
cdef custom_1in_1out_func_ptr_t func = <custom_1in_1out_func_ptr_t > kernel_data.ptr
337-
338-
func(m.get_data(), result.get_data(), k, < size_t * > m._dparray_shape.data(), < size_t * > result._dparray_shape.data(), m.ndim, result.ndim)
337+
func(m.get_data(), result.get_data(), k, < size_t * > input_shape.data(), < size_t * > result_shape.data(), m.ndim, result.ndim)
339338

340339
return result
341340

342341

343-
cpdef dparray dpnp_triu(dparray m, int k):
342+
cpdef utils.dpnp_descriptor dpnp_triu(utils.dpnp_descriptor m, int k):
343+
cdef dparray_shape_type input_shape = m.shape
344+
cdef dparray_shape_type result_shape
345+
344346
if m.ndim == 1:
345-
res_shape = (m.shape[0], m.shape[0])
347+
result_shape = (m.shape[0], m.shape[0])
346348
else:
347-
res_shape = m.shape
348-
349-
cdef dparray result = dparray(shape=res_shape, dtype=m.dtype)
349+
result_shape = m.shape
350350

351351
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(m.dtype)
352-
353352
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_TRIU, param1_type, param1_type)
354353

355-
cdef custom_1in_1out_func_ptr_t func = <custom_1in_1out_func_ptr_t > kernel_data.ptr
354+
# ceate result array with type given by FPTR data
355+
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape, kernel_data.return_type, None)
356356

357-
func(m.get_data(), result.get_data(), k, < size_t * > m._dparray_shape.data(), < size_t * > result._dparray_shape.data(), m.ndim, result.ndim)
357+
cdef custom_1in_1out_func_ptr_t func = <custom_1in_1out_func_ptr_t > kernel_data.ptr
358+
func(m.get_data(), result.get_data(), k, < size_t * > input_shape.data(), < size_t * > result_shape.data(), m.ndim, result.ndim)
358359

359360
return result
360361

361362

362-
cpdef dparray dpnp_vander(dparray x1, int N, int increasing):
363+
cpdef utils.dpnp_descriptor dpnp_vander(utils.dpnp_descriptor x1, int N, int increasing):
363364
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype)
364-
365365
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_VANDER, param1_type, DPNP_FT_NONE)
366366

367-
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
368-
cdef dparray result = dparray((x1.size, N), dtype=result_type)
367+
# ceate result array with type given by FPTR data
368+
cdef dparray_shape_type result_shape = (x1.size, N)
369+
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape, kernel_data.return_type, None)
369370

370371
cdef ftpr_custom_vander_1in_1out_t func = <ftpr_custom_vander_1in_1out_t > kernel_data.ptr
371372
func(x1.get_data(), result.get_data(), x1.size, N, increasing)

dpnp/dpnp_iface_arraycreation.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,7 +1106,7 @@ def ones_like(x1, dtype=None, order='C', subok=False, shape=None):
11061106
return call_origin(numpy.ones_like, x1, dtype, order, subok, shape)
11071107

11081108

1109-
def trace(arr, offset=0, axis1=0, axis2=1, dtype=None, out=None):
1109+
def trace(x1, offset=0, axis1=0, axis2=1, dtype=None, out=None):
11101110
"""
11111111
Return the sum along diagonals of the array.
11121112
@@ -1117,23 +1117,23 @@ def trace(arr, offset=0, axis1=0, axis2=1, dtype=None, out=None):
11171117
Input array is supported as :obj:`dpnp.ndarray`.
11181118
Parameters ``axis1``, ``axis2``, ``out`` and ``dtype`` are supported only with default values.
11191119
"""
1120-
if not use_origin_backend():
1121-
if not isinstance(arr, dparray):
1122-
pass
1123-
elif arr.size == 0:
1120+
1121+
x1_desc = dpnp.get_dpnp_descriptor(x1)
1122+
if x1_desc:
1123+
if x1_desc.size == 0:
11241124
pass
1125-
elif arr.ndim < 2:
1125+
elif x1_desc.ndim < 2:
11261126
pass
11271127
elif axis1 != 0:
11281128
pass
11291129
elif axis2 != 1:
11301130
pass
1131-
elif out is not None and (not isinstance(out, dparray) or (isinstance(out, dparray) and out.shape != arr.shape)):
1131+
elif out is not None:
11321132
pass
11331133
else:
1134-
return dpnp_trace(arr, offset, axis1, axis2, dtype, out)
1134+
return dpnp_trace(x1, offset, axis1, axis2, dtype, out) #.get_pyobj()
11351135

1136-
return call_origin(numpy.trace, arr, offset, axis1, axis2, dtype, out)
1136+
return call_origin(numpy.trace, x1, offset, axis1, axis2, dtype, out)
11371137

11381138

11391139
def tri(N, M=None, k=0, dtype=numpy.float, **kwargs):
@@ -1176,7 +1176,7 @@ def tri(N, M=None, k=0, dtype=numpy.float, **kwargs):
11761176
return call_origin(numpy.tri, N, M, k, dtype, **kwargs)
11771177

11781178

1179-
def tril(m, k=0):
1179+
def tril(x1, k=0):
11801180
"""
11811181
Lower triangle of an array.
11821182
@@ -1195,16 +1195,17 @@ def tril(m, k=0):
11951195
11961196
"""
11971197

1198-
if not use_origin_backend(m):
1199-
if not isinstance(m, dparray):
1198+
x1_desc = dpnp.get_dpnp_descriptor(x1)
1199+
if x1_desc:
1200+
if not isinstance(k, int):
12001201
pass
12011202
else:
1202-
return dpnp_tril(m, k)
1203+
return dpnp_tril(x1_desc, k).get_pyobj()
12031204

1204-
return call_origin(numpy.tril, m, k)
1205+
return call_origin(numpy.tril, x1, k)
12051206

12061207

1207-
def triu(m, k=0):
1208+
def triu(x1, k=0):
12081209
"""
12091210
Upper triangle of an array.
12101211
@@ -1224,15 +1225,14 @@ def triu(m, k=0):
12241225
12251226
"""
12261227

1227-
if not use_origin_backend(m):
1228-
if not isinstance(m, dparray):
1229-
pass
1230-
elif not isinstance(k, int):
1228+
x1_desc = dpnp.get_dpnp_descriptor(x1)
1229+
if x1_desc:
1230+
if not isinstance(k, int):
12311231
pass
12321232
else:
1233-
return dpnp_triu(m, k)
1233+
return dpnp_triu(x1_desc, k).get_pyobj()
12341234

1235-
return call_origin(numpy.triu, m, k)
1235+
return call_origin(numpy.triu, x1, k)
12361236

12371237

12381238
def vander(x1, N=None, increasing=False):
@@ -1263,15 +1263,16 @@ def vander(x1, N=None, increasing=False):
12631263
[ 1, 3, 9, 27],
12641264
[ 1, 5, 25, 125]])
12651265
"""
1266-
if (not use_origin_backend(x1)):
1267-
if not isinstance(x1, dparray):
1268-
pass
1269-
elif x1.ndim != 1:
1266+
1267+
x1_desc = dpnp.get_dpnp_descriptor(x1)
1268+
if x1_desc:
1269+
if x1.ndim != 1:
12701270
pass
12711271
else:
12721272
if N is None:
12731273
N = x1.size
1274-
return dpnp_vander(x1, N, increasing)
1274+
1275+
return dpnp_vander(x1_desc, N, increasing).get_pyobj()
12751276

12761277
return call_origin(numpy.vander, x1, N=N, increasing=increasing)
12771278

0 commit comments

Comments
 (0)