Skip to content

Commit 869cc54

Browse files
Implement FD shmem (#2184)
Try dont_limit on recursive resolve_shmem methods Fixes + more dont limit Matrix field fixes Matrix field fixes DivergenceF2C fix MatrixField fixes Qualify DivergenceF2C wip Refactor + fixed space bug. All seems good. More tests.. Fixes Test updates Fixes Allow disabling shmem using broadcast style Fix fused cuda operations in LG Revert some unwanted pieces More fixes Format wip, adding docs Fixes Fixes Apply formatter + docs Always call disable_shmem_style in else-branch Fix git conflict
1 parent 19cecdd commit 869cc54

16 files changed

+1387
-31
lines changed

.buildkite/pipeline.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,16 @@ steps:
610610
key: unit_spec_ops_plane
611611
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/Operators/spectralelement/plane.jl"
612612

613+
- label: "Unit: FD operator (shmem)"
614+
key: unit_fd_operator_shmem
615+
command:
616+
- "julia --color=yes --check-bounds=yes --project=.buildkite test/Operators/finitedifference/unit_fd_ops_shared_memory.jl"
617+
- "julia --color=yes --project=.buildkite test/Operators/finitedifference/benchmark_fd_ops_shared_memory.jl"
618+
env:
619+
CLIMACOMMS_DEVICE: "CUDA"
620+
agents:
621+
slurm_gpus: 1
622+
613623
- label: "Unit: column"
614624
key: unit_column
615625
command:

docs/make.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@ withenv("GKSwstype" => "nul") do
8080
"Remapping" => "remapping.md",
8181
"MatrixFields" => "matrix_fields.md",
8282
"API" => "api.md",
83-
"Developer docs" => ["Performance tips" => "performance_tips.md"],
83+
"Developer docs" => [
84+
"Performance tips" => "performance_tips.md"
85+
"Shared memory design" => "shmem_design.md"
86+
],
8487
"Tutorials" => [
8588
joinpath("tutorials", tutorial * ".md") for
8689
tutorial in TUTORIALS

docs/src/shmem_design.md

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Shared memory design
2+
3+
ClimaCore stencil operators support staggered (or collocated) finite difference
4+
and interpolation operations. For example, the `DivergenceF2C` operator takes
5+
an argument that lives on the cell faces and the resulting divergence
6+
calculation lives on the cell centers. Such operations are effectively
7+
matrix-vector multiplication and are often a significant portion of the runtime
8+
cost for users.
9+
10+
Here, we outline an optimization, shared memory (or, "shmem" for short), that we
11+
use to improve the performance of these operations.
12+
13+
## Motivation
14+
15+
A naive and simplified implementation of this operation looks like
16+
`div[i] = (f[i+1] - f[i]) / dz[i]`. Such a calculation on the GPU (or CPU) requires `f[i]`
17+
to be read from global memory twice: once for `div[i]` and once for `div[i-1]`. Not
18+
to mention, if `f` is a `Broadcasted` object (`Broadcasted` objects behave like
19+
arrays, and support `f[i]` behavior), then `f[i]` may require several reads
20+
and/or computations.
21+
22+
Reading data from global memory is often the main bottleneck for
23+
bandwidth-limited CUDA kernels. As such, we use shmem to reduce the number of global memory reads (and the amount of repeated computation) in our kernels.
24+
25+
## High-level design
26+
27+
The high-level view of the design is:
28+
29+
- The `bc::StencilBroadcasted` type has a `work` field, which is used to store
30+
shmem for the `bc.op` operator. The element type of the `work`
31+
(or parts of `work` if there are multiple parts) is the type returned by the
32+
`bc.op`'s `Operators.return_eltype`.
33+
- Recursively reconstruct the broadcasted object, allocating shmem for
34+
each `StencilBroadcasted` along the way that supports shmem
35+
(different operators require different arguments, and therefore different
36+
types and amounts of shmem).
37+
- Recursively fill the shmem for all `StencilBroadcasted`. This is done
38+
by reading the argument data from `getidx`.
39+
- The destination field is filled with the result of `getidx` (as it is without
40+
shmem), except that we overload `getidx` (for supported `StencilBroadcasted`
41+
types) to compute the result via `fd_operator_evaluate`, which
42+
reads its operands from shmem instead of global memory.
43+
44+
45+
46+

ext/ClimaCoreCUDAExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ include(joinpath("cuda", "operators_integral.jl"))
3333
include(joinpath("cuda", "remapping_interpolate_array.jl"))
3434
include(joinpath("cuda", "limiters.jl"))
3535
include(joinpath("cuda", "operators_sem_shmem.jl"))
36+
include(joinpath("cuda", "operators_fd_shmem_common.jl"))
37+
include(joinpath("cuda", "operators_fd_shmem.jl"))
3638
include(joinpath("cuda", "matrix_fields_single_field_solve.jl"))
3739
include(joinpath("cuda", "matrix_fields_multiple_field_solve.jl"))
3840
include(joinpath("cuda", "operators_spectral_element.jl"))

ext/cuda/data_layouts_threadblock.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,3 +313,33 @@ end
313313
ij,
314314
slabidx,
315315
) = Operators.is_valid_index(space, ij, slabidx)
316+
317+
##### shmem fd kernel partition
318+
"""
    fd_stencil_partition(us, n_face_levels, n_max_threads = 256)

Build the kernel launch configuration for the shared-memory finite-difference
stencil kernels.

Returns a named tuple `(; threads, blocks, Nvthreads)` where

 - `threads` is `(n_face_levels,)`,
 - `blocks` is `(Nh, cld(Nv, n_face_levels), Nq * Nq)`, with `Nq`, `Nv` and
   `Nh` taken from `DataLayouts.universal_size(us)`,
 - `Nvthreads` is `n_face_levels`.

An assertion fails if `n_face_levels` exceeds the first entry of
`maximum_allowable_threads()`. `n_max_threads` is accepted but not used by
this method.
"""
@inline function fd_stencil_partition(
    us::DataLayouts.UniversalSize,
    n_face_levels::Integer,
    n_max_threads::Integer = 256;
)
    (Nq, _, _, Nv, Nh) = DataLayouts.universal_size(us)
    # One thread per vertical face level within a block.
    Nvthreads = n_face_levels
    @assert Nvthreads <= maximum_allowable_threads()[1] "Number of vertical face levels cannot exceed $(maximum_allowable_threads()[1])"
    threads = (Nvthreads,)
    # NOTE: the vertical block count may need a +1 to guarantee that shared
    # memory is populated at the last cell face.
    blocks = (Nh, cld(Nv, Nvthreads), Nq * Nq)
    return (; threads, blocks, Nvthreads)
end
333+
"""
    fd_stencil_universal_index(space, us)

Map the calling CUDA thread onto a 5-component `CartesianIndex`
`(i, j, 1, v, h)` for the launch layout produced by `fd_stencil_partition`:

 - `h` is the block index along x,
 - `v` is the thread index along x, offset by the block index along y times
   the block size,
 - `(i, j)` is the block index along z, unravelled over an `Nq × Nq` grid.

If the z block index exceeds `Nq * Nq`, the sentinel
`CartesianIndex((-1, -1, 1, -1, -1))` is returned instead. `space` is only
used for dispatch here.
"""
@inline function fd_stencil_universal_index(space::Spaces.AbstractSpace, us)
    tv = CUDA.threadIdx().x
    bidx = CUDA.blockIdx()
    h = bidx.x
    bv = bidx.y
    ij = bidx.z
    v = tv + (bv - 1) * CUDA.blockDim().x
    (Nq, _, _, _, _) = DataLayouts.universal_size(us)
    # Blocks beyond the Nq × Nq nodal grid have no work: hand back a sentinel.
    ij > Nq * Nq && return CartesianIndex((-1, -1, 1, -1, -1))
    @inbounds (i, j) = Tuple(CartesianIndices((Nq, Nq))[ij])
    return CartesianIndex((i, j, 1, v, h))
end
344+
"""
    fd_stencil_is_valid_index(I::CI5, us::UniversalSize)

Return `true` when the fifth component of `I` lies in
`1:DataLayouts.get_Nh(us)`.

`fd_stencil_universal_index` places the horizontal index `h` in that slot and
uses `-1` there for its "no work" sentinel, so the sentinel is rejected by
this check.
"""
@inline fd_stencil_is_valid_index(I::CI5, us::UniversalSize) =
    1 <= I[5] <= DataLayouts.get_Nh(us)

ext/cuda/operators_fd_shmem.jl

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import ClimaCore: DataLayouts, Spaces, Geometry, RecursiveApply, DataLayouts
2+
import CUDA
3+
import ClimaCore.Operators: return_eltype, get_local_geometry
4+
5+
"""
    fd_operator_shmem(space, ::Val{Nvt}, op::Operators.DivergenceF2C, args...)

Allocate the shared-memory work array used by `DivergenceF2C`.

The array has `Nvt` entries (presumably one per vertical thread in the block;
confirm against the launch configuration) whose element type is
`return_eltype(op, args...)`. `space` is not used by this method.
"""
Base.@propagate_inbounds function fd_operator_shmem(
    space,
    ::Val{Nvt},
    op::Operators.DivergenceF2C,
    args...,
) where {Nvt}
    T = return_eltype(op, args...)
    dims = (Nvt,)
    return CUDA.CuStaticSharedArray(T, dims)
end
16+
17+
"""
    fd_operator_fill_shmem_interior!(
        op::Operators.DivergenceF2C, Ju³, loc, space, idx, hidx, arg,
    )

Fill this thread's slot of the work array `Ju³` for an interior face.

The argument is evaluated at face index `idx` through `Operators.getidx`,
converted with `Geometry.Jcontravariant3` using the local geometry at
`(idx, hidx)`, and written to `Ju³[threadIdx().x]`. Returns `nothing`.
"""
Base.@propagate_inbounds function fd_operator_fill_shmem_interior!(
    op::Operators.DivergenceF2C,
    Ju³,
    loc, # can be any location
    space,
    idx::Utilities.PlusHalf,
    hidx,
    arg,
)
    @inbounds begin
        vt = threadIdx().x
        lg = Geometry.LocalGeometry(space, idx, hidx)
        # Value of the argument at this face; consumed on the next line.
        u³ = Operators.getidx(space, arg, loc, idx, hidx)
        Ju³[vt] = Geometry.Jcontravariant3(u³, lg)
    end
    return nothing
end
34+
35+
"""
    fd_operator_fill_shmem_left_boundary!(
        op::Operators.DivergenceF2C, bc::Operators.SetValue,
        Ju³, loc, space, idx, hidx, arg,
    )

Fill this thread's slot of the work array `Ju³` at the left (bottom) face
boundary for a `SetValue` boundary condition.

The boundary value `bc.val` is evaluated through `Operators.getidx` instead of
`arg`, converted with `Geometry.Jcontravariant3` using the local geometry at
`(idx, hidx)`, and written to `Ju³[threadIdx().x]`. Errors if `idx` is not
`Operators.left_face_boundary_idx(space)`. Returns `nothing`.
"""
Base.@propagate_inbounds function fd_operator_fill_shmem_left_boundary!(
    op::Operators.DivergenceF2C,
    bc::Operators.SetValue,
    Ju³,
    loc,
    space,
    idx::Utilities.PlusHalf,
    hidx,
    arg,
)
    idx == Operators.left_face_boundary_idx(space) ||
        error("Incorrect left idx")
    @inbounds begin
        vt = threadIdx().x
        lg = Geometry.LocalGeometry(space, idx, hidx)
        # The boundary condition supplies the face value; `arg` is not read.
        u³ = Operators.getidx(space, bc.val, loc, nothing, hidx)
        Ju³[vt] = Geometry.Jcontravariant3(u³, lg)
    end
    return nothing
end
55+
56+
"""
    fd_operator_fill_shmem_right_boundary!(
        op::Operators.DivergenceF2C, bc::Operators.SetValue,
        Ju³, loc, space, idx, hidx, arg,
    )

Fill this thread's slot of the work array `Ju³` at the right (top) face
boundary for a `SetValue` boundary condition.

The boundary value `bc.val` is evaluated through `Operators.getidx` instead of
`arg`, converted with `Geometry.Jcontravariant3` using the local geometry at
`(idx, hidx)`, and written to `Ju³[threadIdx().x]`. Errors if `idx` is not
`Operators.right_face_boundary_idx(space)`. Returns `nothing`.
"""
Base.@propagate_inbounds function fd_operator_fill_shmem_right_boundary!(
    op::Operators.DivergenceF2C,
    bc::Operators.SetValue,
    Ju³,
    loc,
    space,
    idx::Utilities.PlusHalf,
    hidx,
    arg,
)
    # NOTE(review): an earlier comment here said the right boundary is called
    # at `idx + 1`, that 1 must be subtracted from idx, and that shmem is
    # loaded at vt+1. The code below uses `idx` as passed and writes slot
    # `vt` — confirm which is intended.
    idx == Operators.right_face_boundary_idx(space) ||
        error("Incorrect right idx")
    @inbounds begin
        vt = threadIdx().x
        lg = Geometry.LocalGeometry(space, idx, hidx)
        # The boundary condition supplies the face value; `arg` is not read.
        u³ = Operators.getidx(space, bc.val, loc, nothing, hidx)
        Ju³[vt] = Geometry.Jcontravariant3(u³, lg)
    end
    return nothing
end
77+
78+
"""
    fd_operator_evaluate(
        op::Operators.DivergenceF2C, Ju³, loc, space, idx::Integer, hidx, args...,
    )

Evaluate `DivergenceF2C` at cell-center index `idx` from the work array `Ju³`.

Reads the two adjacent face entries `Ju³[vt]` and `Ju³[vt + 1]` (with
`vt = threadIdx().x`), takes their difference, and scales it by the `invJ`
field of the local geometry at `(idx, hidx)`. `loc` and `args` are not read.

`Ju³[vt + 1]` is a slot written by a different thread, so this assumes every
slot has been filled and made visible before the call — the synchronization
is not visible in this method; confirm against the caller.
"""
Base.@propagate_inbounds function fd_operator_evaluate(
    op::Operators.DivergenceF2C,
    Ju³,
    loc,
    space,
    idx::Integer,
    hidx,
    args...,
)
    @inbounds begin
        vt = threadIdx().x
        local_geometry = Geometry.LocalGeometry(space, idx, hidx)
        Ju³₋ = Ju³[vt] # corresponds to idx - half
        Ju³₊ = Ju³[vt + 1] # corresponds to idx + half
        # `⊟` / `⊠` are presumably the RecursiveApply subtract / multiply
        # operators; they must be in scope in the enclosing module — confirm.
        return (Ju³₊ ⊟ Ju³₋) ⊠ local_geometry.invJ
    end
end

0 commit comments

Comments
 (0)