Skip to content

Commit 180a698

Browse files
Pass queue from cython to backend for matmul (#1063)
* Pass queue from cython to backend for matmul * Add DPNP_FN_MATMUL_EXT to be compatible with DPPY * minor fix in DPNPFuncName enum * Move parameter queue to the first place in dpnp_matmul_c * Change void* to DPCTLSyclQueueRef + add test to check queue of result * add test to check queue of result * Limit list of SYCL devices for testing Co-authored-by: Alexander-Makaryev <alexander.makaryev@gmail.com>
1 parent 8f23a02 commit 180a698

File tree

11 files changed

+242
-21
lines changed

11 files changed

+242
-21
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ typedef ssize_t shape_elem_type;
5858
#include "dpnp_iface_fft.hpp"
5959
#include "dpnp_iface_random.hpp"
6060

61+
#include <dpctl_sycl_interface.h>
62+
6163
/**
6264
* @defgroup BACKEND_API Backend C++ library interface API
6365
* @{
@@ -196,6 +198,47 @@ INP_DLLEXPORT void dpnp_full_c(void* array_in, void* result, const size_t size);
196198
template <typename _DataType>
197199
INP_DLLEXPORT void dpnp_full_like_c(void* array_in, void* result, size_t size);
198200

201+
/**
202+
* @ingroup BACKEND_API
203+
* @brief Matrix multiplication.
204+
*
205+
* Matrix multiplication procedure. The work is submitted to the SYCL queue
* referenced by @p q_ref, and the implementation waits on the resulting
* event before returning.
*
* @tparam _DataType Element type that result_out, input1_in and input2_in
*                   are reinterpreted as by the implementation.
206+
*
207+
* @param [in] q_ref Reference to SYCL queue.
208+
* @param [out] result_out Output array.
209+
* @param [in] result_size Size of output array.
210+
* @param [in] result_ndim Number of output array dimensions.
211+
* @param [in] result_shape Shape of output array.
212+
* @param [in] result_strides Strides of output array.
213+
* @param [in] input1_in First input array.
214+
* @param [in] input1_size Size of first input array.
215+
* @param [in] input1_ndim Number of first input array dimensions.
216+
* @param [in] input1_shape Shape of first input array.
217+
* @param [in] input1_strides Strides of first input array.
218+
* @param [in] input2_in Second input array.
219+
* @param [in] input2_size Size of second input array.
220+
* @param [in] input2_ndim Number of second input array dimensions.
221+
* @param [in] input2_shape Shape of second input array.
222+
* @param [in] input2_strides Strides of second input array.
223+
*/
224+
template <typename _DataType>
225+
INP_DLLEXPORT void dpnp_matmul_c(DPCTLSyclQueueRef q_ref,
226+
void* result_out,
227+
const size_t result_size,
228+
const size_t result_ndim,
229+
const shape_elem_type* result_shape,
230+
const shape_elem_type* result_strides,
231+
const void* input1_in,
232+
const size_t input1_size,
233+
const size_t input1_ndim,
234+
const shape_elem_type* input1_shape,
235+
const shape_elem_type* input1_strides,
236+
const void* input2_in,
237+
const size_t input2_size,
238+
const size_t input2_ndim,
239+
const shape_elem_type* input2_shape,
240+
const shape_elem_type* input2_strides);
241+
199242
/**
200243
* @ingroup BACKEND_API
201244
* @brief Matrix multiplication.

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ enum class DPNPFuncName : size_t
131131
DPNP_FN_LOG2, /**< Used in numpy.log2() implementation */
132132
DPNP_FN_LOG1P, /**< Used in numpy.log1p() implementation */
133133
DPNP_FN_MATMUL, /**< Used in numpy.matmul() implementation */
134+
DPNP_FN_MATMUL_EXT, /**< Used in numpy.matmul() implementation, requires extra parameters */
134135
DPNP_FN_MATRIX_RANK, /**< Used in numpy.linalg.matrix_rank() implementation */
135136
DPNP_FN_MAX, /**< Used in numpy.max() implementation */
136137
DPNP_FN_MAXIMUM, /**< Used in numpy.maximum() implementation */

dpnp/backend/kernels/dpnp_krnl_common.cpp

Lines changed: 77 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,8 @@ template <typename _KernelNameSpecialization>
533533
class dpnp_matmul_c_kernel;
534534

535535
template <typename _DataType>
536-
void dpnp_matmul_c(void* result_out,
536+
void dpnp_matmul_c(DPCTLSyclQueueRef q_ref,
537+
void* result_out,
537538
const size_t result_size,
538539
const size_t result_ndim,
539540
const shape_elem_type* result_shape,
@@ -569,13 +570,12 @@ void dpnp_matmul_c(void* result_out,
569570
return;
570571
}
571572

573+
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
572574
sycl::event event;
573-
DPNPC_ptr_adapter<_DataType> input1_ptr(input1_in, size_m * size_k);
574-
DPNPC_ptr_adapter<_DataType> input2_ptr(input2_in, size_k * size_n);
575-
DPNPC_ptr_adapter<_DataType> result_ptr(result_out, size_m * size_n, false, true);
576-
_DataType* array_1 = input1_ptr.get_ptr();
577-
_DataType* array_2 = input2_ptr.get_ptr();
578-
_DataType* result = result_ptr.get_ptr();
575+
576+
_DataType* array_1 = reinterpret_cast<_DataType*>(const_cast<void*>(input1_in));
577+
_DataType* array_2 = reinterpret_cast<_DataType*>(const_cast<void*>(input2_in));
578+
_DataType* result = reinterpret_cast<_DataType*>(result_out);
579579

580580
if constexpr (std::is_same<_DataType, double>::value || std::is_same<_DataType, float>::value)
581581
{
@@ -584,7 +584,7 @@ void dpnp_matmul_c(void* result_out,
584584
const std::int64_t ldb = std::max<size_t>(1UL, size_n); // First dimensions of array_2
585585
const std::int64_t ldc = std::max<size_t>(1UL, size_n); // Fast dimensions of result
586586

587-
event = mkl_blas::gemm(DPNP_QUEUE,
587+
event = mkl_blas::gemm(q,
588588
oneapi::mkl::transpose::nontrans,
589589
oneapi::mkl::transpose::nontrans,
590590
size_n,
@@ -632,11 +632,70 @@ void dpnp_matmul_c(void* result_out,
632632
cgh.parallel_for<class dpnp_matmul_c_kernel<_DataType>>(gws, kernel_parallel_for_func);
633633
};
634634

635-
event = DPNP_QUEUE.submit(kernel_func);
635+
event = q.submit(kernel_func);
636636
}
637637
event.wait();
638638
}
639639

640+
/**
 * Legacy overload of dpnp_matmul_c that takes no queue argument.
 *
 * Wraps the address of DPNP_QUEUE as a DPCTLSyclQueueRef and forwards every
 * other argument, unchanged and in the same order, to the queue-taking
 * dpnp_matmul_c<_DataType> overload.
 */
template <typename _DataType>
641+
void dpnp_matmul_c(void* result_out,
642+
const size_t result_size,
643+
const size_t result_ndim,
644+
const shape_elem_type* result_shape,
645+
const shape_elem_type* result_strides,
646+
const void* input1_in,
647+
const size_t input1_size,
648+
const size_t input1_ndim,
649+
const shape_elem_type* input1_shape,
650+
const shape_elem_type* input1_strides,
651+
const void* input2_in,
652+
const size_t input2_size,
653+
const size_t input2_ndim,
654+
const shape_elem_type* input2_shape,
655+
const shape_elem_type* input2_strides)
656+
{
657+
// The queue-taking overload casts q_ref back to sycl::queue*, so the
// address of DPNP_QUEUE is passed through the opaque reference type.
DPCTLSyclQueueRef q_ref = reinterpret_cast<DPCTLSyclQueueRef>(&DPNP_QUEUE);
658+
dpnp_matmul_c<_DataType>(q_ref,
659+
result_out, result_size, result_ndim, result_shape, result_strides,
660+
input1_in, input1_size, input1_ndim, input1_shape, input1_strides,
661+
input2_in, input2_size, input2_ndim, input2_shape, input2_strides);
662+
}
663+
664+
/**
 * Variable template holding a pointer to the legacy (no-queue) overload of
 * dpnp_matmul_c<_DataType>.
 *
 * dpnp_matmul_c is overloaded, so the explicit function-pointer type here is
 * what selects the overload without the DPCTLSyclQueueRef parameter. The
 * pointer is registered under DPNP_FN_MATMUL in func_map_init_linalg.
 */
template <typename _DataType>
665+
void (*dpnp_matmul_default_c)(void*,
666+
const size_t,
667+
const size_t,
668+
const shape_elem_type*,
669+
const shape_elem_type*,
670+
const void*,
671+
const size_t,
672+
const size_t,
673+
const shape_elem_type*,
674+
const shape_elem_type*,
675+
const void*,
676+
const size_t,
677+
const size_t,
678+
const shape_elem_type*,
679+
const shape_elem_type*) = dpnp_matmul_c<_DataType>;
680+
681+
/**
 * Variable template holding a pointer to the queue-taking overload of
 * dpnp_matmul_c<_DataType>.
 *
 * dpnp_matmul_c is overloaded, so the explicit function-pointer type here
 * (leading DPCTLSyclQueueRef parameter) is what selects the overload. The
 * pointer is registered under DPNP_FN_MATMUL_EXT in func_map_init_linalg.
 */
template <typename _DataType>
682+
void (*dpnp_matmul_ext_c)(DPCTLSyclQueueRef,
683+
void*,
684+
const size_t,
685+
const size_t,
686+
const shape_elem_type*,
687+
const shape_elem_type*,
688+
const void*,
689+
const size_t,
690+
const size_t,
691+
const shape_elem_type*,
692+
const shape_elem_type*,
693+
const void*,
694+
const size_t,
695+
const size_t,
696+
const shape_elem_type*,
697+
const shape_elem_type*) = dpnp_matmul_c<_DataType>;
698+
640699
void func_map_init_linalg(func_map_t& fmap)
641700
{
642701
fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_BLN][eft_BLN] = {eft_BLN, (void*)dpnp_astype_c<bool, bool>};
@@ -702,10 +761,15 @@ void func_map_init_linalg(func_map_t& fmap)
702761
fmap[DPNPFuncName::DPNP_FN_INITVAL][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_initval_c<double>};
703762
fmap[DPNPFuncName::DPNP_FN_INITVAL][eft_C128][eft_C128] = {eft_C128, (void*)dpnp_initval_c<std::complex<double>>};
704763

705-
fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_matmul_c<int32_t>};
706-
fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_matmul_c<int64_t>};
707-
fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_matmul_c<float>};
708-
fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_matmul_c<double>};
764+
fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_matmul_default_c<int32_t>};
765+
fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_matmul_default_c<int64_t>};
766+
fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_matmul_default_c<float>};
767+
fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_matmul_default_c<double>};
768+
769+
fmap[DPNPFuncName::DPNP_FN_MATMUL_EXT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_matmul_ext_c<int32_t>};
770+
fmap[DPNPFuncName::DPNP_FN_MATMUL_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_matmul_ext_c<int64_t>};
771+
fmap[DPNPFuncName::DPNP_FN_MATMUL_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_matmul_ext_c<float>};
772+
fmap[DPNPFuncName::DPNP_FN_MATMUL_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_matmul_ext_c<double>};
709773

