Skip to content

Commit 2c12dd3

Browse files
Merge pull request #2001 from CliMA/ck/high_res_launch_config
Fix high res threading config for SEM kernels
2 parents bd20629 + a7a68f3 commit 2c12dd3

File tree

3 files changed

+46
-32
lines changed

3 files changed

+46
-32
lines changed

ext/cuda/data_layouts_threadblock.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,40 @@ end
214214
end
215215
@inline multiple_field_solve_is_valid_index(I::CI5, us::UniversalSize) =
216216
1 I[5] DataLayouts.get_Nh(us)
217+
218+
##### spectral kernel partition
219+
@inline function spectral_partition(
220+
us::DataLayouts.UniversalSize,
221+
n_max_threads::Integer = 256;
222+
)
223+
(Nq, _, _, Nv, Nh) = DataLayouts.universal_size(us)
224+
Nvthreads = min(fld(n_max_threads, Nq * Nq), maximum_allowable_threads()[3])
225+
Nvblocks = cld(Nv, Nvthreads)
226+
@assert prod((Nq, Nq, Nvthreads)) n_max_threads "threads,n_max_threads=($(prod((Nq, Nq, Nvthreads))),$n_max_threads)"
227+
@assert Nq * Nq n_max_threads
228+
return (; threads = (Nq, Nq, Nvthreads), blocks = (Nh, Nvblocks), Nvthreads)
229+
end
230+
@inline function spectral_universal_index(space::Spaces.AbstractSpace)
231+
i = threadIdx().x
232+
j = threadIdx().y
233+
k = threadIdx().z
234+
h = blockIdx().x
235+
vid = k + (blockIdx().y - 1) * blockDim().z
236+
if space isa Spaces.AbstractSpectralElementSpace
237+
v = nothing
238+
elseif space isa Spaces.FaceExtrudedFiniteDifferenceSpace
239+
v = vid - half
240+
elseif space isa Spaces.CenterExtrudedFiniteDifferenceSpace
241+
v = vid
242+
else
243+
error("Invalid space")
244+
end
245+
ij = CartesianIndex((i, j))
246+
slabidx = Fields.SlabIndex(v, h)
247+
return (ij, slabidx)
248+
end
249+
@inline spectral_is_valid_index(
250+
space::Spaces.AbstractSpectralElementSpace,
251+
ij,
252+
slabidx,
253+
) = Operators.is_valid_index(space, ij, slabidx)

ext/cuda/operators_spectral_element.jl

Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -34,26 +34,20 @@ function Base.copyto!(
3434
},
3535
)
3636
space = axes(out)
37-
QS = Spaces.quadrature_style(space)
38-
Nq = Quadratures.degrees_of_freedom(QS)
39-
Nh = Topologies.nlocalelems(Spaces.topology(space))
40-
Nv = Spaces.nlevels(space)
41-
max_threads = 256
42-
@assert Nq * Nq max_threads
43-
Nvthreads = fld(max_threads, Nq * Nq)
44-
Nvblocks = cld(Nv, Nvthreads)
37+
us = UniversalSize(Fields.field_values(out))
4538
# executed
39+
p = spectral_partition(us)
4640
args = (
4741
strip_space(out, space),
4842
strip_space(sbc, space),
4943
space,
50-
Val(Nvthreads),
44+
Val(p.Nvthreads),
5145
)
5246
auto_launch!(
5347
copyto_spectral_kernel!,
5448
args;
55-
threads_s = (Nq, Nq, Nvthreads),
56-
blocks_s = (Nh, Nvblocks),
49+
threads_s = p.threads,
50+
blocks_s = p.blocks,
5751
)
5852
return out
5953
end
@@ -66,32 +60,15 @@ function copyto_spectral_kernel!(
6660
::Val{Nvt},
6761
) where {Nvt}
6862
@inbounds begin
69-
i = threadIdx().x
70-
j = threadIdx().y
71-
k = threadIdx().z
72-
h = blockIdx().x
73-
vid = k + (blockIdx().y - 1) * blockDim().z
7463
# allocate required shmem
75-
7664
sbc_reconstructed =
7765
Operators.reconstruct_placeholder_broadcasted(space, sbc)
7866
sbc_shmem = allocate_shmem(Val(Nvt), sbc_reconstructed)
7967

80-
8168
# can loop over blocks instead?
82-
if space isa Spaces.AbstractSpectralElementSpace
83-
v = nothing
84-
elseif space isa Spaces.FaceExtrudedFiniteDifferenceSpace
85-
v = vid - half
86-
elseif space isa Spaces.CenterExtrudedFiniteDifferenceSpace
87-
v = vid
88-
else
89-
error("Invalid space")
90-
end
91-
ij = CartesianIndex((i, j))
92-
slabidx = Fields.SlabIndex(v, h)
93-
# v may potentially be out-of-range: any time memory is accessed, it
94-
# should be checked by a call to is_valid_index(space, ij, slabidx)
69+
(ij, slabidx) = spectral_universal_index(space)
70+
# v in `slabidx` may potentially be out-of-range: any time memory is
71+
# accessed, it should be checked by a call to is_valid_index(space, ij, slabidx)
9572

9673
# resolve_shmem! needs to be called even when out of range, so that
9774
# sync_threads() is invoked collectively

test/Spaces/distributed_cuda/ddss4.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ pid, nprocs = ClimaComms.init(context)
100100
end
101101
p = @allocated Spaces.weighted_dss!(y0, dss_buffer)
102102
if pid == 1
103-
@test p 7776
103+
@test p 410296
104104
end
105105

106106
end

0 commit comments

Comments
 (0)