Skip to content

Commit f3b3f8b

Browse files
authored
[CUSOLVER] Support symmetric factorization without pivoting (#2640)
1 parent dd23441 commit f3b3f8b

File tree

3 files changed

+74
-15
lines changed

3 files changed

+74
-15
lines changed

lib/cusolver/dense.jl

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,34 @@ for (bname, fname,elty) in ((:cusolverDnSsytrf_bufferSize, :cusolverDnSsytrf, :F
206206
A, ipiv, info
207207
end
208208

209-
function sytrf!(uplo::Char, A::StridedCuMatrix{$elty})
209+
function sytrf!(uplo::Char, A::StridedCuMatrix{$elty}; pivoting::Bool=true)
210210
n = checksquare(A)
211-
ipiv = CuArray{Cint}(undef, n)
212-
sytrf!(uplo, A, ipiv)
211+
if pivoting
212+
ipiv = CuArray{Cint}(undef, n)
213+
return sytrf!(uplo, A, ipiv)
214+
else
215+
CUSOLVER.version() < v"11.7.2" && throw(ErrorException("This operation is not supported by the current CUDA version."))
216+
chkuplo(uplo)
217+
n = checksquare(A)
218+
lda = max(1, stride(A, 2))
219+
dh = dense_handle()
220+
221+
function bufferSize()
222+
out = Ref{Cint}(0)
223+
$bname(dh, n, A, lda, out)
224+
return out[] * sizeof($elty)
225+
end
226+
227+
with_workspace(dh.workspace_gpu, bufferSize) do buffer
228+
$fname(dh, uplo, n, A, lda, CU_NULL,
229+
buffer, sizeof(buffer) ÷ sizeof($elty), dh.info)
230+
end
231+
232+
info = @allowscalar dh.info[1]
233+
chkargsok(info |> BlasInt)
234+
235+
return A, info
236+
end
213237
end
214238
end
215239
end

lib/cusolver/dense_generic.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,33 @@ function sytrs!(uplo::Char, A::StridedCuMatrix{T}, p::CuVector{Int64}, B::Stride
149149
B
150150
end
151151

152+
function sytrs!(uplo::Char, A::StridedCuMatrix{T}, B::StridedCuMatrix{T}) where {T <: BlasFloat}
153+
chkuplo(uplo)
154+
n = checksquare(A)
155+
nrhs = size(B, 2)
156+
lda = max(1, stride(A, 2))
157+
ldb = max(1, stride(B, 2))
158+
dh = dense_handle()
159+
160+
function bufferSize()
161+
out_cpu = Ref{Csize_t}(0)
162+
out_gpu = Ref{Csize_t}(0)
163+
cusolverDnXsytrs_bufferSize(dh, uplo, n, nrhs, T, A,
164+
lda, CU_NULL, T, B, ldb, out_gpu, out_cpu)
165+
out_gpu[], out_cpu[]
166+
end
167+
with_workspaces(dh.workspace_gpu, dh.workspace_cpu,
168+
bufferSize()...) do buffer_gpu, buffer_cpu
169+
cusolverDnXsytrs(dh, uplo, n, nrhs, T, A, lda, CU_NULL,
170+
T, B, ldb, buffer_gpu, sizeof(buffer_gpu),
171+
buffer_cpu, sizeof(buffer_cpu), dh.info)
172+
end
173+
174+
flag = @allowscalar dh.info[1]
175+
chkargsok(flag |> BlasInt)
176+
B
177+
end
178+
152179
# Xtrtri
153180
function trtri!(uplo::Char, diag::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
154181
chkuplo(uplo)

test/libraries/cusolver/dense_generic.jl

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -92,18 +92,26 @@ p = 5
9292
end
9393

9494
@testset "sytrs!" begin
95-
for uplo in ('L', 'U')
96-
A = rand(elty,n,n)
97-
B = rand(elty,n,p)
98-
A = A + transpose(A)
99-
d_A = CuMatrix(A)
100-
d_B = CuMatrix(B)
101-
d_A, d_ipiv, _ = CUSOLVER.sytrf!(uplo, d_A)
102-
d_ipiv = CuVector{Int64}(d_ipiv)
103-
A, ipiv, _ = LAPACK.sytrf!(uplo, A)
104-
CUSOLVER.sytrs!(uplo, d_A, d_ipiv, d_B)
105-
LAPACK.sytrs!(uplo, A, ipiv, B)
106-
@test B ≈ collect(d_B)
95+
@testset "uplo = $uplo" for uplo in ('L', 'U')
96+
@testset "pivoting = $pivoting" for pivoting in (false, true)
97+
A = rand(elty,n,n)
98+
B = rand(elty,n,p)
99+
A = A + transpose(A)
100+
d_A = CuMatrix(A)
101+
d_B = CuMatrix(B)
102+
!pivoting && (CUSOLVER.version() < v"11.7.2") && continue
103+
if pivoting
104+
d_A, d_ipiv, _ = CUSOLVER.sytrf!(uplo, d_A; pivoting)
105+
d_ipiv = CuVector{Int64}(d_ipiv)
106+
CUSOLVER.sytrs!(uplo, d_A, d_ipiv, d_B)
107+
else
108+
d_A, _ = CUSOLVER.sytrf!(uplo, d_A; pivoting)
109+
CUSOLVER.sytrs!(uplo, d_A, d_B)
110+
end
111+
A, ipiv, _ = LAPACK.sytrf!(uplo, A)
112+
LAPACK.sytrs!(uplo, A, ipiv, B)
113+
@test B ≈ collect(d_B)
114+
end
107115
end
108116
end
109117

0 commit comments

Comments
 (0)