710774
return;
711775
}

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
106106
DPNP_FN_LOG1P
107107
DPNP_FN_LOG2
108108
DPNP_FN_MATMUL
109+
DPNP_FN_MATMUL_EXT
109110
DPNP_FN_MATRIX_RANK
110111
DPNP_FN_MAX
111112
DPNP_FN_MAXIMUM

dpnp/dpnp_algo/dpnp_algo.pyx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,14 @@ import dpnp.config as config
3939
import dpnp.dpnp_utils as utils_py
4040
from dpnp.dpnp_array import dpnp_array
4141

42-
import numpy
42+
cimport dpctl as c_dpctl
4343
import dpctl
4444

4545
cimport cpython
4646
cimport dpnp.dpnp_utils as utils
47+
4748
cimport numpy
49+
import numpy
4850

4951

5052
__all__ = [

dpnp/dpnp_algo/dpnp_algo_linearalgebra.pyx

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,13 @@ ctypedef void(*fptr_2in_1out_dot_t)(void * , const size_t, const size_t,
5252
const shape_elem_type *, const shape_elem_type * ,
5353
void * , const size_t, const size_t,
5454
const shape_elem_type *, const shape_elem_type * )
55+
# Function-pointer type for the queue-aware matmul backend entry point
# (looked up via DPNP_FN_MATMUL_EXT). Arguments: the SYCL queue reference,
# followed by (data, size, ndim, shape, strides) for the result, the first
# input and the second input, in that order.
ctypedef void(*fptr_2in_1out_matmul_t)(c_dpctl.DPCTLSyclQueueRef,
56+
void * , const size_t, const size_t,
57+
const shape_elem_type *, const shape_elem_type * ,
58+
void * , const size_t, const size_t,
59+
const shape_elem_type *, const shape_elem_type * ,
60+
void * , const size_t, const size_t,
61+
const shape_elem_type *, const shape_elem_type * )
5562

5663
cpdef utils.dpnp_descriptor dpnp_dot(utils.dpnp_descriptor in_array1, utils.dpnp_descriptor in_array2):
5764

@@ -271,7 +278,7 @@ cpdef utils.dpnp_descriptor dpnp_matmul(utils.dpnp_descriptor in_array1, utils.d
271278
cdef DPNPFuncType param2_type = dpnp_dtype_to_DPNPFuncType(in_array2.dtype)
272279

273280
# get the FPTR data structure
274-
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_MATMUL, param1_type, param2_type)
281+
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_MATMUL_EXT, param1_type, param2_type)
275282

