Skip to content

Bringing GPU programming to DFTK #697

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Brillouin = "23470ee3-d0df-4052-8b1a-8cbd6363e7f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DftFunctionals = "6bd331d2-b28d-4fd3-880e-1a1c7f37947f"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
InteratomicPotentials = "a9efe35a-c65d-452d-b8a8-82646cd5cb04"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
Expand All @@ -33,6 +35,7 @@ Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
26 changes: 26 additions & 0 deletions examples/gpu.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using DFTK
using CUDA
using MKL
setup_threading(n_blas=1)

a = 10.263141334305942 # Lattice constant in Bohr
lattice = a / 2 .* [[0 1 1.]; [1 0 1.]; [1 1 0.]]
Si = ElementPsp(:Si, psp=load_psp("hgh/lda/Si-q4"))
atoms = [Si, Si]
positions = [ones(3)/8, -ones(3)/8];
terms_LDA = [Kinetic(), AtomicLocal(), AtomicNonlocal()]

# Setup an LDA model and discretize using
# a single k-point and a small `Ecut` of 5 Hartree.
mod = Model(lattice, atoms, positions; terms=terms_LDA,symmetries=false)
basis = PlaneWaveBasis(mod; Ecut=30, kgrid=(1, 1, 1))
basis_gpu = PlaneWaveBasis(mod; Ecut=30, kgrid=(1, 1, 1), array_type = CuArray)


DFTK.reset_timer!(DFTK.timer)
scfres = self_consistent_field(basis; solver=scf_damping_solver(1.0), is_converged=DFTK.ScfConvergenceDensity(1e-3))
println(DFTK.timer)

DFTK.reset_timer!(DFTK.timer)
scfres_gpu = self_consistent_field(basis_gpu; solver=scf_damping_solver(1.0), is_converged=DFTK.ScfConvergenceDensity(1e-3))
println(DFTK.timer)
4 changes: 4 additions & 0 deletions src/DFTK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ using spglib_jll
using Unitful
using UnitfulAtomic
using ForwardDiff
using AbstractFFTs
using GPUArrays
using CUDA
using Random
Comment on lines +16 to +19
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure they should be here (and a hard dependency of DFTK) long-term.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we will need to discuss dependencies (especially if we want to move LOBPCG out of DFTK, that can take some work): I also didn't really know where to put my imports and how they were managed in a big package, so there is room for improvement.

using ChainRulesCore

export Vec3
Expand Down
27 changes: 16 additions & 11 deletions src/PlaneWaveBasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ struct PlaneWaveBasis{T, VT} <: AbstractBasis{T} where {VT <: Real}
G_to_r_normalization::T # G_to_r = G_to_r_normalization * BFFT

# "cubic" basis in reciprocal and real space, on which potentials and densities are stored
G_vectors::Array{Vec3{Int}, 3}
r_vectors::Array{Vec3{VT }, 3}
G_vectors::AbstractArray{Vec3{Int}, 3}
r_vectors::AbstractArray{Vec3{VT }, 3}

## MPI-local information of the kpoints this processor treats
# Irreducible kpoints. In the case of collinear spin,
Expand Down Expand Up @@ -148,7 +148,7 @@ end
# and are stored in PlaneWaveBasis for easy reconstruction.
function PlaneWaveBasis(model::Model{T}, Ecut::Number, fft_size, variational,
kcoords, kweights, kgrid, kshift,
symmetries_respect_rgrid, comm_kpts) where {T <: Real}
symmetries_respect_rgrid, comm_kpts, array_type = Array) where {T <: Real}
# Validate fft_size
if variational
max_E = sum(abs2, model.recip_lattice * floor.(Int, Vec3(fft_size) ./ 2)) / 2
Expand Down Expand Up @@ -191,7 +191,8 @@ function PlaneWaveBasis(model::Model{T}, Ecut::Number, fft_size, variational,
kweights_global = kweights

# Setup FFT plans
(ipFFT, opFFT, ipBFFT, opBFFT) = build_fft_plans(T, fft_size)
G_vects = G_vectors(fft_size, array_type)
(ipFFT, opFFT, ipBFFT, opBFFT) = build_fft_plans(similar(G_vects,T), fft_size)

# Normalization constants
# r_to_G = r_to_G_normalization * FFT
Expand Down Expand Up @@ -255,15 +256,14 @@ function PlaneWaveBasis(model::Model{T}, Ecut::Number, fft_size, variational,
Ecut, variational,
opFFT, ipFFT, opBFFT, ipBFFT,
r_to_G_normalization, G_to_r_normalization,
G_vectors(fft_size), r_vectors,
G_vects, r_vectors,
kpoints, kweights_thisproc, kgrid, kshift,
kcoords_global, kweights_global, comm_kpts, krange_thisproc, krange_allprocs,
symmetries, symmetries_respect_rgrid, terms)

# Instantiate the terms with the basis
for (it, t) in enumerate(model.term_types)
term_name = string(nameof(typeof(t)))
@timing "Instantiation $term_name" basis.terms[it] = t(basis)
@timing "Instantiation $term_name" basis.terms[it] = t(basis, array_type = array_type)
end
basis
end
Expand All @@ -277,7 +277,7 @@ end
variational=true, fft_size=nothing,
kgrid=nothing, kshift=nothing,
symmetries_respect_rgrid=isnothing(fft_size),
comm_kpts=MPI.COMM_WORLD) where {T <: Real}
comm_kpts=MPI.COMM_WORLD, array_type = Array) where {T <: Real}
if isnothing(fft_size)
@assert variational
if symmetries_respect_rgrid
Expand All @@ -295,7 +295,7 @@ end
fft_size = compute_fft_size(model, Ecut, kcoords; factors=factors)
end
PlaneWaveBasis(model, Ecut, fft_size, variational, kcoords, kweights,
kgrid, kshift, symmetries_respect_rgrid, comm_kpts)
kgrid, kshift, symmetries_respect_rgrid, comm_kpts, array_type)
end

