Skip to content

Commit 9259c96

Browse files
GVignemfherbst
andauthored
Make some computations in DFTK GPU-compatible (#712)
Co-authored-by: Michael F. Herbst <info@michael-herbst.com>
1 parent 703e0d0 commit 9259c96

38 files changed

+324
-145
lines changed

.github/workflows/documentation.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ on:
66
tags:
77
- 'v*'
88
pull_request:
9+
concurrency:
10+
# Skip intermediate builds: always.
11+
# Cancel intermediate builds: only if it is a pull request build.
12+
group: ${{ github.workflow }}-${{ github.ref }}
13+
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
914

1015
jobs:
1116
docs:

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
3535
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
3636
PseudoPotentialIO = "cb339c56-07fa-4cb2-923a-142469552264"
3737
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
38+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3839
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
3940
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
4041
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

examples/gpu.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using DFTK
2+
using CUDA
3+
4+
a = 10.26 # Silicon lattice constant in Bohr
5+
lattice = a / 2 * [[0 1 1.];
6+
[1 0 1.];
7+
[1 1 0.]]
8+
Si = ElementPsp(:Si, psp=load_psp("hgh/lda/Si-q4"))
9+
atoms = [Si, Si]
10+
positions = [ones(3)/8, -ones(3)/8]
11+
model = model_DFT(lattice, atoms, positions, []; temperature=1e-3)
12+
13+
# If available, use CUDA to store DFT quantities and perform main computations
14+
architecture = has_cuda() ? DFTK.GPU(CuArray) : DFTK.CPU()
15+
16+
basis = PlaneWaveBasis(model; Ecut=30, kgrid=(1, 1, 1), architecture)
17+
scfres = self_consistent_field(basis; tol=1e-3,
18+
solver=scf_damping_solver(),
19+
mixing=KerkerMixing())

src/DFTK.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ using spglib_jll
1313
using Unitful
1414
using UnitfulAtomic
1515
using ForwardDiff
16+
using AbstractFFTs
17+
using GPUArraysCore
18+
using Random
1619
using ChainRulesCore
1720

1821
export Vec3
@@ -31,6 +34,7 @@ include("common/mpi.jl")
3134
include("common/threading.jl")
3235
include("common/printing.jl")
3336
include("common/cis2pi.jl")
37+
include("architecture.jl")
3438
include("common/zeros_like.jl")
3539
include("common/norm.jl")
3640

src/Model.jl

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -279,13 +279,25 @@ Examples of covectors are forces.
279279
Reciprocal vectors are a special case: they are covectors, but conventionally have an
280280
additional factor of 2π in their definition, so they transform rather with 2π times the
281281
inverse lattice transpose: q_cart = 2π lattice' \ q_red = recip_lattice * q_red.
282+
283+
For each of the function there is a one-argument version (returning a function to do the
284+
transformation) and a two-argument version applying the transformation to a passed vector.
282285
=#
283-
vector_red_to_cart(model::Model, rred) = model.lattice * rred
284-
vector_cart_to_red(model::Model, rcart) = model.inv_lattice * rcart
285-
covector_red_to_cart(model::Model, fred) = model.inv_lattice' * fred
286-
covector_cart_to_red(model::Model, fcart) = model.lattice' * fcart
287-
recip_vector_red_to_cart(model::Model, qred) = model.recip_lattice * qred
288-
recip_vector_cart_to_red(model::Model, qcart) = model.inv_recip_lattice * qcart
286+
@inline _gen_matmul(mat) = vec -> mat * vec
287+
288+
vector_red_to_cart(model::Model) = _gen_matmul(model.lattice)
289+
vector_cart_to_red(model::Model) = _gen_matmul(model.inv_lattice)
290+
covector_red_to_cart(model::Model) = _gen_matmul(model.inv_lattice')
291+
covector_cart_to_red(model::Model) = _gen_matmul(model.lattice')
292+
recip_vector_red_to_cart(model::Model) = _gen_matmul(model.recip_lattice)
293+
recip_vector_cart_to_red(model::Model) = _gen_matmul(model.inv_recip_lattice)
294+
295+
vector_red_to_cart(model::Model, vec) = vector_red_to_cart(model)(vec)
296+
vector_cart_to_red(model::Model, vec) = vector_cart_to_red(model)(vec)
297+
covector_red_to_cart(model::Model, vec) = covector_red_to_cart(model)(vec)
298+
covector_cart_to_red(model::Model, vec) = covector_cart_to_red(model)(vec)
299+
recip_vector_red_to_cart(model::Model, vec) = recip_vector_red_to_cart(model)(vec)
300+
recip_vector_cart_to_red(model::Model, vec) = recip_vector_cart_to_red(model)(vec)
289301

290302
#=
291303
Transformations on vectors and covectors are matrices and comatrices.
@@ -300,7 +312,14 @@ s_cart = L s_red = L A_red r_red = L A_red L⁻¹ r_cart, thus A_cart = L A_red
300312
Examples of matrices are the symmetries in real space (W)
301313
Examples of comatrices are the symmetries in reciprocal space (S)
302314
=#
303-
matrix_red_to_cart(model::Model, Ared) = model.lattice * Ared * model.inv_lattice
304-
matrix_cart_to_red(model::Model, Acart) = model.inv_lattice * Acart * model.lattice
305-
comatrix_red_to_cart(model::Model, Bred) = model.inv_lattice' * Bred * model.lattice'
306-
comatrix_cart_to_red(model::Model, Bcart) = model.lattice' * Bcart * model.inv_lattice'
315+
@inline _gen_matmatmul(M, Minv) = mat -> M * mat * Minv
316+
317+
matrix_red_to_cart(model::Model) = _gen_matmatmul(model.lattice, model.inv_lattice)
318+
matrix_cart_to_red(model::Model) = _gen_matmatmul(model.inv_lattice, model.lattice)
319+
comatrix_red_to_cart(model::Model) = _gen_matmatmul(model.inv_lattice', model.lattice')
320+
comatrix_cart_to_red(model::Model) = _gen_matmatmul(model.lattice', model.inv_lattice')
321+
322+
matrix_red_to_cart(model::Model, Ared) = matrix_red_to_cart(model)(Ared)
323+
matrix_cart_to_red(model::Model, Acart) = matrix_cart_to_red(model)(Acart)
324+
comatrix_red_to_cart(model::Model, Bred) = comatrix_red_to_cart(model)(Bred)
325+
comatrix_cart_to_red(model::Model, Bcart) = comatrix_cart_to_red(model)(Bcart)

0 commit comments

Comments
 (0)