Skip to content

Commit 0d6029e

Browse files
Avoid copying array when strides are default (#1031)
Co-authored-by: Alexander-Makaryev <alexander.makaryev@gmail.com>
1 parent 761e571 commit 0d6029e

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

dpnp/backend/kernels/dpnp_krnl_common.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -567,9 +567,9 @@ void dpnp_matmul_c(void* result_out,
567567
}
568568

569569
cl::sycl::event event;
570-
DPNPC_ptr_adapter<_DataType> input1_ptr(input1_in, size_m * size_k, true);
571-
DPNPC_ptr_adapter<_DataType> input2_ptr(input2_in, size_k * size_n, true);
572-
DPNPC_ptr_adapter<_DataType> result_ptr(result_out, size_m * size_n, true, true);
570+
DPNPC_ptr_adapter<_DataType> input1_ptr(input1_in, size_m * size_k);
571+
DPNPC_ptr_adapter<_DataType> input2_ptr(input2_in, size_k * size_n);
572+
DPNPC_ptr_adapter<_DataType> result_ptr(result_out, size_m * size_n, false, true);
573573
_DataType* array_1 = input1_ptr.get_ptr();
574574
_DataType* array_2 = input2_ptr.get_ptr();
575575
_DataType* result = result_ptr.get_ptr();

dpnp/dpnp_iface.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,9 @@ def get_dpnp_descriptor(ext_obj, copy_when_strides=True):
211211
# then this behavior can be disabled with setting "copy_when_strides"
212212
if copy_when_strides and getattr(ext_obj, "strides", None) is not None:
213213
# TODO: replace this workaround when usm_ndarray will provide such functionality
214-
ext_obj = array(ext_obj)
214+
shape_offsets = tuple(numpy.prod(ext_obj.shape[i+1:], dtype=numpy.int64) for i in range(ext_obj.ndim))
215+
if ext_obj.strides != shape_offsets:
216+
ext_obj = array(ext_obj)
215217

216218
dpnp_desc = dpnp_descriptor(ext_obj)
217219
if dpnp_desc.is_valid:

0 commit comments

Comments
 (0)