Skip to content

Commit 0d6ba42

Browse files
charleskawczynskiCharlie Kawczynski
andauthored
Fix, simplify and generalize shmem for FD stencils (#2282)
Co-authored-by: Charlie Kawczynski <charliek@clima.gps.caltech.edu>
1 parent 456e3dc commit 0d6ba42

File tree

7 files changed

+598
-329
lines changed

7 files changed

+598
-329
lines changed

ext/cuda/operators_fd_shmem.jl

Lines changed: 102 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -12,85 +12,83 @@ Base.@propagate_inbounds function fd_operator_shmem(
1212
# allocate temp output
1313
RT = return_eltype(op, args...)
1414
Ju³ = CUDA.CuStaticSharedArray(RT, (Nvt,))
15-
return Ju³
15+
lJu³ = CUDA.CuStaticSharedArray(RT, (1,))
16+
rJu³ = CUDA.CuStaticSharedArray(RT, (1,))
17+
return (Ju³, lJu³, rJu³)
1618
end
1719

18-
Base.@propagate_inbounds function fd_operator_fill_shmem_interior!(
20+
Base.@propagate_inbounds function fd_operator_fill_shmem!(
1921
op::Operators.DivergenceF2C,
20-
Ju³,
21-
loc, # can be any location
22-
space,
23-
idx::Utilities.PlusHalf,
24-
hidx,
25-
arg,
26-
)
27-
@inbounds begin
28-
vt = threadIdx().x
29-
lg = Geometry.LocalGeometry(space, idx, hidx)
30-
= Operators.getidx(space, arg, loc, idx, hidx)
31-
Ju³[vt] = Geometry.Jcontravariant3(u³, lg)
32-
end
33-
return nothing
34-
end
35-
36-
Base.@propagate_inbounds function fd_operator_fill_shmem_left_boundary!(
37-
op::Operators.DivergenceF2C,
38-
bc::Operators.SetValue,
39-
Ju³,
22+
(Ju³, lJu³, rJu³),
4023
loc,
24+
bc_bds,
25+
arg_space,
4126
space,
4227
idx::Utilities.PlusHalf,
4328
hidx,
4429
arg,
4530
)
46-
idx == Operators.left_face_boundary_idx(space) ||
47-
error("Incorrect left idx")
4831
@inbounds begin
4932
vt = threadIdx().x
5033
lg = Geometry.LocalGeometry(space, idx, hidx)
51-
= Operators.getidx(space, bc.val, loc, nothing, hidx)
52-
Ju³[vt] = Geometry.Jcontravariant3(u³, lg)
53-
end
54-
return nothing
55-
end
56-
57-
Base.@propagate_inbounds function fd_operator_fill_shmem_right_boundary!(
58-
op::Operators.DivergenceF2C,
59-
bc::Operators.SetValue,
60-
Ju³,
61-
loc,
62-
space,
63-
idx::Utilities.PlusHalf,
64-
hidx,
65-
arg,
66-
)
67-
# The right boundary is called at `idx + 1`, so we need to subtract 1 from idx (shmem is loaded at vt+1)
68-
idx == Operators.right_face_boundary_idx(space) ||
69-
error("Incorrect right idx")
70-
@inbounds begin
71-
vt = threadIdx().x
72-
lg = Geometry.LocalGeometry(space, idx, hidx)
73-
= Operators.getidx(space, bc.val, loc, nothing, hidx)
74-
Ju³[vt] = Geometry.Jcontravariant3(u³, lg)
34+
if !on_boundary(space, op, loc, idx)
35+
= Operators.getidx(space, arg, loc, idx, hidx)
36+
Ju³[vt] = Geometry.Jcontravariant3(u³, lg)
37+
else
38+
bc = Operators.get_boundary(op, loc)
39+
ub = Operators.getidx(space, bc.val, loc, nothing, hidx)
40+
bJu³ = on_left_boundary(idx, space) ? lJu³ : rJu³
41+
if bc isa Operators.SetValue
42+
bJu³[1] = Geometry.Jcontravariant3(ub, lg)
43+
elseif bc isa Operators.SetDivergence
44+
bJu³[1] = ub
45+
elseif bc isa Operators.Extrapolate # no shmem needed
46+
end
47+
end
7548
end
7649
return nothing
7750
end
7851

7952
Base.@propagate_inbounds function fd_operator_evaluate(
8053
op::Operators.DivergenceF2C,
81-
Ju³,
54+
(Ju³, lJu³, rJu³),
8255
loc,
8356
space,
8457
idx::Integer,
8558
hidx,
86-
args...,
59+
arg,
8760
)
8861
@inbounds begin
8962
vt = threadIdx().x
90-
local_geometry = Geometry.LocalGeometry(space, idx, hidx)
91-
Ju³₋ = Ju³[vt] # corresponds to idx - half
92-
Ju³₊ = Ju³[vt + 1] # corresponds to idx + half
93-
return (Ju³₊ Ju³₋) local_geometry.invJ
63+
lg = Geometry.LocalGeometry(space, idx, hidx)
64+
if !on_boundary(space, op, loc, idx)
65+
Ju³₋ = Ju³[vt] # corresponds to idx - half
66+
Ju³₊ = Ju³[vt + 1] # corresponds to idx + half
67+
return (Ju³₊ Ju³₋) lg.invJ
68+
else
69+
bc = Operators.get_boundary(op, loc)
70+
@assert bc isa Operators.SetValue || bc isa Operators.SetDivergence
71+
if on_left_boundary(idx, space)
72+
if bc isa Operators.SetValue
73+
Ju³₋ = lJu³[1] # corresponds to idx - half
74+
Ju³₊ = Ju³[vt + 1] # corresponds to idx + half
75+
return (Ju³₊ Ju³₋) lg.invJ
76+
else
77+
# @assert bc isa Operators.SetDivergence
78+
return lJu³[1]
79+
end
80+
else
81+
@assert on_right_boundary(idx, space)
82+
if bc isa Operators.SetValue
83+
Ju³₋ = Ju³[vt] # corresponds to idx - half
84+
Ju³₊ = rJu³[1] # corresponds to idx + half
85+
return (Ju³₊ Ju³₋) lg.invJ
86+
else
87+
@assert bc isa Operators.SetDivergence
88+
return rJu³[1]
89+
end
90+
end
91+
end
9492
end
9593
end
9694

@@ -108,10 +106,12 @@ Base.@propagate_inbounds function fd_operator_shmem(
108106
return (u, lb, rb)
109107
end
110108

111-
Base.@propagate_inbounds function fd_operator_fill_shmem_interior!(
109+
Base.@propagate_inbounds function fd_operator_fill_shmem!(
112110
op::Operators.GradientC2F,
113111
(u, lb, rb),
114112
loc, # can be any location
113+
bc_bds,
114+
arg_space,
115115
space,
116116
idx::Integer,
117117
hidx,
@@ -120,50 +120,33 @@ Base.@propagate_inbounds function fd_operator_fill_shmem_interior!(
120120
@inbounds begin
121121
vt = threadIdx().x
122122
cov3 = Geometry.Covariant3Vector(1)
123-
u[vt] = cov3 Operators.getidx(space, arg, loc, idx, hidx)
124-
end
125-
return nothing
126-
end
127-
128-
Base.@propagate_inbounds function fd_operator_fill_shmem_left_boundary!(
129-
op::Operators.GradientC2F,
130-
bc::Operators.SetValue,
131-
(u, lb, rb),
132-
loc,
133-
space,
134-
idx::Integer,
135-
hidx,
136-
arg,
137-
)
138-
idx == Operators.left_center_boundary_idx(space) ||
139-
error("Incorrect left idx")
140-
@inbounds begin
141-
vt = threadIdx().x
142-
cov3 = Geometry.Covariant3Vector(1)
143-
u[vt] = cov3 Operators.getidx(space, arg, loc, idx, hidx)
144-
lb[1] = cov3 Operators.getidx(space, bc.val, loc, nothing, hidx)
145-
end
146-
return nothing
147-
end
148-
149-
Base.@propagate_inbounds function fd_operator_fill_shmem_right_boundary!(
150-
op::Operators.GradientC2F,
151-
bc::Operators.SetValue,
152-
(u, lb, rb),
153-
loc,
154-
space,
155-
idx::Integer,
156-
hidx,
157-
arg,
158-
)
159-
# The right boundary is called at `idx + 1`, so we need to subtract 1 from idx (shmem is loaded at vt+1)
160-
idx == Operators.right_center_boundary_idx(space) ||
161-
error("Incorrect right idx")
162-
@inbounds begin
163-
vt = threadIdx().x
164-
cov3 = Geometry.Covariant3Vector(1)
165-
u[vt] = cov3 Operators.getidx(space, arg, loc, idx, hidx)
166-
rb[1] = cov3 Operators.getidx(space, bc.val, loc, nothing, hidx)
123+
if in_domain(idx, arg_space)
124+
u[vt] = cov3 Operators.getidx(space, arg, loc, idx, hidx)
125+
else # idx can be Spaces.nlevels(ᶜspace)+1 because threads must extend to faces
126+
ᶜspace = Spaces.center_space(arg_space)
127+
@assert idx == Spaces.nlevels(ᶜspace) + 1
128+
end
129+
if on_any_boundary(idx, space, op)
130+
lloc =
131+
Operators.LeftBoundaryWindow{Spaces.left_boundary_name(space)}()
132+
rloc = Operators.RightBoundaryWindow{
133+
Spaces.right_boundary_name(space),
134+
}()
135+
bloc = on_left_boundary(idx, space, op) ? lloc : rloc
136+
@assert bloc isa typeof(lloc) && on_left_boundary(idx, space, op) ||
137+
bloc isa typeof(rloc) && on_right_boundary(idx, space, op)
138+
bc = Operators.get_boundary(op, bloc)
139+
@assert bc isa Operators.SetValue || bc isa Operators.SetGradient
140+
ub = Operators.getidx(space, bc.val, bloc, nothing, hidx)
141+
bu = on_left_boundary(idx, space) ? lb : rb
142+
if bc isa Operators.SetValue
143+
bu[1] = cov3 ub
144+
elseif bc isa Operators.SetGradient
145+
lg = Geometry.LocalGeometry(space, idx, hidx)
146+
bu[1] = Geometry.project(Geometry.Covariant3Axis(), ub, lg)
147+
elseif bc isa Operators.Extrapolate # no shmem needed
148+
end
149+
end
167150
end
168151
return nothing
169152
end
@@ -179,17 +162,28 @@ Base.@propagate_inbounds function fd_operator_evaluate(
179162
)
180163
@inbounds begin
181164
vt = threadIdx().x
182-
# @assert idx.i == vt-1 # assertion passes, but commented to remove potential thrown exception in llvm output
183-
if idx == Operators.right_face_boundary_idx(space)
184-
u₋ = 2 * u[vt - 1] # corresponds to idx - half
185-
u₊ = 2 * rb[1] # corresponds to idx + half
186-
elseif idx == Operators.left_face_boundary_idx(space)
187-
u₋ = 2 * lb[1] # corresponds to idx - half
188-
u₊ = 2 * u[vt] # corresponds to idx + half
189-
else
165+
lg = Geometry.LocalGeometry(space, idx, hidx)
166+
if !on_boundary(space, op, loc, idx)
190167
u₋ = u[vt - 1] # corresponds to idx - half
191168
u₊ = u[vt] # corresponds to idx + half
169+
return u₊ u₋
170+
else
171+
bc = Operators.get_boundary(op, loc)
172+
@assert bc isa Operators.SetValue
173+
if on_left_boundary(idx, space)
174+
if bc isa Operators.SetValue
175+
u₋ = 2 * lb[1] # corresponds to idx - half
176+
u₊ = 2 * u[vt] # corresponds to idx + half
177+
return u₊ u₋
178+
end
179+
else
180+
@assert on_right_boundary(idx, space)
181+
if bc isa Operators.SetValue
182+
u₋ = 2 * u[vt - 1] # corresponds to idx - half
183+
u₊ = 2 * rb[1] # corresponds to idx + half
184+
return u₊ u₋
185+
end
186+
end
192187
end
193-
return u₊ u₋
194188
end
195189
end

0 commit comments

Comments
 (0)