Skip to content

Commit 1f857be

Browse files
Add ptrace support for any GPUArray (#350)
* Add `ptrace` support for any `GPUArray` * Add changelog --------- Co-authored-by: Yi-Te Huang <44385685+ytdHuang@users.noreply.github.com>
1 parent 3ffdd3d commit 1f857be

File tree

5 files changed

+131
-2
lines changed

5 files changed

+131
-2
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88
## [Unreleased](https://github.com/qutip/QuantumToolbox.jl/tree/main)
99

1010
- Change the structure of block diagonalization functions, using `BlockDiagonalForm` struct and changing the function name from `bdf` to `block_diagonal_form`. ([#349])
11-
11+
- Add **GPUArrays** compatibility for `ptrace` function, by using **KernelAbstractions.jl**. ([#350])
1212

1313
## [v0.24.0]
1414
Release date: 2024-12-13
@@ -70,3 +70,4 @@ Release date: 2024-11-13
7070
[#346]: https://github.com/qutip/QuantumToolbox.jl/issues/346
7171
[#347]: https://github.com/qutip/QuantumToolbox.jl/issues/347
7272
[#349]: https://github.com/qutip/QuantumToolbox.jl/issues/349
73+
[#350]: https://github.com/qutip/QuantumToolbox.jl/issues/350

Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,13 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
2929
[weakdeps]
3030
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3131
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
32+
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
33+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
3234

3335
[extensions]
3436
QuantumToolboxCUDAExt = "CUDA"
3537
QuantumToolboxCairoMakieExt = "CairoMakie"
38+
QuantumToolboxGPUArraysExt = ["GPUArrays", "KernelAbstractions"]
3639

3740
[compat]
3841
Aqua = "0.8"
@@ -44,9 +47,11 @@ DiffEqCallbacks = "4.2.1 - 4"
4447
DiffEqNoiseProcess = "5"
4548
Distributed = "1"
4649
FFTW = "1.5"
50+
GPUArrays = "10"
4751
Graphs = "1.7"
4852
IncompleteLU = "0.2"
4953
JET = "0.9"
54+
KernelAbstractions = "0.9.2"
5055
LinearAlgebra = "1"
5156
LinearSolve = "2"
5257
OrdinaryDiffEqCore = "1"

ext/QuantumToolboxGPUArraysExt.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
module QuantumToolboxGPUArraysExt

using QuantumToolbox

import GPUArrays: AbstractGPUArray
import KernelAbstractions
import KernelAbstractions: @kernel, @Const, @index, get_backend, synchronize

# Batched-trace kernel: one work-item per output entry `(i, j)`, each computing
# `B[i, j] = sum over k of A[i, j, k, k]`.
#
# TODO: also parallelize over `k` with `Atomix.@atomic B[i, j] += A[i, j, k, k]`
# once Atomix supports `Complex` element types.
@kernel function tr_kernel!(B, @Const(A))
    i, j = @index(Global, NTuple)

    # Accumulate locally so that `B[i, j]` is written exactly once per work-item,
    # instead of being read and re-written on every iteration of the loop.
    acc = zero(eltype(B))
    @inbounds for k in 1:size(A, 3)
        acc += A[i, j, k, k]
    end
    @inbounds B[i, j] = acc
end

# GPU method of `QuantumToolbox._map_trace`: for a 4-dimensional GPU array `A`,
# return the matrix `B` with `B[i, j] = sum over k of A[i, j, k, k]`, computed by
# launching `tr_kernel!` on the backend that owns `A`.
function QuantumToolbox._map_trace(A::AbstractGPUArray{T,4}) where {T}
    # No `fill!` is needed here: the kernel assigns every entry of `B`.
    B = similar(A, size(A, 1), size(A, 2))

    backend = get_backend(A)
    kernel! = tr_kernel!(backend)

    # One work-item per entry of the output matrix.
    kernel!(B, A; ndrange = size(B))
    # Wait for the kernel to finish before handing `B` back to the caller.
    synchronize(backend)

    return B
end

end

src/qobj/arithmetic_and_attributes.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,11 +631,13 @@ function _ptrace_oper(QO::AbstractArray, dims::Union{SVector,MVector}, sel)
631631
topermute = reverse(2 * n_d + 1 .- qtrace_sel)
632632
ρmat = permutedims(ρmat, topermute) # TODO: use PermutedDimsArray when Julia v1.11.0 is released
633633
ρmat = reshape(ρmat, prod(dkeep), prod(dkeep), prod(dtrace), prod(dtrace))
634-
res = map(tr, eachslice(ρmat, dims = (1, 2)))
634+
res = _map_trace(ρmat)
635635

636636
return res, dkeep
637637
end
638638

639+
# Generic fallback: apply `tr` to every matrix slice `A[i, j, :, :]` of the
# 4-dimensional array `A` and collect the results into an array indexed by `(i, j)`.
function _map_trace(A::AbstractArray{T,4}) where {T}
    slices = eachslice(A; dims = (1, 2))
    return map(tr, slices)
end
640+
639641
@doc raw"""
640642
purity(ρ::QuantumObject)
641643

test/ext-test/gpu/cuda_ext.jl

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,91 @@
106106
@test all([isapprox(sol_cpu.expect[i], sol_gpu64.expect[i]) for i in 1:length(tlist)])
107107
@test all([isapprox(sol_cpu.expect[i], sol_gpu32.expect[i]; atol = 1e-6) for i in 1:length(tlist)])
108108
end
109+
110+
@testset "CUDA ptrace" begin
    g = fock(2, 1)
    e = fock(2, 0)
    α = sqrt(0.7)
    β = sqrt(0.3) * 1im
    ψ = α * kron(g, e) + β * kron(e, g) |> cu

    # partial trace of a ket stored on the GPU
    ρ1 = ptrace(ψ, 1)
    ρ2 = ptrace(ψ, 2)
    @test ρ1.data isa CuArray
    @test ρ2.data isa CuArray
    @test Array(ρ1.data) ≈ [0.3 0.0; 0.0 0.7] atol = 1e-10
    @test Array(ρ2.data) ≈ [0.7 0.0; 0.0 0.3] atol = 1e-10

    ψ_d = ψ'

    # partial trace of the adjoint state: same expected reduced matrices as for `ψ`
    ρ1 = ptrace(ψ_d, 1)
    ρ2 = ptrace(ψ_d, 2)
    @test ρ1.data isa CuArray
    @test ρ2.data isa CuArray
    @test Array(ρ1.data) ≈ [0.3 0.0; 0.0 0.7] atol = 1e-10
    @test Array(ρ2.data) ≈ [0.7 0.0; 0.0 0.3] atol = 1e-10

    # partial trace of the density matrix built from `ψ`
    ρ = ket2dm(ψ)
    ρ1 = ptrace(ρ, 1)
    ρ2 = ptrace(ρ, 2)
    @test ρ.data isa CuArray
    @test ρ1.data isa CuArray
    @test ρ2.data isa CuArray
    @test Array(ρ1.data) ≈ [0.3 0.0; 0.0 0.7] atol = 1e-10
    @test Array(ρ2.data) ≈ [0.7 0.0; 0.0 0.3] atol = 1e-10

    # product state: tracing out one subsystem must return the other factor
    ψ1 = normalize(g + 1im * e)
    ψ2 = normalize(g + e)
    ρ1 = ket2dm(ψ1)
    ρ2 = ket2dm(ψ2)
    ρ = kron(ρ1, ρ2) |> cu
    ρ1_ptr = ptrace(ρ, 1)
    ρ2_ptr = ptrace(ρ, 2)
    @test ρ1_ptr.data isa CuArray
    @test ρ2_ptr.data isa CuArray
    @test ρ1.data ≈ Array(ρ1_ptr.data) atol = 1e-10
    @test ρ2.data ≈ Array(ρ2_ptr.data) atol = 1e-10

    # four subsystems of different dimensions, checked for every selection in `sel_tests`
    ψlist = [rand_ket(2), rand_ket(3), rand_ket(4), rand_ket(5)]
    ρlist = [rand_dm(2), rand_dm(3), rand_dm(4), rand_dm(5)]
    ψtotal = tensor(ψlist...) |> cu
    ρtotal = tensor(ρlist...) |> cu
    sel_tests = [
        SVector{0,Int}(),
        1,
        2,
        3,
        4,
        (1, 2),
        (1, 3),
        (1, 4),
        (2, 3),
        (2, 4),
        (3, 4),
        (1, 2, 3),
        (1, 2, 4),
        (1, 3, 4),
        (2, 3, 4),
        (1, 2, 3, 4),
    ]
    for sel in sel_tests
        if length(sel) == 0
            # empty selection: everything is traced out, so the result is compared with 1.0
            @test ptrace(ψtotal, sel) ≈ 1.0
            @test ptrace(ρtotal, sel) ≈ 1.0
        else
            @test ptrace(ψtotal, sel) ≈ cu(tensor([ket2dm(ψlist[i]) for i in sel]...))
            @test ptrace(ρtotal, sel) ≈ cu(tensor([ρlist[i] for i in sel]...))
        end
    end
    @test ptrace(ψtotal, (1, 3, 4)) ≈ ptrace(ψtotal, (4, 3, 1)) # check sort of sel
    @test ptrace(ρtotal, (1, 3, 4)) ≈ ptrace(ρtotal, (3, 1, 4)) # check sort of sel

    @testset "Type Inference (ptrace)" begin
        @inferred ptrace(ρ, 1)
        @inferred ptrace(ρ, 2)
        @inferred ptrace(ψ_d, 1)
        @inferred ptrace(ψ_d, 2)
        @inferred ptrace(ψ, 1)
        @inferred ptrace(ψ, 2)
    end
end

0 commit comments

Comments
 (0)