Skip to content

Commit 009a497

Browse files
Compute stencil location on the fly (#2284)
Make non-breaking with promote_bcs
1 parent 8fb42cf commit 009a497

File tree

11 files changed

+492
-706
lines changed

11 files changed

+492
-706
lines changed

examples/column/hydrostatic_ekman.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ function tendency!(dY, Y, _, t)
133133
# TODO!: Undesirable casting to vector required
134134
@. dρθ = -∂c(w * If(ρθ)) + ρ * ∂c* ∂f(ρθ / ρ))
135135

136-
uv_1 = Operators.getidx(axes(uv), uv, Operators.Interior(), 1)
136+
uv_1 = Operators.getidx(axes(uv), uv, 1)
137137
u_wind = LinearAlgebra.norm(uv_1)
138138

139139
A = Operators.AdvectionC2C(

ext/cuda/operators_fd_shmem.jl

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ end
2020
Base.@propagate_inbounds function fd_operator_fill_shmem!(
2121
op::Operators.DivergenceF2C,
2222
(Ju³, lJu³, rJu³),
23-
loc,
2423
bc_bds,
2524
arg_space,
2625
space,
@@ -31,12 +30,24 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
3130
@inbounds begin
3231
vt = threadIdx().x
3332
lg = Geometry.LocalGeometry(space, idx, hidx)
34-
if !on_boundary(space, op, loc, idx)
35-
= Operators.getidx(space, arg, loc, idx, hidx)
33+
if !on_boundary(idx, space, op)
34+
= Operators.getidx(space, arg, idx, hidx)
3635
Ju³[vt] = Geometry.Jcontravariant3(u³, lg)
37-
else
38-
bc = Operators.get_boundary(op, loc)
39-
ub = Operators.getidx(space, bc.val, loc, nothing, hidx)
36+
elseif on_left_boundary(idx, space, op)
37+
bloc = Operators.left_boundary_window(space)
38+
bc = Operators.get_boundary(op, bloc)
39+
ub = Operators.getidx(space, bc.val, nothing, hidx)
40+
bJu³ = on_left_boundary(idx, space) ? lJu³ : rJu³
41+
if bc isa Operators.SetValue
42+
bJu³[1] = Geometry.Jcontravariant3(ub, lg)
43+
elseif bc isa Operators.SetDivergence
44+
bJu³[1] = ub
45+
elseif bc isa Operators.Extrapolate # no shmem needed
46+
end
47+
elseif on_right_boundary(idx, space, op)
48+
bloc = Operators.right_boundary_window(space)
49+
bc = Operators.get_boundary(op, bloc)
50+
ub = Operators.getidx(space, bc.val, nothing, hidx)
4051
bJu³ = on_left_boundary(idx, space) ? lJu³ : rJu³
4152
if bc isa Operators.SetValue
4253
bJu³[1] = Geometry.Jcontravariant3(ub, lg)
@@ -52,7 +63,6 @@ end
5263
Base.@propagate_inbounds function fd_operator_evaluate(
5364
op::Operators.DivergenceF2C,
5465
(Ju³, lJu³, rJu³),
55-
loc,
5666
space,
5767
idx::Integer,
5868
hidx,
@@ -61,12 +71,16 @@ Base.@propagate_inbounds function fd_operator_evaluate(
6171
@inbounds begin
6272
vt = threadIdx().x
6373
lg = Geometry.LocalGeometry(space, idx, hidx)
64-
if !on_boundary(space, op, loc, idx)
74+
if !on_boundary(idx, space, op)
6575
Ju³₋ = Ju³[vt] # corresponds to idx - half
6676
Ju³₊ = Ju³[vt + 1] # corresponds to idx + half
6777
return (Ju³₊ Ju³₋) lg.invJ
6878
else
69-
bc = Operators.get_boundary(op, loc)
79+
bloc =
80+
on_left_boundary(idx, space, op) ?
81+
Operators.left_boundary_window(space) :
82+
Operators.right_boundary_window(space)
83+
bc = Operators.get_boundary(op, bloc)
7084
@assert bc isa Operators.SetValue || bc isa Operators.SetDivergence
7185
if on_left_boundary(idx, space)
7286
if bc isa Operators.SetValue
@@ -109,7 +123,6 @@ end
109123
Base.@propagate_inbounds function fd_operator_fill_shmem!(
110124
op::Operators.GradientC2F,
111125
(u, lb, rb),
112-
loc, # can be any location
113126
bc_bds,
114127
arg_space,
115128
space,
@@ -121,23 +134,20 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
121134
vt = threadIdx().x
122135
cov3 = Geometry.Covariant3Vector(1)
123136
if in_domain(idx, arg_space)
124-
u[vt] = cov3 Operators.getidx(space, arg, loc, idx, hidx)
137+
u[vt] = cov3 Operators.getidx(space, arg, idx, hidx)
125138
else # idx can be Spaces.nlevels(ᶜspace)+1 because threads must extend to faces
126139
ᶜspace = Spaces.center_space(arg_space)
127140
@assert idx == Spaces.nlevels(ᶜspace) + 1
128141
end
129142
if on_any_boundary(idx, space, op)
130-
lloc =
131-
Operators.LeftBoundaryWindow{Spaces.left_boundary_name(space)}()
132-
rloc = Operators.RightBoundaryWindow{
133-
Spaces.right_boundary_name(space),
134-
}()
143+
lloc = Operators.left_boundary_window(space)
144+
rloc = Operators.right_boundary_window(space)
135145
bloc = on_left_boundary(idx, space, op) ? lloc : rloc
136146
@assert bloc isa typeof(lloc) && on_left_boundary(idx, space, op) ||
137147
bloc isa typeof(rloc) && on_right_boundary(idx, space, op)
138148
bc = Operators.get_boundary(op, bloc)
139149
@assert bc isa Operators.SetValue || bc isa Operators.SetGradient
140-
ub = Operators.getidx(space, bc.val, bloc, nothing, hidx)
150+
ub = Operators.getidx(space, bc.val, nothing, hidx)
141151
bu = on_left_boundary(idx, space) ? lb : rb
142152
if bc isa Operators.SetValue
143153
bu[1] = cov3 ub
@@ -154,7 +164,6 @@ end
154164
Base.@propagate_inbounds function fd_operator_evaluate(
155165
op::Operators.GradientC2F,
156166
(u, lb, rb),
157-
loc,
158167
space,
159168
idx::PlusHalf,
160169
hidx,
@@ -163,12 +172,16 @@ Base.@propagate_inbounds function fd_operator_evaluate(
163172
@inbounds begin
164173
vt = threadIdx().x
165174
lg = Geometry.LocalGeometry(space, idx, hidx)
166-
if !on_boundary(space, op, loc, idx)
175+
if !on_boundary(idx, space, op)
167176
u₋ = u[vt - 1] # corresponds to idx - half
168177
u₊ = u[vt] # corresponds to idx + half
169178
return u₊ u₋
170179
else
171-
bc = Operators.get_boundary(op, loc)
180+
bloc =
181+
on_left_boundary(idx, space, op) ?
182+
Operators.left_boundary_window(space) :
183+
Operators.right_boundary_window(space)
184+
bc = Operators.get_boundary(op, bloc)
172185
@assert bc isa Operators.SetValue
173186
if on_left_boundary(idx, space)
174187
if bc isa Operators.SetValue

ext/cuda/operators_fd_shmem_common.jl

Lines changed: 22 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,13 @@ import ClimaCore.Utilities
99
##### Boundary helpers
1010
#####
1111

12-
@inline function has_left_boundary(space, op)
13-
lloc = Operators.LeftBoundaryWindow{Spaces.left_boundary_name(space)}()
14-
return Operators.has_boundary(op, lloc)
15-
end
16-
@inline function has_right_boundary(space, op)
17-
rloc = Operators.RightBoundaryWindow{Spaces.right_boundary_name(space)}()
18-
return Operators.has_boundary(op, rloc)
19-
end
12+
@inline has_left_boundary(space, op) =
13+
Operators.has_boundary(op, Operators.left_boundary_window(space))
14+
@inline has_right_boundary(space, op) =
15+
Operators.has_boundary(op, Operators.right_boundary_window(space))
2016

21-
@inline on_boundary(space, op, loc, idx) =
22-
Operators.has_boundary(op, loc) && on_boundary(idx, space)
17+
@inline on_boundary(idx, space, op) =
18+
on_left_boundary(idx, space, op) || on_right_boundary(idx, space, op)
2319

2420
@inline on_left_boundary(idx, space, op) =
2521
has_left_boundary(space, op) && on_left_boundary(idx, space)
@@ -92,8 +88,7 @@ end
9288
op,
9389
args...,
9490
)
95-
lloc = Operators.LeftBoundaryWindow{Spaces.left_boundary_name(space)}()
96-
Operators.should_call_left_boundary(idx, space, lloc, op, args...) ||
91+
Operators.should_call_left_boundary(idx, space, op, args...) ||
9792
in_left_boundary_window_range(idx, bc_bds)
9893
end
9994

@@ -104,8 +99,7 @@ end
10499
op,
105100
args...,
106101
)
107-
rloc = Operators.RightBoundaryWindow{Spaces.right_boundary_name(space)}()
108-
Operators.should_call_right_boundary(idx, space, rloc, op, args...) ||
102+
Operators.should_call_right_boundary(idx, space, op, args...) ||
109103
in_right_boundary_window_range(idx, bc_bds)
110104
end
111105

@@ -146,8 +140,7 @@ end
146140
op,
147141
args...,
148142
)
149-
lloc = Operators.LeftBoundaryWindow{Spaces.left_boundary_name(space)}()
150-
Operators.should_call_left_boundary(idx, space, lloc, op, args...) ||
143+
Operators.should_call_left_boundary(idx, space, op, args...) ||
151144
in_left_boundary_window_range(idx, bc_bds)
152145
end
153146

@@ -158,10 +151,9 @@ end
158151
op,
159152
args...,
160153
)
161-
rloc = Operators.RightBoundaryWindow{Spaces.right_boundary_name(space)}()
162154
ᶜspace = Spaces.center_space(space)
163155
idx > Spaces.nlevels(ᶜspace) && return false # short-circuit if
164-
Operators.should_call_right_boundary(idx, space, rloc, op, args...) ||
156+
Operators.should_call_right_boundary(idx, space, op, args...) ||
165157
in_right_boundary_window_range(idx, bc_bds)
166158
end
167159

@@ -172,30 +164,6 @@ end
172164
Base.@propagate_inbounds function getidx(
173165
parent_space,
174166
bc::StencilBroadcasted{CUDAWithShmemColumnStencilStyle},
175-
loc::Interior,
176-
idx,
177-
hidx,
178-
)
179-
space = axes(bc)
180-
if Operators.fd_shmem_is_supported(bc)
181-
return fd_operator_evaluate(
182-
bc.op,
183-
bc.work,
184-
loc,
185-
space,
186-
idx,
187-
hidx,
188-
bc.args...,
189-
)
190-
end
191-
Operators.stencil_interior(bc.op, loc, space, idx, hidx, bc.args...)
192-
end
193-
194-
195-
Base.@propagate_inbounds function getidx(
196-
parent_space,
197-
bc::StencilBroadcasted{CUDAWithShmemColumnStencilStyle},
198-
loc::Operators.LeftBoundaryWindow,
199167
idx,
200168
hidx,
201169
)
@@ -204,63 +172,34 @@ Base.@propagate_inbounds function getidx(
204172
return fd_operator_evaluate(
205173
bc.op,
206174
bc.work,
207-
loc,
208175
space,
209176
idx,
210177
hidx,
211178
bc.args...,
212179
)
213180
end
214181
op = bc.op
215-
if Operators.should_call_left_boundary(idx, space, loc, bc.op, bc.args...)
182+
if Operators.should_call_left_boundary(idx, space, bc.op, bc.args...)
216183
Operators.stencil_left_boundary(
217184
op,
218-
Operators.get_boundary(op, loc),
219-
loc,
220-
space,
221-
idx,
222-
hidx,
223-
bc.args...,
224-
)
225-
else
226-
# fallback to interior stencil
227-
Operators.stencil_interior(op, loc, space, idx, hidx, bc.args...)
228-
end
229-
end
230-
231-
Base.@propagate_inbounds function getidx(
232-
parent_space,
233-
bc::StencilBroadcasted{CUDAWithShmemColumnStencilStyle},
234-
loc::Operators.RightBoundaryWindow,
235-
idx,
236-
hidx,
237-
)
238-
space = axes(bc)
239-
if Operators.fd_shmem_is_supported(bc)
240-
return fd_operator_evaluate(
241-
bc.op,
242-
bc.work,
243-
loc,
185+
Operators.get_boundary(op, Operators.left_boundary_window(space)),
244186
space,
245187
idx,
246188
hidx,
247189
bc.args...,
248190
)
249-
end
250-
op = bc.op
251-
if Operators.should_call_right_boundary(idx, space, loc, bc.op, bc.args...)
191+
elseif Operators.should_call_right_boundary(idx, space, bc.op, bc.args...)
252192
Operators.stencil_right_boundary(
253193
op,
254-
Operators.get_boundary(op, loc),
255-
loc,
194+
Operators.get_boundary(op, Operators.right_boundary_window(space)),
256195
space,
257196
idx,
258197
hidx,
259198
bc.args...,
260199
)
261200
else
262201
# fallback to interior stencil
263-
Operators.stencil_interior(op, loc, space, idx, hidx, bc.args...)
202+
Operators.stencil_interior(op, space, idx, hidx, bc.args...)
264203
end
265204
end
266205

@@ -375,9 +314,6 @@ Base.@propagate_inbounds function fd_resolve_shmem!(
375314
)
376315
(li, lw, rw, ri) = bds
377316
space = axes(sbc)
378-
379-
ᶜspace = Spaces.center_space(space)
380-
ᶠspace = Spaces.face_space(space)
381317
arg_space = get_arg_space(sbc, sbc.args)
382318
ᶜidx = get_cent_idx(idx)
383319
ᶠidx = get_face_idx(idx)
@@ -387,13 +323,6 @@ Base.@propagate_inbounds function fd_resolve_shmem!(
387323
# After recursion, check if shmem is supported for this operator
388324
Operators.fd_shmem_is_supported(sbc) || return nothing
389325

390-
(; op) = sbc
391-
lloc = Operators.LeftBoundaryWindow{Spaces.left_boundary_name(space)}()
392-
rloc = Operators.RightBoundaryWindow{Spaces.right_boundary_name(space)}()
393-
iloc = Operators.Interior()
394-
395-
IP = Topologies.isperiodic(Spaces.vertical_topology(space))
396-
397326
# There are `Nf` threads, where `Nf` is the number of face levels. So,
398327
# each thread is responsible for filling shared memory at its cell center
399328
# (if the broadcasted argument lives on cell centers)
@@ -403,52 +332,18 @@ Base.@propagate_inbounds function fd_resolve_shmem!(
403332
# (the space of all broadcasted arguments must all match, so using the first is valid).
404333

405334
bc_bds = Operators.window_bounds(space, sbc)
406-
(bc_li, bc_lw, bc_rw, bc_ri) = bc_bds
407335
ᵃidx = arg_space isa Operators.AllFaceFiniteDifferenceSpace ? ᶠidx : ᶜidx
408336

409-
if in_interior(ᵃidx, arg_space, bc_bds, sbc.op, sbc.args...)
410-
fd_operator_fill_shmem!(
411-
sbc.op,
412-
sbc.work,
413-
iloc,
414-
bc_bds,
415-
arg_space,
416-
space,
417-
ᵃidx,
418-
hidx,
419-
sbc.args...,
420-
)
421-
elseif in_left_boundary_window(ᵃidx, arg_space, bc_bds, sbc.op, sbc.args...)
422-
fd_operator_fill_shmem!(
423-
sbc.op,
424-
sbc.work,
425-
lloc,
426-
bc_bds,
427-
arg_space,
428-
space,
429-
ᵃidx,
430-
hidx,
431-
sbc.args...,
432-
)
433-
elseif in_right_boundary_window(
434-
ᵃidx,
435-
arg_space,
436-
bc_bds,
337+
fd_operator_fill_shmem!(
437338
sbc.op,
339+
sbc.work,
340+
bc_bds,
341+
arg_space,
342+
space,
343+
ᵃidx,
344+
hidx,
438345
sbc.args...,
439346
)
440-
fd_operator_fill_shmem!(
441-
sbc.op,
442-
sbc.work,
443-
rloc,
444-
bc_bds,
445-
arg_space,
446-
space,
447-
ᵃidx,
448-
hidx,
449-
sbc.args...,
450-
)
451-
end
452347
CUDA.sync_threads()
453348
return nothing
454349
end

0 commit comments

Comments
 (0)