1
1
2
2
using LinearAlgebra
3
3
using LinearAlgebra: BlasInt, checksquare
4
- using LinearAlgebra. LAPACK: chkargsok, chklapackerror
4
+ using LinearAlgebra. LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag
5
5
6
6
using .. CUBLAS: unsafe_batch
7
7
@@ -18,10 +18,11 @@ for (bname, fname,elty) in ((:cusolverDnSpotrf_bufferSize, :cusolverDnSpotrf, :F
18
18
(:cusolverDnCpotrf_bufferSize , :cusolverDnCpotrf , :ComplexF32 ),
19
19
(:cusolverDnZpotrf_bufferSize , :cusolverDnZpotrf , :ComplexF64 ))
20
20
@eval begin
21
- function LinearAlgebra . LAPACK . potrf! (uplo:: Char ,
21
+ function potrf! (uplo:: Char ,
22
22
A:: StridedCuMatrix{$elty} )
23
- n = checksquare (A)
24
- lda = max (1 , stride (A, 2 ))
23
+ VERSION >= v " 1.8" && LinearAlgebra. BLAS. chkuplo (uplo)
24
+ n = checksquare (A)
25
+ lda = max (1 , stride (A, 2 ))
25
26
26
27
function bufferSize ()
27
28
out = Ref {Cint} (0 )
@@ -49,9 +50,10 @@ for (fname,elty) in ((:cusolverDnSpotrs, :Float32),
49
50
(:cusolverDnCpotrs , :ComplexF32 ),
50
51
(:cusolverDnZpotrs , :ComplexF64 ))
51
52
@eval begin
52
- function LinearAlgebra . LAPACK . potrs! (uplo:: Char ,
53
+ function potrs! (uplo:: Char ,
53
54
A:: StridedCuMatrix{$elty} ,
54
55
B:: StridedCuVecOrMat{$elty} )
56
+ VERSION >= v " 1.8" && LinearAlgebra. BLAS. chkuplo (uplo)
55
57
n = checksquare (A)
56
58
if size (B, 1 ) != n
57
59
throw (DimensionMismatch (" first dimension of B, $(size (B,1 )) , must match second dimension of A, $n " ))
@@ -72,24 +74,16 @@ for (fname,elty) in ((:cusolverDnSpotrs, :Float32),
72
74
end
73
75
end
74
76
75
- # # potri
76
- """
77
- potri!(uplo::Char, A::CuMatrix)
78
-
79
- !!! note
80
-
81
- `potri!` requires CUDA >= 10.1
82
- """
83
- LinearAlgebra. LAPACK. potri! (uplo:: Char , A:: CuMatrix )
84
77
for (bname, fname,elty) in ((:cusolverDnSpotri_bufferSize , :cusolverDnSpotri , :Float32 ),
85
- (:cusolverDnDpotri_bufferSize , :cusolverDnDpotri , :Float64 ),
86
- (:cusolverDnCpotri_bufferSize , :cusolverDnCpotri , :ComplexF32 ),
87
- (:cusolverDnZpotri_bufferSize , :cusolverDnZpotri , :ComplexF64 ))
88
- @eval begin
89
- function LinearAlgebra. LAPACK. potri! (uplo:: Char ,
90
- A:: StridedCuMatrix{$elty} )
91
- n = checksquare (A)
92
- lda = max (1 , stride (A, 2 ))
78
+ (:cusolverDnDpotri_bufferSize , :cusolverDnDpotri , :Float64 ),
79
+ (:cusolverDnCpotri_bufferSize , :cusolverDnCpotri , :ComplexF32 ),
80
+ (:cusolverDnZpotri_bufferSize , :cusolverDnZpotri , :ComplexF64 ))
81
+ @eval begin
82
+ function potri! (uplo:: Char ,
83
+ A:: StridedCuMatrix{$elty} )
84
+ VERSION >= v " 1.8" && LinearAlgebra. BLAS. chkuplo (uplo)
85
+ n = checksquare (A)
86
+ lda = max (1 , stride (A, 2 ))
93
87
94
88
function bufferSize ()
95
89
out = Ref {Cint} (0 )
@@ -118,26 +112,26 @@ for (bname, fname,elty) in ((:cusolverDnSgetrf_bufferSize, :cusolverDnSgetrf, :F
118
112
(:cusolverDnZgetrf_bufferSize , :cusolverDnZgetrf , :ComplexF64 ))
119
113
@eval begin
120
114
function getrf! (A:: StridedCuMatrix{$elty} )
121
- m,n = size (A)
122
- lda = max (1 , stride (A, 2 ))
115
+ m,n = size (A)
116
+ lda = max (1 , stride (A, 2 ))
123
117
124
118
function bufferSize ()
125
119
out = Ref {Cint} (0 )
126
120
$ bname (dense_handle (), m, n, A, lda, out)
127
121
return out[]
128
122
end
129
123
130
- devipiv = CuArray {Cint} (undef, min (m,n))
124
+ ipiv = CuArray {Cint} (undef, min (m,n))
131
125
devinfo = CuArray {Cint} (undef, 1 )
132
126
with_workspace ($ elty, bufferSize) do buffer
133
- $ fname (dense_handle (), m, n, A, lda, buffer, devipiv , devinfo)
127
+ $ fname (dense_handle (), m, n, A, lda, buffer, ipiv , devinfo)
134
128
end
135
129
136
130
info = @allowscalar devinfo[1 ]
137
131
unsafe_free! (devinfo)
138
132
chkargsok (BlasInt (info))
139
133
140
- A, devipiv , info
134
+ A, ipiv , info
141
135
end
142
136
end
143
137
end
@@ -149,8 +143,8 @@ for (bname, fname,elty) in ((:cusolverDnSgeqrf_bufferSize, :cusolverDnSgeqrf, :F
149
143
(:cusolverDnZgeqrf_bufferSize , :cusolverDnZgeqrf , :ComplexF64 ))
150
144
@eval begin
151
145
function geqrf! (A:: StridedCuMatrix{$elty} )
152
- m, n = size (A)
153
- lda = max (1 , stride (A, 2 ))
146
+ m, n = size (A)
147
+ lda = max (1 , stride (A, 2 ))
154
148
155
149
function bufferSize ()
156
150
out = Ref {Cint} (0 )
@@ -181,7 +175,8 @@ for (bname, fname,elty) in ((:cusolverDnSsytrf_bufferSize, :cusolverDnSsytrf, :F
181
175
@eval begin
182
176
function sytrf! (uplo:: Char ,
183
177
A:: StridedCuMatrix{$elty} )
184
- n = checksquare (A)
178
+ VERSION >= v " 1.8" && LinearAlgebra. BLAS. chkuplo (uplo)
179
+ n = checksquare (A)
185
180
lda = max (1 , stride (A, 2 ))
186
181
187
182
function bufferSize ()
@@ -190,17 +185,17 @@ for (bname, fname,elty) in ((:cusolverDnSsytrf_bufferSize, :cusolverDnSsytrf, :F
190
185
return out[]
191
186
end
192
187
193
- devipiv = CuArray {Cint} (undef, n)
188
+ ipiv = CuArray {Cint} (undef, n)
194
189
devinfo = CuArray {Cint} (undef, 1 )
195
190
with_workspace ($ elty, bufferSize) do buffer
196
- $ fname (dense_handle (), uplo, n, A, lda, devipiv , buffer, length (buffer), devinfo)
191
+ $ fname (dense_handle (), uplo, n, A, lda, ipiv , buffer, length (buffer), devinfo)
197
192
end
198
193
199
194
info = @allowscalar devinfo[1 ]
200
195
unsafe_free! (devinfo)
201
196
chkargsok (BlasInt (info))
202
197
203
- A, devipiv , info
198
+ A, ipiv , info
204
199
end
205
200
end
206
201
end
@@ -215,6 +210,11 @@ for (fname,elty) in ((:cusolverDnSgetrs, :Float32),
215
210
A:: StridedCuMatrix{$elty} ,
216
211
ipiv:: CuVector{Cint} ,
217
212
B:: StridedCuVecOrMat{$elty} )
213
+
214
+ # Support transa = 'C' for real matrices
215
+ trans = $ elty <: Real && trans == ' C' ? ' T' : trans
216
+
217
+ chktrans (trans)
218
218
n = size (A, 1 )
219
219
if size (A, 2 ) != n
220
220
throw (DimensionMismatch (" LU factored matrix A must be square!" ))
@@ -232,6 +232,7 @@ for (fname,elty) in ((:cusolverDnSgetrs, :Float32),
232
232
info = @allowscalar devinfo[1 ]
233
233
unsafe_free! (devinfo)
234
234
chkargsok (BlasInt (info))
235
+
235
236
B
236
237
end
237
238
end
@@ -248,6 +249,12 @@ for (bname, fname, elty) in ((:cusolverDnSormqr_bufferSize, :cusolverDnSormqr, :
248
249
A:: CuMatrix{$elty} ,
249
250
tau:: CuVector{$elty} ,
250
251
C:: CuVecOrMat{$elty} )
252
+
253
+ # Support transa = 'C' for real matrices
254
+ trans = $ elty <: Real && trans == ' C' ? ' T' : trans
255
+
256
+ chkside (side)
257
+ chktrans (trans)
251
258
m,n = ndims (C) == 2 ? size (C) : (size (C, 1 ), 1 )
252
259
mA = size (A, 1 )
253
260
k = length (tau)
@@ -330,8 +337,8 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgebrd_bufferSize, :cusolverDnSg
330
337
(:cusolverDnZgebrd_bufferSize , :cusolverDnZgebrd , :ComplexF64 , :Float64 ))
331
338
@eval begin
332
339
function gebrd! (A:: StridedCuMatrix{$elty} )
333
- m, n = size (A)
334
- lda = max (1 , stride (A, 2 ))
340
+ m, n = size (A)
341
+ lda = max (1 , stride (A, 2 ))
335
342
336
343
function bufferSize ()
337
344
out = Ref {Cint} (0 )
@@ -367,7 +374,7 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvd_bufferSize, :cusolverDnSg
367
374
function gesvd! (jobu:: Char ,
368
375
jobvt:: Char ,
369
376
A:: StridedCuMatrix{$elty} )
370
- m, n = size (A)
377
+ m, n = size (A)
371
378
if m < n
372
379
throw (ArgumentError (" CUSOLVER's gesvd currently requires m >= n" ))
373
380
end
@@ -484,6 +491,7 @@ for (jname, bname, fname, elty, relty) in ((:syevd!, :cusolverDnSsyevd_bufferSiz
484
491
function $jname (jobz:: Char ,
485
492
uplo:: Char ,
486
493
A:: StridedCuMatrix{$elty} )
494
+ VERSION >= v " 1.8" && LinearAlgebra. BLAS. chkuplo (uplo)
487
495
n = checksquare (A)
488
496
lda = max (1 , stride (A, 2 ))
489
497
W = CuArray {$relty} (undef, n)
@@ -523,6 +531,7 @@ for (jname, bname, fname, elty, relty) in ((:sygvd!, :cusolverDnSsygvd_bufferSiz
523
531
uplo:: Char ,
524
532
A:: StridedCuMatrix{$elty} ,
525
533
B:: StridedCuMatrix{$elty} )
534
+ VERSION >= v " 1.8" && LinearAlgebra. BLAS. chkuplo (uplo)
526
535
nA, nB = checksquare (A, B)
527
536
if nB != nA
528
537
throw (DimensionMismatch (" Dimensions of A ($nA , $nA ) and B ($nB , $nB ) must match!" ))
@@ -569,6 +578,7 @@ for (jname, bname, fname, elty, relty) in ((:sygvj!, :cusolverDnSsygvj_bufferSiz
569
578
B:: StridedCuMatrix{$elty} ;
570
579
tol:: $relty = eps ($ relty),
571
580
max_sweeps:: Int = 100 )
581
+ VERSION >= v " 1.8" && LinearAlgebra. BLAS. chkuplo (uplo)
572
582
nA, nB = checksquare (A, B)
573
583
if nB != nA
574
584
throw (DimensionMismatch (" Dimensions of A ($nA , $nA ) and B ($nB , $nB ) must match!" ))
613
623
for (jname, bname, fname, elty, relty) in ((:syevjBatched! , :cusolverDnSsyevjBatched_bufferSize , :cusolverDnSsyevjBatched , :Float32 , :Float32 ),
614
624
(:syevjBatched! , :cusolverDnDsyevjBatched_bufferSize , :cusolverDnDsyevjBatched , :Float64 , :Float64 ),
615
625
(:heevjBatched! , :cusolverDnCheevjBatched_bufferSize , :cusolverDnCheevjBatched , :ComplexF32 , :Float32 ),
616
- (:heevjBatched! , :cusolverDnZheevjBatched_bufferSize , :cusolverDnZheevjBatched , :ComplexF64 , :Float64 )
617
- )
626
+ (:heevjBatched! , :cusolverDnZheevjBatched_bufferSize , :cusolverDnZheevjBatched , :ComplexF64 , :Float64 ))
618
627
@eval begin
619
628
function $jname (jobz:: Char ,
620
629
uplo:: Char ,
@@ -623,6 +632,7 @@ for (jname, bname, fname, elty, relty) in ((:syevjBatched!, :cusolverDnSsyevjBat
623
632
max_sweeps:: Int = 100 )
624
633
625
634
# Set up information for the solver arguments
635
+ VERSION >= v " 1.8" && LinearAlgebra. BLAS. chkuplo (uplo)
626
636
n = checksquare (A)
627
637
lda = max (1 , stride (A, 2 ))
628
638
batchSize = size (A,3 )
@@ -667,19 +677,19 @@ for (jname, bname, fname, elty, relty) in ((:syevjBatched!, :cusolverDnSsyevjBat
667
677
end
668
678
end
669
679
670
- for (jname, fname, elty) in ((:potrsBatched! , :cusolverDnSpotrsBatched , :Float32 ),
671
- (:potrsBatched! , :cusolverDnDpotrsBatched , :Float64 ),
672
- (:potrsBatched! , :cusolverDnCpotrsBatched , :ComplexF32 ),
673
- (:potrsBatched! , :cusolverDnZpotrsBatched , :ComplexF64 )
674
- )
680
+ for (fname, elty) in ((:cusolverDnSpotrsBatched , :Float32 ),
681
+ (:cusolverDnDpotrsBatched , :Float64 ),
682
+ (:cusolverDnCpotrsBatched , :ComplexF32 ),
683
+ (:cusolverDnZpotrsBatched , :ComplexF64 ))
675
684
@eval begin
676
- function $jname (uplo:: Char ,
677
- A:: Vector{<:StridedCuMatrix{$elty}} ,
678
- B:: Vector{<:StridedCuVecOrMat{$elty}} )
685
+ function potrsBatched! (uplo:: Char ,
686
+ A:: Vector{<:StridedCuMatrix{$elty}} ,
687
+ B:: Vector{<:StridedCuVecOrMat{$elty}} )
679
688
if length (A) != length (B)
680
689
throw (DimensionMismatch (" " ))
681
690
end
682
691
# Set up information for the solver arguments
692
+ VERSION >= v " 1.8" && LinearAlgebra. BLAS. chkuplo (uplo)
683
693
n = checksquare (A[1 ])
684
694
if size (B[1 ], 1 ) != n
685
695
throw (DimensionMismatch (" first dimension of B[i], $(size (B[1 ],1 )) , must match second dimension of A, $n " ))
@@ -710,15 +720,15 @@ for (jname, fname, elty) in ((:potrsBatched!, :cusolverDnSpotrsBatched, :Float32
710
720
end
711
721
end
712
722
713
- for (jname, fname, elty) in ((:potrfBatched! , :cusolverDnSpotrfBatched , :Float32 ),
714
- (:potrfBatched! , :cusolverDnDpotrfBatched , :Float64 ),
715
- (:potrfBatched! , :cusolverDnCpotrfBatched , :ComplexF32 ),
716
- (:potrfBatched! , :cusolverDnZpotrfBatched , :ComplexF64 )
717
- )
723
+ for (fname, elty) in ((:cusolverDnSpotrfBatched , :Float32 ),
724
+ (:cusolverDnDpotrfBatched , :Float64 ),
725
+ (:cusolverDnCpotrfBatched , :ComplexF32 ),
726
+ (:cusolverDnZpotrfBatched , :ComplexF64 ))
718
727
@eval begin
719
- function $jname (uplo:: Char , A:: Vector{<:StridedCuMatrix{$elty}} )
728
+ function potrfBatched! (uplo:: Char , A:: Vector{<:StridedCuMatrix{$elty}} )
720
729
721
730
# Set up information for the solver arguments
731
+ VERSION >= v " 1.8" && LinearAlgebra. BLAS. chkuplo (uplo)
722
732
n = checksquare (A[1 ])
723
733
lda = max (1 , stride (A[1 ], 2 ))
724
734
batchSize = length (A)
@@ -745,3 +755,32 @@ for (jname, fname, elty) in ((:potrfBatched!, :cusolverDnSpotrfBatched, :Float32
745
755
end
746
756
end
747
757
end
758
+
759
+ # LAPACK
760
+ for elty in (:Float32 , :Float64 , :ComplexF32 , :ComplexF64 )
761
+ @eval begin
762
+ LinearAlgebra. LAPACK. potrf! (uplo:: Char , A:: StridedCuMatrix{$elty} ) = CUSOLVER. potrf! (uplo, A)
763
+ LinearAlgebra. LAPACK. potrs! (uplo:: Char , A:: StridedCuMatrix{$elty} , B:: StridedCuVecOrMat{$elty} ) = CUSOLVER. potrs! (uplo, A, B)
764
+ LinearAlgebra. LAPACK. potri! (uplo:: Char , A:: StridedCuMatrix{$elty} ) = CUSOLVER. potri! (uplo, A)
765
+ LinearAlgebra. LAPACK. getrf! (A:: StridedCuMatrix{$elty} ) = CUSOLVER. getrf! (A)
766
+ LinearAlgebra. LAPACK. geqrf! (A:: StridedCuMatrix{$elty} ) = CUSOLVER. geqrf! (A)
767
+ LinearAlgebra. LAPACK. sytrf! (uplo:: Char , A:: StridedCuMatrix{$elty} ) = sytrf! (uplo, A)
768
+ LinearAlgebra. LAPACK. getrs! (trans:: Char , A:: StridedCuMatrix{$elty} , ipiv:: CuVector{Cint} , B:: StridedCuVecOrMat{$elty} ) = CUSOLVER. getrs! (trans, A, ipiv, B)
769
+ LinearAlgebra. LAPACK. ormqr! (side:: Char , trans:: Char , A:: CuMatrix{$elty} , tau:: CuVector{$elty} , C:: CuVecOrMat{$elty} ) = CUSOLVER. ormqr! (side, trans, A, tau, C)
770
+ LinearAlgebra. LAPACK. orgqr! (A:: StridedCuMatrix{$elty} , tau:: CuVector{$elty} ) = CUSOLVER. orgqr! (A, tau)
771
+ LinearAlgebra. LAPACK. gebrd! (A:: StridedCuMatrix{$elty} ) = CUSOLVER. gebrd! (A)
772
+ LinearAlgebra. LAPACK. gesvd! (jobu:: Char , jobvt:: Char , A:: StridedCuMatrix{$elty} ) = CUSOLVER. gesvd! (jobu, jobvt, A)
773
+ end
774
+ end
775
+
776
+ for elty in (:Float32 , :Float64 )
777
+ @eval begin
778
+ LinearAlgebra. LAPACK. sygvd! (itype:: Int , jobz:: Char , uplo:: Char , A:: StridedCuMatrix{$elty} , B:: StridedCuMatrix{$elty} ) = CUSOLVER. sygvd! (itype, jobz, uplo, A, B)
779
+ end
780
+ end
781
+
782
+ for elty in (:ComplexF32 , :ComplexF64 )
783
+ @eval begin
784
+ LinearAlgebra. LAPACK. sygvd! (itype:: Int , jobz:: Char , uplo:: Char , A:: StridedCuMatrix{$elty} , B:: StridedCuMatrix{$elty} ) = CUSOLVER. hegvd! (itype, jobz, uplo, A, B)
785
+ end
786
+ end
0 commit comments