276283
# create result array with type given by FPTR data
277284
result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(in_array1, in_array2)
@@ -284,9 +291,13 @@ cpdef utils.dpnp_descriptor dpnp_matmul(utils.dpnp_descriptor in_array1, utils.d
284291
if result.size == 0:
285292
return result
286293

287-
cdef fptr_2in_1out_dot_t func = <fptr_2in_1out_dot_t > kernel_data.ptr
294+
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
295+
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
296+
297+
cdef fptr_2in_1out_matmul_t func = <fptr_2in_1out_matmul_t > kernel_data.ptr
288298
# call FPTR function
289-
func(result.get_data(),
299+
func(q_ref,
300+
result.get_data(),
290301
result.size,
291302
result.ndim,
292303
NULL, # result_shape

dpnp/dpnp_iface.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def convert_single_elem_array_to_scalar(obj, keepdims=False):
186186
return obj
187187

188188

189-
def get_dpnp_descriptor(ext_obj, copy_when_strides=True):
189+
def get_dpnp_descriptor(ext_obj, copy_when_strides=True, copy_when_nondefault_queue=True):
190190
"""
191191
Return True:
192192
never
@@ -221,6 +221,18 @@ def get_dpnp_descriptor(ext_obj, copy_when_strides=True):
221221
if ext_obj.strides != shape_offsets or ext_obj_offset != 0:
222222
ext_obj = array(ext_obj)
223223

224+
# while dpnp functions are based on DPNP_QUEUE,
225+
# we need to create a copy on the device associated with DPNP_QUEUE;
226+
# if a function has an implementation that supports a different queue,
227+
# this behavior can be disabled by setting "copy_when_nondefault_queue" to False
228+
arr_obj = unwrap_array(ext_obj)
229+
queue = getattr(arr_obj, "sycl_queue", None)
230+
if queue is not None and copy_when_nondefault_queue:
231+
default_queue = dpctl.SyclQueue()
232+
queue_is_default = dpctl.utils.get_execution_queue([queue, default_queue]) is not None
233+
if not queue_is_default:
234+
ext_obj = array(arr_obj, sycl_queue=default_queue)
235+
224236
dpnp_desc = dpnp_descriptor(ext_obj)
225237
if dpnp_desc.is_valid:
226238
return dpnp_desc

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,8 @@ def matmul(x1, x2, out=None, **kwargs):
241241
242242
"""
243243

244-
x1_desc = dpnp.get_dpnp_descriptor(x1)
245-
x2_desc = dpnp.get_dpnp_descriptor(x2)
244+
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
245+
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_nondefault_queue=False)
246246
if x1_desc and x2_desc and not kwargs:
247247
if x1_desc.ndim != 2 or x2_desc.ndim != 2:
248248
pass

dpnp/dpnp_utils/dpnp_algo_utils.pyx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ __all__ = [
6363
"_get_linear_index",
6464
"normalize_axis",
6565
"_object_to_tuple",
66+
"unwrap_array",
6667
"use_origin_backend"
6768
]
6869

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import importlib.machinery as imm # Python 3 is required
4141
import sys
4242
import os
43+
import dpctl
4344
import numpy
4445

4546
from setuptools import setup, Extension
@@ -132,7 +133,7 @@
132133
The project modules description
133134
"""
134135
kwargs_common = {
135-
"include_dirs": [numpy.get_include()] + _project_backend_dir,
136+
"include_dirs": [numpy.get_include(), dpctl.get_include()] + _project_backend_dir,
136137
"extra_compile_args": _sdl_cflags,
137138
"extra_link_args": _project_extra_link_args,
138139
"define_macros": [("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")],

0 commit comments

Comments
 (0)