Skip to content

Introduce QuantumToolboxMetalExt #233

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions .buildkite/Metal_Ext.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
steps:
- label: "Metal Julia {{matrix.version}}"
matrix:
setup:
version:
- "1.10" # oldest
- "1" # latest
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.version}}"
- JuliaCI/julia-test#v1:
test_args: "--quickfail"
- JuliaCI/julia-coverage#v1:
codecov: true
dirs:
- src
- ext
agents:
queue: "juliaecosystem"
os: "macos"
arch: "aarch64"
env:
GROUP: "Metal_Ext"
SECRET_CODECOV_TOKEN: "ZfhQu/IcRLqNyZ//ZNs5sjBPaV76IHfU5gui52Qn+Rp8tOurukqgScuyDt+3HQ4R0hJYBw1/Nqg6jmBsvWSc9NEUx8kGsUJFHfN3no0+b+PFxA8oJkWc9EpyIsjht5ZIjlsFWR3f0DpPqMEle/QyWOPcal63CChXR8oAoR+Fz1Bh8GkokLlnC8F9Ugp9xBlu401GCbyZhvLTZnNIgK5yy9q8HBJnBg1cPOhI81J6JvYpEmcIofEzFV/qkfpTUPclu43WNoFX2DZPzbxilf3fsAd5/+nRkRfkNML8KiN4mnmjHxPPbuY8F5zC/PS5ybXtDpfvaMQc01WApXCkZk0ZAQ==;U2FsdGVkX1+eDT7dqCME5+Ox5i8GvWRTQbwiP/VYjapThDbxXFDeSSIC6Opmon+M8go22Bun3bat6Fzie65ang=="
timeout_in_minutes: 60
if: |
// Don't run Buildkite if the commit message includes the text [skip ci], [ci skip], or [no ci]
// Don't run Buildkite for PR draft
// Only run Buildkite when new commits and PR are made to main branch
build.message !~ /\[skip ci\]/ &&
build.message !~ /\[ci skip\]/ &&
build.message !~ /\[no ci\]/ &&
!build.pull_request.draft &&
(build.branch =~ /main/ || build.pull_request.base_branch =~ /main/)
12 changes: 11 additions & 1 deletion .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,18 @@ steps:
- "src/**"
- "ext/QuantumToolboxCUDAExt.jl"
- "test/runtests.jl"
- "test/ext-test/gpu/**"
- "test/ext-test/cuda/**"
- "Project.toml"
target: ".buildkite/CUDA_Ext.yml"
- staticfloat/forerunner: # Metal.jl tests
watch:
- ".buildkite/pipeline.yml"
- ".buildkite/Metal_Ext.yml"
- "src/**"
- "ext/QuantumToolboxMetalExt.jl"
- "test/runtests.jl"
- "test/ext-test/metal/**"
- "Project.toml"
target: ".buildkite/Metal_Ext.yml"
agents:
queue: "juliagpu"
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"

[extensions]
QuantumToolboxCUDAExt = "CUDA"
QuantumToolboxCairoMakieExt = "CairoMakie"
QuantumToolboxGPUArraysExt = ["GPUArrays", "KernelAbstractions"]
QuantumToolboxMetalExt = "Metal"

[compat]
Aqua = "0.8"
Expand All @@ -47,13 +49,14 @@ DiffEqCallbacks = "4.2.1 - 4"
DiffEqNoiseProcess = "5"
Distributed = "1"
FFTW = "1.5"
GPUArrays = "10, 11"
GPUArrays = "11.2 - 11"
Graphs = "1.7"
IncompleteLU = "0.2"
JET = "0.9"
KernelAbstractions = "0.9.2"
LinearAlgebra = "1"
LinearSolve = "2"
Metal = "1.5"
OrdinaryDiffEqCore = "1"
OrdinaryDiffEqTsit5 = "1"
Pkg = "1"
Expand Down
55 changes: 55 additions & 0 deletions ext/QuantumToolboxMetalExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
module QuantumToolboxMetalExt

using QuantumToolbox
import Metal: mtl, MtlArray

@doc raw"""
MtlArray(A::QuantumObject)
If `A.data` is an arbitrary array, return a new [`QuantumObject`](@ref) where `A.data` is in the type of `Metal.MtlArray` for gpu calculations.
Note that this function will always change element type into `32`-bit (`Int32`, `Float32`, and `ComplexF32`).
"""
MtlArray(A::QuantumObject{<:AbstractArray{T}}) where {T<:Number} = QuantumObject(MtlArray(A.data), A.type, A.dims)
MtlArray(A::QuantumObject{<:AbstractArray{T}}) where {T<:Int64} = QuantumObject(MtlArray{Int32}(A.data), A.type, A.dims)
MtlArray(A::QuantumObject{<:AbstractArray{T}}) where {T<:Float64} =
QuantumObject(MtlArray{Float32}(A.data), A.type, A.dims)
MtlArray(A::QuantumObject{<:AbstractArray{T}}) where {T<:ComplexF64} =
QuantumObject(MtlArray{ComplexF32}(A.data), A.type, A.dims)

