Skip to content

Commit 4c99634

Browse files
authored
GEMM combine MKL and SYCL into one kernel. plus some restyle. (#110)
1 parent 29e191b commit 4c99634

File tree

5 files changed

+84
-115
lines changed

5 files changed

+84
-115
lines changed

dpnp/backend/backend_iface.hpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,21 @@ INP_DLLEXPORT char* dpnp_memory_alloc_c(size_t size_in_bytes);
9393
INP_DLLEXPORT void dpnp_memory_free_c(void* ptr);
9494
void dpnp_memory_memcpy_c(void* dst, const void* src, size_t size_in_bytes);
9595

96+
/**
97+
* @ingroup BACKEND_API
98+
* @brief Matrix multiplication.
99+
*
100+
* Matrix multiplication procedure. Works with 2-D matrices
101+
*
102+
* @param [in] array1 Input array.
103+
*
104+
* @param [in] array2 Input array.
105+
*
106+
* @param [out] result1 Output array.
107+
*
108+
* @param [in] size Number of elements in input arrays.
109+
*
110+
*/
96111
template <typename _DataType>
97112
INP_DLLEXPORT void
98113
custom_blas_gemm_c(void* array1, void* array2, void* result1, size_t size_m, size_t size_n, size_t size_k);
@@ -388,15 +403,6 @@ template <typename _DataType, typename _ResultType>
388403
INP_DLLEXPORT void custom_var_c(
389404
void* array, void* result, const size_t* shape, size_t ndim, const size_t* axis, size_t naxis, size_t ddof);
390405

391-
#if 0 // Example for OpenCL kernel
392-
template <typename _DataType>
393-
void custom_dgemm_c_opencl(void* array_1, void* array_2, void* result_1, size_t size);
394-
#endif
395-
396-
template <typename _DataType>
397-
INP_DLLEXPORT void
398-
dpnp_blas_gemm_c(void* array1, void* array2, void* result1, size_t size_m, size_t size_n, size_t size_k);
399-
400406
/**
401407
* @ingroup BACKEND_API
402408
* @brief Element wise function __name__

dpnp/backend/backend_iface_fptr.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -416,8 +416,8 @@ static func_map_t func_map_init()
416416

417417
fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_INT][eft_INT] = {eft_INT, (void*)custom_blas_gemm_c<int>};
418418
fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_LNG][eft_LNG] = {eft_LNG, (void*)custom_blas_gemm_c<long>};
419-
fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_blas_gemm_c<float>};
420-
fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_blas_gemm_c<double>};
419+
fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_FLT][eft_FLT] = {eft_FLT, (void*)custom_blas_gemm_c<float>};
420+
fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_DBL][eft_DBL] = {eft_DBL, (void*)custom_blas_gemm_c<double>};
421421

422422
fmap[DPNPFuncName::DPNP_FN_MAX][eft_INT][eft_INT] = {eft_INT, (void*)custom_max_c<int>};
423423
fmap[DPNPFuncName::DPNP_FN_MAX][eft_LNG][eft_LNG] = {eft_LNG, (void*)custom_max_c<long>};

dpnp/backend/custom_kernels.cpp

Lines changed: 64 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,15 @@
2626
#include <cmath>
2727
#include <iostream>
2828
#include <mkl_blas_sycl.hpp>
29+
#include <type_traits>
2930

3031
#include <backend_iface.hpp>
3132
#include "backend_pstl.hpp"
3233
#include "backend_utils.hpp"
3334
#include "queue_sycl.hpp"
3435

36+
namespace mkl_blas = oneapi::mkl::blas;
37+
3538
template <typename _KernelNameSpecialization>
3639
class custom_blas_gemm_c_kernel;
3740

@@ -43,43 +46,78 @@ void custom_blas_gemm_c(void* array1_in, void* array2_in, void* result1, size_t
4346
_DataType* array_2 = reinterpret_cast<_DataType*>(array2_in);
4447
_DataType* result = reinterpret_cast<_DataType*>(result1);
4548

46-
// input1: M x K
47-
// input2: K x N
48-
// result: M x N
49-
const size_t dim_m = size_m; // shape1.front(); // First dimensions of array1
50-
const size_t dim_n = size_n; // shape2.back(); // Last dimensions of array2
51-
const size_t dim_k = size_k; // shape1.back(); // First dimensions of array2
52-
53-
cl::sycl::range<2> gws(dim_m, dim_n); // dimensions are: "i" and "j"
54-
event = DPNP_QUEUE.submit([&](cl::sycl::handler& cgh) {
55-
cgh.parallel_for<class custom_blas_gemm_c_kernel<_DataType> >(
56-
gws,
57-
[=](cl::sycl::id<2> global_id)
49+
if (!size_m || !size_n || !size_k)
50+
{
51+
return;
52+
}
53+
54+
if constexpr (std::is_same<_DataType, double>::value || std::is_same<_DataType, float>::value)
55+
{
56+
// using std::max for these ldx variables is required by MKL
57+
const std::int64_t lda = std::max<size_t>(1UL, size_k); // First dimensions of array_1
58+
const std::int64_t ldb = std::max<size_t>(1UL, size_n); // First dimensions of array_2
59+
const std::int64_t ldc = std::max<size_t>(1UL, size_n); // Fast dimensions of result
60+
61+
event = mkl_blas::gemm(DPNP_QUEUE,
62+
oneapi::mkl::transpose::nontrans,
63+
oneapi::mkl::transpose::nontrans,
64+
size_n,
65+
size_m,
66+
size_k,
67+
_DataType(1),
68+
array_2,
69+
ldb,
70+
array_1,
71+
lda,
72+
_DataType(0),
73+
result,
74+
ldc);
75+
}
76+
else
77+
{
78+
// input1: M x K
79+
// input2: K x N
80+
// result: M x N
81+
const size_t dim_m = size_m; // shape1.front(); // First dimensions of array1
82+
const size_t dim_n = size_n; // shape2.back(); // Last dimensions of array2
83+
const size_t dim_k = size_k; // shape1.back(); // First dimensions of array2
84+
85+
cl::sycl::range<2> gws(dim_m, dim_n); // dimensions are: "i" and "j"
86+
87+
auto kernel_parallel_for_func = [=](cl::sycl::id<2> global_id) {
88+
size_t i = global_id[0]; //for (size_t i = 0; i < size; ++i)
5889
{
59-
size_t i = global_id[0]; //for (size_t i = 0; i < size; ++i)
90+
size_t j = global_id[1]; //for (size_t j = 0; j < size; ++j)
6091
{
61-
size_t j = global_id[1]; //for (size_t j = 0; j < size; ++j)
92+
_DataType acc = _DataType(0);
93+
for (size_t k = 0; k < dim_k; ++k)
6294
{
63-
_DataType acc = _DataType(0);
64-
for (size_t k = 0; k < dim_k; ++k)
65-
{
66-
const size_t index_1 = i * dim_k + k;
67-
const size_t index_2 = k * dim_n + j;
68-
acc += array_1[index_1] * array_2[index_2];
69-
}
70-
const size_t index_result = i * dim_n + j;
71-
result[index_result] = acc;
95+
const size_t index_1 = i * dim_k + k;
96+
const size_t index_2 = k * dim_n + j;
97+
acc += array_1[index_1] * array_2[index_2];
7298
}
99+
const size_t index_result = i * dim_n + j;
100+
result[index_result] = acc;
73101
}
74-
}); // parallel_for
75-
}); // queue.submit
102+
}
103+
};
104+
105+
auto kernel_func = [&](cl::sycl::handler& cgh) {
106+
cgh.parallel_for<class custom_blas_gemm_c_kernel<_DataType>>(gws, kernel_parallel_for_func);
107+
};
76108

109+
event = DPNP_QUEUE.submit(kernel_func);
110+
}
77111
event.wait();
78112
}
79113

114+
template void custom_blas_gemm_c<int>(
115+
void* array1_in, void* array2_in, void* result1, size_t size_m, size_t size_n, size_t size_k);
80116
template void custom_blas_gemm_c<long>(
81117
void* array1_in, void* array2_in, void* result1, size_t size_m, size_t size_n, size_t size_k);
82-
template void custom_blas_gemm_c<int>(
118+
template void custom_blas_gemm_c<float>(
119+
void* array1_in, void* array2_in, void* result1, size_t size_m, size_t size_n, size_t size_k);
120+
template void custom_blas_gemm_c<double>(
83121
void* array1_in, void* array2_in, void* result1, size_t size_m, size_t size_n, size_t size_k);
84122

85123
template <typename _KernelNameSpecialization>

dpnp/backend/mkl_wrap_blas3.cpp

Lines changed: 0 additions & 75 deletions
This file was deleted.

setup.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,9 @@
205205
if _mkl_root is None:
206206
raise EnvironmentError("Intel NumPy: Please install Intel OneAPI environment. MKLROOT is empty")
207207
_mkl_include = [os.path.join(_mkl_root, 'include')]
208-
_mkl_libs = ["mkl_rt", "mkl_sycl", "mkl_intel_ilp64", "mkl_sequential", "mkl_core", "sycl", "OpenCL", "pthread", "m", "dl"]
209-
_project_cmplr_macro += [("MKL_ILP64", "1")] # using 64bit integers in MKL interface (long)
208+
_mkl_libs = ["mkl_rt", "mkl_sycl", "mkl_intel_ilp64", "mkl_sequential",
209+
"mkl_core", "sycl", "OpenCL", "pthread", "m", "dl"]
210+
_project_cmplr_macro += [("MKL_ILP64", "1")] # using 64bit integers in MKL interface (long)
210211

211212
_mkl_libpath = [os.path.join(_mkl_root, 'lib', 'intel64')]
212213
if IS_LIN:
@@ -272,7 +273,6 @@
272273
"dpnp/backend/custom_kernels_statistics.cpp",
273274
"dpnp/backend/memory_sycl.cpp",
274275
"dpnp/backend/mkl_wrap_blas1.cpp",
275-
"dpnp/backend/mkl_wrap_blas3.cpp",
276276
"dpnp/backend/mkl_wrap_lapack.cpp",
277277
"dpnp/backend/mkl_wrap_rng.cpp",
278278
"dpnp/backend/queue_sycl.cpp"

0 commit comments

Comments
 (0)