Skip to content

Commit bb49887

Browse files
kshyattmaleadt
andauthored
Add support for caching workspace buffers. (#2279)
Co-authored-by: Tim Besard <tim.besard@gmail.com>
1 parent 2783c88 commit bb49887

File tree

6 files changed

+224
-176
lines changed

6 files changed

+224
-176
lines changed

lib/cusolver/dense.jl

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@ for (bname, fname,elty) in ((:cusolverDnSpotrf_bufferSize, :cusolverDnSpotrf, :F
2727
function bufferSize()
2828
out = Ref{Cint}(0)
2929
$bname(dense_handle(), uplo, n, A, lda, out)
30-
out[]
30+
out[] * sizeof($elty)
3131
end
3232

3333
devinfo = CuArray{Cint}(undef, 1)
34-
with_workspace($elty, bufferSize) do buffer
34+
with_workspace(bufferSize) do buffer
3535
$fname(dense_handle(), uplo, n, A, lda, buffer, length(buffer), devinfo)
3636
end
3737

@@ -88,11 +88,11 @@ for (bname, fname,elty) in ((:cusolverDnSpotri_bufferSize, :cusolverDnSpotri, :F
8888
function bufferSize()
8989
out = Ref{Cint}(0)
9090
$bname(dense_handle(), uplo, n, A, lda, out)
91-
out[]
91+
out[] * sizeof($elty)
9292
end
9393

9494
devinfo = CuArray{Cint}(undef, 1)
95-
with_workspace($elty, bufferSize) do buffer
95+
with_workspace(bufferSize) do buffer
9696
$fname(dense_handle(), uplo, n, A, lda, buffer, length(buffer), devinfo)
9797
end
9898

@@ -118,11 +118,11 @@ for (bname, fname,elty) in ((:cusolverDnSgetrf_bufferSize, :cusolverDnSgetrf, :F
118118
function bufferSize()
119119
out = Ref{Cint}(0)
120120
$bname(dense_handle(), m, n, A, lda, out)
121-
return out[]
121+
return out[] * sizeof($elty)
122122
end
123123

124124
devinfo = CuArray{Cint}(undef, 1)
125-
with_workspace($elty, bufferSize) do buffer
125+
with_workspace(bufferSize) do buffer
126126
$fname(dense_handle(), m, n, A, lda, buffer, ipiv, devinfo)
127127
end
128128

@@ -154,11 +154,11 @@ for (bname, fname,elty) in ((:cusolverDnSgeqrf_bufferSize, :cusolverDnSgeqrf, :F
154154
function bufferSize()
155155
out = Ref{Cint}(0)
156156
$bname(dense_handle(), m, n, A, lda, out)
157-
return out[]
157+
return out[] * sizeof($elty)
158158
end
159159

160160
devinfo = CuArray{Cint}(undef, 1)
161-
with_workspace($elty, bufferSize) do buffer
161+
with_workspace(bufferSize) do buffer
162162
$fname(dense_handle(), m, n, A, lda, tau, buffer, length(buffer), devinfo)
163163
end
164164

@@ -193,11 +193,11 @@ for (bname, fname,elty) in ((:cusolverDnSsytrf_bufferSize, :cusolverDnSsytrf, :F
193193
function bufferSize()
194194
out = Ref{Cint}(0)
195195
$bname(dense_handle(), n, A, lda, out)
196-
return out[]
196+
return out[] * sizeof($elty)
197197
end
198198

199199
devinfo = CuArray{Cint}(undef, 1)
200-
with_workspace($elty, bufferSize) do buffer
200+
with_workspace(bufferSize) do buffer
201201
$fname(dense_handle(), uplo, n, A, lda, ipiv, buffer, length(buffer), devinfo)
202202
end
203203

@@ -292,11 +292,11 @@ for (bname, fname, elty) in ((:cusolverDnSormqr_bufferSize, :cusolverDnSormqr, :
292292
function bufferSize()
293293
out = Ref{Cint}(0)
294294
$bname(dense_handle(), side, trans, m, n, k, A, lda, tau, C, ldc, out)
295-
return out[]
295+
return out[] * sizeof($elty)
296296
end
297297

298298
devinfo = CuArray{Cint}(undef, 1)
299-
with_workspace($elty, bufferSize) do buffer
299+
with_workspace(bufferSize) do buffer
300300
$fname(dense_handle(), side, trans, m, n, k, A, lda, tau, C, ldc,
301301
buffer, length(buffer), devinfo)
302302
end
@@ -325,11 +325,11 @@ for (bname, fname, elty) in ((:cusolverDnSorgqr_bufferSize, :cusolverDnSorgqr, :
325325
function bufferSize()
326326
out = Ref{Cint}(0)
327327
$bname(dense_handle(), m, n, k, A, lda, tau, out)
328-
return out[]
328+
return out[] * sizeof($elty)
329329
end
330330

331331
devinfo = CuArray{Cint}(undef, 1)
332-
with_workspace($elty, bufferSize) do buffer
332+
with_workspace(bufferSize) do buffer
333333
$fname(dense_handle(), m, n, k, A, lda, tau, buffer, length(buffer), devinfo)
334334
end
335335

@@ -359,7 +359,7 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgebrd_bufferSize, :cusolverDnSg
359359
function bufferSize()
360360
out = Ref{Cint}(0)
361361
$bname(dense_handle(), m, n, out)
362-
return out[]
362+
return out[] * sizeof($elty)
363363
end
364364

365365
devinfo = CuArray{Cint}(undef, 1)
@@ -369,7 +369,7 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgebrd_bufferSize, :cusolverDnSg
369369
TAUQ = CuArray{$elty}(undef, k)
370370
TAUP = CuArray{$elty}(undef, k)
371371

372-
with_workspace($elty, bufferSize) do buffer
372+
with_workspace(bufferSize) do buffer
373373
$fname(dense_handle(), m, n, A, lda, D, E, TAUQ, TAUP, buffer, length(buffer), devinfo)
374374
end
375375

@@ -421,12 +421,12 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvd_bufferSize, :cusolverDnSg
421421
function bufferSize()
422422
out = Ref{Cint}(0)
423423
$bname(dense_handle(), m, n, out)
424-
return out[]
424+
return out[] * sizeof($elty)
425425
end
426426

427427
rwork = CuArray{$relty}(undef, min(m, n) - 1)
428428
devinfo = CuArray{Cint}(undef, 1)
429-
with_workspace($elty, bufferSize) do work
429+
with_workspace(bufferSize) do work
430430
$fname(dense_handle(), jobu, jobvt, m, n, A, lda, S, U, ldu, Vt, ldvt,
431431
work, length(work), rwork, devinfo)
432432
end
@@ -481,11 +481,11 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdj_bufferSize, :cusolverDnS
481481
out = Ref{Cint}(0)
482482
$bname(dense_handle(), jobz, econ, m, n, A, lda, S, U, ldu, V, ldv,
483483
out, params[])
484-
return out[]
484+
return out[] * sizeof($elty)
485485
end
486486

487487
devinfo = CuArray{Cint}(undef, 1)
488-
with_workspace($elty, bufferSize) do work
488+
with_workspace(bufferSize) do work
489489
$fname(dense_handle(), jobz, econ, m, n, A, lda, S, U, ldu, V, ldv,
490490
work, length(work), devinfo, params[])
491491
end
@@ -533,11 +533,11 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdjBatched_bufferSize, :cuso
533533
out = Ref{Cint}(0)
534534
$bname(dense_handle(), jobz, m, n, A, lda, S, U, ldu, V, ldv,
535535
out, params[], batchSize)
536-
return out[]
536+
return out[] * sizeof($elty)
537537
end
538538

539539
devinfo = CuArray{Cint}(undef, batchSize)
540-
with_workspace($elty, bufferSize) do work
540+
with_workspace(bufferSize) do work
541541
$fname(dense_handle(), jobz, m, n, A, lda, S, U, ldu, V, ldv,
542542
work, length(work), devinfo, params[], batchSize)
543543
end
@@ -589,14 +589,14 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdaStridedBatched_bufferSize
589589
$bname(dense_handle(), jobz, rank, m, n, A, lda, strideA,
590590
S, strideS, U, ldu, strideU, V, ldv, strideV,
591591
out, batchSize)
592-
return out[]
592+
return out[] * sizeof($elty)
593593
end
594594

595595
devinfo = CuArray{Cint}(undef, batchSize)
596596
# residual storage
597597
h_RnrmF = Array{Cdouble}(undef, batchSize)
598598

599-
with_workspace($elty, bufferSize) do work
599+
with_workspace(bufferSize) do work
600600
$fname(dense_handle(), jobz, rank, m, n, A, lda, strideA,
601601
S, strideS, U, ldu, strideU, V, ldv, strideV,
602602
work, length(work), devinfo, h_RnrmF, batchSize)
@@ -631,11 +631,11 @@ for (jname, bname, fname, elty, relty) in ((:syevd!, :cusolverDnSsyevd_bufferSiz
631631
function bufferSize()
632632
out = Ref{Cint}(0)
633633
$bname(dense_handle(), jobz, uplo, n, A, lda, W, out)
634-
return out[]
634+
return out[] * sizeof($elty)
635635
end
636636

637637
devinfo = CuArray{Cint}(undef, 1)
638-
with_workspace($elty, bufferSize) do buffer
638+
with_workspace(bufferSize) do buffer
639639
$fname(dense_handle(), jobz, uplo, n, A, lda, W,
640640
buffer, length(buffer), devinfo)
641641
end
@@ -676,11 +676,11 @@ for (jname, bname, fname, elty, relty) in ((:sygvd!, :cusolverDnSsygvd_bufferSiz
676676
function bufferSize()
677677
out = Ref{Cint}(0)
678678
$bname(dense_handle(), itype, jobz, uplo, n, A, lda, B, ldb, W, out)
679-
return out[]
679+
return out[] * sizeof($elty)
680680
end
681681

682682
devinfo = CuArray{Cint}(undef, 1)
683-
with_workspace($elty, bufferSize) do buffer
683+
with_workspace(bufferSize) do buffer
684684
$fname(dense_handle(), itype, jobz, uplo, n, A, lda, B, ldb, W,
685685
buffer, length(buffer), devinfo)
686686
end
@@ -728,11 +728,11 @@ for (jname, bname, fname, elty, relty) in ((:sygvj!, :cusolverDnSsygvj_bufferSiz
728728
out = Ref{Cint}(0)
729729
$bname(dense_handle(), itype, jobz, uplo, n, A, lda, B, ldb, W,
730730
out, params[])
731-
return out[]
731+
return out[] * sizeof($elty)
732732
end
733733

734734
devinfo = CuArray{Cint}(undef, 1)
735-
with_workspace($elty, bufferSize) do buffer
735+
with_workspace(bufferSize) do buffer
736736
$fname(dense_handle(), itype, jobz, uplo, n, A, lda, B, ldb, W,
737737
buffer, length(buffer), devinfo, params[])
738738
end
@@ -781,11 +781,11 @@ for (jname, bname, fname, elty, relty) in ((:syevjBatched!, :cusolverDnSsyevjBat
781781
function bufferSize()
782782
out = Ref{Cint}(0)
783783
$bname(dense_handle(), jobz, uplo, n, A, lda, W, out, params[], batchSize)
784-
return out[]
784+
return out[] * sizeof($elty)
785785
end
786786

787787
# Run the solver
788-
with_workspace($elty, bufferSize) do work
788+
with_workspace(bufferSize) do work
789789
$fname(dense_handle(), jobz, uplo, n, A, lda, W, work,
790790
length(work), devinfo, params[], batchSize)
791791
end

0 commit comments

Comments
 (0)