Skip to content

Commit b1cabe8

Browse files
authored
Fix and test for mgpu batch measure (#2671)
1 parent 3974de5 commit b1cabe8

File tree

2 files changed

+76
-11
lines changed

2 files changed

+76
-11
lines changed

lib/custatevec/src/statevec.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
function initialize!(sv::CuStateVec, sv_type::custatevecStateVectorType_t)
22
custatevecInitializeStateVector(handle(), sv.data, eltype(sv), sv.nbits, sv_type)
3-
sv
3+
return sv
44
end
55

66
function applyPauliExp!(sv::CuStateVec, theta::Float64, paulis::Vector{<:Pauli}, targets::Vector{Int32}, controls::Vector{Int32}, controlValues::Vector{Int32}=fill(one(Int32), length(controls)))
77
cupaulis = CuStateVecPauli.(paulis)
88
custatevecApplyPauliRotation(handle(), sv.data, eltype(sv), sv.nbits, theta, cupaulis, targets, length(targets), controls, controlValues, length(controls))
9-
sv
9+
return sv
1010
end
1111

1212
function applyMatrix!(sv::CuStateVec, matrix::Union{Matrix, CuMatrix}, adjoint::Bool, targets::Vector{<:Integer}, controls::Vector{<:Integer}, controlValues::Vector{<:Integer}=fill(one(Int32), length(controls)))
@@ -18,7 +18,7 @@ function applyMatrix!(sv::CuStateVec, matrix::Union{Matrix, CuMatrix}, adjoint::
1818
with_workspace(handle().cache, bufferSize) do buffer
1919
custatevecApplyMatrix(handle(), sv.data, eltype(sv), sv.nbits, matrix, eltype(matrix), CUSTATEVEC_MATRIX_LAYOUT_COL, Int32(adjoint), convert(Vector{Int32}, targets), length(targets), convert(Vector{Int32}, controls), convert(Vector{Int32}, controlValues), length(controls), compute_type(eltype(sv), eltype(matrix)), buffer, sizeof(buffer))
2020
end
21-
sv
21+
return sv
2222
end
2323

2424
function applyMatrixBatched!(sv::CuStateVec, n_svs::Int, map_type::custatevecMatrixMapType_t, matrix_inds::Vector{Int}, matrix::Union{Vector, CuVector}, n_matrices::Int, adjoint::Bool, targets::Vector{<:Integer}, controls::Vector{<:Integer}, controlValues::Vector{<:Integer}=fill(one(Int32), length(controls)))
@@ -32,7 +32,7 @@ function applyMatrixBatched!(sv::CuStateVec, n_svs::Int, map_type::custatevecMat
3232
with_workspace(handle().cache, bufferSize) do buffer
3333
custatevecApplyMatrixBatched(handle(), sv.data, eltype(sv), n_index_bits, n_svs, sv_stride, map_type, matrix_inds, matrix, eltype(matrix), CUSTATEVEC_MATRIX_LAYOUT_COL, Int32(adjoint), n_matrices, convert(Vector{Int32}, targets), length(targets), convert(Vector{Int32}, controls), convert(Vector{Int32}, controlValues), length(controls), compute_type(eltype(sv), eltype(matrix)), buffer, sizeof(buffer))
3434
end
35-
sv
35+
return sv
3636
end
3737

3838
function applyGeneralizedPermutationMatrix!(sv::CuStateVec, permutation::Union{Vector{<:Integer}, CuVector{<:Integer}}, diagonals::Union{Vector, CuVector}, adjoint::Bool, targets::Vector{<:Integer}, controls::Vector{<:Integer}, controlValues::Vector{<:Integer}=fill(one(Int32), length(controls)))
@@ -44,7 +44,7 @@ function applyGeneralizedPermutationMatrix!(sv::CuStateVec, permutation::Union{V
4444
with_workspace(handle().cache, bufferSize) do buffer
4545
custatevecApplyGeneralizedPermutationMatrix(handle(), sv.data, eltype(sv), sv.nbits, permutation, diagonals, eltype(diagonals), Int32(adjoint), convert(Vector{Int32}, targets), length(targets), convert(Vector{Int32}, controls), convert(Vector{Int32}, controlValues), length(controls), buffer, sizeof(buffer))
4646
end
47-
sv
47+
return sv
4848
end
4949

5050
function abs2SumOnZBasis(sv::CuStateVec, basisInds::Vector{<:Integer})
@@ -56,7 +56,7 @@ end
5656

5757
function collapseOnZBasis!(sv::CuStateVec, parity::Int, basisInds::Vector{<:Integer}, norm::Float64)
5858
custatevecCollapseOnZBasis(handle(), sv.data, eltype(sv), sv.nbits, parity, convert(Vector{Int32}, basisInds), length(basisInds), norm)
59-
sv
59+
return sv
6060
end
6161

6262
function measureOnZBasis!(sv::CuStateVec, basisInds::Vector{<:Integer}, randnum::Float64, collapse::custatevecCollapseOp_t=CUSTATEVEC_COLLAPSE_NONE)
@@ -68,7 +68,7 @@ end
6868

6969
function collapseByBitString!(sv::CuStateVec, bitstring::Union{Vector{<:Integer}, BitVector, Vector{Bool}}, bitordering::Vector{<:Integer}, norm::Float64)
7070
custatevecCollapseByBitString(handle(), sv.data, eltype(sv), sv.nbits, convert(Vector{Int32}, bitstring), convert(Vector{Int32}, bitordering), length(bitstring), norm)
71-
sv
71+
return sv
7272
end
7373

7474
function collapseByBitStringBatched!(sv::CuStateVec, n_svs::Int, bitstrings::Vector{<:Integer}, bitordering::Vector{<:Integer}, norms::Vector{Float64})
@@ -82,7 +82,7 @@ function collapseByBitStringBatched!(sv::CuStateVec, n_svs::Int, bitstrings::Vec
8282
with_workspace(handle().cache, bufferSize) do buffer
8383
custatevecCollapseByBitStringBatched(handle(), sv.data, eltype(sv), n_index_bits, n_svs, sv_stride, convert(Vector{custatevecIndex_t}, bitstrings), convert(Vector{Int32}, bitordering), n_index_bits, norms, buffer, sizeof(buffer))
8484
end
85-
sv
85+
return sv
8686
end
8787

8888
function abs2SumArray(sv::CuStateVec, bitordering::Vector{<:Integer}, maskBitString::Vector{<:Integer}, maskOrdering::Vector{<:Integer})
@@ -110,7 +110,7 @@ end
110110
function batchMeasureWithOffset!(sv::CuStateVec, bitordering::Vector{<:Integer}, randnum::Float64, offset::Float64, abs2sum::Float64, collapse::custatevecCollapseOp_t=CUSTATEVEC_COLLAPSE_NONE)
111111
0.0 <= randnum < 1.0 || throw(ArgumentError("randnum $randnum must be in the interval [0, 1)."))
112112
bitstring = zeros(Int32, length(bitordering))
113-
custatevecBatchMeasure(handle(), sv.data, eltype(sv), sv.nbits, convert(Vector{Int32}, bitstring), convert(Vector{Int32}, bitordering), length(bitstring), randnum, collapse, offset, abs2sum)
113+
custatevecBatchMeasureWithOffset(handle(), sv.data, eltype(sv), sv.nbits, convert(Vector{Int32}, bitstring), convert(Vector{Int32}, bitordering), length(bitstring), randnum, collapse, offset, abs2sum)
114114
return sv, bitstring
115115
end
116116

@@ -147,7 +147,7 @@ end
147147

148148
function swapIndexBits!(sv::CuStateVec, bitSwaps::Vector{Pair{T, T}}, maskBitString::Vector{<:Integer}, maskOrdering::Vector{<:Integer}) where {T<:Integer}
149149
custatevecSwapIndexBits(handle(), sv.data, eltype(sv), sv.nbits, convert(Vector{Pair{Int32, Int32}}, bitSwaps), length(bitSwaps), convert(Vector{Int32}, maskBitString), convert(Vector{Int32}, maskOrdering), length(maskOrdering))
150-
sv
150+
return sv
151151
end
152152

153153
function swapIndexBitsMultiDevice!(sub_svs::Vector{CuStateVec}, devices::Vector{CuDevice}, indexBitSwaps::Vector{Pair{T, T}}, maskBitString::Vector{<:Integer}, maskOrdering::Vector{<:Integer}, device_network_type::custatevecDeviceNetworkType_t) where {T<:Integer}

lib/custatevec/test/runtests.jl

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using cuStateVec
88
@info "cuStateVec version: $(cuStateVec.version())"
99

1010
@testset "cuStateVec" begin
11-
import cuStateVec: CuStateVec, applyMatrix!, applyMatrixBatched!, applyPauliExp!, applyGeneralizedPermutationMatrix!, expectation, expectationsOnPauliBasis, sample, testMatrixType, Pauli, PauliX, PauliY, PauliZ, PauliI, measureOnZBasis!, swapIndexBits!, abs2SumOnZBasis, collapseOnZBasis!, batchMeasure!, abs2SumArray, collapseByBitString!, abs2SumArrayBatched, collapseByBitStringBatched!, accessorSet!, accessorGet, CuStateVecAccessor
11+
import cuStateVec: CuStateVec, applyMatrix!, applyMatrixBatched!, applyPauliExp!, applyGeneralizedPermutationMatrix!, expectation, expectationsOnPauliBasis, sample, testMatrixType, Pauli, PauliX, PauliY, PauliZ, PauliI, measureOnZBasis!, swapIndexBits!, abs2SumOnZBasis, collapseOnZBasis!, batchMeasure!, batchMeasureWithOffset!, abs2SumArray, collapseByBitString!, abs2SumArrayBatched, collapseByBitStringBatched!, accessorSet!, accessorGet, CuStateVecAccessor
1212

1313
@testset "applyMatrix! and expectation" begin
1414
# build a simple state and compute expectations
@@ -273,3 +273,68 @@ using cuStateVec
273273
end
274274
end
275275
end
276+
277+
@testset "cuStateVec multiGPU" begin
278+
279+
nGlobalBits = 2;
280+
nLocalBits = 2;
281+
nSubSvs = 2^nGlobalBits
282+
subSvSize = 2^nLocalBits
283+
bitStringLen = 2
284+
bitOrdering = [1, 0]
285+
286+
bitString = Vector{Int}(undef, bitStringLen)
287+
bitString_result = zeros(Int, bitStringLen)
288+
# the most random of all numbers
289+
randnum = 0.71
290+
291+
h_sv = Vector{ComplexF64}[]
292+
push!(h_sv, [0.0; 0.125im; 0.250im; 0.375im])
293+
push!(h_sv, [0.0; -0.125im; -0.250im; -0.375im])
294+
push!(h_sv, [0.125; 0.125-0.125im; 0.125-0.250im; 0.125-0.375im])
295+
push!(h_sv, [-0.125; -0.125-0.125im; -0.125-0.250im; -0.125-0.375im])
296+
297+
h_sv_result = Vector{ComplexF64}[]
298+
push!(h_sv_result, zeros(ComplexF64, subSvSize))
299+
push!(h_sv_result, zeros(ComplexF64, subSvSize))
300+
push!(h_sv_result, ComplexF64[1/√2; 0; 0; 0])
301+
push!(h_sv_result, ComplexF64[-1/√2; 0; 0; 0])
302+
303+
n_devices = 4;
304+
# on CI, if we only have a single device, set up multiple devices
305+
# so that we properly cover the multigpu code paths.
306+
if ndevices() < n_devices
307+
sv_devices = fill(device(), n_devices)
308+
else
309+
sv_devices = collect(devices())[1:n_devices]
310+
end
311+
initial_dev = device()
312+
d_sv = similar(h_sv, CuStateVec{ComplexF64})
313+
normArray = similar(d_sv, Float64)
314+
try
315+
for sv_i in 1:length(d_sv)
316+
device!(sv_devices[sv_i])
317+
d_sv[sv_i] = CuStateVec(h_sv[sv_i])
318+
normArray[sv_i] = abs2SumArray(d_sv[sv_i], Int[], Int[], Int[])[]
319+
end
320+
finally
321+
device!(initial_dev)
322+
end
323+
cumulativeArray = zeros(Float64, length(normArray) + 1)
324+
for sv_i in 1:length(normArray)
325+
cumulativeArray[sv_i+1] = cumulativeArray[sv_i] + normArray[sv_i]
326+
end
327+
try
328+
for sv_i in 1:length(d_sv)
329+
if cumulativeArray[sv_i] <= randnum && randnum < cumulativeArray[sv_i + 1]
330+
norm = cumulativeArray[end]
331+
offset = cumulativeArray[sv_i]
332+
device!(sv_devices[sv_i])
333+
new_sv, bitstring = batchMeasureWithOffset!(d_sv[sv_i], bitOrdering, randnum, offset, norm)
334+
@test length(bitstring) == nLocalBits
335+
end
336+
end
337+
finally
338+
device!(initial_dev)
339+
end
340+
end

0 commit comments

Comments
 (0)