@doc raw"""
MtlArray{T}(A::QuantumObject)
If `A.data` is an arbitrary array, return a new [`QuantumObject`](@ref) where `A.data` is in the type of `Metal.MtlArray` with element type `T` for gpu calculations.
"""
MtlArray{T}(A::QuantumObject{<:AbstractArray{Tq}}) where {T,Tq<:Number} =
QuantumObject(MtlArray{T}(A.data), A.type, A.dims)

@doc raw"""
mtl(A::QuantumObject)
Return a new [`QuantumObject`](@ref) where `A.data` is in the type of `Metal` arrays for gpu calculations.
Note that this function will always change element type into `32`-bit (`Int32`, `Float32`, and `ComplexF32`).
"""
mtl(A::QuantumObject{<:AbstractArray{T}}) where {T<:Int64} = QuantumObject(MtlArray{Int32}(A.data), A.type, A.dims)
mtl(A::QuantumObject{<:AbstractArray{T}}) where {T<:Float64} = QuantumObject(MtlArray{Float32}(A.data), A.type, A.dims)
mtl(A::QuantumObject{<:AbstractArray{T}}) where {T<:ComplexF64} =
QuantumObject(MtlArray{ComplexF32}(A.data), A.type, A.dims)

## TODO: Remove the following part if Metal.jl support `sparse`
import LinearAlgebra: Transpose, Adjoint
import QuantumToolbox: _spre, _spost, _sprepost
_spre(A::MtlArray, Id::AbstractMatrix) = kron(Id, A)
_spre(A::Transpose{T,<:MtlArray}, Id::AbstractMatrix) where {T<:Number} = kron(Id, A)
_spre(A::Adjoint{T,<:MtlArray}, Id::AbstractMatrix) where {T<:Number} = kron(Id, A)
_spost(B::MtlArray, Id::AbstractMatrix) = kron(transpose(B), Id)
_spost(B::Transpose{T,<:MtlArray}, Id::AbstractMatrix) where {T<:Number} = kron(transpose(B), Id)
_spost(B::Adjoint{T,<:MtlArray}, Id::AbstractMatrix) where {T<:Number} = kron(transpose(B), Id)
_sprepost(A::MtlArray, B::MtlArray) = kron(transpose(B), A)
_sprepost(A::MtlArray, B::Transpose{T,<:MtlArray}) where {T<:Number} = kron(transpose(B), A)
_sprepost(A::MtlArray, B::Adjoint{T,<:MtlArray}) where {T<:Number} = kron(transpose(B), A)
_sprepost(A::Transpose{T,<:MtlArray}, B::MtlArray) where {T<:Number} = kron(transpose(B), A)
_sprepost(A::Transpose{T1,<:MtlArray}, B::Transpose{T2,<:MtlArray}) where {T1<:Number,T2<:Number} =
kron(transpose(B), A)
_sprepost(A::Transpose{T1,<:MtlArray}, B::Adjoint{T2,<:MtlArray}) where {T1<:Number,T2<:Number} = kron(transpose(B), A)
_sprepost(A::Adjoint{T,<:MtlArray}, B::MtlArray) where {T<:Number} = kron(transpose(B), A)
_sprepost(A::Adjoint{T1,<:MtlArray}, B::Transpose{T2,<:MtlArray}) where {T1<:Number,T2<:Number} = kron(transpose(B), A)
_sprepost(A::Adjoint{T1,<:MtlArray}, B::Adjoint{T2,<:MtlArray}) where {T1<:Number,T2<:Number} = kron(transpose(B), A)

end
File renamed without changes.
File renamed without changes.
6 changes: 6 additions & 0 deletions test/ext-test/metal/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[deps]
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
QuantumToolbox = "6c2fb7c5-b903-41d2-bc5e-5a7c320b9fab"

[compat]
Metal = "1"
75 changes: 75 additions & 0 deletions test/ext-test/metal/metal_ext.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
@testset "Metal Extension" verbose = true begin
ψdi = Qobj(Int64[1, 0])
ψdf = Qobj(Float64[1, 0])
ψdc = Qobj(ComplexF64[1, 0])
ψsi = dense_to_sparse(ψdi)
ψsf = dense_to_sparse(ψdf)
ψsc = dense_to_sparse(ψdc)

