Skip to content

Commit f9d3f53

Browse files
Make CUDA an extension
Apply formatter
1 parent 3faa8da commit f9d3f53

38 files changed

+1671
-1551
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ version = "0.13.4"
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
99
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
10-
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1110
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
1211
CubedSphere = "7445602f-e544-4518-8976-18f8e8ae6cdb"
1312
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
@@ -32,9 +31,11 @@ Unrolled = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
3231

3332
[weakdeps]
3433
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
34+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3535

3636
[extensions]
3737
KrylovExt = "Krylov"
38+
ClimaCoreCUDAExt = "CUDA"
3839

3940
[compat]
4041
Adapt = "3, 4"

ext/ClimaCoreCUDAExt.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Package extension module holding the CUDA-specific methods that this commit
# moves out of the main package. Project.toml registers it under `[extensions]`
# as `ClimaCoreCUDAExt = "CUDA"`, with CUDA listed under `[weakdeps]`.
module ClimaCoreCUDAExt
2+
3+
import ClimaComms
4+
import ClimaCore: DataLayouts, Grids, Spaces, Fields
5+
import CUDA
6+
7+
# Every file below is evaluated inside this module, in the order listed, so
# names imported by one file are visible to the files included after it.
# The order is significant: "data_layouts.jl" does `import Adapt`, and
# "fields.jl" defines `Adapt.adapt_structure` methods without importing
# `Adapt` itself.
include(joinpath("cuda", "data_layouts.jl"))
8+
include(joinpath("cuda", "fields.jl"))
9+
include(joinpath("cuda", "topologies_dss.jl"))
10+
include(joinpath("cuda", "operators_finite_difference.jl"))
11+
include(joinpath("cuda", "remapping_distributed.jl"))
12+
include(joinpath("cuda", "operators_integral.jl"))
13+
include(joinpath("cuda", "remapping_interpolate_array.jl"))
14+
include(joinpath("cuda", "limiters.jl"))
15+
include(joinpath("cuda", "operators_sem_shmem.jl"))
16+
include(joinpath("cuda", "operators_thomas_algorithm.jl"))
17+
include(joinpath("cuda", "matrix_fields_multiple_field_solve.jl"))
18+
include(joinpath("cuda", "operators_spectral_element.jl"))
19+
include(joinpath("cuda", "matrix_fields_single_field_solve.jl"))
20+
21+
end

src/DataLayouts/cuda.jl renamed to ext/cuda/data_layouts.jl

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,11 @@
1+
2+
import ClimaCore.DataLayouts: IJKFVH, IJFH, VIJFH, VIFH, IFH, IJF, IF, VF, DataF
3+
import ClimaCore.DataLayouts: IJFHStyle, VIJFHStyle, VFStyle, DataFStyle
4+
import ClimaCore.DataLayouts: promote_parent_array_type
5+
import ClimaCore.DataLayouts: parent_array_type
16
import Adapt
27
import CUDA
38

4-
Adapt.adapt_structure(to, data::IJKFVH{S, Nij, Nk}) where {S, Nij, Nk} =
5-
IJKFVH{S, Nij, Nk}(Adapt.adapt(to, parent(data)))
6-
7-
Adapt.adapt_structure(to, data::IJFH{S, Nij}) where {S, Nij} =
8-
IJFH{S, Nij}(Adapt.adapt(to, parent(data)))
9-
10-
Adapt.adapt_structure(to, data::VIJFH{S, Nij}) where {S, Nij} =
11-
VIJFH{S, Nij}(Adapt.adapt(to, parent(data)))
12-
13-
Adapt.adapt_structure(to, data::VIFH{S, Ni, A}) where {S, Ni, A} =
14-
VIFH{S, Ni}(Adapt.adapt(to, parent(data)))
15-
16-
Adapt.adapt_structure(to, data::IFH{S, Ni}) where {S, Ni} =
17-
IFH{S, Ni}(Adapt.adapt(to, parent(data)))
18-
19-
Adapt.adapt_structure(to, data::IJF{S, Nij}) where {S, Nij} =
20-
IJF{S, Nij}(Adapt.adapt(to, parent(data)))
21-
22-
Adapt.adapt_structure(to, data::IF{S, Ni}) where {S, Ni} =
23-
IF{S, Ni}(Adapt.adapt(to, parent(data)))
24-
25-
Adapt.adapt_structure(to, data::VF{S}) where {S} =
26-
VF{S}(Adapt.adapt(to, parent(data)))
27-
28-
Adapt.adapt_structure(to, data::DataF{S}) where {S} =
29-
DataF{S}(Adapt.adapt(to, parent(data)))
30-
319
# `parent_array_type` for `CuArray`s: keep the element type `T` and the buffer
# parameter `B`, but leave the dimensionality `N` free in the returned type.
function parent_array_type(
    ::Type{<:CUDA.CuArray{T, N, B} where {N}},
) where {T, B}
    return CUDA.CuArray{T, N, B} where {N}
