@@ -32,7 +32,8 @@ for (bname, fname,elty) in ((:cusolverDnSpotrf_bufferSize, :cusolverDnSpotrf, :F
32
32
33
33
devinfo = CuArray {Cint} (undef, 1 )
34
34
with_workspace (bufferSize) do buffer
35
- $ fname (dense_handle (), uplo, n, A, lda, buffer, length (buffer), devinfo)
35
+ $ fname (dense_handle (), uplo, n, A, lda,
36
+ buffer, sizeof (buffer) ÷ sizeof ($ elty), devinfo)
36
37
end
37
38
38
39
info = @allowscalar devinfo[1 ]
@@ -93,7 +94,8 @@ for (bname, fname,elty) in ((:cusolverDnSpotri_bufferSize, :cusolverDnSpotri, :F
93
94
94
95
devinfo = CuArray {Cint} (undef, 1 )
95
96
with_workspace (bufferSize) do buffer
96
- $ fname (dense_handle (), uplo, n, A, lda, buffer, length (buffer), devinfo)
97
+ $ fname (dense_handle (), uplo, n, A, lda,
98
+ buffer, sizeof (buffer) ÷ sizeof ($ elty), devinfo)
97
99
end
98
100
99
101
info = @allowscalar devinfo[1 ]
@@ -159,7 +161,8 @@ for (bname, fname,elty) in ((:cusolverDnSgeqrf_bufferSize, :cusolverDnSgeqrf, :F
159
161
160
162
devinfo = CuArray {Cint} (undef, 1 )
161
163
with_workspace (bufferSize) do buffer
162
- $ fname (dense_handle (), m, n, A, lda, tau, buffer, length (buffer), devinfo)
164
+ $ fname (dense_handle (), m, n, A, lda, tau,
165
+ buffer, sizeof (buffer) ÷ sizeof ($ elty), devinfo)
163
166
end
164
167
165
168
info = @allowscalar devinfo[1 ]
@@ -198,7 +201,8 @@ for (bname, fname,elty) in ((:cusolverDnSsytrf_bufferSize, :cusolverDnSsytrf, :F
198
201
199
202
devinfo = CuArray {Cint} (undef, 1 )
200
203
with_workspace (bufferSize) do buffer
201
- $ fname (dense_handle (), uplo, n, A, lda, ipiv, buffer, length (buffer), devinfo)
204
+ $ fname (dense_handle (), uplo, n, A, lda, ipiv,
205
+ buffer, sizeof (buffer) ÷ sizeof ($ elty), devinfo)
202
206
end
203
207
204
208
info = @allowscalar devinfo[1 ]
@@ -299,7 +303,7 @@ for (bname, fname, elty) in ((:cusolverDnSormqr_bufferSize, :cusolverDnSormqr, :
299
303
devinfo = CuArray {Cint} (undef, 1 )
300
304
with_workspace (bufferSize) do buffer
301
305
$ fname (dense_handle (), side, trans, m, n, k, A, lda, tau, C, ldc,
302
- buffer, length (buffer), devinfo)
306
+ buffer, sizeof (buffer) ÷ sizeof ( $ elty ), devinfo)
303
307
end
304
308
305
309
info = @allowscalar devinfo[1 ]
@@ -331,7 +335,8 @@ for (bname, fname, elty) in ((:cusolverDnSorgqr_bufferSize, :cusolverDnSorgqr, :
331
335
332
336
devinfo = CuArray {Cint} (undef, 1 )
333
337
with_workspace (bufferSize) do buffer
334
- $ fname (dense_handle (), m, n, k, A, lda, tau, buffer, length (buffer), devinfo)
338
+ $ fname (dense_handle (), m, n, k, A, lda, tau,
339
+ buffer, sizeof (buffer) ÷ sizeof ($ elty), devinfo)
335
340
end
336
341
337
342
info = @allowscalar devinfo[1 ]
@@ -371,7 +376,8 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgebrd_bufferSize, :cusolverDnSg
371
376
TAUP = CuArray {$elty} (undef, k)
372
377
373
378
with_workspace (bufferSize) do buffer
374
- $ fname (dense_handle (), m, n, A, lda, D, E, TAUQ, TAUP, buffer, length (buffer), devinfo)
379
+ $ fname (dense_handle (), m, n, A, lda, D, E, TAUQ, TAUP,
380
+ buffer, sizeof (buffer) ÷ sizeof ($ elty), devinfo)
375
381
end
376
382
377
383
info = @allowscalar devinfo[1 ]
@@ -427,9 +433,9 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvd_bufferSize, :cusolverDnSg
427
433
428
434
rwork = CuArray {$relty} (undef, min (m, n) - 1 )
429
435
devinfo = CuArray {Cint} (undef, 1 )
430
- with_workspace (bufferSize) do work
436
+ with_workspace (bufferSize) do buffer
431
437
$ fname (dense_handle (), jobu, jobvt, m, n, A, lda, S, U, ldu, Vt, ldvt,
432
- work, length (work ), rwork, devinfo)
438
+ buffer, sizeof (buffer) ÷ sizeof ( $ elty ), rwork, devinfo)
433
439
end
434
440
unsafe_free! (rwork)
435
441
@@ -486,9 +492,9 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdj_bufferSize, :cusolverDnS
486
492
end
487
493
488
494
devinfo = CuArray {Cint} (undef, 1 )
489
- with_workspace (bufferSize) do work
495
+ with_workspace (bufferSize) do buffer
490
496
$ fname (dense_handle (), jobz, econ, m, n, A, lda, S, U, ldu, V, ldv,
491
- work, length (work ), devinfo, params[])
497
+ buffer, sizeof (buffer) ÷ sizeof ( $ elty ), devinfo, params[])
492
498
end
493
499
494
500
info = @allowscalar devinfo[1 ]
@@ -538,9 +544,9 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdjBatched_bufferSize, :cuso
538
544
end
539
545
540
546
devinfo = CuArray {Cint} (undef, batchSize)
541
- with_workspace (bufferSize) do work
547
+ with_workspace (bufferSize) do buffer
542
548
$ fname (dense_handle (), jobz, m, n, A, lda, S, U, ldu, V, ldv,
543
- work, length (work ), devinfo, params[], batchSize)
549
+ buffer, sizeof (buffer) ÷ sizeof ( $ elty ), devinfo, params[], batchSize)
544
550
end
545
551
546
552
info = @allowscalar collect (devinfo)
@@ -597,10 +603,10 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdaStridedBatched_bufferSize
597
603
# residual storage
598
604
h_RnrmF = Array {Cdouble} (undef, batchSize)
599
605
600
- with_workspace (bufferSize) do work
606
+ with_workspace (bufferSize) do buffer
601
607
$ fname (dense_handle (), jobz, rank, m, n, A, lda, strideA,
602
608
S, strideS, U, ldu, strideU, V, ldv, strideV,
603
- work, length (work ), devinfo, h_RnrmF, batchSize)
609
+ buffer, sizeof (buffer) ÷ sizeof ( $ elty ), devinfo, h_RnrmF, batchSize)
604
610
end
605
611
606
612
info = @allowscalar collect (devinfo)
@@ -638,7 +644,7 @@ for (jname, bname, fname, elty, relty) in ((:syevd!, :cusolverDnSsyevd_bufferSiz
638
644
devinfo = CuArray {Cint} (undef, 1 )
639
645
with_workspace (bufferSize) do buffer
640
646
$ fname (dense_handle (), jobz, uplo, n, A, lda, W,
641
- buffer, length (buffer), devinfo)
647
+ buffer, sizeof (buffer) ÷ sizeof ( $ elty ), devinfo)
642
648
end
643
649
644
650
info = @allowscalar devinfo[1 ]
@@ -683,7 +689,7 @@ for (jname, bname, fname, elty, relty) in ((:sygvd!, :cusolverDnSsygvd_bufferSiz
683
689
devinfo = CuArray {Cint} (undef, 1 )
684
690
with_workspace (bufferSize) do buffer
685
691
$ fname (dense_handle (), itype, jobz, uplo, n, A, lda, B, ldb, W,
686
- buffer, length (buffer), devinfo)
692
+ buffer, sizeof (buffer) ÷ sizeof ( $ elty ), devinfo)
687
693
end
688
694
689
695
info = @allowscalar devinfo[1 ]
@@ -735,7 +741,7 @@ for (jname, bname, fname, elty, relty) in ((:sygvj!, :cusolverDnSsygvj_bufferSiz
735
741
devinfo = CuArray {Cint} (undef, 1 )
736
742
with_workspace (bufferSize) do buffer
737
743
$ fname (dense_handle (), itype, jobz, uplo, n, A, lda, B, ldb, W,
738
- buffer, length (buffer), devinfo, params[])
744
+ buffer, sizeof (buffer) ÷ sizeof ( $ elty ), devinfo, params[])
739
745
end
740
746
741
747
info = @allowscalar devinfo[1 ]
@@ -786,9 +792,9 @@ for (jname, bname, fname, elty, relty) in ((:syevjBatched!, :cusolverDnSsyevjBat
786
792
end
787
793
788
794
# Run the solver
789
- with_workspace (bufferSize) do work
790
- $ fname (dense_handle (), jobz, uplo, n, A, lda, W, work ,
791
- length (work ), devinfo, params[], batchSize)
795
+ with_workspace (bufferSize) do buffer
796
+ $ fname (dense_handle (), jobz, uplo, n, A, lda, W, buffer ,
797
+ sizeof (buffer) ÷ sizeof ( $ elty ), devinfo, params[], batchSize)
792
798
end
793
799
794
800
# Copy the solver info and delete the device memory
@@ -894,47 +900,47 @@ end
894
900
# LAPACK
895
901
for elty in (:Float32 , :Float64 , :ComplexF32 , :ComplexF64 )
896
902
@eval begin
897
- LinearAlgebra . LAPACK. potrf! (uplo:: Char , A:: StridedCuMatrix{$elty} ) = CUSOLVER . potrf! (uplo, A)
898
- LinearAlgebra . LAPACK. potrs! (uplo:: Char , A:: StridedCuMatrix{$elty} , B:: StridedCuVecOrMat{$elty} ) = CUSOLVER . potrs! (uplo, A, B)
899
- LinearAlgebra . LAPACK. potri! (uplo:: Char , A:: StridedCuMatrix{$elty} ) = CUSOLVER . potri! (uplo, A)
900
- LinearAlgebra . LAPACK. getrf! (A:: StridedCuMatrix{$elty} ) = CUSOLVER . getrf! (A)
901
- LinearAlgebra . LAPACK. getrf! (A:: StridedCuMatrix{$elty} , ipiv:: CuVector{Cint} ) = CUSOLVER . getrf! (A, ipiv)
902
- LinearAlgebra . LAPACK. geqrf! (A:: StridedCuMatrix{$elty} ) = CUSOLVER . geqrf! (A)
903
- LinearAlgebra . LAPACK. geqrf! (A:: StridedCuMatrix{$elty} , tau:: CuVector{$elty} ) = CUSOLVER . geqrf! (A, tau)
904
- LinearAlgebra . LAPACK. sytrf! (uplo:: Char , A:: StridedCuMatrix{$elty} ) = sytrf! (uplo, A)
905
- LinearAlgebra . LAPACK. sytrf! (uplo:: Char , A:: StridedCuMatrix{$elty} , ipiv:: CuVector{Cint} ) = CUSOLVER . sytrf! (uplo, A, ipiv)
906
- LinearAlgebra . LAPACK. getrs! (trans:: Char , A:: StridedCuMatrix{$elty} , ipiv:: CuVector{Cint} , B:: StridedCuVecOrMat{$elty} ) = CUSOLVER . getrs! (trans, A, ipiv, B)
907
- LinearAlgebra . LAPACK. ormqr! (side:: Char , trans:: Char , A:: CuMatrix{$elty} , tau:: CuVector{$elty} , C:: CuVecOrMat{$elty} ) = CUSOLVER . ormqr! (side, trans, A, tau, C)
908
- LinearAlgebra . LAPACK. orgqr! (A:: StridedCuMatrix{$elty} , tau:: CuVector{$elty} ) = CUSOLVER . orgqr! (A, tau)
909
- LinearAlgebra . LAPACK. gebrd! (A:: StridedCuMatrix{$elty} ) = CUSOLVER . gebrd! (A)
910
- LinearAlgebra . LAPACK. gesvd! (jobu:: Char , jobvt:: Char , A:: StridedCuMatrix{$elty} ) = CUSOLVER . gesvd! (jobu, jobvt, A)
903
+ LAPACK. potrf! (uplo:: Char , A:: StridedCuMatrix{$elty} ) = potrf! (uplo, A)
904
+ LAPACK. potrs! (uplo:: Char , A:: StridedCuMatrix{$elty} , B:: StridedCuVecOrMat{$elty} ) = potrs! (uplo, A, B)
905
+ LAPACK. potri! (uplo:: Char , A:: StridedCuMatrix{$elty} ) = potri! (uplo, A)
906
+ LAPACK. getrf! (A:: StridedCuMatrix{$elty} ) = getrf! (A)
907
+ LAPACK. getrf! (A:: StridedCuMatrix{$elty} , ipiv:: CuVector{Cint} ) = getrf! (A, ipiv)
908
+ LAPACK. geqrf! (A:: StridedCuMatrix{$elty} ) = geqrf! (A)
909
+ LAPACK. geqrf! (A:: StridedCuMatrix{$elty} , tau:: CuVector{$elty} ) = geqrf! (A, tau)
910
+ LAPACK. sytrf! (uplo:: Char , A:: StridedCuMatrix{$elty} ) = sytrf! (uplo, A)
911
+ LAPACK. sytrf! (uplo:: Char , A:: StridedCuMatrix{$elty} , ipiv:: CuVector{Cint} ) = sytrf! (uplo, A, ipiv)
912
+ LAPACK. getrs! (trans:: Char , A:: StridedCuMatrix{$elty} , ipiv:: CuVector{Cint} , B:: StridedCuVecOrMat{$elty} ) = getrs! (trans, A, ipiv, B)
913
+ LAPACK. ormqr! (side:: Char , trans:: Char , A:: CuMatrix{$elty} , tau:: CuVector{$elty} , C:: CuVecOrMat{$elty} ) = ormqr! (side, trans, A, tau, C)
914
+ LAPACK. orgqr! (A:: StridedCuMatrix{$elty} , tau:: CuVector{$elty} ) = orgqr! (A, tau)
915
+ LAPACK. gebrd! (A:: StridedCuMatrix{$elty} ) = gebrd! (A)
916
+ LAPACK. gesvd! (jobu:: Char , jobvt:: Char , A:: StridedCuMatrix{$elty} ) = gesvd! (jobu, jobvt, A)
911
917
end
912
918
end
913
919
914
920
for elty in (:Float32 , :Float64 )
915
921
@eval begin
916
- LinearAlgebra . LAPACK. syev! (jobz:: Char , uplo:: Char , A:: StridedCuMatrix{$elty} ) = CUSOLVER . syevd! (jobz, uplo, A)
917
- LinearAlgebra . LAPACK. sygvd! (itype:: Int , jobz:: Char , uplo:: Char , A:: StridedCuMatrix{$elty} , B:: StridedCuMatrix{$elty} ) = CUSOLVER . sygvd! (itype, jobz, uplo, A, B)
922
+ LAPACK. syev! (jobz:: Char , uplo:: Char , A:: StridedCuMatrix{$elty} ) = syevd! (jobz, uplo, A)
923
+ LAPACK. sygvd! (itype:: Int , jobz:: Char , uplo:: Char , A:: StridedCuMatrix{$elty} , B:: StridedCuMatrix{$elty} ) = sygvd! (itype, jobz, uplo, A, B)
918
924
end
919
925
end
920
926
921
927
for elty in (:ComplexF32 , :ComplexF64 )
922
928
@eval begin
923
- LinearAlgebra . LAPACK. syev! (jobz:: Char , uplo:: Char , A:: StridedCuMatrix{$elty} ) = CUSOLVER . heevd! (jobz, uplo, A)
924
- LinearAlgebra . LAPACK. sygvd! (itype:: Int , jobz:: Char , uplo:: Char , A:: StridedCuMatrix{$elty} , B:: StridedCuMatrix{$elty} ) = CUSOLVER . hegvd! (itype, jobz, uplo, A, B)
929
+ LAPACK. syev! (jobz:: Char , uplo:: Char , A:: StridedCuMatrix{$elty} ) = heevd! (jobz, uplo, A)
930
+ LAPACK. sygvd! (itype:: Int , jobz:: Char , uplo:: Char , A:: StridedCuMatrix{$elty} , B:: StridedCuMatrix{$elty} ) = hegvd! (itype, jobz, uplo, A, B)
925
931
end
926
932
end
927
933
928
934
if VERSION >= v " 1.10"
929
935
for elty in (:Float32 , :Float64 )
930
936
@eval begin
931
- LinearAlgebra . LAPACK. syevd! (jobz:: Char , uplo:: Char , A:: StridedCuMatrix{$elty} ) = CUSOLVER . syevd! (jobz, uplo, A)
937
+ LAPACK. syevd! (jobz:: Char , uplo:: Char , A:: StridedCuMatrix{$elty} ) = syevd! (jobz, uplo, A)
932
938
end
933
939
end
934
940
935
941
for elty in (:ComplexF32 , :ComplexF64 )
936
942
@eval begin
937
- LinearAlgebra . LAPACK. syevd! (jobz:: Char , uplo:: Char , A:: StridedCuMatrix{$elty} ) = CUSOLVER . heevd! (jobz, uplo, A)
943
+ LAPACK. syevd! (jobz:: Char , uplo:: Char , A:: StridedCuMatrix{$elty} ) = heevd! (jobz, uplo, A)
938
944
end
939
945
end
940
946
end
0 commit comments