Skip to content

Commit ceab3dd

Browse files
authored
Pass SYCL dependant event to matmul as parameter (#1117)
* Pass SYCL dependant event to matmul as parameter
1 parent d267349 commit ceab3dd

File tree

4 files changed

+138
-78
lines changed

4 files changed

+138
-78
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -220,24 +220,26 @@ INP_DLLEXPORT void dpnp_full_like_c(void* array_in, void* result, size_t size);
220220
* @param [in] input2_ndim Number of second input array dimensions.
221221
* @param [in] input2_shape Shape of second input array.
222222
* @param [in] input2_strides Strides of second input array.
223+
* @param [in] dep_event_vec_ref Reference to vector of SYCL events.
223224
*/
224225
template <typename _DataType>
225-
INP_DLLEXPORT void dpnp_matmul_c(DPCTLSyclQueueRef q_ref,
226-
void* result_out,
227-
const size_t result_size,
228-
const size_t result_ndim,
229-
const shape_elem_type* result_shape,
230-
const shape_elem_type* result_strides,
231-
const void* input1_in,
232-
const size_t input1_size,
233-
const size_t input1_ndim,
234-
const shape_elem_type* input1_shape,
235-
const shape_elem_type* input1_strides,
236-
const void* input2_in,
237-
const size_t input2_size,
238-
const size_t input2_ndim,
239-
const shape_elem_type* input2_shape,
240-
const shape_elem_type* input2_strides);
226+
INP_DLLEXPORT DPCTLSyclEventRef dpnp_matmul_c(DPCTLSyclQueueRef q_ref,
227+
void* result_out,
228+
const size_t result_size,
229+
const size_t result_ndim,
230+
const shape_elem_type* result_shape,
231+
const shape_elem_type* result_strides,
232+
const void* input1_in,
233+
const size_t input1_size,
234+
const size_t input1_ndim,
235+
const shape_elem_type* input1_shape,
236+
const shape_elem_type* input1_strides,
237+
const void* input2_in,
238+
const size_t input2_size,
239+
const size_t input2_ndim,
240+
const shape_elem_type* input2_shape,
241+
const shape_elem_type* input2_strides,
242+
const DPCTLEventVectorRef dep_event_vec_ref);
241243

242244
/**
243245
* @ingroup BACKEND_API

dpnp/backend/kernels/dpnp_krnl_common.cpp

Lines changed: 65 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -533,22 +533,23 @@ template <typename _KernelNameSpecialization>
533533
class dpnp_matmul_c_kernel;
534534

535535
template <typename _DataType>
536-
void dpnp_matmul_c(DPCTLSyclQueueRef q_ref,
537-
void* result_out,
538-
const size_t result_size,
539-
const size_t result_ndim,
540-
const shape_elem_type* result_shape,
541-
const shape_elem_type* result_strides,
542-
const void* input1_in,
543-
const size_t input1_size,
544-
const size_t input1_ndim,
545-
const shape_elem_type* input1_shape,
546-
const shape_elem_type* input1_strides,
547-
const void* input2_in,
548-
const size_t input2_size,
549-
const size_t input2_ndim,
550-
const shape_elem_type* input2_shape,
551-
const shape_elem_type* input2_strides)
536+
DPCTLSyclEventRef dpnp_matmul_c(DPCTLSyclQueueRef q_ref,
537+
void* result_out,
538+
const size_t result_size,
539+
const size_t result_ndim,
540+
const shape_elem_type* result_shape,
541+
const shape_elem_type* result_strides,
542+
const void* input1_in,
543+
const size_t input1_size,
544+
const size_t input1_ndim,
545+
const shape_elem_type* input1_shape,
546+
const shape_elem_type* input1_strides,
547+
const void* input2_in,
548+
const size_t input2_size,
549+
const size_t input2_ndim,
550+
const shape_elem_type* input2_shape,
551+
const shape_elem_type* input2_strides,
552+
const DPCTLEventVectorRef dep_event_vec_ref)
552553
{
553554
(void)result_size;
554555
(void)result_ndim;
@@ -561,16 +562,19 @@ void dpnp_matmul_c(DPCTLSyclQueueRef q_ref,
561562
(void)input2_ndim;
562563
(void)input2_strides;
563564

565+
DPCTLSyclEventRef event_ref = nullptr;
566+
564567
size_t size_m = input1_shape[0];
565568
size_t size_n = input2_shape[1];
566569
size_t size_k = input1_shape[1];
567570

568571
if (!size_m || !size_n || !size_k)
569572
{
570-
return;
573+
return event_ref;
571574
}
572575

573576
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
577+
std::vector<sycl::event> dep_events = cast_event_vector(dep_event_vec_ref);
574578
sycl::event event;
575579

576580
_DataType* array_1 = reinterpret_cast<_DataType*>(const_cast<void*>(input1_in));
@@ -597,7 +601,8 @@ void dpnp_matmul_c(DPCTLSyclQueueRef q_ref,
597601
lda,
598602
_DataType(0),
599603
result,
600-
ldc);
604+
ldc,
605+
dep_events);
601606
}
602607
else
603608
{
@@ -629,12 +634,16 @@ void dpnp_matmul_c(DPCTLSyclQueueRef q_ref,
629634
};
630635

631636
auto kernel_func = [&](sycl::handler& cgh) {
637+
cgh.depends_on(dep_events);
632638
cgh.parallel_for<class dpnp_matmul_c_kernel<_DataType>>(gws, kernel_parallel_for_func);
633639
};
634640

635641
event = q.submit(kernel_func);
636642
}
637-
event.wait();
643+
644+
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
645+
646+
return DPCTLEvent_Copy(event_ref);
638647
}
639648

640649
template <typename _DataType>
@@ -655,10 +664,26 @@ void dpnp_matmul_c(void* result_out,
655664
const shape_elem_type* input2_strides)
656665
{
657666
DPCTLSyclQueueRef q_ref = reinterpret_cast<DPCTLSyclQueueRef>(&DPNP_QUEUE);
658-
dpnp_matmul_c<_DataType>(q_ref,
659-
result_out, result_size, result_ndim, result_shape, result_strides,
660-
input1_in, input1_size, input1_ndim, input1_shape, input1_strides,
661-
input2_in, input2_size, input2_ndim, input2_shape, input2_strides);
667+
DPCTLEventVectorRef dep_event_vec_ref = nullptr;
668+
DPCTLSyclEventRef event_ref = dpnp_matmul_c<_DataType>(q_ref,
669+
result_out,
670+
result_size,
671+
result_ndim,
672+
result_shape,
673+
result_strides,
674+
input1_in,
675+
input1_size,
676+
input1_ndim,
677+
input1_shape,
678+
input1_strides,
679+
input2_in,
680+
input2_size,
681+
input2_ndim,
682+
input2_shape,
683+
input2_strides,
684+
dep_event_vec_ref);
685+
sycl::event event = *(reinterpret_cast<sycl::event*>(event_ref));
686+
event.wait_and_throw();
662687
}
663688

664689
template <typename _DataType>
@@ -679,22 +704,23 @@ void (*dpnp_matmul_default_c)(void*,
679704
const shape_elem_type*) = dpnp_matmul_c<_DataType>;
680705

681706
template <typename _DataType>
682-
void (*dpnp_matmul_ext_c)(DPCTLSyclQueueRef,
683-
void*,
684-
const size_t,
685-
const size_t,
686-
const shape_elem_type*,
687-
const shape_elem_type*,
688-
const void*,
689-
const size_t,
690-
const size_t,
691-
const shape_elem_type*,
692-
const shape_elem_type*,
693-
const void*,
694-
const size_t,
695-
const size_t,
696-
const shape_elem_type*,
697-
const shape_elem_type*) = dpnp_matmul_c<_DataType>;
707+
DPCTLSyclEventRef (*dpnp_matmul_ext_c)(DPCTLSyclQueueRef,
708+
void*,
709+
const size_t,
710+
const size_t,
711+
const shape_elem_type*,
712+
const shape_elem_type*,
713+
const void*,
714+
const size_t,
715+
const size_t,
716+
const shape_elem_type*,
717+
const shape_elem_type*,
718+
const void*,
719+
const size_t,
720+
const size_t,
721+
const shape_elem_type*,
722+
const shape_elem_type*,
723+
const DPCTLEventVectorRef) = dpnp_matmul_c<_DataType>;
698724

699725
void func_map_init_linalg(func_map_t& fmap)
700726
{

dpnp/backend/src/dpnp_utils.hpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
#include <iostream>
3333
#include <iterator>
3434

35+
#include <CL/sycl.hpp>
36+
3537
#include <dpnp_iface_fptr.hpp>
3638

3739
#define LIBSYCL_VERSION_GREATER(major, minor, patch) \
@@ -228,6 +230,32 @@ static inline bool
228230
return std::equal(std::begin(input1_vec), std::end(input1_vec), std::begin(input2_vec));
229231
}
230232

233+
/**
234+
* @ingroup BACKEND_UTILS
235+
* @brief Cast vector of DPCtl events to vector of SYCL enents.
236+
*
237+
* @param [in] events_ref Reference to vector of DPCtl events.
238+
*
239+
* @return Vector of SYCL events.
240+
*/
241+
namespace
242+
{
243+
std::vector<sycl::event> cast_event_vector(const DPCTLEventVectorRef event_vec_ref)
244+
{
245+
const size_t event_vec_size = DPCTLEventVector_Size(event_vec_ref);
246+
247+
std::vector<sycl::event> event_vec;
248+
event_vec.reserve(event_vec_size);
249+
for (size_t i = 0; i < event_vec_size; ++i)
250+
{
251+
DPCTLSyclEventRef event_ref = DPCTLEventVector_GetAt(event_vec_ref, i);
252+
sycl::event event = *(reinterpret_cast<sycl::event*>(event_ref));
253+
event_vec.push_back(event);
254+
}
255+
return event_vec;
256+
}
257+
}
258+
231259
/**
232260
* @ingroup BACKEND_UTILS
233261
* @brief Get common shape based on input shapes.

dpnp/dpnp_algo/dpnp_algo_linearalgebra.pyx

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,14 @@ ctypedef void(*fptr_2in_1out_dot_t)(void * , const size_t, const size_t,
5252
const shape_elem_type *, const shape_elem_type * ,
5353
void * , const size_t, const size_t,
5454
const shape_elem_type *, const shape_elem_type * )
55-
ctypedef void(*fptr_2in_1out_matmul_t)(c_dpctl.DPCTLSyclQueueRef,
56-
void * , const size_t, const size_t,
57-
const shape_elem_type *, const shape_elem_type * ,
58-
void * , const size_t, const size_t,
59-
const shape_elem_type *, const shape_elem_type * ,
60-
void * , const size_t, const size_t,
61-
const shape_elem_type *, const shape_elem_type * )
55+
ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_2in_1out_matmul_t)(c_dpctl.DPCTLSyclQueueRef,
56+
void * , const size_t, const size_t,
57+
const shape_elem_type *, const shape_elem_type * ,
58+
void * , const size_t, const size_t,
59+
const shape_elem_type *, const shape_elem_type * ,
60+
void * , const size_t, const size_t,
61+
const shape_elem_type *, const shape_elem_type * ,
62+
const c_dpctl.DPCTLEventVectorRef)
6263

6364
cpdef utils.dpnp_descriptor dpnp_dot(utils.dpnp_descriptor in_array1, utils.dpnp_descriptor in_array2):
6465

@@ -296,22 +297,25 @@ cpdef utils.dpnp_descriptor dpnp_matmul(utils.dpnp_descriptor in_array1, utils.d
296297

297298
cdef fptr_2in_1out_matmul_t func = <fptr_2in_1out_matmul_t > kernel_data.ptr
298299
# call FPTR function
299-
func(q_ref,
300-
result.get_data(),
301-
result.size,
302-
result.ndim,
303-
NULL, # result_shape
304-
NULL, # result_strides
305-
in_array1.get_data(),
306-
in_array1.size,
307-
in_array1.ndim,
308-
shape1.data(),
309-
NULL, # in_array1_strides
310-
in_array2.get_data(),
311-
in_array2.size,
312-
in_array2.ndim,
313-
shape2.data(),
314-
NULL) # in_array2_strides
300+
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
301+
result.get_data(),
302+
result.size,
303+
result.ndim,
304+
NULL, # result_shape
305+
NULL, # result_strides
306+
in_array1.get_data(),
307+
in_array1.size,
308+
in_array1.ndim,
309+
shape1.data(),
310+
NULL, # in_array1_strides
311+
in_array2.get_data(),
312+
in_array2.size,
313+
in_array2.ndim,
314+
shape2.data(),
315+
NULL, # in_array2_strides
316+
NULL) # dep_event_vec_ref
317+
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
318+
c_dpctl.DPCTLEvent_Delete(event_ref)
315319

316320
return result
317321

0 commit comments

Comments
 (0)