Skip to content

Commit 5bdd984

Browse files
authored
matrix_power, matmul with out param to desc (#812)
* matrix_power, matmul with out param to desc
1 parent 92a0a66 commit 5bdd984

File tree

8 files changed

+59
-50
lines changed

8 files changed

+59
-50
lines changed

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ cpdef dpnp_descriptor dpnp_not_equal(dpnp_descriptor input1, dpnp_descriptor inp
273273
Linear algebra
274274
"""
275275
cpdef dparray dpnp_dot(dpnp_descriptor in_array1, dpnp_descriptor in_array2)
276-
cpdef dparray dpnp_matmul(dpnp_descriptor in_array1, dpnp_descriptor in_array2, dparray out=*)
276+
cpdef dpnp_descriptor dpnp_matmul(dpnp_descriptor in_array1, dpnp_descriptor in_array2, dpnp_descriptor out=*)
277277

278278

279279
"""

dpnp/dpnp_algo/dpnp_algo_linearalgebra.pyx

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ cpdef dparray dpnp_kron(dpnp_descriptor in_array1, dpnp_descriptor in_array2):
197197
return result
198198

199199

200-
cpdef dparray dpnp_matmul(dpnp_descriptor in_array1, dpnp_descriptor in_array2, dparray out=None):
200+
cpdef utils.dpnp_descriptor dpnp_matmul(utils.dpnp_descriptor in_array1, utils.dpnp_descriptor in_array2, utils.dpnp_descriptor out=None):
201201

202202
cdef dparray_shape_type shape_result
203203

@@ -258,19 +258,8 @@ cpdef dparray dpnp_matmul(dpnp_descriptor in_array1, dpnp_descriptor in_array2,
258258
# get the FPTR data structure
259259
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_MATMUL, param1_type, param2_type)
260260

261-
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
262-
263-
cdef dparray result
264-
265-
if out is not None:
266-
if out.dtype != result_type:
267-
utils.checker_throw_value_error('matmul', 'out.dtype', out.dtype, result_type)
268-
if out.shape != shape_result:
269-
utils.checker_throw_value_error('matmul', 'out.shape', out.shape, shape_result)
270-
result = out
271-
else:
272-
result = dparray(shape_result, dtype=result_type)
273-
261+
# ceate result array with type given by FPTR data
262+
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(shape_result, kernel_data.return_type, out)
274263
if result.size == 0:
275264
return result
276265

dpnp/dpnp_algo/dpnp_algo_mathematical.pyx

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,13 @@ cpdef utils.dpnp_descriptor dpnp_power(utils.dpnp_descriptor x1_obj, utils.dpnp_
341341
return call_fptr_2in_1out(DPNP_FN_POWER, x1_obj, x2_obj, dtype=dtype, out=out, where=where, func_name="power")
342342

343343

344-
cpdef dparray dpnp_prod(utils.dpnp_descriptor input, object axis=None, object dtype=None, dparray out=None, cpp_bool keepdims=False, object initial=None, object where=True):
344+
cpdef utils.dpnp_descriptor dpnp_prod(utils.dpnp_descriptor input,
345+
object axis=None,
346+
object dtype=None,
347+
utils.dpnp_descriptor out=None,
348+
cpp_bool keepdims=False,
349+
object initial=None,
350+
object where=True):
345351
"""
346352
input:float64 : outout:float64 : name:prod
347353
input:float32 : outout:float32 : name:prod
@@ -364,7 +370,7 @@ cpdef dparray dpnp_prod(utils.dpnp_descriptor input, object axis=None, object dt
364370
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_PROD, input_c_type, result_c_type)
365371

366372
""" Create result array """
367-
cdef dparray result = utils.create_output_array(result_shape, result_c_type, out)
373+
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape, result_c_type, out)
368374
cdef dpnp_reduction_c_t func = <dpnp_reduction_c_t > kernel_data.ptr
369375

370376
""" Call FPTR interface function """
@@ -385,7 +391,13 @@ cpdef utils.dpnp_descriptor dpnp_subtract(object x1_obj, object x2_obj, object d
385391
return call_fptr_2in_1out(DPNP_FN_SUBTRACT, x1_obj, x2_obj, dtype=dtype, out=out, where=where)
386392

387393

388-
cpdef utils.dpnp_descriptor dpnp_sum(utils.dpnp_descriptor input, object axis=None, object dtype=None, dparray out=None, cpp_bool keepdims=False, object initial=None, object where=True):
394+
cpdef utils.dpnp_descriptor dpnp_sum(utils.dpnp_descriptor input,
395+
object axis=None,
396+
object dtype=None,
397+
utils.dpnp_descriptor out=None,
398+
cpp_bool keepdims=False,
399+
object initial=None,
400+
object where=True):
389401

390402
cdef dparray_shape_type input_shape = input.shape
391403
cdef DPNPFuncType input_c_type = dpnp_dtype_to_DPNPFuncType(input.dtype)

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,7 @@ def matmul(x1, x2, out=None, **kwargs):
245245

246246
x1_desc = dpnp.get_dpnp_descriptor(x1)
247247
x2_desc = dpnp.get_dpnp_descriptor(x2)
248-
out_desc = dpnp.get_dpnp_descriptor(x2)
249-
if x1_desc and x2_desc and out_desc and not kwargs:
248+
if x1_desc and x2_desc and not kwargs:
250249
if x1_desc.size != x2_desc.size:
251250
pass
252251
elif not x1_desc.ndim:
@@ -276,7 +275,8 @@ def matmul(x1, x2, out=None, **kwargs):
276275
if (dparray1_size > cost_size) and (dparray2_size > cost_size):
277276
return dpnp_matmul(x1_desc, x2_desc, out)
278277
else:
279-
return dpnp_matmul(x1_desc, x2_desc, out)
278+
out_desc = dpnp.get_dpnp_descriptor(out) if out is not None else None
279+
return dpnp_matmul(x1_desc, x2_desc, out_desc).get_pyobj()
280280

281281
return call_origin(numpy.matmul, x1, x2, out=out, **kwargs)
282282

dpnp/dpnp_iface_mathematical.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1372,12 +1372,11 @@ def prod(x1, axis=None, dtype=None, out=None, keepdims=False, initial=None, wher
13721372

13731373
x1_desc = dpnp.get_dpnp_descriptor(x1)
13741374
if x1_desc:
1375-
if out is not None and not isinstance(out, dparray):
1376-
pass
1377-
elif where is not True:
1375+
if where is not True:
13781376
pass
13791377
else:
1380-
result_obj = dpnp_prod(x1_desc, axis, dtype, out, keepdims, initial, where)
1378+
out_desc = dpnp.get_dpnp_descriptor(out) if out is not None else None
1379+
result_obj = dpnp_prod(x1_desc, axis, dtype, out_desc, keepdims, initial, where).get_pyobj()
13811380
result = dpnp.convert_single_elem_array_to_scalar(result_obj, keepdims)
13821381

13831382
return result
@@ -1499,7 +1498,6 @@ def subtract(x1, x2, dtype=None, out=None, where=True, **kwargs):
14991498
x2_is_scalar = dpnp.isscalar(x2)
15001499
x1_desc = dpnp.get_dpnp_descriptor(x1)
15011500
x2_desc = dpnp.get_dpnp_descriptor(x2)
1502-
15031501
if x1_desc and x2_desc and not kwargs:
15041502
if not x1_desc and not x1_is_scalar:
15051503
pass
@@ -1522,7 +1520,8 @@ def subtract(x1, x2, dtype=None, out=None, where=True, **kwargs):
15221520
elif not where:
15231521
pass
15241522
else:
1525-
return dpnp_subtract(x1_desc, x2_desc, dtype=dtype, out=out, where=where).get_pyobj()
1523+
out_desc = dpnp.get_dpnp_descriptor(out) if out is not None else None
1524+
return dpnp_subtract(x1_desc, x2_desc, dtype=dtype, out=out_desc, where=where).get_pyobj()
15261525

15271526
return call_origin(numpy.subtract, x1, x2, dtype=dtype, out=out, where=where, **kwargs)
15281527

@@ -1551,12 +1550,11 @@ def sum(x1, axis=None, dtype=None, out=None, keepdims=False, initial=None, where
15511550

15521551
x1_desc = dpnp.get_dpnp_descriptor(x1)
15531552
if x1_desc:
1554-
if out is not None and not isinstance(out, dparray):
1555-
pass
1556-
elif where is not True:
1553+
if where is not True:
15571554
pass
15581555
else:
1559-
result_obj = dpnp_sum(x1_desc, axis, dtype, out, keepdims, initial, where).get_pyobj()
1556+
out_desc = dpnp.get_dpnp_descriptor(out) if out is not None else None
1557+
result_obj = dpnp_sum(x1_desc, axis, dtype, out_desc, keepdims, initial, where).get_pyobj()
15601558
result = dpnp.convert_single_elem_array_to_scalar(result_obj, keepdims)
15611559

15621560
return result

dpnp/dpnp_utils/dpnp_algo_utils.pxd

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,16 @@ cdef DPNPFuncType get_output_c_type(DPNPFuncName funcID,
156156
Calculate output array type by 'out' and 'dtype' cast parameters
157157
"""
158158

159-
cdef dparray create_output_array(dparray_shape_type output_shape, DPNPFuncType c_type, object requested_out)
159+
cdef dparray create_output_array(dparray_shape_type output_shape,
160+
DPNPFuncType c_type,
161+
object requested_out)
160162
"""
161163
Create output array based on shape, type and 'out' parameters
162164
"""
163165

164-
cdef dpnp_descriptor create_output_descriptor(dparray_shape_type output_shape, DPNPFuncType c_type, object requested_out)
166+
cdef dpnp_descriptor create_output_descriptor(dparray_shape_type output_shape,
167+
DPNPFuncType c_type,
168+
dpnp_descriptor requested_out)
165169
"""
166170
Same as "create_output_array" but output is "dpnp_descriptor"
167171
"""

dpnp/dpnp_utils/dpnp_algo_utils.pyx

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,10 @@ cdef DPNPFuncType get_output_c_type(DPNPFuncName funcID,
336336

337337

338338
cdef dparray create_output_array(dparray_shape_type output_shape, DPNPFuncType c_type, object requested_out):
339+
"""
340+
TODO This function needs to be deleted. Replace with create_output_descriptor()
341+
"""
342+
339343
cdef dparray result
340344

341345
if requested_out is None:
@@ -349,10 +353,24 @@ cdef dparray create_output_array(dparray_shape_type output_shape, DPNPFuncType c
349353

350354
return result
351355

352-
cdef dpnp_descriptor create_output_descriptor(dparray_shape_type output_shape, DPNPFuncType c_type, object requested_out):
353-
result = create_output_array(output_shape, c_type, requested_out)
356+
cdef dpnp_descriptor create_output_descriptor(dparray_shape_type output_shape,
357+
DPNPFuncType c_type,
358+
dpnp_descriptor requested_out):
359+
cdef dpnp_descriptor result_desc
360+
361+
if requested_out is None:
362+
""" Create DPNP array """
363+
result = dparray(output_shape, dtype=dpnp_DPNPFuncType_to_dtype( < size_t > c_type))
364+
result_desc = dpnp_descriptor(result)
365+
else:
366+
""" Based on 'out' parameter """
367+
if (output_shape != requested_out.shape):
368+
checker_throw_value_error("create_output_array", "out.shape", requested_out.shape, output_shape)
354369

355-
cdef dpnp_descriptor result_desc = dpnp_descriptor(result)
370+
if isinstance(requested_out, dpnp_descriptor):
371+
result_desc = requested_out
372+
else:
373+
result_desc = dpnp_descriptor(requested_out)
356374

357375
return result_desc
358376

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -239,26 +239,14 @@ def matrix_power(input, count):
239239
240240
"""
241241

242-
is_input_dparray = isinstance(input, dparray)
243-
244-
if not use_origin_backend(input) and is_input_dparray and count > 0:
242+
if not use_origin_backend() and count > 0:
245243
result = input
246244
for id in range(count - 1):
247245
result = dpnp.matmul(result, input)
248246

249247
return result
250248

251-
input1 = dpnp.asnumpy(input) if is_input_dparray else input
252-
253-
# TODO need to put dparray memory into NumPy call
254-
result_numpy = numpy.linalg.matrix_power(input1, count)
255-
result = result_numpy
256-
if isinstance(result, numpy.ndarray):
257-
result = dparray(result_numpy.shape, dtype=result_numpy.dtype)
258-
for i in range(result.size):
259-
result._setitem_scalar(i, result_numpy.item(i))
260-
261-
return result
249+
return call_origin(numpy.linalg.matrix_power, input, count)
262250

263251

264252
def matrix_rank(input, tol=None, hermitian=False):

0 commit comments

Comments
 (0)