Skip to content

Commit 4a4261e

Browse files
Add more docs for SharedMemory (#2303)
1 parent b732f9f commit 4a4261e

File tree

3 files changed

+40
-3
lines changed

3 files changed

+40
-3
lines changed

docs/src/shmem_design.md

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,40 @@ The high-level view of the design is:
3535
(different operators require different arguments, and therefore different
3636
types and amounts of shmem).
3737
- Recursively fill the shmem for all `StencilBroadcasted`. This is done
38-
by reading the argument data from `getidx`
38+
by reading the argument data from `getidx`. See the section discussion below for more details.
3939
- The destination field is filled with the result of `getidx` (as it is without
4040
shmem), except that we overload `getidx` (for supported `StencilBroadcasted`
4141
types) to retrieve the result of `getidx` via `fd_operator_evaluate`, which
4242
retrieves the result from the shmem, instead of global memory.
4343

44+
### Populating shared memory, and memory access safety
4445

46+
We use tail-recursion when filling shared memory of the broadcast expressions.
47+
That is, we visit leaves of the broadcast expression, then work our way up.
48+
It's important to note that the `StencilBroadcasted` and `Broadcasted` can be
49+
interleaved.
4550

51+
Let's take `DivergenceF2C()(f*GradientC2F()(a*b)))` as an example (depicted in
52+
the image below).
4653

54+
Recursion must go through the entire expression in order to ensure that we've
55+
reached all of the leaves of the `StencilBroadcasted` objects (otherwise, we
56+
could introduce race conditions with memory access). The leaves of the
57+
`StencilBroadcasted` will call `getidx`, below which there are (by definition)
58+
no more `StencilBroadcasted`, and those `getidx` calls will read from global
59+
memory. All subsequent reads will be from shmem(as they will be caught by the
60+
`getidx(parent_space, bc::StencilBroadcasted
61+
{CUDAWithShmemColumnStencilStyle}, idx, hidx)` defined in the
62+
`ClimaCoreCUDAExt` module).
63+
64+
In the diagram below, we traverse and fill the yellow highlighted sections
65+
(bottom first and top last). The algorithmic impact of using shared memory is
66+
that the duplicate global memory reads (highlighted in red circles) become one
67+
global memory read (performed in `fd_operator_fill_shmem!`).
68+
69+
Finally, its important to note that threads must by syncrhonized after each node
70+
in the tree is filled, to avoid race conditions for subsequent `getidx
71+
(parent_space, bc::StencilBroadcasted{CUDAWithShmemColumnStencilStyle}, idx,
72+
hidx)` calls (which are retrieved via shmem).
73+
74+
![](shmem_diagram_example.png)

docs/src/shmem_diagram_example.png

215 KB
Loading

ext/cuda/operators_fd_shmem_common.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,9 +323,18 @@ Base.@propagate_inbounds function fd_resolve_shmem!(
323323
ᶜidx = get_cent_idx(idx)
324324
ᶠidx = get_face_idx(idx)
325325

326-
_fd_resolve_shmem!(idx, hidx, bds, sbc.args...) # propagate idx, not bc_idx recursively through broadcast expressions
326+
# Here, we use tail-recursion. We visit leaves of the broadcast expression,
327+
# then work our way up. The StencilBroadcasted and Broadcasted can be
328+
# interleaved (e.g., `DivergenceF2C()(f*GradientC2F()(a*b)))`. The leaves of
329+
# the StencilBroadcasted will call `getidx`, below which there are
330+
# (by definition) no more `StencilBroadcasted`, and those `getidx` calls
331+
# will read from global memory. Immediately above those reads, all
332+
# subsequent reads will be from shmem (as they will be caught by the
333+
# `getidx` defined above).
334+
_fd_resolve_shmem!(idx, hidx, bds, sbc.args...)
327335

328-
# After recursion, check if shmem is supported for this operator
336+
# Once we're about ready to fill the shmem, check if shmem is supported for
337+
# this operator
329338
Operators.fd_shmem_is_supported(sbc) || return nothing
330339

331340
# There are `Nf` threads, where `Nf` is the number of face levels. So,

0 commit comments

Comments
 (0)