Skip to content

Commit f3156e3

Browse files
authored
[CUSOLVER] Improve the dispatch for LAPACK routines (#1806)
1 parent d53a63e commit f3156e3

File tree

2 files changed

+93
-54
lines changed

2 files changed

+93
-54
lines changed

lib/cusolver/dense.jl

Lines changed: 91 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
using LinearAlgebra
33
using LinearAlgebra: BlasInt, checksquare
4-
using LinearAlgebra.LAPACK: chkargsok, chklapackerror
4+
using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag
55

66
using ..CUBLAS: unsafe_batch
77

@@ -18,10 +18,11 @@ for (bname, fname,elty) in ((:cusolverDnSpotrf_bufferSize, :cusolverDnSpotrf, :F
1818
(:cusolverDnCpotrf_bufferSize, :cusolverDnCpotrf, :ComplexF32),
1919
(:cusolverDnZpotrf_bufferSize, :cusolverDnZpotrf, :ComplexF64))
2020
@eval begin
21-
function LinearAlgebra.LAPACK.potrf!(uplo::Char,
21+
function potrf!(uplo::Char,
2222
A::StridedCuMatrix{$elty})
23-
n = checksquare(A)
24-
lda = max(1, stride(A, 2))
23+
VERSION >= v"1.8" && LinearAlgebra.BLAS.chkuplo(uplo)
24+
n = checksquare(A)
25+
lda = max(1, stride(A, 2))
2526

2627
function bufferSize()
2728
out = Ref{Cint}(0)
@@ -49,9 +50,10 @@ for (fname,elty) in ((:cusolverDnSpotrs, :Float32),
4950
(:cusolverDnCpotrs, :ComplexF32),
5051
(:cusolverDnZpotrs, :ComplexF64))
5152
@eval begin
52-
function LinearAlgebra.LAPACK.potrs!(uplo::Char,
53+
function potrs!(uplo::Char,
5354
A::StridedCuMatrix{$elty},
5455
B::StridedCuVecOrMat{$elty})
56+
VERSION >= v"1.8" && LinearAlgebra.BLAS.chkuplo(uplo)
5557
n = checksquare(A)
5658
if size(B, 1) != n
5759
throw(DimensionMismatch("first dimension of B, $(size(B,1)), must match second dimension of A, $n"))
@@ -72,24 +74,16 @@ for (fname,elty) in ((:cusolverDnSpotrs, :Float32),
7274
end
7375
end
7476

75-
## potri
76-
"""
77-
potri!(uplo::Char, A::CuMatrix)
78-
79-
!!! note
80-
81-
`potri!` requires CUDA >= 10.1
82-
"""
83-
LinearAlgebra.LAPACK.potri!(uplo::Char, A::CuMatrix)
8477
for (bname, fname,elty) in ((:cusolverDnSpotri_bufferSize, :cusolverDnSpotri, :Float32),
85-
(:cusolverDnDpotri_bufferSize, :cusolverDnDpotri, :Float64),
86-
(:cusolverDnCpotri_bufferSize, :cusolverDnCpotri, :ComplexF32),
87-
(:cusolverDnZpotri_bufferSize, :cusolverDnZpotri, :ComplexF64))
88-
@eval begin
89-
function LinearAlgebra.LAPACK.potri!(uplo::Char,
90-
A::StridedCuMatrix{$elty})
91-
n = checksquare(A)
92-
lda = max(1, stride(A, 2))
78+
(:cusolverDnDpotri_bufferSize, :cusolverDnDpotri, :Float64),
79+
(:cusolverDnCpotri_bufferSize, :cusolverDnCpotri, :ComplexF32),
80+
(:cusolverDnZpotri_bufferSize, :cusolverDnZpotri, :ComplexF64))
81+
@eval begin
82+
function potri!(uplo::Char,
83+
A::StridedCuMatrix{$elty})
84+
VERSION >= v"1.8" && LinearAlgebra.BLAS.chkuplo(uplo)
85+
n = checksquare(A)
86+
lda = max(1, stride(A, 2))
9387

9488
function bufferSize()
9589
out = Ref{Cint}(0)
@@ -118,26 +112,26 @@ for (bname, fname,elty) in ((:cusolverDnSgetrf_bufferSize, :cusolverDnSgetrf, :F
118112
(:cusolverDnZgetrf_bufferSize, :cusolverDnZgetrf, :ComplexF64))
119113
@eval begin
120114
function getrf!(A::StridedCuMatrix{$elty})
121-
m,n = size(A)
122-
lda = max(1, stride(A, 2))
115+
m,n = size(A)
116+
lda = max(1, stride(A, 2))
123117

124118
function bufferSize()
125119
out = Ref{Cint}(0)
126120
$bname(dense_handle(), m, n, A, lda, out)
127121
return out[]
128122
end
129123

130-
devipiv = CuArray{Cint}(undef, min(m,n))
124+
ipiv = CuArray{Cint}(undef, min(m,n))
131125
devinfo = CuArray{Cint}(undef, 1)
132126
with_workspace($elty, bufferSize) do buffer
133-
$fname(dense_handle(), m, n, A, lda, buffer, devipiv, devinfo)
127+
$fname(dense_handle(), m, n, A, lda, buffer, ipiv, devinfo)
134128
end
135129

136130
info = @allowscalar devinfo[1]
137131
unsafe_free!(devinfo)
138132
chkargsok(BlasInt(info))
139133

140-
A, devipiv, info
134+
A, ipiv, info
141135
end
142136
end
143137
end
@@ -149,8 +143,8 @@ for (bname, fname,elty) in ((:cusolverDnSgeqrf_bufferSize, :cusolverDnSgeqrf, :F
149143
(:cusolverDnZgeqrf_bufferSize, :cusolverDnZgeqrf, :ComplexF64))
150144
@eval begin
151145
function geqrf!(A::StridedCuMatrix{$elty})
152-
m, n = size(A)
153-
lda = max(1, stride(A, 2))
146+
m, n = size(A)
147+
lda = max(1, stride(A, 2))
154148

155149
function bufferSize()
156150
out = Ref{Cint}(0)
@@ -181,7 +175,8 @@ for (bname, fname,elty) in ((:cusolverDnSsytrf_bufferSize, :cusolverDnSsytrf, :F
181175
@eval begin
182176
function sytrf!(uplo::Char,
183177
A::StridedCuMatrix{$elty})
184-
n = checksquare(A)
178+
VERSION >= v"1.8" && LinearAlgebra.BLAS.chkuplo(uplo)
179+
n = checksquare(A)
185180
lda = max(1, stride(A, 2))
186181

187182
function bufferSize()
@@ -190,17 +185,17 @@ for (bname, fname,elty) in ((:cusolverDnSsytrf_bufferSize, :cusolverDnSsytrf, :F
190185
return out[]
191186
end
192187

193-
devipiv = CuArray{Cint}(undef, n)
188+
ipiv = CuArray{Cint}(undef, n)
194189
devinfo = CuArray{Cint}(undef, 1)
195190
with_workspace($elty, bufferSize) do buffer
196-
$fname(dense_handle(), uplo, n, A, lda, devipiv, buffer, length(buffer), devinfo)
191+
$fname(dense_handle(), uplo, n, A, lda, ipiv, buffer, length(buffer), devinfo)
197192
end
198193

199194
info = @allowscalar devinfo[1]
200195
unsafe_free!(devinfo)
201196
chkargsok(BlasInt(info))
202197

203-
A, devipiv, info
198+
A, ipiv, info
204199
end
205200
end
206201
end
@@ -215,6 +210,11 @@ for (fname,elty) in ((:cusolverDnSgetrs, :Float32),
215210
A::StridedCuMatrix{$elty},
216211
ipiv::CuVector{Cint},
217212
B::StridedCuVecOrMat{$elty})
213+
214+
# Support trans = 'C' for real matrices: the adjoint of a real matrix is its transpose, so 'C' is mapped to 'T'
215+
trans = $elty <: Real && trans == 'C' ? 'T' : trans
216+
217+
chktrans(trans)
218218
n = size(A, 1)
219219
if size(A, 2) != n
220220
throw(DimensionMismatch("LU factored matrix A must be square!"))
@@ -232,6 +232,7 @@ for (fname,elty) in ((:cusolverDnSgetrs, :Float32),
232232
info = @allowscalar devinfo[1]
233233
unsafe_free!(devinfo)
234234
chkargsok(BlasInt(info))
235+
235236
B
236237
end
237238
end
@@ -248,6 +249,12 @@ for (bname, fname, elty) in ((:cusolverDnSormqr_bufferSize, :cusolverDnSormqr, :
248249
A::CuMatrix{$elty},
249250
tau::CuVector{$elty},
250251
C::CuVecOrMat{$elty})
252+
253+
# Support trans = 'C' for real matrices: the adjoint of a real matrix is its transpose, so 'C' is mapped to 'T'
254+
trans = $elty <: Real && trans == 'C' ? 'T' : trans
255+
256+
chkside(side)
257+
chktrans(trans)
251258
m,n = ndims(C) == 2 ? size(C) : (size(C, 1), 1)
252259
mA = size(A, 1)
253260
k = length(tau)
@@ -330,8 +337,8 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgebrd_bufferSize, :cusolverDnSg
330337
(:cusolverDnZgebrd_bufferSize, :cusolverDnZgebrd, :ComplexF64, :Float64))
331338
@eval begin
332339
function gebrd!(A::StridedCuMatrix{$elty})
333-
m, n = size(A)
334-
lda = max(1, stride(A, 2))
340+
m, n = size(A)
341+
lda = max(1, stride(A, 2))
335342

336343
function bufferSize()
337344
out = Ref{Cint}(0)
@@ -367,7 +374,7 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvd_bufferSize, :cusolverDnSg
367374
function gesvd!(jobu::Char,
368375
jobvt::Char,
369376
A::StridedCuMatrix{$elty})
370-
m, n = size(A)
377+
m, n = size(A)
371378
if m < n
372379
throw(ArgumentError("CUSOLVER's gesvd currently requires m >= n"))
373380
end
@@ -484,6 +491,7 @@ for (jname, bname, fname, elty, relty) in ((:syevd!, :cusolverDnSsyevd_bufferSiz
484491
function $jname(jobz::Char,
485492
uplo::Char,
486493
A::StridedCuMatrix{$elty})
494+
VERSION >= v"1.8" && LinearAlgebra.BLAS.chkuplo(uplo)
487495
n = checksquare(A)
488496
lda = max(1, stride(A, 2))
489497
W = CuArray{$relty}(undef, n)
@@ -523,6 +531,7 @@ for (jname, bname, fname, elty, relty) in ((:sygvd!, :cusolverDnSsygvd_bufferSiz
523531
uplo::Char,
524532
A::StridedCuMatrix{$elty},
525533
B::StridedCuMatrix{$elty})
534+
VERSION >= v"1.8" && LinearAlgebra.BLAS.chkuplo(uplo)
526535
nA, nB = checksquare(A, B)
527536
if nB != nA
528537
throw(DimensionMismatch("Dimensions of A ($nA, $nA) and B ($nB, $nB) must match!"))
@@ -569,6 +578,7 @@ for (jname, bname, fname, elty, relty) in ((:sygvj!, :cusolverDnSsygvj_bufferSiz
569578
B::StridedCuMatrix{$elty};
570579
tol::$relty=eps($relty),
571580
max_sweeps::Int=100)
581+
VERSION >= v"1.8" && LinearAlgebra.BLAS.chkuplo(uplo)
572582
nA, nB = checksquare(A, B)
573583
if nB != nA
574584
throw(DimensionMismatch("Dimensions of A ($nA, $nA) and B ($nB, $nB) must match!"))
@@ -613,8 +623,7 @@ end
613623
for (jname, bname, fname, elty, relty) in ((:syevjBatched!, :cusolverDnSsyevjBatched_bufferSize, :cusolverDnSsyevjBatched, :Float32, :Float32),
614624
(:syevjBatched!, :cusolverDnDsyevjBatched_bufferSize, :cusolverDnDsyevjBatched, :Float64, :Float64),
615625
(:heevjBatched!, :cusolverDnCheevjBatched_bufferSize, :cusolverDnCheevjBatched, :ComplexF32, :Float32),
616-
(:heevjBatched!, :cusolverDnZheevjBatched_bufferSize, :cusolverDnZheevjBatched, :ComplexF64, :Float64)
617-
)
626+
(:heevjBatched!, :cusolverDnZheevjBatched_bufferSize, :cusolverDnZheevjBatched, :ComplexF64, :Float64))
618627
@eval begin
619628
function $jname(jobz::Char,
620629
uplo::Char,
@@ -623,6 +632,7 @@ for (jname, bname, fname, elty, relty) in ((:syevjBatched!, :cusolverDnSsyevjBat
623632
max_sweeps::Int=100)
624633

625634
# Set up information for the solver arguments
635+
VERSION >= v"1.8" && LinearAlgebra.BLAS.chkuplo(uplo)
626636
n = checksquare(A)
627637
lda = max(1, stride(A, 2))
628638
batchSize = size(A,3)
@@ -667,19 +677,19 @@ for (jname, bname, fname, elty, relty) in ((:syevjBatched!, :cusolverDnSsyevjBat
667677
end
668678
end
669679

670-
for (jname, fname, elty) in ((:potrsBatched!, :cusolverDnSpotrsBatched, :Float32),
671-
(:potrsBatched!, :cusolverDnDpotrsBatched, :Float64),
672-
(:potrsBatched!, :cusolverDnCpotrsBatched, :ComplexF32),
673-
(:potrsBatched!, :cusolverDnZpotrsBatched, :ComplexF64)
674-
)
680+
for (fname, elty) in ((:cusolverDnSpotrsBatched, :Float32),
681+
(:cusolverDnDpotrsBatched, :Float64),
682+
(:cusolverDnCpotrsBatched, :ComplexF32),
683+
(:cusolverDnZpotrsBatched, :ComplexF64))
675684
@eval begin
676-
function $jname(uplo::Char,
677-
A::Vector{<:StridedCuMatrix{$elty}},
678-
B::Vector{<:StridedCuVecOrMat{$elty}})
685+
function potrsBatched!(uplo::Char,
686+
A::Vector{<:StridedCuMatrix{$elty}},
687+
B::Vector{<:StridedCuVecOrMat{$elty}})
679688
if length(A) != length(B)
680689
throw(DimensionMismatch(""))
681690
end
682691
# Set up information for the solver arguments
692+
VERSION >= v"1.8" && LinearAlgebra.BLAS.chkuplo(uplo)
683693
n = checksquare(A[1])
684694
if size(B[1], 1) != n
685695
throw(DimensionMismatch("first dimension of B[i], $(size(B[1],1)), must match second dimension of A, $n"))
@@ -710,15 +720,15 @@ for (jname, fname, elty) in ((:potrsBatched!, :cusolverDnSpotrsBatched, :Float32
710720
end
711721
end
712722

713-
for (jname, fname, elty) in ((:potrfBatched!, :cusolverDnSpotrfBatched, :Float32),
714-
(:potrfBatched!, :cusolverDnDpotrfBatched, :Float64),
715-
(:potrfBatched!, :cusolverDnCpotrfBatched, :ComplexF32),
716-
(:potrfBatched!, :cusolverDnZpotrfBatched, :ComplexF64)
717-
)
723+
for (fname, elty) in ((:cusolverDnSpotrfBatched, :Float32),
724+
(:cusolverDnDpotrfBatched, :Float64),
725+
(:cusolverDnCpotrfBatched, :ComplexF32),
726+
(:cusolverDnZpotrfBatched, :ComplexF64))
718727
@eval begin
719-
function $jname(uplo::Char, A::Vector{<:StridedCuMatrix{$elty}})
728+
function potrfBatched!(uplo::Char, A::Vector{<:StridedCuMatrix{$elty}})
720729

721730
# Set up information for the solver arguments
731+
VERSION >= v"1.8" && LinearAlgebra.BLAS.chkuplo(uplo)
722732
n = checksquare(A[1])
723733
lda = max(1, stride(A[1], 2))
724734
batchSize = length(A)
@@ -745,3 +755,32 @@ for (jname, fname, elty) in ((:potrfBatched!, :cusolverDnSpotrfBatched, :Float32
745755
end
746756
end
747757
end
758+
759+
# LAPACK
760+
# Forward the LinearAlgebra.LAPACK entry points to the CUSOLVER implementations for each
# of the four element types listed below, so that LAPACK calls made with these GPU array
# types dispatch to CUSOLVER.
#
# Every target is module-qualified (`CUSOLVER.<name>`): the forwarders then state
# explicitly which function they call, instead of depending on how an unqualified name
# happens to resolve in the enclosing module.
for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
    @eval begin
        LinearAlgebra.LAPACK.potrf!(uplo::Char, A::StridedCuMatrix{$elty}) =
            CUSOLVER.potrf!(uplo, A)

        LinearAlgebra.LAPACK.potrs!(uplo::Char, A::StridedCuMatrix{$elty},
                                    B::StridedCuVecOrMat{$elty}) =
            CUSOLVER.potrs!(uplo, A, B)

        LinearAlgebra.LAPACK.potri!(uplo::Char, A::StridedCuMatrix{$elty}) =
            CUSOLVER.potri!(uplo, A)

        LinearAlgebra.LAPACK.getrf!(A::StridedCuMatrix{$elty}) =
            CUSOLVER.getrf!(A)

        LinearAlgebra.LAPACK.geqrf!(A::StridedCuMatrix{$elty}) =
            CUSOLVER.geqrf!(A)

        # Qualified like its siblings (the original called a bare `sytrf!` here).
        LinearAlgebra.LAPACK.sytrf!(uplo::Char, A::StridedCuMatrix{$elty}) =
            CUSOLVER.sytrf!(uplo, A)

        LinearAlgebra.LAPACK.getrs!(trans::Char, A::StridedCuMatrix{$elty},
                                    ipiv::CuVector{Cint}, B::StridedCuVecOrMat{$elty}) =
            CUSOLVER.getrs!(trans, A, ipiv, B)

        LinearAlgebra.LAPACK.ormqr!(side::Char, trans::Char, A::CuMatrix{$elty},
                                    tau::CuVector{$elty}, C::CuVecOrMat{$elty}) =
            CUSOLVER.ormqr!(side, trans, A, tau, C)

        LinearAlgebra.LAPACK.orgqr!(A::StridedCuMatrix{$elty}, tau::CuVector{$elty}) =
            CUSOLVER.orgqr!(A, tau)

        LinearAlgebra.LAPACK.gebrd!(A::StridedCuMatrix{$elty}) =
            CUSOLVER.gebrd!(A)

        LinearAlgebra.LAPACK.gesvd!(jobu::Char, jobvt::Char, A::StridedCuMatrix{$elty}) =
            CUSOLVER.gesvd!(jobu, jobvt, A)
    end
end
775+
776+
# Real element types: LinearAlgebra.LAPACK.sygvd! on strided GPU matrices is forwarded,
# with its arguments unchanged, to CUSOLVER.sygvd!.
for T in (:Float32, :Float64)
    @eval function LinearAlgebra.LAPACK.sygvd!(itype::Int, jobz::Char, uplo::Char,
                                               A::StridedCuMatrix{$T},
                                               B::StridedCuMatrix{$T})
        return CUSOLVER.sygvd!(itype, jobz, uplo, A, B)
    end
end
781+
782+
# Complex element types: LinearAlgebra.LAPACK.sygvd! on strided GPU matrices is forwarded,
# with its arguments unchanged, to CUSOLVER.hegvd! (by its name, the Hermitian
# counterpart of sygvd! -- confirm against the CUSOLVER definitions).
for T in (:ComplexF32, :ComplexF64)
    @eval function LinearAlgebra.LAPACK.sygvd!(itype::Int, jobz::Char, uplo::Char,
                                               A::StridedCuMatrix{$T},
                                               B::StridedCuMatrix{$T})
        return CUSOLVER.hegvd!(itype, jobz, uplo, A, B)
    end
end

lib/cusolver/linalg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
using LinearAlgebra
44
using LinearAlgebra: BlasComplex, BlasFloat, BlasReal
5-
using ..CUBLAS: CublasFloat, trsm!
5+
using ..CUBLAS: CublasFloat
66

77
function copy_cublasfloat(As...)
88
eltypes = eltype.(As)
@@ -53,7 +53,7 @@ function Base.:\(_A::CuMatOrAdj, _B::CuOrAdj)
5353
else
5454
# QR decomposition
5555
F, tau = CUSOLVER.geqrf!(A) # A = QR
56-
CUSOLVER.ormqr!('L', T <: Real ? 'T' : 'C', F, tau, B)
56+
CUSOLVER.ormqr!('L', 'C', F, tau, B)
5757
if B isa CuVector{T}
5858
X = B[1:m]
5959
CUBLAS.trsv!('U', 'N', 'N', view(F,1:m,1:m), X)

0 commit comments

Comments
 (0)