@doc raw"""
Expand All @@ -317,12 +317,12 @@ end
Creates a new basis identical to `basis`, but with a custom set of kpoints
"""
@timing function PlaneWaveBasis(basis::PlaneWaveBasis, kcoords::AbstractVector,
kweights::AbstractVector)
kweights::AbstractVector; array_type = Array)
kgrid = kshift = nothing
PlaneWaveBasis(basis.model, basis.Ecut,
basis.fft_size, basis.variational,
kcoords, kweights, kgrid, kshift,
basis.symmetries_respect_rgrid, basis.comm_kpts)
basis.symmetries_respect_rgrid, basis.comm_kpts, array_type)
end

"""
Expand All @@ -331,6 +331,11 @@ end
The wave vectors `G` in reduced (integer) coordinates for a cubic basis set
of given sizes.
"""
function G_vectors(fft_size::Union{Tuple,AbstractVector}, array_type::UnionAll)
#This functions allows to convert the G_vectors (currently being built on the CPU) to a GPU Array.
convert(array_type, G_vectors(fft_size))
end

function G_vectors(fft_size::Union{Tuple,AbstractVector})
# Note that a collect(G_vectors_generator(fft_size)) is 100-fold slower
# than this implementation, hence the code duplication.
Expand Down
3 changes: 2 additions & 1 deletion src/common/ortho.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# Orthonormalize
ortho_qr(φk) = Matrix(qr(φk).Q)
ortho_qr(φk::AbstractArray) = Matrix(qr(φk).Q) #LinearAlgebra.QRCompactWYQ -> Matrix
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this should be ::Array instead. Also it somehow feels wrong to need to put CuArray explicitly here. We should think of a way to generalise this (perhaps also with some "stripping off type arguments" construct as discussed on slack.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't we simply do this?

ortho_qr(φk::Array) = Matrix(qr(φk).Q) 
ortho_qr(φk::T) where T <: AbstractGPUArray = T(qr(φk).Q) 

Another way to do it would be to have only one function and to get the the "base type" of φk, then convert qr(φk).Q to this type: this can be done by calling T.name.wrapper (or maybe one day a dedicated function in Base). We would then have the following code:
ortho_qr(φk::T) where T <: AbstractArray = T.name.wrapper(qr(φk).Q)

ortho_qr(φk::CuArray) = CuArray(qr(φk).Q) #CUDA.CUSOLVER.CuQRPackedQ -> CuArray
20 changes: 15 additions & 5 deletions src/densities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,27 @@ grid `basis`, where the individual k-points are occupied according to `occupatio

@sync for (ichunk, chunk) in enumerate(Iterators.partition(ik_n, chunk_length))
Threads.@spawn for (ik, n) in chunk # spawn a task per chunk
ψnk_real = ψnk_real_chunklocal[ichunk]
ρ_loc = ρ_chunklocal[ichunk]

kpt = basis.kpoints[ik]
G_to_r!(ψnk_real, basis, kpt, ψ[ik][:, n])
ρ_loc[:, :, :, kpt.spin] .+= occupation[ik][n] .* basis.kweights[ik] .* abs2.(ψnk_real)
#TODO: is this the right way to got? Probably rewrite compute_density for GPUArrays
if typeof(basis.G_vectors) <:AbstractGPUArray
ψnk_real = similar(basis.G_vectors, complex(T), basis.fft_size)
G_to_r!(ψnk_real, basis, kpt, ψ[ik][:, n])
ρ_loc = ρ_chunklocal[ichunk]
ρ_loc[:, :, :, kpt.spin] .+= occupation[ik][n] .* basis.kweights[ik] .* Array(abs2.(ψnk_real))
else
ψnk_real = ψnk_real_chunklocal[ichunk]
ρ_loc = ρ_chunklocal[ichunk]

G_to_r!(ψnk_real, basis, kpt, ψ[ik][:, n])
ρ_loc[:, :, :, kpt.spin] .+= occupation[ik][n] .* basis.kweights[ik] .* abs2.(ψnk_real)
end
end
end

ρ = sum(ρ_chunklocal)
mpi_sum!(ρ, basis.comm_kpts)
array_type = typeof(similar(basis.G_vectors,complex(T), size(ρ)))
ρ = convert(array_type, ρ)
ρ = symmetrize_ρ(basis, ρ; do_lowpass=false)

_check_positive(ρ)
Expand Down
Loading