end
3311

src/Fields/mapreduce_cuda.jl renamed to ext/cuda/fields.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
1+
import ClimaComms
2+
using CUDA: @cuda
3+
import LinearAlgebra, Statistics
4+
import ClimaCore: DataLayouts, Spaces, Grids, Fields
5+
import ClimaCore.Fields: Field, FieldStyle
6+
import ClimaCore.Fields: AbstractFieldStyle
7+
import ClimaCore.Spaces: cuda_synchronize
28
function Base.sum(
39
field::Union{Field, Base.Broadcast.Broadcasted{<:FieldStyle}},
410
::ClimaComms.CUDADevice,
@@ -280,3 +286,29 @@ end
280286
newsize = _cuda_reduce!(op, reduction, tidx, newsize, 1)
281287
return nothing
282288
end
289+
290+
291+
# Adapt a field-style `Broadcasted` for use inside a CUDA kernel: the function,
# the argument tuple and the axes are each adapted with `to`, and the result is
# rebuilt with the same broadcast `Style`.
function Adapt.adapt_structure(
    to::CUDA.KernelAdaptor,
    bc::Base.Broadcast.Broadcasted{Style},
) where {Style <: AbstractFieldStyle}
    adapted_f = Adapt.adapt(to, bc.f)
    adapted_args = Adapt.adapt(to, bc.args)
    adapted_axes = Adapt.adapt(to, bc.axes)
    return Base.Broadcast.Broadcasted{Style}(
        adapted_f,
        adapted_args,
        adapted_axes,
    )
end
301+
302+
# Variant for a `Broadcasted` whose callee is a type `T` (i.e. a constructor
# call). The type object is replaced by a closure that forwards to `T`
# (presumably because a bare `Type{T}` value cannot be passed to a kernel as
# is — TODO confirm). Only the arguments are adapted; unlike the generic
# method, `bc.axes` is passed through unchanged.
function Adapt.adapt_structure(
    to::CUDA.KernelAdaptor,
    bc::Base.Broadcast.Broadcasted{Style, <:Any, Type{T}},
) where {Style <: AbstractFieldStyle, T}
    construct = (x...) -> T(x...)
    adapted_args = Adapt.adapt(to, bc.args)
    return Base.Broadcast.Broadcasted{Style}(construct, adapted_args, bc.axes)
end
312+
313+
# CUDA method of `ClimaCore.Spaces.cuda_synchronize`: forwards all keyword
# arguments to `CUDA.synchronize`. `device` is used for dispatch only.
function cuda_synchronize(device::ClimaComms.CUDADevice; kwargs...)
    return CUDA.synchronize(; kwargs...)
end

ext/cuda/limiters.jl

Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
import ClimaCore.Limiters: QuasiMonotoneLimiter
2+
import ClimaCore.Fields
3+
4+
"""
    config_threadblock(Nv, Nh; max_threads = 256)

Return `(nthreads, nblocks)` for a 1-D kernel launch that assigns one thread to
each of the `Nv * Nh` work items.

`nthreads` is capped at `max_threads` (default 256, the value that used to be
hard-coded), and `nblocks` is the smallest block count that covers every item.
When `Nv * Nh` is not a multiple of `nthreads`, the last block contains surplus
threads that the kernel must skip.
"""
function config_threadblock(Nv, Nh; max_threads = 256)
    nitems = Nv * Nh
    nthreads = min(max_threads, nitems)
    nblocks = cld(nitems, nthreads)
    return (nthreads, nblocks)
end
10+
11+
"""
    get_hv(Nv, Nh, blockIdx, threadIdx, blockDim, gridDim)

Convert the calling thread's global 1-based linear index
`(blockIdx.x - 1) * blockDim.x + threadIdx.x` into an `(h, v)` pair, with `h`
running fastest over `1:Nh` and `v` over `1:Nv`.

Only the `x` property of `blockIdx`, `threadIdx` and `blockDim` is read;
`gridDim` is accepted for signature compatibility and is not used.

For a linear index beyond `Nv * Nh` (surplus threads of the last block) the
returned `v` is greater than `Nv`, so a caller-side `h ≤ Nh && v ≤ Nv` guard
can skip the thread instead of hitting a bounds error here.
"""
function get_hv(Nv, Nh, blockIdx, threadIdx, blockDim, gridDim)
    tidx = (blockIdx.x - 1) * blockDim.x + threadIdx.x
    # Column-major decomposition of the linear index. For in-range `tidx` this
    # equals `CartesianIndices((1:Nh, 1:Nv))[tidx].I`, but it is plain
    # arithmetic and therefore never bounds-checks.
    (v0, h0) = divrem(tidx - 1, Nh)
    return (h0 + 1, v0 + 1)
end
17+
18+
"""
    compute_element_bounds!(limiter::QuasiMonotoneLimiter, ρq, ρ, ::ClimaComms.CUDADevice)

Host-side launcher for `compute_element_bounds_kernel!`.

The launch is sized with `config_threadblock(Nv, Nh)`, where `Nv` and `Nh` are
the 4th and 5th entries of `size(Fields.field_values(ρ))`. The kernel receives
the field values of `ρq` and `ρ` after `Operators.strip_space`, with the first
two size entries passed as `Val`s. Returns `nothing`.

NOTE(review): this file imports only `QuasiMonotoneLimiter` from
`ClimaCore.Limiters`. Unless `compute_element_bounds!` is brought into scope
elsewhere in the enclosing module, this definition creates a new function
instead of adding a method to the `Limiters` one — confirm.
"""
function compute_element_bounds!(
    limiter::QuasiMonotoneLimiter,
    ρq,
    ρ,
    ::ClimaComms.CUDADevice,
)
    (Ni, Nj, _, Nv, Nh) = size(Fields.field_values(ρ))
    (nthreads, nblocks) = config_threadblock(Nv, Nh)

    ρq_values = Fields.field_values(Operators.strip_space(ρq, axes(ρq)))
    ρ_values = Fields.field_values(Operators.strip_space(ρ, axes(ρ)))
    CUDA.@cuda always_inline = true threads = nthreads blocks = nblocks compute_element_bounds_kernel!(
        limiter,
        ρq_values,
        ρ_values,
        Nv,
        Nh,
        Val(Ni),
        Val(Nj),
    )
    return nothing
end
39+
40+
41+
"""
    compute_element_bounds_kernel!(limiter, ρq, ρ, Nv, Nh, Val(Ni), Val(Nj))

Device kernel: one thread per `(h, v)` pair, as returned by `get_hv`.

Each active thread scans the `Ni × Nj` nodes of its `(v, h)` slab, forms
`q = rdiv(ρq, ρ)` at every node, and writes the running minimum and maximum of
`q` (combined with `rmin` / `rmax`) into entries 1 and 2 of the matching slab
of `limiter.q_bounds`. Threads outside `1:Nh × 1:Nv` do nothing.
"""
function compute_element_bounds_kernel!(
    limiter,
    ρq,
    ρ,
    Nv,
    Nh,
    ::Val{Ni},
    ::Val{Nj},
) where {Ni, Nj}
    (h, v) = get_hv(Nv, Nh, blockIdx(), threadIdx(), blockDim(), gridDim())
    # Skip the surplus threads of the last block.
    if h ≤ Nh && v ≤ Nv
        (; q_bounds) = limiter
        local q_min, q_max
        slab_ρq = slab(ρq, v, h)
        slab_ρ = slab(ρ, v, h)
        for j in 1:Nj
            for i in 1:Ni
                q = rdiv(slab_ρq[i, j], slab_ρ[i, j])
                if i == 1 && j == 1
                    # The first node seeds both bounds.
                    q_min = q
                    q_max = q
                else
                    q_min = rmin(q_min, q)
                    q_max = rmax(q_max, q)
                end
            end
        end
        slab_q_bounds = slab(q_bounds, v, h)
        slab_q_bounds[1] = q_min
        slab_q_bounds[2] = q_max
    end
    return nothing
end
74+
75+
76+
"""
    compute_neighbor_bounds_local!(limiter::QuasiMonotoneLimiter, ρ, ::ClimaComms.CUDADevice)

Host-side launcher for `compute_neighbor_bounds_local_kernel!`.

The kernel is given the `local_neighbor_elem` and `local_neighbor_elem_offset`
properties of `Spaces.topology(axes(ρ))`, and the launch is sized with
`config_threadblock(Nv, Nh)` from `size(Fields.field_values(ρ))`.

Unlike the sibling launchers there is no explicit `return nothing`: the value
of the `CUDA.@cuda` expression is returned.

NOTE(review): this file imports only `QuasiMonotoneLimiter` from
`ClimaCore.Limiters`. Unless `compute_neighbor_bounds_local!` is brought into
scope elsewhere in the enclosing module, this definition creates a new function
instead of adding a method to the `Limiters` one — confirm.
"""
function compute_neighbor_bounds_local!(
    limiter::QuasiMonotoneLimiter,
    ρ,
    ::ClimaComms.CUDADevice,
)
    topo = Spaces.topology(axes(ρ))
    (Ni, Nj, _, Nv, Nh) = size(Fields.field_values(ρ))
    (nthreads, nblocks) = config_threadblock(Nv, Nh)
    CUDA.@cuda always_inline = true threads = nthreads blocks = nblocks compute_neighbor_bounds_local_kernel!(
        limiter,
        topo.local_neighbor_elem,
        topo.local_neighbor_elem_offset,
        Nv,
        Nh,
        Val(Ni),
        Val(Nj),
    )
end
94+
95+
"""
    compute_neighbor_bounds_local_kernel!(
        limiter, local_neighbor_elem, local_neighbor_elem_offset,
        Nv, Nh, Val(Ni), Val(Nj),
    )

Device kernel: one thread per `(h, v)` pair, as returned by `get_hv`.

Each active thread starts from the bounds stored in `limiter.q_bounds` for
`(v, h)` and widens them (via `rmin` / `rmax`) with the bounds of every element
`h_nbr = local_neighbor_elem[lne]`, for `lne` in
`local_neighbor_elem_offset[h]:(local_neighbor_elem_offset[h + 1] - 1)`, at the
same `v`. The result is written to entries 1 and 2 of the `(v, h)` slab of
`limiter.q_bounds_nbr`. Threads outside `1:Nh × 1:Nv` do nothing.
"""
function compute_neighbor_bounds_local_kernel!(
    limiter,
    local_neighbor_elem,
    local_neighbor_elem_offset,
    Nv,
    Nh,
    ::Val{Ni},
    ::Val{Nj},
) where {Ni, Nj}
    (h, v) = get_hv(Nv, Nh, blockIdx(), threadIdx(), blockDim(), gridDim())
    # Skip the surplus threads of the last block.
    if h ≤ Nh && v ≤ Nv
        (; q_bounds, q_bounds_nbr) = limiter
        slab_q_bounds = slab(q_bounds, v, h)
        q_min = slab_q_bounds[1]
        q_max = slab_q_bounds[2]
        for lne in
            local_neighbor_elem_offset[h]:(local_neighbor_elem_offset[h + 1] - 1)
            h_nbr = local_neighbor_elem[lne]
            slab_q_bounds = slab(q_bounds, v, h_nbr)
            q_min = rmin(q_min, slab_q_bounds[1])
            q_max = rmax(q_max, slab_q_bounds[2])
        end
        slab_q_bounds_nbr = slab(q_bounds_nbr, v, h)
        slab_q_bounds_nbr[1] = q_min
        slab_q_bounds_nbr[2] = q_max
    end
    return nothing
end
124+
125+
"""
    apply_limiter!(ρq::Fields.Field, ρ::Fields.Field, limiter::QuasiMonotoneLimiter, ::ClimaComms.CUDADevice)

Host-side launcher for `apply_limiter_kernel!`.

Sizes are read from `Fields.field_values(ρq)`. The iteration cap passed to the
kernel is `Ni * Nj`, and the component count is `DataLayouts.ncomponents` of
the same data. The quadrature weights are the `WJ` property of
`Spaces.local_geometry_data(axes(ρq))`. Returns `nothing`.

NOTE(review): this file imports only `QuasiMonotoneLimiter` from
`ClimaCore.Limiters`. Unless `apply_limiter!` is brought into scope elsewhere
in the enclosing module, this definition creates a new function instead of
adding a method to the `Limiters` one — confirm.
"""
function apply_limiter!(
    ρq::Fields.Field,
    ρ::Fields.Field,
    limiter::QuasiMonotoneLimiter,
    ::ClimaComms.CUDADevice,
)
    tracer_values = Fields.field_values(ρq)
    (Ni, Nj, _, Nv, Nh) = size(tracer_values)
    Nf = DataLayouts.ncomponents(tracer_values)
    maxiter = Ni * Nj
    weights = Spaces.local_geometry_data(axes(ρq)).WJ
    (nthreads, nblocks) = config_threadblock(Nv, Nh)

    ρq_values = Fields.field_values(Operators.strip_space(ρq, axes(ρq)))
    ρ_values = Fields.field_values(Operators.strip_space(ρ, axes(ρ)))
    CUDA.@cuda always_inline = true threads = nthreads blocks = nblocks apply_limiter_kernel!(
        limiter,
        ρq_values,
        ρ_values,
        weights,
        Nv,
        Nh,
        Val(Nf),
        Val(Ni),
        Val(Nj),
        Val(maxiter),
    )
    return nothing
end
151+
152+
"""
    apply_limiter_kernel!(
        limiter, ρq_data, ρ_data, WJ_data, Nv, Nh,
        Val(Nf), Val(Ni), Val(Nj), Val(maxiter),
    )

Device kernel: one thread per `(h, v)` pair, as returned by `get_hv`.

For its `(v, h)` slab, each active thread processes every component
`f in 1:Nf` of `ρq_data` as follows:

1. Integrate `ρ` and `ρq` over the slab with the weights from `WJ_data`.
2. Widen the bounds read from `limiter.q_bounds_nbr` so that they contain the
   slab average `tracer_mass / total_mass`.
3. Repeat up to `maxiter` times: clip `ρq` to `[ρ * q_min, ρ * q_max]`, then
   spread the clipped mass `Δtracer_mass` over the nodes that are still
   strictly inside the bound on the relevant side. The loop stops early once
   `abs(Δtracer_mass) <= limiter.rtol * abs(tracer_mass)`.

`ρq_data` is modified in place. Threads outside `1:Nh × 1:Nv` do nothing.
"""
function apply_limiter_kernel!(
    limiter::QuasiMonotoneLimiter,
    ρq_data,
    ρ_data,
    WJ_data,
    Nv,
    Nh,
    ::Val{Nf},
    ::Val{Ni},
    ::Val{Nj},
    ::Val{maxiter},
) where {Nf, Ni, Nj, maxiter}
    (; q_bounds_nbr, rtol) = limiter
    converged = true
    (h, v) = get_hv(Nv, Nh, blockIdx(), threadIdx(), blockDim(), gridDim())
    # Skip the surplus threads of the last block.
    if h ≤ Nh && v ≤ Nv

        slab_ρ = slab(ρ_data, v, h)
        slab_ρq = slab(ρq_data, v, h)
        slab_WJ = slab(WJ_data, v, h)
        slab_q_bounds = slab(q_bounds_nbr, v, h)

        array_ρq = parent(slab_ρq)
        array_ρ = parent(slab_ρ)
        array_w = parent(slab_WJ)
        array_q_bounds = parent(slab_q_bounds)

        # 1) compute ∫ρ
        total_mass = zero(eltype(array_ρ))
        for j in 1:Nj, i in 1:Ni
            total_mass += array_ρ[i, j, 1] * array_w[i, j, 1]
        end

        @assert total_mass > 0

        for f in 1:Nf
            q_min = array_q_bounds[1, f]
            q_max = array_q_bounds[2, f]

            # 2) compute ∫ρq
            tracer_mass = zero(eltype(array_ρq))
            for j in 1:Nj, i in 1:Ni
                tracer_mass += array_ρq[i, j, f] * array_w[i, j, 1]
            end

            # TODO: Should this condition be enforced? (It isn't in HOMME.)
            # @assert tracer_mass >= 0

            # 3) set bounds: widen them so the slab average is always allowed
            q_avg = tracer_mass / total_mass
            q_min = min(q_min, q_avg)
            q_max = max(q_max, q_avg)

            # 4) modify ρq
            for iter in 1:maxiter
                # Clip to the bounds and accumulate the mass removed (> 0) or
                # added (< 0) by clipping.
                Δtracer_mass = zero(eltype(array_ρq))
                for j in 1:Nj, i in 1:Ni
                    ρ = array_ρ[i, j, 1]
                    ρq = array_ρq[i, j, f]
                    ρq_max = ρ * q_max
                    ρq_min = ρ * q_min
                    w = array_w[i, j, 1]
                    if ρq > ρq_max
                        Δtracer_mass += (ρq - ρq_max) * w
                        array_ρq[i, j, f] = ρq_max
                    elseif ρq < ρq_min
                        Δtracer_mass += (ρq - ρq_min) * w
                        array_ρq[i, j, f] = ρq_min
                    end
                end

                if abs(Δtracer_mass) <= rtol * abs(tracer_mass)
                    break
                end

                if Δtracer_mass > 0 # add mass
                    # Only nodes strictly below the upper bound receive mass.
                    total_mass_at_Δ_points = zero(eltype(array_ρ))
                    for j in 1:Nj, i in 1:Ni
                        ρ = array_ρ[i, j, 1]
                        ρq = array_ρq[i, j, f]
                        w = array_w[i, j, 1]
                        if ρq < ρ * q_max
                            total_mass_at_Δ_points += ρ * w
                        end
                    end
                    Δq_at_Δ_points = Δtracer_mass / total_mass_at_Δ_points
                    for j in 1:Nj, i in 1:Ni
                        ρ = array_ρ[i, j, 1]
                        ρq = array_ρq[i, j, f]
                        if ρq < ρ * q_max
                            array_ρq[i, j, f] += ρ * Δq_at_Δ_points
                        end
                    end
                else # remove mass
                    # Only nodes strictly above the lower bound give up mass.
                    total_mass_at_Δ_points = zero(eltype(array_ρ))
                    for j in 1:Nj, i in 1:Ni
                        ρ = array_ρ[i, j, 1]
                        ρq = array_ρq[i, j, f]
                        w = array_w[i, j, 1]
                        if ρq > ρ * q_min
                            total_mass_at_Δ_points += ρ * w
                        end
                    end
                    Δq_at_Δ_points = Δtracer_mass / total_mass_at_Δ_points
                    for j in 1:Nj, i in 1:Ni
                        ρ = array_ρ[i, j, 1]
                        ρq = array_ρq[i, j, f]
                        if ρq > ρ * q_min
                            array_ρq[i, j, f] += ρ * Δq_at_Δ_points
                        end
                    end
                end

                if iter == maxiter
                    converged = false
                end
            end
        end

    end
    # `converged` is tracked but not reported; the warning below is disabled.
    # converged || @warn "Limiter failed to converge with rtol = $rtol"

    return nothing
end

0 commit comments

Comments
 (0)