Skip to content

Commit a6b21aa

Browse files
authored
call_origin() has no dparray's (#879)
* call_origin() has no dparray's
1 parent b8c1e3e commit a6b21aa

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

dpnp/dpnp_utils/dpnp_algo_utils.pyx

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -150,20 +150,20 @@ def call_origin(function, *args, **kwargs):
150150
if (kwargs_dtype is not None):
151151
result_dtype = kwargs_dtype
152152

153-
result = dparray(result_origin.shape, dtype=result_dtype)
153+
result = create_output_container(result_origin.shape, result_dtype)
154154
else:
155155
result = kwargs_out
156156

157157
for i in range(result.size):
158158
result.flat[i] = result_origin.item(i)
159159

160160
elif isinstance(result, tuple):
161-
# convert tuple(ndarray) to tuple(dparray)
161+
# convert tuple(fallback_array) to tuple(result_array)
162162
result_list = []
163163
for res_origin in result:
164164
res = res_origin
165165
if isinstance(res_origin, numpy.ndarray):
166-
res = dparray(res_origin.shape, dtype=res_origin.dtype)
166+
res = create_output_container(res_origin.shape, res_origin.dtype)
167167
for i in range(res.size):
168168
res.flat[i] = res_origin.item(i)
169169
result_list.append(res)
@@ -384,6 +384,20 @@ cdef DPNPFuncType get_output_c_type(DPNPFuncName funcID,
384384
checker_throw_value_error("get_output_c_type", "dtype and out", requested_dtype, requested_out)
385385

386386

387+
def create_output_container(shape, type):
388+
if config.__DPNP_OUTPUT_NUMPY__:
389+
""" Create NumPy ndarray """
390+
# TODO need to use "buffer=" parameter to use SYCL aware memory
391+
result = numpy.ndarray(shape, dtype=type)
392+
elif config.__DPNP_DPCTL_AVAILABLE__:
393+
""" Create DPCTL array """
394+
result = dpctl.usm_ndarray(shape, dtype=numpy.dtype(type).name)
395+
else:
396+
""" Create DPNP array """
397+
result = dparray(shape, dtype=type)
398+
399+
return result
400+
387401
cdef dpnp_descriptor create_output_descriptor(shape_type_c output_shape,
388402
DPNPFuncType c_type,
389403
dpnp_descriptor requested_out):
@@ -392,18 +406,8 @@ cdef dpnp_descriptor create_output_descriptor(shape_type_c output_shape,
392406
if requested_out is None:
393407
result = None
394408
result_dtype = dpnp_DPNPFuncType_to_dtype( < size_t > c_type)
395-
if config.__DPNP_OUTPUT_NUMPY__:
396-
""" Create NumPy ndarray """
397-
# TODO need to use "buffer=" parameter to use SYCL aware memory
398-
result = numpy.ndarray(output_shape, dtype=result_dtype)
399-
elif config.__DPNP_DPCTL_AVAILABLE__:
400-
""" Create DPCTL array """
401-
result = dpctl.usm_ndarray(output_shape, dtype=numpy.dtype(result_dtype).name)
402-
else:
403-
""" Create DPNP array """
404-
result = dparray(output_shape, dtype=result_dtype)
405-
406-
result_desc = dpnp_descriptor(result)
409+
result_obj = create_output_container(output_shape, result_dtype)
410+
result_desc = dpnp_descriptor(result_obj)
407411
else:
408412
""" Based on 'out' parameter """
409413
if (output_shape != requested_out.shape):

0 commit comments

Comments
 (0)