@@ -2,12 +2,13 @@ using oneAPI_Support_Headers_jll
2
2
3
3
include (" generate_helpers.jl" )
4
4
5
- blas = [joinpath (oneAPI_Support_Headers_jll. artifact_dir, " include" , " oneapi" , " mkl" , " blas" , " buffer_decls.hpp" )]
6
- lapack = [joinpath (oneAPI_Support_Headers_jll. artifact_dir, " include" , " oneapi" , " mkl" , " lapack" , " lapack.hpp" ),
7
- joinpath (oneAPI_Support_Headers_jll. artifact_dir, " include" , " oneapi" , " mkl" , " lapack" , " scratchpad.hpp" )]
8
- sparse = [joinpath (oneAPI_Support_Headers_jll. artifact_dir, " include" , " oneapi" , " mkl" , " spblas" , " sparse_structures.hpp" ),
9
- joinpath (oneAPI_Support_Headers_jll. artifact_dir, " include" , " oneapi" , " mkl" , " spblas" , " sparse_auxiliary.hpp" ),
10
- joinpath (oneAPI_Support_Headers_jll. artifact_dir, " include" , " oneapi" , " mkl" , " spblas" , " sparse_operations.hpp" )]
5
+ include_dir = joinpath (oneAPI_Support_Headers_jll. artifact_dir, " include" )
6
+ blas = [joinpath (include_dir, " oneapi" , " mkl" , " blas" , " buffer_decls.hpp" )]
7
+ lapack = [joinpath (include_dir, " oneapi" , " mkl" , " lapack" , " lapack.hpp" ),
8
+ joinpath (include_dir, " oneapi" , " mkl" , " lapack" , " scratchpad.hpp" )]
9
+ sparse = [joinpath (include_dir, " oneapi" , " mkl" , " spblas" , " sparse_structures.hpp" ),
10
+ joinpath (include_dir, " oneapi" , " mkl" , " spblas" , " sparse_auxiliary.hpp" ),
11
+ joinpath (include_dir, " oneapi" , " mkl" , " spblas" , " sparse_operations.hpp" )]
11
12
12
13
dict_version = Dict {Int, Char} (1 => ' S' , 2 => ' D' , 3 => ' C' , 4 => ' Z' )
13
14
@@ -23,7 +24,8 @@ version_types_header = Dict{Char, String}('S' => "float",
23
24
24
25
comments = [" namespace" , " #" , " }" , " /*" , " *" , " //" , " [[" , " ONEMKL_DECLARE_" , " ONEMKL_INLINE_DECLARE" ]
25
26
26
- void_output = [" init_matrix_handle" , " init_matmat_descr" , " release_matmat_descr" , " set_matmat_data" , " get_matmat_data" ]
27
+ void_output = [" init_matrix_handle" , " init_matmat_descr" , " release_matmat_descr" , " set_matmat_data" ,
28
+ " get_matmat_data" , " init_omatadd_descr" , " init_omatconvert_desc" ]
27
29
28
30
function generate_headers (library:: String , filename:: Vector{String} , output:: String ; pattern:: String = " " )
29
31
routines = Dict {String,Int} ()
@@ -189,6 +191,8 @@ function generate_headers(library::String, filename::Vector{String}, output::Str
189
191
header = replace (header, " ,)" => " )" )
190
192
header = replace (header, " void" => " void" )
191
193
header = replace (header, " sycl::event" => " sycl::event" )
194
+ header = replace (header, " * const* " => " **" )
195
+ header = replace (header, " int64_t**" => " int64_t **" )
192
196
193
197
ind1 = findfirst (' ' , header)
194
198
ind2 = findfirst (' (' , header)
@@ -245,6 +249,7 @@ function generate_headers(library::String, filename::Vector{String}, output::Str
245
249
(version == ' C' ) && (header = replace (header, " std::complex " => " float _Complex " ))
246
250
(version == ' Z' ) && (header = replace (header, " std::complex " => " double _Complex " ))
247
251
end
252
+ header = replace (header, " omatconvert (" => " omatconvert(" )
248
253
header = replace (header, " transpose " => " onemklTranspose " )
249
254
header = replace (header, " uplo " => " onemklUplo " )
250
255
header = replace (header, " diag " => " onemklDiag " )
@@ -255,6 +260,8 @@ function generate_headers(library::String, filename::Vector{String}, output::Str
255
260
header = replace (header, " sparse::matrix_view_descr " => " onemklMatrixView " )
256
261
header = replace (header, " matrix_view_descr " => " onemklMatrixView " )
257
262
header = replace (header, " sparse::matmat_request " => " onemklMatmatRequest " )
263
+ header = replace (header, " omatconvert_alg " => " onemklOmatconvertAlg " )
264
+ header = replace (header, " omatadd_alg " => " onemklOmataddAlg " )
258
265
header = replace (header, name_routine => " sparse_" * name_routine)
259
266
end
260
267
push! (signatures, (header, name_routine, version, type_routine, template))
@@ -381,6 +388,10 @@ function generate_cpp(library::String, filename::Vector{String}, output::String;
381
388
parameters = replace (parameters, " matrix_handle_t " => " (oneapi::mkl::sparse::matrix_handle_t) " )
382
389
parameters = replace (parameters, " matmat_descr_t *" => " (oneapi::mkl::sparse::matmat_descr_t*) " )
383
390
parameters = replace (parameters, " matmat_descr_t " => " (oneapi::mkl::sparse::matmat_descr_t) " )
391
+ parameters = replace (parameters, " omatadd_descr_t *" => " (oneapi::mkl::sparse::omatadd_descr_t*) " )
392
+ parameters = replace (parameters, " omatadd_descr_t " => " (oneapi::mkl::sparse::omatadd_descr_t) " )
393
+ parameters = replace (parameters, " omatconvert_descr_t *" => " (oneapi::mkl::sparse::omatconvert_descr_t*) " )
394
+ parameters = replace (parameters, " omatconvert_descr_t " => " (oneapi::mkl::sparse::omatconvert_descr_t) " )
384
395
parameters = replace (parameters, " short **" => " reinterpret_cast<sycl::half **>" )
385
396
parameters = replace (parameters, " float _Complex **" => " reinterpret_cast<std::complex<float> **>" )
386
397
parameters = replace (parameters, " double _Complex **" => " reinterpret_cast<std::complex<double> **>" )
@@ -407,7 +418,8 @@ function generate_cpp(library::String, filename::Vector{String}, output::String;
407
418
408
419
for type in (" onemklTranspose" , " onemklSide" , " onemklUplo" , " onemklDiag" , " onemklGenerate" ,
409
420
" onemklLayout" , " onemklJob" , " onemklJobsvd" , " onemklCompz" , " onemklRangev" ,
410
- " onemklIndex" , " onemklProperty" , " onemklMatrixView" , " onemklMatmatRequest" )
421
+ " onemklIndex" , " onemklProperty" , " onemklMatrixView" , " onemklMatmatRequest" ,
422
+ " onemklOmatconvertAlg" , " onemklOmataddAlg" )
411
423
parameters = replace (parameters, Regex (" $type ([A-Za-z0-9_]+)," ) => SubstitutionString (" convert(\\ 1)," ))
412
424
parameters = replace (parameters, Regex (" , $type ([A-Za-z0-9_]+)" ) => SubstitutionString (" , convert(\\ 1)" ))
413
425
end
0 commit comments