Skip to content

Commit 5b20411

Browse files
authored
[CUSPARSE] Interface gtsv2 (#1795)
1 parent 563a06c commit 5b20411

File tree

2 files changed

+124
-10
lines changed

2 files changed

+124
-10
lines changed

lib/cusparse/preconditioners.jl

Lines changed: 70 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,34 @@
11
# routines that implement different preconditioners
22

3-
export ic02!, ic02, ilu02!, ilu02
3+
export ic02!, ic02, ilu02!, ilu02, gtsv2!, gtsv2
44

55
"""
6-
ic02!(A::CuSparseMatrix, index::SparseChar)
6+
ic02!(A::CuSparseMatrix, index::SparseChar='O')
77
88
Incomplete Cholesky factorization with no pivoting.
99
Preserves the sparse layout of matrix `A`.
1010
"""
11-
ic02!(A::CuSparseMatrix, index::SparseChar)
11+
function ic02! end
12+
13+
"""
14+
ilu02!(A::CuSparseMatrix, index::SparseChar='O')
15+
16+
Incomplete LU factorization with no pivoting.
17+
Preserves the sparse layout of matrix `A`.
18+
"""
19+
function ilu02! end
20+
21+
"""
22+
gtsv2!(dl::CuVector, d::CuVector, du::CuVector, B::CuVecOrMat, index::SparseChar='O'; pivoting::Bool=true)
23+
24+
Solve the linear system `A * X = B` where `A` is a tridiagonal matrix defined
25+
by three vectors corresponding to its lower (`dl`), main (`d`), and upper (`du`) diagonals.
26+
With `pivoting`, the solution is more accurate but also more expensive.
27+
Note that the solution `X` overwrites the right-hand side `B`.
28+
"""
29+
function gtsv2! end
30+
31+
# csric02
1232
for (bname,aname,sname,elty) in ((:cusparseScsric02_bufferSize, :cusparseScsric02_analysis, :cusparseScsric02, :Float32),
1333
(:cusparseDcsric02_bufferSize, :cusparseDcsric02_analysis, :cusparseDcsric02, :Float64),
1434
(:cusparseCcsric02_bufferSize, :cusparseCcsric02_analysis, :cusparseCcsric02, :ComplexF32),
@@ -88,13 +108,7 @@ for (bname,aname,sname,elty) in ((:cusparseScsric02_bufferSize, :cusparseScsric0
88108
end
89109
end
90110

91-
"""
92-
ilu02!(A::CuSparseMatrix, index::SparseChar)
93-
94-
Incomplete LU factorization with no pivoting.
95-
Preserves the sparse layout of matrix `A`.
96-
"""
97-
ilu02!(A::CuSparseMatrix, index::SparseChar)
111+
# csrilu02
98112
for (bname,aname,sname,elty) in ((:cusparseScsrilu02_bufferSize, :cusparseScsrilu02_analysis, :cusparseScsrilu02, :Float32),
99113
(:cusparseDcsrilu02_bufferSize, :cusparseDcsrilu02_analysis, :cusparseDcsrilu02, :Float64),
100114
(:cusparseCcsrilu02_bufferSize, :cusparseCcsrilu02_analysis, :cusparseCcsrilu02, :ComplexF32),
@@ -280,3 +294,49 @@ for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
280294
end
281295
end
282296
end
297+
298+
# gtsv2
299+
for (bname_pivot,fname_pivot,bname_nopivot,fname_nopivot,elty) in ((:cusparseSgtsv2_bufferSizeExt, :cusparseSgtsv2, :cusparseSgtsv2_nopivot_bufferSizeExt, :cusparseSgtsv2_nopivot, :Float32),
300+
(:cusparseDgtsv2_bufferSizeExt, :cusparseDgtsv2, :cusparseDgtsv2_nopivot_bufferSizeExt, :cusparseDgtsv2_nopivot, :Float64),
301+
(:cusparseCgtsv2_bufferSizeExt, :cusparseCgtsv2, :cusparseCgtsv2_nopivot_bufferSizeExt, :cusparseCgtsv2_nopivot, :ComplexF32),
302+
(:cusparseZgtsv2_bufferSizeExt, :cusparseZgtsv2, :cusparseZgtsv2_nopivot_bufferSizeExt, :cusparseZgtsv2_nopivot, :ComplexF64))
303+
@eval begin
304+
function gtsv2!(dl::CuVector{$elty}, d::CuVector{$elty}, du::CuVector{$elty}, B::CuVecOrMat{$elty}, index::SparseChar='O'; pivoting::Bool=true)
305+
ml = length(dl)
306+
m = length(d)
307+
mu = length(du)
308+
mB = size(B,1)
309+
(m ≤ 2) && throw(DimensionMismatch("The size of the linear system must be at least 3."))
310+
!(ml == m == mu) && throw(DimensionMismatch("(dl, d, du) must have the same length, the size of the vectors is ($ml,$m,$mu)!"))
311+
(m != mB) && throw(DimensionMismatch("The tridiagonal matrix and the right-hand side B have inconsistent dimensions ($m != $mB)!"))
312+
n = size(B,2)
313+
ldb = max(1,stride(B,2))
314+
315+
function bufferSize()
316+
out = Ref{Csize_t}(1)
317+
if pivoting
318+
$bname_pivot(handle(), m, n, dl, d, du, B, ldb, out)
319+
else
320+
$bname_nopivot(handle(), m, n, dl, d, du, B, ldb, out)
321+
end
322+
return out[]
323+
end
324+
with_workspace(bufferSize) do buffer
325+
if pivoting
326+
$fname_pivot(handle(), m, n, dl, d, du, B, ldb, buffer)
327+
else
328+
$fname_nopivot(handle(), m, n, dl, d, du, B, ldb, buffer)
329+
end
330+
end
331+
B
332+
end
333+
end
334+
end
335+
336+
for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
337+
@eval begin
338+
function gtsv2(dl::CuVector{$elty}, d::CuVector{$elty}, du::CuVector{$elty}, B::CuVecOrMat{$elty}, index::SparseChar='O'; pivoting::Bool=true)
339+
gtsv2!(dl, d, du, copy(B), index; pivoting)
340+
end
341+
end
342+
end

test/cusparse.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,3 +972,57 @@ for SparseMatrixType in [CuSparseMatrixCSC, CuSparseMatrixCSR]
972972
end
973973
end
974974
end
975+
976+
@testset "gtsv2" begin
977+
dl1 = [0; 1; 3]
978+
d1 = [1; 1; 4]
979+
du1 = [1; 2; 0]
980+
B1 = [1 0 0; 0 1 0; 0 0 1]
981+
X1 = [1/3 2/3 -1/3; 2/3 -2/3 1/3; -1/2 1/2 0]
982+
983+
dl2 = [0; 1; 1; 1; 1; 1; 0]
984+
d2 = [6; 4; 4; 4; 4; 4; 6]
985+
du2 = [0; 1; 1; 1; 1; 1; 0]
986+
B2 = [0; 1; 2; -6; 2; 1; 0]
987+
X2 = [0; 0; 1; -2; 1; 0; 0]
988+
989+
dl3 = [0; 1; 1; 7; 6; 3; 8; 6; 5; 4]
990+
d3 = [2; 3; 3; 2; 2; 4; 1; 2; 4; 5]
991+
du3 = [1; 2; 1; 6; 1; 3; 5; 7; 3; 0]
992+
B3 = [1; 2; 6; 34; 10; 1; 4; 22; 25; 3]
993+
X3 = [1; -1; 2; 1; 3; -2; 0; 4; 2; -1]
994+
for pivoting ∈ (false, true)
995+
@testset "gtsv2 with pivoting=$pivoting -- $elty" for elty in [Float32,Float64,ComplexF32,ComplexF64]
996+
@testset "example 1" begin
997+
dl1_d = CuVector{elty}(dl1)
998+
d1_d = CuVector{elty}(d1)
999+
du1_d = CuVector{elty}(du1)
1000+
B1_d = CuArray{elty}(B1)
1001+
X1_d = gtsv2(dl1_d, d1_d, du1_d, B1_d; pivoting)
1002+
@test collect(X1_d) ≈ X1
1003+
gtsv2!(dl1_d, d1_d, du1_d, B1_d; pivoting)
1004+
@test collect(B1_d) ≈ X1
1005+
end
1006+
@testset "example 2" begin
1007+
dl2_d = CuVector{elty}(dl2)
1008+
d2_d = CuVector{elty}(d2)
1009+
du2_d = CuVector{elty}(du2)
1010+
B2_d = CuArray{elty}(B2)
1011+
X2_d = gtsv2(dl2_d, d2_d, du2_d, B2_d; pivoting)
1012+
@test collect(X2_d) ≈ X2
1013+
gtsv2!(dl2_d, d2_d, du2_d, B2_d; pivoting)
1014+
@test collect(B2_d) ≈ X2
1015+
end
1016+
@testset "example 3" begin
1017+
dl3_d = CuVector{elty}(dl3)
1018+
d3_d = CuVector{elty}(d3)
1019+
du3_d = CuVector{elty}(du3)
1020+
B3_d = CuArray{elty}(B3)
1021+
X3_d = gtsv2(dl3_d, d3_d, du3_d, B3_d; pivoting)
1022+
@test collect(X3_d) ≈ X3
1023+
gtsv2!(dl3_d, d3_d, du3_d, B3_d; pivoting)
1024+
@test collect(B3_d) ≈ X3
1025+
end
1026+
end
1027+
end
1028+
end

0 commit comments

Comments
 (0)