Skip to content

Commit e506eb4

Browse files
Add shmem for GradientC2F (#2268)
1 parent 757bf01 commit e506eb4

File tree

5 files changed

+211
-19
lines changed

5 files changed

+211
-19
lines changed

ext/cuda/operators_fd_shmem.jl

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import ClimaCore: DataLayouts, Spaces, Geometry, RecursiveApply, DataLayouts
22
import CUDA
33
import ClimaCore.Operators: return_eltype, get_local_geometry
4+
import ClimaCore.Geometry: ⊗
45

56
Base.@propagate_inbounds function fd_operator_shmem(
67
space,
@@ -92,3 +93,103 @@ Base.@propagate_inbounds function fd_operator_evaluate(
9293
return (Ju³₊ ⊟ Ju³₋) ⊠ local_geometry.invJ
9394
end
9495
end
96+
97+
# Allocate the shared-memory work arrays for the `GradientC2F` operator.
#
# The element type of every array is `return_eltype(op, args...)`.
# Returns the tuple `(u, lb, rb)`:
# - `u`:  `Nvt` entries, holding cell-center values
# - `lb`: a single entry for the left boundary value
# - `rb`: a single entry for the right boundary value
Base.@propagate_inbounds function fd_operator_shmem(
    space,
    ::Val{Nvt},
    op::Operators.GradientC2F,
    args...,
) where {Nvt}
    T = return_eltype(op, args...)
    # Allocation order is kept as centers, left, right.
    centers = CUDA.CuStaticSharedArray(T, (Nvt,))
    left = CUDA.CuStaticSharedArray(T, (1,))
    right = CUDA.CuStaticSharedArray(T, (1,))
    return (centers, left, right)
end
110+
111+
# Fill this thread's slot of the `GradientC2F` shared-memory buffer.
#
# Arguments:
# - `op`:          the `GradientC2F` operator (only used for dispatch here)
# - `(u, lb, rb)`: work arrays returned by `fd_operator_shmem`; only `u` is
#                  written by this method
# - `loc`:         location forwarded to `Operators.getidx` (can be any location)
# - `space`, `idx`, `hidx`: forwarded unchanged to `Operators.getidx`
# - `arg`:         the operand the gradient is applied to
#
# Returns `nothing`.
Base.@propagate_inbounds function fd_operator_fill_shmem_interior!(
    op::Operators.GradientC2F,
    (u, lb, rb),
    loc, # can be any location
    space,
    idx::Integer,
    hidx,
    arg,
)
    @inbounds begin
        vt = threadIdx().x
        # Unit covariant-3 vector; the tensor product `⊗` with the operand's
        # value at `idx` gives the entry stored in shared memory.
        cov3 = Geometry.Covariant3Vector(1)
        u[vt] = cov3 ⊗ Operators.getidx(space, arg, loc, idx, hidx)
    end
    return nothing
end
127+
128+
# Fill shared memory for `GradientC2F` at the left boundary with a `SetValue`
# boundary condition.
#
# Writes two entries:
# - `u[vt]`: the operand's value at `idx` (same as the interior fill)
# - `lb[1]`: the boundary value `bc.val`
#
# Arguments:
# - `op`:          the `GradientC2F` operator (only used for dispatch here)
# - `bc`:          the `SetValue` boundary condition; `bc.val` is read
# - `(u, lb, rb)`: work arrays returned by `fd_operator_shmem`; `rb` is untouched
# - `loc`, `space`, `hidx`: forwarded unchanged to `Operators.getidx`
# - `idx`:         must equal `Operators.left_center_boundary_idx(space)`
# - `arg`:         the operand the gradient is applied to
#
# Calls `error("Incorrect left idx")` when `idx` is not the left center
# boundary index. Returns `nothing` otherwise.
Base.@propagate_inbounds function fd_operator_fill_shmem_left_boundary!(
    op::Operators.GradientC2F,
    bc::Operators.SetValue,
    (u, lb, rb),
    loc,
    space,
    idx::Integer,
    hidx,
    arg,
)
    idx == Operators.left_center_boundary_idx(space) ||
        error("Incorrect left idx")
    @inbounds begin
        vt = threadIdx().x
        cov3 = Geometry.Covariant3Vector(1)
        u[vt] = cov3 ⊗ Operators.getidx(space, arg, loc, idx, hidx)
        # The boundary value is fetched with `nothing` as the vertical index.
        lb[1] = cov3 ⊗ Operators.getidx(space, bc.val, loc, nothing, hidx)
    end
    return nothing
end
148+
149+
# Fill shared memory for `GradientC2F` at the right boundary with a `SetValue`
# boundary condition.
#
# Writes two entries:
# - `u[vt]`: the operand's value at `idx` (same as the interior fill)
# - `rb[1]`: the boundary value `bc.val`
#
# Arguments:
# - `op`:          the `GradientC2F` operator (only used for dispatch here)
# - `bc`:          the `SetValue` boundary condition; `bc.val` is read
# - `(u, lb, rb)`: work arrays returned by `fd_operator_shmem`; `lb` is untouched
# - `loc`, `space`, `hidx`: forwarded unchanged to `Operators.getidx`
# - `idx`:         must equal `Operators.right_center_boundary_idx(space)`
# - `arg`:         the operand the gradient is applied to
#
# Calls `error("Incorrect right idx")` when `idx` is not the right center
# boundary index. Returns `nothing` otherwise.
#
# NOTE(review): the comment previously at the top of this body said the right
# boundary is called at `idx + 1`, so 1 must be subtracted from `idx` and shmem
# is loaded at `vt+1`. This method does neither: it uses `idx` as given and
# writes `u[vt]`. That remark was dropped as stale — confirm.
Base.@propagate_inbounds function fd_operator_fill_shmem_right_boundary!(
    op::Operators.GradientC2F,
    bc::Operators.SetValue,
    (u, lb, rb),
    loc,
    space,
    idx::Integer,
    hidx,
    arg,
)
    idx == Operators.right_center_boundary_idx(space) ||
        error("Incorrect right idx")
    @inbounds begin
        vt = threadIdx().x
        cov3 = Geometry.Covariant3Vector(1)
        u[vt] = cov3 ⊗ Operators.getidx(space, arg, loc, idx, hidx)
        # The boundary value is fetched with `nothing` as the vertical index.
        rb[1] = cov3 ⊗ Operators.getidx(space, bc.val, loc, nothing, hidx)
    end
    return nothing
end
170+
171+
# Evaluate `GradientC2F` at face index `idx` from the shared-memory buffers
# filled by the `fd_operator_fill_shmem_*!` methods.
#
# Arguments:
# - `op`:          the `GradientC2F` operator (only used for dispatch here)
# - `(u, lb, rb)`: work arrays returned by `fd_operator_shmem`
# - `space`:       used to look up the left/right face boundary indices
# - `idx`:         the face (`PlusHalf`) index being evaluated
# - `loc`, `hidx`, `args...`: accepted for interface uniformity; not read here
#
# Returns `u₊ ⊟ u₋`, where `u₋`/`u₊` are the entries on either side of the face:
# - interior:            `u[vt - 1]` and `u[vt]`
# - left face boundary:  `2 * lb[1]` and `2 * u[vt]`
# - right face boundary: `2 * u[vt - 1]` and `2 * rb[1]`
#
# NOTE(review): the factor of 2 at the boundaries presumably accounts for the
# boundary value sitting half a cell from the adjacent center — confirm.
# NOTE(review): `⊟` is assumed to be in scope from the enclosing module
# (ClimaCore's `RecursiveApply`) — confirm.
Base.@propagate_inbounds function fd_operator_evaluate(
    op::Operators.GradientC2F,
    (u, lb, rb),
    loc,
    space,
    idx::PlusHalf,
    hidx,
    args...,
)
    @inbounds begin
        vt = threadIdx().x
        # @assert idx.i == vt-1 # assertion passes, but commented to remove potential thrown exception in llvm output
        if idx == Operators.right_face_boundary_idx(space)
            u₋ = 2 * u[vt - 1] # corresponds to idx - half
            u₊ = 2 * rb[1] # corresponds to idx + half
        elseif idx == Operators.left_face_boundary_idx(space)
            u₋ = 2 * lb[1] # corresponds to idx - half
            u₊ = 2 * u[vt] # corresponds to idx + half
        else
            u₋ = u[vt - 1] # corresponds to idx - half
            u₊ = u[vt] # corresponds to idx + half
        end
        return u₊ ⊟ u₋
    end
end

ext/cuda/operators_fd_shmem_common.jl

Lines changed: 67 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -192,10 +192,11 @@ get_arg_space(
192192
args::Tuple,
193193
) = axes(args[1])
194194

195-
get_cent_idx(idx::Integer) = idx
196-
get_face_idx(idx::PlusHalf) = idx
197-
get_cent_idx(idx::PlusHalf) = idx + half
198-
get_face_idx(idx::Integer) = idx - half
195+
# Index conversion between the center and face staggerings.

# Traversing centers and asking for a center: the index is returned unchanged.
function get_cent_idx(idx::Integer)
    return idx
end

# Traversing faces and asking for a face: the index is returned unchanged.
function get_face_idx(idx::PlusHalf)
    return idx
end

# Traversing faces and asking for a center.
# Convention: pick the center to the right of the face.
function get_cent_idx(idx::PlusHalf)
    return idx + half
end

# Traversing centers and asking for a face.
# Convention: pick the face to the left of the center.
function get_face_idx(idx::Integer)
    return idx - half
end
199200

200201
"""
201202
fd_resolve_shmem!(
@@ -301,21 +302,69 @@ Base.@propagate_inbounds function fd_resolve_shmem!(
301302
else # this else should never run
302303
end
303304
else # populate shmem on centers
304-
#! format: off
305-
# TODO: this needs exercised
306-
case1 = IP || get_cent_idx(bc_lw) ≤ ᶜidx ≤ get_cent_idx(bc_rw) + 1 # interior
307-
case2 = get_cent_idx(bc_li) < ᶜidx < get_cent_idx(bc_lw) && Operators.has_boundary(op, lloc) # left
308-
case3 = get_cent_idx(bc_rw) < ᶜidx < get_cent_idx(bc_ri) && Operators.has_boundary(op, rloc) # right
309-
case4 = get_cent_idx(bc_li) < ᶜidx < get_cent_idx(bc_lw) && !Operators.has_boundary(op, lloc) # left
310-
case5 = get_cent_idx(bc_rw) < ᶜidx < get_cent_idx(bc_ri) && !Operators.has_boundary(op, rloc) # right
311-
if case1; fd_operator_fill_shmem_interior!(sbc.op,sbc.work,iloc,space,ᶜidx,hidx,sbc.args...)
312-
elseif case2; fd_operator_fill_shmem_left_boundary!(sbc.op,Operators.get_boundary(op, lloc),sbc.work,lloc,space,ᶜidx,hidx,sbc.args...)
313-
elseif case3; fd_operator_fill_shmem_right_boundary!(sbc.op,Operators.get_boundary(op, rloc),sbc.work,rloc,space,ᶜidx,hidx,sbc.args...)
314-
elseif case4; fd_operator_fill_shmem_interior!(sbc.op,sbc.work,lloc,space,ᶜidx,hidx,sbc.args...)
315-
elseif case5; fd_operator_fill_shmem_interior!(sbc.op,sbc.work,rloc,space,ᶜidx,hidx,sbc.args...)
316-
else # this else should never run
305+
if IP || get_cent_idx(bc_lw) ≤ ᶜidx < get_cent_idx(bc_rw) # interior
306+
fd_operator_fill_shmem_interior!(
307+
sbc.op,
308+
sbc.work,
309+
iloc,
310+
space,
311+
ᶜidx,
312+
hidx,
313+
sbc.args...,
314+
)
315+
elseif get_cent_idx(bc_li) ≤ ᶜidx < get_cent_idx(bc_lw) &&
316+
Operators.has_boundary(op, lloc) # left
317+
fd_operator_fill_shmem_left_boundary!(
318+
sbc.op,
319+
Operators.get_boundary(op, lloc),
320+
sbc.work,
321+
lloc,
322+
space,
323+
ᶜidx,
324+
hidx,
325+
sbc.args...,
326+
)
327+
elseif get_cent_idx(bc_rw) ≤ ᶜidx < get_cent_idx(bc_ri) &&
328+
Operators.has_boundary(op, rloc) # right
329+
fd_operator_fill_shmem_right_boundary!(
330+
sbc.op,
331+
Operators.get_boundary(op, rloc),
332+
sbc.work,
333+
rloc,
334+
space,
335+
ᶜidx,
336+
hidx,
337+
sbc.args...,
338+
)
339+
elseif get_cent_idx(bc_li) < ᶜidx < get_cent_idx(bc_lw) &&
340+
!Operators.has_boundary(op, lloc) # left
341+
fd_operator_fill_shmem_interior!(
342+
sbc.op,
343+
sbc.work,
344+
lloc,
345+
space,
346+
ᶜidx,
347+
hidx,
348+
sbc.args...,
349+
)
350+
elseif get_cent_idx(bc_rw) < ᶜidx < get_cent_idx(bc_ri) &&
351+
!Operators.has_boundary(op, rloc) # right
352+
fd_operator_fill_shmem_interior!(
353+
sbc.op,
354+
sbc.work,
355+
rloc,
356+
space,
357+
ᶜidx,
358+
hidx,
359+
sbc.args...,
360+
)
361+
else # this should only ever be exercised at Spaces.nlevels(ᶜspace)+1
362+
# We don't have or need to fill shmem at `Spaces.nlevels
363+
# (ᶜspace)+1`, but threads may have this ᶜidx because they may be
364+
# filling shmem for an operator whose shmem exists on the face
365+
# space, which extends beyond the center space.
366+
# @assert ᶜidx == Spaces.nlevels(ᶜspace) + 1 # assertion passes, but commented to remove potential thrown exception in llvm output
317367
end
318-
#! format: on
319368
end
320369
return nothing
321370
end

ext/cuda/operators_fd_shmem_is_supported.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ end
149149
) = false
150150

151151
# Add cases here where shmem is supported:
152+
153+
##### DivergenceF2C
152154
@inline Operators.fd_shmem_is_supported(op::Operators.DivergenceF2C) =
153155
Operators.fd_shmem_is_supported(op, op.bcs)
154156
@inline Operators.fd_shmem_is_supported(
@@ -162,3 +164,18 @@ end
162164
all(values(bcs)) do bc
163165
all(supported_bc -> bc isa supported_bc, (Operators.SetValue,))
164166
end
167+
168+
##### GradientC2F
169+
# Entry point: decide shmem support for `GradientC2F` from its boundary
# conditions, stored in `op.bcs`.
@inline function Operators.fd_shmem_is_supported(op::Operators.GradientC2F)
    return Operators.fd_shmem_is_supported(op, op.bcs)
end

# No boundary conditions at all: shmem is not supported.
@inline function Operators.fd_shmem_is_supported(
    op::Operators.GradientC2F,
    ::@NamedTuple{},
)
    return false
end

# With boundary conditions: supported only when every boundary condition is an
# instance of the supported types (currently just `Operators.SetValue`).
# NOTE(review): the inner `all` requires a bc to match *every* listed type. That
# is equivalent to `any` while the tuple has one element, but would need
# revisiting if more supported types are added.
@inline function Operators.fd_shmem_is_supported(
    op::Operators.GradientC2F,
    bcs::NamedTuple,
)
    return all(values(bcs)) do bc
        all(supported_bc -> bc isa supported_bc, (Operators.SetValue,))
    end
end

test/Operators/finitedifference/unit_fd_ops_shared_memory.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ end
3636
@test Operators.any_fd_shmem_supported(bc)
3737
@test !ext.any_fd_shmem_style(ext.disable_shmem_style(bc))
3838
@. c = div(Geometry.WVector(f))
39+
ᶠgrad = Operators.GradientC2F(;
40+
bottom = Operators.SetValue(FT(0)),
41+
top = Operators.SetValue(FT(0)),
42+
)
43+
bc = @. lazy(ᶠgrad(c))
44+
@test Operators.any_fd_shmem_supported(bc)
45+
@test Operators.fd_shmem_is_supported(bc)
3946
end
4047

4148
#! format: off
@@ -66,6 +73,7 @@ end
6673
@test compare_cpu_gpu(fields_cpu.ᶜout9, fields.ᶜout9); @test !is_trivial(fields_cpu.ᶜout9)
6774
@test compare_cpu_gpu(fields_cpu.ᶜout10, fields.ᶜout10); @test !is_trivial(fields_cpu.ᶜout10)
6875
@test compare_cpu_gpu(fields_cpu.ᶜout_uₕ, fields.ᶜout_uₕ); @test !is_trivial(fields_cpu.ᶜout_uₕ)
76+
@test compare_cpu_gpu(fields_cpu.ᶠout3_cov, fields.ᶠout3_cov); @test !is_trivial(fields_cpu.ᶠout3_cov)
6977
end
7078

7179
@testset "Correctness plane" begin

test/Operators/finitedifference/utils_fd_ops_shared_memory.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ function kernels!(fields)
4949
(; ᶜout1, ᶜout2, ᶜout3, ᶜout4, ᶜout5, ᶜout6, ᶜout7, ᶜout8, ᶜout9) = fields
5050
(; ᶜout10) = fields
5151
(; ᶠout1_contra, ᶠout2_contra) = fields
52+
(; ᶠout3_cov) = fields
5253
(; w_cov) = fields
5354
(; ᶜout_uₕ, ᶜuₕ) = fields
5455
FT = Spaces.undertype(axes(ϕ))
@@ -125,6 +126,12 @@ function kernels!(fields)
125126
div_uh = Operators.DivergenceF2C(outer)
126127
@. ᶜout_uₕ = div_uh(f * grad(ᶜuₕ))
127128

129+
ᶠgrad = Operators.GradientC2F(;
130+
bottom = Operators.SetValue(FT(10)),
131+
top = Operators.SetValue(FT(10)),
132+
)
133+
@. ᶠout3_cov = ᶠgrad(ϕ)
134+
128135
return nothing
129136
end;
130137

@@ -138,12 +145,22 @@ function get_fields(space::Operators.AllFaceFiniteDifferenceSpace)
138145
i -> Fields.Field(Geometry.Contravariant3Vector{FT}, space),
139146
length(K_contra),
140147
)
148+
K_cov_out = (ntuple(i -> Symbol("ᶠout$(i)_cov"), 8)...,)
149+
V_cov_out = ntuple(
150+
i -> Fields.zeros(Geometry.Covariant3Vector{FT}, space),
151+
length(K_cov_out),
152+
)
141153
K_cov = (:w_cov,)
142154
V_cov = ntuple(
143155
i -> Fields.Field(Geometry.Covariant3Vector{FT}, space),
144156
length(K_cov),
145157
)
146-
nt = (; zip(K, V)..., zip(K_contra, V_contra)..., zip(K_cov, V_cov)...)
158+
nt = (;
159+
zip(K, V)...,
160+
zip(K_contra, V_contra)...,
161+
zip(K_cov, V_cov)...,
162+
zip(K_cov_out, V_cov_out)...,
163+
)
147164
@. nt.f = sin(z)
148165
@. nt.w_cov.components.data.:1 = sin(z)
149166
return nt

0 commit comments

Comments
 (0)