26
26
#include < cmath>
27
27
#include < iostream>
28
28
#include < mkl_blas_sycl.hpp>
29
+ #include < type_traits>
29
30
30
31
#include < backend_iface.hpp>
31
32
#include " backend_pstl.hpp"
32
33
#include " backend_utils.hpp"
33
34
#include " queue_sycl.hpp"
34
35
36
+ namespace mkl_blas = oneapi::mkl::blas;
37
+
35
38
// Forward declaration of the SYCL kernel name class used by the naive gemm
// fallback; the template parameter makes the kernel name unique per data type.
template <typename _KernelNameSpecialization>
class custom_blas_gemm_c_kernel;
37
40
@@ -43,43 +46,78 @@ void custom_blas_gemm_c(void* array1_in, void* array2_in, void* result1, size_t
43
46
_DataType* array_2 = reinterpret_cast <_DataType*>(array2_in);
44
47
_DataType* result = reinterpret_cast <_DataType*>(result1);
45
48
46
- // input1: M x K
47
- // input2: K x N
48
- // result: M x N
49
- const size_t dim_m = size_m; // shape1.front(); // First dimensions of array1
50
- const size_t dim_n = size_n; // shape2.back(); // Last dimensions of array2
51
- const size_t dim_k = size_k; // shape1.back(); // First dimensions of array2
52
-
53
- cl::sycl::range<2 > gws (dim_m, dim_n); // dimensions are: "i" and "j"
54
- event = DPNP_QUEUE.submit ([&](cl::sycl::handler& cgh) {
55
- cgh.parallel_for <class custom_blas_gemm_c_kernel <_DataType> >(
56
- gws,
57
- [=](cl::sycl::id<2 > global_id)
49
+ if (!size_m || !size_n || !size_k)
50
+ {
51
+ return ;
52
+ }
53
+
54
+ if constexpr (std::is_same<_DataType, double >::value || std::is_same<_DataType, float >::value)
55
+ {
56
+ // using std::max for these ldx variables is required by MKL
57
+ const std::int64_t lda = std::max<size_t >(1UL , size_k); // First dimensions of array_1
58
+ const std::int64_t ldb = std::max<size_t >(1UL , size_n); // First dimensions of array_2
59
+ const std::int64_t ldc = std::max<size_t >(1UL , size_n); // Fast dimensions of result
60
+
61
+ event = mkl_blas::gemm (DPNP_QUEUE,
62
+ oneapi::mkl::transpose::nontrans,
63
+ oneapi::mkl::transpose::nontrans,
64
+ size_n,
65
+ size_m,
66
+ size_k,
67
+ _DataType (1 ),
68
+ array_2,
69
+ ldb,
70
+ array_1,
71
+ lda,
72
+ _DataType (0 ),
73
+ result,
74
+ ldc);
75
+ }
76
+ else
77
+ {
78
+ // input1: M x K
79
+ // input2: K x N
80
+ // result: M x N
81
+ const size_t dim_m = size_m; // shape1.front(); // First dimensions of array1
82
+ const size_t dim_n = size_n; // shape2.back(); // Last dimensions of array2
83
+ const size_t dim_k = size_k; // shape1.back(); // First dimensions of array2
84
+
85
+ cl::sycl::range<2 > gws (dim_m, dim_n); // dimensions are: "i" and "j"
86
+
87
+ auto kernel_parallel_for_func = [=](cl::sycl::id<2 > global_id) {
88
+ size_t i = global_id[0 ]; // for (size_t i = 0; i < size; ++i)
58
89
{
59
- size_t i = global_id[0 ]; // for (size_t i = 0; i < size; ++i )
90
+ size_t j = global_id[1 ]; // for (size_t j = 0; j < size; ++j )
60
91
{
61
- size_t j = global_id[1 ]; // for (size_t j = 0; j < size; ++j)
92
+ _DataType acc = _DataType (0 );
93
+ for (size_t k = 0 ; k < dim_k; ++k)
62
94
{
63
- _DataType acc = _DataType (0 );
64
- for (size_t k = 0 ; k < dim_k; ++k)
65
- {
66
- const size_t index_1 = i * dim_k + k;
67
- const size_t index_2 = k * dim_n + j;
68
- acc += array_1[index_1] * array_2[index_2];
69
- }
70
- const size_t index_result = i * dim_n + j;
71
- result[index_result] = acc;
95
+ const size_t index_1 = i * dim_k + k;
96
+ const size_t index_2 = k * dim_n + j;
97
+ acc += array_1[index_1] * array_2[index_2];
72
98
}
99
+ const size_t index_result = i * dim_n + j;
100
+ result[index_result] = acc;
73
101
}
74
- }); // parallel_for
75
- }); // queue.submit
102
+ }
103
+ };
104
+
105
+ auto kernel_func = [&](cl::sycl::handler& cgh) {
106
+ cgh.parallel_for <class custom_blas_gemm_c_kernel <_DataType>>(gws, kernel_parallel_for_func);
107
+ };
76
108
109
+ event = DPNP_QUEUE.submit (kernel_func);
110
+ }
77
111
event.wait ();
78
112
}
79
113
114
+ template void custom_blas_gemm_c<int >(
115
+ void * array1_in, void * array2_in, void * result1, size_t size_m, size_t size_n, size_t size_k);
80
116
template void custom_blas_gemm_c<long >(
81
117
void * array1_in, void * array2_in, void * result1, size_t size_m, size_t size_n, size_t size_k);
82
- template void custom_blas_gemm_c<int >(
118
+ template void custom_blas_gemm_c<float >(
119
+ void * array1_in, void * array2_in, void * result1, size_t size_m, size_t size_n, size_t size_k);
120
+ template void custom_blas_gemm_c<double >(
83
121
void * array1_in, void * array2_in, void * result1, size_t size_m, size_t size_n, size_t size_k);
84
122
85
123
template <typename _KernelNameSpecialization>
0 commit comments