Skip to content

Commit d769cce

Browse files
Fix issue where passing dpnp_array to dpctl API as parameter (#1108)
* Fix issue where passing dpnp_array to dpctl API as parameter Co-authored-by: Alexander-Makaryev <alexander.makaryev@gmail.com>
1 parent ff987dd commit d769cce

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

dpnp/dpnp_utils/dpnp_algo_utils.pyx

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def copy_from_origin(dst, src):
102102
if config.__DPNP_OUTPUT_DPCTL__ and hasattr(dst, "__sycl_usm_array_interface__"):
103103
if src.size:
104104
# dst.usm_data.copy_from_host(src.reshape(-1).view("|u1"))
105-
dpctl.tensor._copy_utils._copy_from_numpy_into(dst, src)
105+
dpctl.tensor._copy_utils._copy_from_numpy_into(unwrap_array(dst), src)
106106
else:
107107
for i in range(dst.size):
108108
dst.flat[i] = src.item(i)
@@ -177,6 +177,14 @@ def call_origin(function, *args, **kwargs):
177177
return result
178178

179179

180+
def unwrap_array(x1):
181+
"""Get array from input object."""
182+
if isinstance(x1, dpnp.dpnp_array.dpnp_array):
183+
return x1.get_array()
184+
185+
return x1
186+
187+
180188
cpdef checker_throw_axis_error(function_name, param_name, param, expected):
181189
err_msg = f"{ERROR_PREFIX} in function {function_name}()"
182190
err_msg += f" axes '{param_name}' expected `{expected}`, but '{param}' provided"

0 commit comments

Comments
 (0)