Xdi = Qobj(Int64[0 1; 1 0])
Xdf = Qobj(Float64[0 1; 1 0])
Xdc = Qobj(ComplexF64[0 1; 1 0])
Xsi = dense_to_sparse(Xdi)
Xsf = dense_to_sparse(Xdf)
Xsc = dense_to_sparse(Xdc)

# type conversion of dense arrays
@test typeof(mtl(ψdi).data) == typeof(MtlArray{Int32}(ψdi).data) <: MtlVector{Int32}
@test typeof(mtl(ψdf).data) ==
typeof(MtlArray(ψdf).data) ==
typeof(MtlArray{Float32}(ψdf).data) <:
MtlVector{Float32}
@test typeof(mtl(ψdc).data) ==
typeof(MtlArray(ψdc).data) ==
typeof(MtlArray{ComplexF32}(ψdc).data) <:
MtlVector{ComplexF32}
@test typeof(mtl(Xdi).data) == typeof(MtlArray{Int32}(Xdi).data) <: MtlMatrix{Int32}
@test typeof(mtl(Xdf).data) ==
typeof(MtlArray(Xdf).data) ==
typeof(MtlArray{Float32}(Xdf).data) <:
MtlMatrix{Float32}
@test typeof(mtl(Xdc).data) ==
typeof(MtlArray(Xdc).data) ==
typeof(MtlArray{ComplexF32}(Xdc).data) <:
MtlMatrix{ComplexF32}

# type conversion of sparse arrays
@test typeof(mtl(ψsi).data) == typeof(MtlArray{Int32}(ψsi).data) <: MtlVector{Int32}
@test typeof(mtl(ψsf).data) ==
typeof(MtlArray(ψsf).data) ==
typeof(MtlArray{Float32}(ψsf).data) <:
MtlVector{Float32}
@test typeof(mtl(ψsc).data) ==
typeof(MtlArray(ψsc).data) ==
typeof(MtlArray{ComplexF32}(ψsc).data) <:
MtlVector{ComplexF32}
@test typeof(mtl(Xsi).data) == typeof(MtlArray{Int32}(Xsi).data) <: MtlMatrix{Int32}
@test typeof(mtl(Xsf).data) ==
typeof(MtlArray(Xsf).data) ==
typeof(MtlArray{Float32}(Xsf).data) <:
MtlMatrix{Float32}
@test typeof(mtl(Xsc).data) ==
typeof(MtlArray(Xsc).data) ==
typeof(MtlArray{ComplexF32}(Xsc).data) <:
MtlMatrix{ComplexF32}

# brief example in README and documentation
N = 5 # cannot be too large since Metal.jl does not support sparse matrix
ω = 1.0f0 # Float32
γ = 0.1f0 # Float32
tlist = range(0, 10, 100)

## calculate by CPU
a_cpu = destroy(N)
ψ0_cpu = fock(N, 3)
H_cpu = ω * a_cpu' * a_cpu
sol_cpu = mesolve(H_cpu, ψ0_cpu, tlist, [sqrt(γ) * a_cpu], e_ops = [a_cpu' * a_cpu], progress_bar = Val(false))

## calculate by GPU
a_gpu = mtl(destroy(N))
ψ0_gpu = mtl(fock(N, 3))
H_gpu = ω * a_gpu' * a_gpu
sol_gpu = mesolve(H_gpu, ψ0_gpu, tlist, [sqrt(γ) * a_gpu], e_ops = [a_gpu' * a_gpu], progress_bar = Val(false))

@test all(isapprox.(sol_cpu.expect, sol_gpu.expect; atol = 1e-6))
end
19 changes: 17 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ if (GROUP == "CairoMakie_Ext")# || (GROUP == "All")
end

if (GROUP == "CUDA_Ext")# || (GROUP == "All")
Pkg.activate("ext-test/gpu")
Pkg.activate("ext-test/cuda")
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
Pkg.instantiate()

Expand All @@ -71,5 +71,20 @@ if (GROUP == "CUDA_Ext")# || (GROUP == "All")
QuantumToolbox.about()
CUDA.versioninfo()

include(joinpath(testdir, "ext-test", "gpu", "cuda_ext.jl"))
include(joinpath(testdir, "ext-test", "cuda", "cuda_ext.jl"))
end

if (GROUP == "Metal_Ext")# || (GROUP == "All")
Pkg.activate("ext-test/metal")
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
Pkg.instantiate()

using QuantumToolbox
using Metal
Metal.allowscalar(false) # Avoid unexpected scalar indexing

QuantumToolbox.about()
Metal.versioninfo()

include(joinpath(testdir, "ext-test", "metal", "metal_ext.jl"))
end
Loading