Skip to content

Commit fadcd8d

Browse files
committed
[oneMKL] Interface sparse_optimize_* routines
1 parent 0ea7978 commit fadcd8d

File tree

3 files changed

+22
-1
lines changed

3 files changed

+22
-1
lines changed

deps/src/onemkl.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4146,7 +4146,6 @@ extern "C" int onemklZsparse_update_diagonal_values(syclQueue_t device_queue, ma
41464146

41474147
extern "C" int onemklXsparse_optimize_gemv(syclQueue_t device_queue, onemklTranspose opA, matrix_handle_t A) {
41484148
auto status = oneapi::mkl::sparse::optimize_gemv(device_queue->val, convert(opA), (oneapi::mkl::sparse::matrix_handle_t) A, {});
4149-
__FORCE_MKL_FLUSH__(status);
41504149
return 0;
41514150
}
41524151

lib/mkl/wrappers_sparse.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ for (fname, elty) in ((:onemklSsparse_gemv, :Float32),
5757
end
5858
end
5959

60+
function sparse_optimize_gemv!(trans::Char, A::oneSparseMatrixCSR)
61+
queue = global_queue(context(A.nzVal), device(A.nzVal))
62+
onemklXsparse_optimize_gemv(sycl_queue(queue), trans, A.handle)
63+
return A
64+
end
65+
6066
for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
6167
(:onemklDsparse_gemm, :Float64),
6268
(:onemklCsparse_gemm, :ComplexF32),
@@ -124,6 +130,12 @@ for (fname, elty) in ((:onemklSsparse_trmv, :Float32),
124130
end
125131
end
126132

133+
function sparse_optimize_trmv!(uplo::Char, trans::Char, diag::Char, A::oneSparseMatrixCSR)
134+
queue = global_queue(context(A.nzVal), device(A.nzVal))
135+
onemklXsparse_optimize_trmv(sycl_queue(queue), uplo, trans, diag, A.handle)
136+
return A
137+
end
138+
127139
for (fname, elty) in ((:onemklSsparse_trsv, :Float32),
128140
(:onemklDsparse_trsv, :Float64),
129141
(:onemklCsparse_trsv, :ComplexF32),
@@ -142,3 +154,9 @@ for (fname, elty) in ((:onemklSsparse_trsv, :Float32),
142154
end
143155
end
144156
end
157+
158+
function sparse_optimize_trsv!(uplo::Char, trans::Char, diag::Char, A::oneSparseMatrixCSR)
159+
queue = global_queue(context(A.nzVal), device(A.nzVal))
160+
onemklXsparse_optimize_trsv(sycl_queue(queue), uplo, trans, diag, A.handle)
161+
return A
162+
end

test/onemkl.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,6 +1100,7 @@ end
11001100

11011101
alpha = rand(T)
11021102
beta = rand(T)
1103+
oneMKL.sparse_optimize_gemv!(transa, dA)
11031104
oneMKL.sparse_gemv!(transa, alpha, dA, dx, beta, dy)
11041105
# @test alpha * opa(A) * x + beta * y ≈ collect(dy)
11051106
end
@@ -1160,6 +1161,8 @@ end
11601161

11611162
alpha = rand(T)
11621163
beta = rand(T)
1164+
1165+
oneMKL.sparse_optimize_trmv!(uplo, transa, diag, dA)
11631166
oneMKL.sparse_trmv!(uplo, transa, diag, alpha, dA, dx, beta, dy)
11641167
@test alpha * wrapper(opa(A)) * x + beta * y collect(dy)
11651168
end
@@ -1182,6 +1185,7 @@ end
11821185
dx = oneVector{T}(x)
11831186
dy = oneVector{T}(y)
11841187

1188+
oneMKL.sparse_optimize_trsv!(uplo, transa, diag, dA)
11851189
oneMKL.sparse_trsv!(uplo, transa, diag, dA, dx, dy)
11861190
y = wrapper(opa(A)) \ x
11871191
@test y collect(dy)

0 commit comments

Comments
 (0)