Skip to content

Commit 24651ba

Browse files
authored
Correct workspace handling (#2437)
1 parent 8a6a1e9 commit 24651ba

File tree

2 files changed

+54
-48
lines changed

2 files changed

+54
-48
lines changed

lib/cusolver/dense.jl

Lines changed: 47 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ for (bname, fname,elty) in ((:cusolverDnSpotrf_bufferSize, :cusolverDnSpotrf, :F
3232

3333
devinfo = CuArray{Cint}(undef, 1)
3434
with_workspace(bufferSize) do buffer
35-
$fname(dense_handle(), uplo, n, A, lda, buffer, length(buffer), devinfo)
35+
$fname(dense_handle(), uplo, n, A, lda,
36+
buffer, sizeof(buffer) ÷ sizeof($elty), devinfo)
3637
end
3738

3839
info = @allowscalar devinfo[1]
@@ -93,7 +94,8 @@ for (bname, fname,elty) in ((:cusolverDnSpotri_bufferSize, :cusolverDnSpotri, :F
9394

9495
devinfo = CuArray{Cint}(undef, 1)
9596
with_workspace(bufferSize) do buffer
96-
$fname(dense_handle(), uplo, n, A, lda, buffer, length(buffer), devinfo)
97+
$fname(dense_handle(), uplo, n, A, lda,
98+
buffer, sizeof(buffer) ÷ sizeof($elty), devinfo)
9799
end
98100

99101
info = @allowscalar devinfo[1]
@@ -159,7 +161,8 @@ for (bname, fname,elty) in ((:cusolverDnSgeqrf_bufferSize, :cusolverDnSgeqrf, :F
159161

160162
devinfo = CuArray{Cint}(undef, 1)
161163
with_workspace(bufferSize) do buffer
162-
$fname(dense_handle(), m, n, A, lda, tau, buffer, length(buffer), devinfo)
164+
$fname(dense_handle(), m, n, A, lda, tau,
165+
buffer, sizeof(buffer) ÷ sizeof($elty), devinfo)
163166
end
164167

165168
info = @allowscalar devinfo[1]
@@ -198,7 +201,8 @@ for (bname, fname,elty) in ((:cusolverDnSsytrf_bufferSize, :cusolverDnSsytrf, :F
198201

199202
devinfo = CuArray{Cint}(undef, 1)
200203
with_workspace(bufferSize) do buffer
201-
$fname(dense_handle(), uplo, n, A, lda, ipiv, buffer, length(buffer), devinfo)
204+
$fname(dense_handle(), uplo, n, A, lda, ipiv,
205+
buffer, sizeof(buffer) ÷ sizeof($elty), devinfo)
202206
end
203207

204208
info = @allowscalar devinfo[1]
@@ -299,7 +303,7 @@ for (bname, fname, elty) in ((:cusolverDnSormqr_bufferSize, :cusolverDnSormqr, :
299303
devinfo = CuArray{Cint}(undef, 1)
300304
with_workspace(bufferSize) do buffer
301305
$fname(dense_handle(), side, trans, m, n, k, A, lda, tau, C, ldc,
302-
buffer, length(buffer), devinfo)
306+
buffer, sizeof(buffer) ÷ sizeof($elty), devinfo)
303307
end
304308

305309
info = @allowscalar devinfo[1]
@@ -331,7 +335,8 @@ for (bname, fname, elty) in ((:cusolverDnSorgqr_bufferSize, :cusolverDnSorgqr, :
331335

332336
devinfo = CuArray{Cint}(undef, 1)
333337
with_workspace(bufferSize) do buffer
334-
$fname(dense_handle(), m, n, k, A, lda, tau, buffer, length(buffer), devinfo)
338+
$fname(dense_handle(), m, n, k, A, lda, tau,
339+
buffer, sizeof(buffer) ÷ sizeof($elty), devinfo)
335340
end
336341

337342
info = @allowscalar devinfo[1]
@@ -371,7 +376,8 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgebrd_bufferSize, :cusolverDnSg
371376
TAUP = CuArray{$elty}(undef, k)
372377

373378
with_workspace(bufferSize) do buffer
374-
$fname(dense_handle(), m, n, A, lda, D, E, TAUQ, TAUP, buffer, length(buffer), devinfo)
379+
$fname(dense_handle(), m, n, A, lda, D, E, TAUQ, TAUP,
380+
buffer, sizeof(buffer) ÷ sizeof($elty), devinfo)
375381
end
376382

377383
info = @allowscalar devinfo[1]
@@ -427,9 +433,9 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvd_bufferSize, :cusolverDnSg
427433

428434
rwork = CuArray{$relty}(undef, min(m, n) - 1)
429435
devinfo = CuArray{Cint}(undef, 1)
430-
with_workspace(bufferSize) do work
436+
with_workspace(bufferSize) do buffer
431437
$fname(dense_handle(), jobu, jobvt, m, n, A, lda, S, U, ldu, Vt, ldvt,
432-
work, length(work), rwork, devinfo)
438+
buffer, sizeof(buffer) ÷ sizeof($elty), rwork, devinfo)
433439
end
434440
unsafe_free!(rwork)
435441

@@ -486,9 +492,9 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdj_bufferSize, :cusolverDnS
486492
end
487493

488494
devinfo = CuArray{Cint}(undef, 1)
489-
with_workspace(bufferSize) do work
495+
with_workspace(bufferSize) do buffer
490496
$fname(dense_handle(), jobz, econ, m, n, A, lda, S, U, ldu, V, ldv,
491-
work, length(work), devinfo, params[])
497+
buffer, sizeof(buffer) ÷ sizeof($elty), devinfo, params[])
492498
end
493499

494500
info = @allowscalar devinfo[1]
@@ -538,9 +544,9 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdjBatched_bufferSize, :cuso
538544
end
539545

540546
devinfo = CuArray{Cint}(undef, batchSize)
541-
with_workspace(bufferSize) do work
547+
with_workspace(bufferSize) do buffer
542548
$fname(dense_handle(), jobz, m, n, A, lda, S, U, ldu, V, ldv,
543-
work, length(work), devinfo, params[], batchSize)
549+
buffer, sizeof(buffer) ÷ sizeof($elty), devinfo, params[], batchSize)
544550
end
545551

546552
info = @allowscalar collect(devinfo)
@@ -597,10 +603,10 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdaStridedBatched_bufferSize
597603
# residual storage
598604
h_RnrmF = Array{Cdouble}(undef, batchSize)
599605

600-
with_workspace(bufferSize) do work
606+
with_workspace(bufferSize) do buffer
601607
$fname(dense_handle(), jobz, rank, m, n, A, lda, strideA,
602608
S, strideS, U, ldu, strideU, V, ldv, strideV,
603-
work, length(work), devinfo, h_RnrmF, batchSize)
609+
buffer, sizeof(buffer) ÷ sizeof($elty), devinfo, h_RnrmF, batchSize)
604610
end
605611

606612
info = @allowscalar collect(devinfo)
@@ -638,7 +644,7 @@ for (jname, bname, fname, elty, relty) in ((:syevd!, :cusolverDnSsyevd_bufferSiz
638644
devinfo = CuArray{Cint}(undef, 1)
639645
with_workspace(bufferSize) do buffer
640646
$fname(dense_handle(), jobz, uplo, n, A, lda, W,
641-
buffer, length(buffer), devinfo)
647+
buffer, sizeof(buffer) ÷ sizeof($elty), devinfo)
642648
end
643649

644650
info = @allowscalar devinfo[1]
@@ -683,7 +689,7 @@ for (jname, bname, fname, elty, relty) in ((:sygvd!, :cusolverDnSsygvd_bufferSiz
683689
devinfo = CuArray{Cint}(undef, 1)
684690
with_workspace(bufferSize) do buffer
685691
$fname(dense_handle(), itype, jobz, uplo, n, A, lda, B, ldb, W,
686-
buffer, length(buffer), devinfo)
692+
buffer, sizeof(buffer) ÷ sizeof($elty), devinfo)
687693
end
688694

689695
info = @allowscalar devinfo[1]
@@ -735,7 +741,7 @@ for (jname, bname, fname, elty, relty) in ((:sygvj!, :cusolverDnSsygvj_bufferSiz
735741
devinfo = CuArray{Cint}(undef, 1)
736742
with_workspace(bufferSize) do buffer
737743
$fname(dense_handle(), itype, jobz, uplo, n, A, lda, B, ldb, W,
738-
buffer, length(buffer), devinfo, params[])
744+
buffer, sizeof(buffer) ÷ sizeof($elty), devinfo, params[])
739745
end
740746

741747
info = @allowscalar devinfo[1]
@@ -786,9 +792,9 @@ for (jname, bname, fname, elty, relty) in ((:syevjBatched!, :cusolverDnSsyevjBat
786792
end
787793

788794
# Run the solver
789-
with_workspace(bufferSize) do work
790-
$fname(dense_handle(), jobz, uplo, n, A, lda, W, work,
791-
length(work), devinfo, params[], batchSize)
795+
with_workspace(bufferSize) do buffer
796+
$fname(dense_handle(), jobz, uplo, n, A, lda, W, buffer,
797+
sizeof(buffer) ÷ sizeof($elty), devinfo, params[], batchSize)
792798
end
793799

794800
# Copy the solver info and delete the device memory
@@ -894,47 +900,47 @@ end
894900
# LAPACK
895901
for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
896902
@eval begin
897-
LinearAlgebra.LAPACK.potrf!(uplo::Char, A::StridedCuMatrix{$elty}) = CUSOLVER.potrf!(uplo, A)
898-
LinearAlgebra.LAPACK.potrs!(uplo::Char, A::StridedCuMatrix{$elty}, B::StridedCuVecOrMat{$elty}) = CUSOLVER.potrs!(uplo, A, B)
899-
LinearAlgebra.LAPACK.potri!(uplo::Char, A::StridedCuMatrix{$elty}) = CUSOLVER.potri!(uplo, A)
900-
LinearAlgebra.LAPACK.getrf!(A::StridedCuMatrix{$elty}) = CUSOLVER.getrf!(A)
901-
LinearAlgebra.LAPACK.getrf!(A::StridedCuMatrix{$elty}, ipiv::CuVector{Cint}) = CUSOLVER.getrf!(A, ipiv)
902-
LinearAlgebra.LAPACK.geqrf!(A::StridedCuMatrix{$elty}) = CUSOLVER.geqrf!(A)
903-
LinearAlgebra.LAPACK.geqrf!(A::StridedCuMatrix{$elty}, tau::CuVector{$elty}) = CUSOLVER.geqrf!(A, tau)
904-
LinearAlgebra.LAPACK.sytrf!(uplo::Char, A::StridedCuMatrix{$elty}) = sytrf!(uplo, A)
905-
LinearAlgebra.LAPACK.sytrf!(uplo::Char, A::StridedCuMatrix{$elty}, ipiv::CuVector{Cint}) = CUSOLVER.sytrf!(uplo, A, ipiv)
906-
LinearAlgebra.LAPACK.getrs!(trans::Char, A::StridedCuMatrix{$elty}, ipiv::CuVector{Cint}, B::StridedCuVecOrMat{$elty}) = CUSOLVER.getrs!(trans, A, ipiv, B)
907-
LinearAlgebra.LAPACK.ormqr!(side::Char, trans::Char, A::CuMatrix{$elty}, tau::CuVector{$elty}, C::CuVecOrMat{$elty}) = CUSOLVER.ormqr!(side, trans, A, tau, C)
908-
LinearAlgebra.LAPACK.orgqr!(A::StridedCuMatrix{$elty}, tau::CuVector{$elty}) = CUSOLVER.orgqr!(A, tau)
909-
LinearAlgebra.LAPACK.gebrd!(A::StridedCuMatrix{$elty}) = CUSOLVER.gebrd!(A)
910-
LinearAlgebra.LAPACK.gesvd!(jobu::Char, jobvt::Char, A::StridedCuMatrix{$elty}) = CUSOLVER.gesvd!(jobu, jobvt, A)
903+
LAPACK.potrf!(uplo::Char, A::StridedCuMatrix{$elty}) = potrf!(uplo, A)
904+
LAPACK.potrs!(uplo::Char, A::StridedCuMatrix{$elty}, B::StridedCuVecOrMat{$elty}) = potrs!(uplo, A, B)
905+
LAPACK.potri!(uplo::Char, A::StridedCuMatrix{$elty}) = potri!(uplo, A)
906+
LAPACK.getrf!(A::StridedCuMatrix{$elty}) = getrf!(A)
907+
LAPACK.getrf!(A::StridedCuMatrix{$elty}, ipiv::CuVector{Cint}) = getrf!(A, ipiv)
908+
LAPACK.geqrf!(A::StridedCuMatrix{$elty}) = geqrf!(A)
909+
LAPACK.geqrf!(A::StridedCuMatrix{$elty}, tau::CuVector{$elty}) = geqrf!(A, tau)
910+
LAPACK.sytrf!(uplo::Char, A::StridedCuMatrix{$elty}) = sytrf!(uplo, A)
911+
LAPACK.sytrf!(uplo::Char, A::StridedCuMatrix{$elty}, ipiv::CuVector{Cint}) = sytrf!(uplo, A, ipiv)
912+
LAPACK.getrs!(trans::Char, A::StridedCuMatrix{$elty}, ipiv::CuVector{Cint}, B::StridedCuVecOrMat{$elty}) = getrs!(trans, A, ipiv, B)
913+
LAPACK.ormqr!(side::Char, trans::Char, A::CuMatrix{$elty}, tau::CuVector{$elty}, C::CuVecOrMat{$elty}) = ormqr!(side, trans, A, tau, C)
914+
LAPACK.orgqr!(A::StridedCuMatrix{$elty}, tau::CuVector{$elty}) = orgqr!(A, tau)
915+
LAPACK.gebrd!(A::StridedCuMatrix{$elty}) = gebrd!(A)
916+
LAPACK.gesvd!(jobu::Char, jobvt::Char, A::StridedCuMatrix{$elty}) = gesvd!(jobu, jobvt, A)
911917
end
912918
end
913919

914920
for elty in (:Float32, :Float64)
915921
@eval begin
916-
LinearAlgebra.LAPACK.syev!(jobz::Char, uplo::Char, A::StridedCuMatrix{$elty}) = CUSOLVER.syevd!(jobz, uplo, A)
917-
LinearAlgebra.LAPACK.sygvd!(itype::Int, jobz::Char, uplo::Char, A::StridedCuMatrix{$elty}, B::StridedCuMatrix{$elty}) = CUSOLVER.sygvd!(itype, jobz, uplo, A, B)
922+
LAPACK.syev!(jobz::Char, uplo::Char, A::StridedCuMatrix{$elty}) = syevd!(jobz, uplo, A)
923+
LAPACK.sygvd!(itype::Int, jobz::Char, uplo::Char, A::StridedCuMatrix{$elty}, B::StridedCuMatrix{$elty}) = sygvd!(itype, jobz, uplo, A, B)
918924
end
919925
end
920926

921927
for elty in (:ComplexF32, :ComplexF64)
922928
@eval begin
923-
LinearAlgebra.LAPACK.syev!(jobz::Char, uplo::Char, A::StridedCuMatrix{$elty}) = CUSOLVER.heevd!(jobz, uplo, A)
924-
LinearAlgebra.LAPACK.sygvd!(itype::Int, jobz::Char, uplo::Char, A::StridedCuMatrix{$elty}, B::StridedCuMatrix{$elty}) = CUSOLVER.hegvd!(itype, jobz, uplo, A, B)
929+
LAPACK.syev!(jobz::Char, uplo::Char, A::StridedCuMatrix{$elty}) = heevd!(jobz, uplo, A)
930+
LAPACK.sygvd!(itype::Int, jobz::Char, uplo::Char, A::StridedCuMatrix{$elty}, B::StridedCuMatrix{$elty}) = hegvd!(itype, jobz, uplo, A, B)
925931
end
926932
end
927933

928934
if VERSION >= v"1.10"
929935
for elty in (:Float32, :Float64)
930936
@eval begin
931-
LinearAlgebra.LAPACK.syevd!(jobz::Char, uplo::Char, A::StridedCuMatrix{$elty}) = CUSOLVER.syevd!(jobz, uplo, A)
937+
LAPACK.syevd!(jobz::Char, uplo::Char, A::StridedCuMatrix{$elty}) = syevd!(jobz, uplo, A)
932938
end
933939
end
934940

935941
for elty in (:ComplexF32, :ComplexF64)
936942
@eval begin
937-
LinearAlgebra.LAPACK.syevd!(jobz::Char, uplo::Char, A::StridedCuMatrix{$elty}) = CUSOLVER.heevd!(jobz, uplo, A)
943+
LAPACK.syevd!(jobz::Char, uplo::Char, A::StridedCuMatrix{$elty}) = heevd!(jobz, uplo, A)
938944
end
939945
end
940946
end

lib/custatevec/src/statevec.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ function applyMatrix!(sv::CuStateVec, matrix::Union{Matrix, CuMatrix}, adjoint::
1111
out[]
1212
end
1313
with_workspace(handle().cache, bufferSize) do buffer
14-
custatevecApplyMatrix(handle(), sv.data, eltype(sv), sv.nbits, matrix, eltype(matrix), CUSTATEVEC_MATRIX_LAYOUT_COL, Int32(adjoint), convert(Vector{Int32}, targets), length(targets), convert(Vector{Int32}, controls), convert(Vector{Int32}, controlValues), length(controls), compute_type(eltype(sv), eltype(matrix)), buffer, length(buffer))
14+
custatevecApplyMatrix(handle(), sv.data, eltype(sv), sv.nbits, matrix, eltype(matrix), CUSTATEVEC_MATRIX_LAYOUT_COL, Int32(adjoint), convert(Vector{Int32}, targets), length(targets), convert(Vector{Int32}, controls), convert(Vector{Int32}, controlValues), length(controls), compute_type(eltype(sv), eltype(matrix)), buffer, sizeof(buffer))
1515
end
1616
sv
1717
end
@@ -25,7 +25,7 @@ function applyMatrixBatched!(sv::CuStateVec, n_svs::Int, map_type::custatevecMat
2525
out[]
2626
end
2727
with_workspace(handle().cache, bufferSize) do buffer
28-
custatevecApplyMatrixBatched(handle(), sv.data, eltype(sv), n_index_bits, n_svs, sv_stride, map_type, matrix_inds, matrix, eltype(matrix), CUSTATEVEC_MATRIX_LAYOUT_COL, Int32(adjoint), n_matrices, convert(Vector{Int32}, targets), length(targets), convert(Vector{Int32}, controls), convert(Vector{Int32}, controlValues), length(controls), compute_type(eltype(sv), eltype(matrix)), buffer, length(buffer))
28+
custatevecApplyMatrixBatched(handle(), sv.data, eltype(sv), n_index_bits, n_svs, sv_stride, map_type, matrix_inds, matrix, eltype(matrix), CUSTATEVEC_MATRIX_LAYOUT_COL, Int32(adjoint), n_matrices, convert(Vector{Int32}, targets), length(targets), convert(Vector{Int32}, controls), convert(Vector{Int32}, controlValues), length(controls), compute_type(eltype(sv), eltype(matrix)), buffer, sizeof(buffer))
2929
end
3030
sv
3131
end
@@ -37,7 +37,7 @@ function applyGeneralizedPermutationMatrix!(sv::CuStateVec, permutation::Union{V
3737
out[]
3838
end
3939
with_workspace(handle().cache, bufferSize) do buffer
40-
custatevecApplyGeneralizedPermutationMatrix(handle(), sv.data, eltype(sv), sv.nbits, permutation, diagonals, eltype(diagonals), Int32(adjoint), convert(Vector{Int32}, targets), length(targets), convert(Vector{Int32}, controls), convert(Vector{Int32}, controlValues), length(controls), buffer, length(buffer))
40+
custatevecApplyGeneralizedPermutationMatrix(handle(), sv.data, eltype(sv), sv.nbits, permutation, diagonals, eltype(diagonals), Int32(adjoint), convert(Vector{Int32}, targets), length(targets), convert(Vector{Int32}, controls), convert(Vector{Int32}, controlValues), length(controls), buffer, sizeof(buffer))
4141
end
4242
sv
4343
end
@@ -75,7 +75,7 @@ function collapseByBitStringBatched!(sv::CuStateVec, n_svs::Int, bitstrings::Vec
7575
sv_stride = div(length(sv.data), n_svs)
7676
n_index_bits = Int(log2(div(length(sv.data), n_svs)))
7777
with_workspace(handle().cache, bufferSize) do buffer
78-
custatevecCollapseByBitStringBatched(handle(), sv.data, eltype(sv), n_index_bits, n_svs, sv_stride, convert(Vector{custatevecIndex_t}, bitstrings), convert(Vector{Int32}, bitordering), n_index_bits, norms, buffer, length(buffer))
78+
custatevecCollapseByBitStringBatched(handle(), sv.data, eltype(sv), n_index_bits, n_svs, sv_stride, convert(Vector{custatevecIndex_t}, bitstrings), convert(Vector{Int32}, bitordering), n_index_bits, norms, buffer, sizeof(buffer))
7979
end
8080
sv
8181
end
@@ -118,7 +118,7 @@ function expectation(sv::CuStateVec, matrix::Union{Matrix, CuMatrix}, basis_bits
118118
expVal = Ref{Float64}()
119119
residualNorm = Ref{Float64}()
120120
with_workspace(handle().cache, bufferSize) do buffer
121-
custatevecComputeExpectation(handle(), sv.data, eltype(sv), sv.nbits, expVal, Float64, residualNorm, matrix, eltype(matrix), CUSTATEVEC_MATRIX_LAYOUT_COL, convert(Vector{Int32}, basis_bits), length(basis_bits), compute_type(eltype(sv), eltype(matrix)), buffer, length(buffer))
121+
custatevecComputeExpectation(handle(), sv.data, eltype(sv), sv.nbits, expVal, Float64, residualNorm, matrix, eltype(matrix), CUSTATEVEC_MATRIX_LAYOUT_COL, convert(Vector{Int32}, basis_bits), length(basis_bits), compute_type(eltype(sv), eltype(matrix)), buffer, sizeof(buffer))
122122
end
123123
return expVal[], residualNorm[]
124124
end
@@ -134,7 +134,7 @@ function sample(sv::CuStateVec, sampled_bits::Vector{<:Integer}, shot_count)
134134
sampler = CuStateVecSampler(sv, UInt32(shot_count))
135135
bitstrings = Vector{custatevecIndex_t}(undef, shot_count)
136136
with_workspace(handle().cache, sampler.ws_size) do buffer
137-
custatevecSamplerPreprocess(handle(), sampler.handle, buffer, length(buffer))
137+
custatevecSamplerPreprocess(handle(), sampler.handle, buffer, sizeof(buffer))
138138
custatevecSamplerSample(handle(), sampler.handle, bitstrings, convert(Vector{Int32}, sampled_bits), length(sampled_bits), rand(shot_count), shot_count, CUSTATEVEC_SAMPLER_OUTPUT_RANDNUM_ORDER)
139139
end
140140
return bitstrings
@@ -173,7 +173,7 @@ function testMatrixType(matrix::Union{Matrix, CuMatrix}, adjoint::Bool, matrix_t
173173
out[]
174174
end
175175
with_workspace(handle().cache, bufferSize) do buffer
176-
custatevecTestMatrixType(handle(), residualNorm, matrix_type, matrix, eltype(matrix), CUSTATEVEC_MATRIX_LAYOUT_COL, n_targets, Int32(adjoint), compute_type, buffer, length(buffer))
176+
custatevecTestMatrixType(handle(), residualNorm, matrix_type, matrix, eltype(matrix), CUSTATEVEC_MATRIX_LAYOUT_COL, n_targets, Int32(adjoint), compute_type, buffer, sizeof(buffer))
177177
end
178178
return residualNorm[]
179179
end

0 commit comments

Comments
 (0)