Skip to content

Commit 0c30ae5

Browse files
Merge pull request #1663 from CliMA/ck/try_cuda_launch_config
Define and use `auto_launch!`
2 parents ca24ff9 + 33ce411 commit 0c30ae5

13 files changed

+328
-78
lines changed

ext/ClimaCoreCUDAExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import ClimaCore.Utilities: half
1313
import ClimaCore.RecursiveApply:
1414
⊞, ⊠, ⊟, radd, rmul, rsub, rdiv, rmap, rzero, rmin, rmax
1515

16+
include(joinpath("cuda", "cuda_utils.jl"))
1617
include(joinpath("cuda", "data_layouts.jl"))
1718
include(joinpath("cuda", "fields.jl"))
1819
include(joinpath("cuda", "topologies_dss.jl"))
@@ -23,8 +24,8 @@ include(joinpath("cuda", "remapping_interpolate_array.jl"))
2324
include(joinpath("cuda", "limiters.jl"))
2425
include(joinpath("cuda", "operators_sem_shmem.jl"))
2526
include(joinpath("cuda", "operators_thomas_algorithm.jl"))
27+
include(joinpath("cuda", "matrix_fields_single_field_solve.jl"))
2628
include(joinpath("cuda", "matrix_fields_multiple_field_solve.jl"))
2729
include(joinpath("cuda", "operators_spectral_element.jl"))
28-
include(joinpath("cuda", "matrix_fields_single_field_solve.jl"))
2930

3031
end

ext/cuda/cuda_utils.jl

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import CUDA
2+
import ClimaCore.Fields
3+
import ClimaCore.DataLayouts
4+
5+
"""
    get_n_items(x)

Return the total number of items (scalar entries) in `x`,
used to size kernel launches. Supported inputs: a `Fields.Field`,
a `DataLayouts.AbstractData`, an `AbstractArray` (counted via its
`parent` array), or a size tuple.
"""
get_n_items(field::Fields.Field) =
    length(parent(Fields.field_values(field)))
get_n_items(data::DataLayouts.AbstractData) = length(parent(data))
get_n_items(arr::AbstractArray) = length(parent(arr))
get_n_items(tup::Tuple) = prod(tup)
10+
11+
"""
    auto_launch!(f!::F!, args,
        ::Union{
            Int,
            NTuple{N, <:Int},
            AbstractArray,
            AbstractData,
            Field,
        };
        threads_s,
        blocks_s,
        always_inline = true
    )

Launch a cuda kernel, using `CUDA.launch_configuration`
to determine the number of threads/blocks.

Suggested threads and blocks (`threads_s`, `blocks_s`) can be given
to benchmark compare against auto-determined threads/blocks.
"""
function auto_launch!(
    f!::F!,
    args,
    data;
    threads_s,
    blocks_s,
    always_inline = true,
) where {F!}
    # For now, we'll simply use the
    # suggested threads and blocks:
    CUDA.@cuda always_inline = always_inline threads = threads_s blocks =
        blocks_s f!(args...)

    # Soon, we'll experiment with `CUDA.launch_configuration`.
    # NOTE: `nitems` is only needed by the auto configuration below,
    # so it is not computed on every launch (it was previously dead
    # work in the hot path).
    # nitems = get_n_items(data)
    # kernel = CUDA.@cuda always_inline = true launch = false f!(args...)
    # config = CUDA.launch_configuration(kernel.fun)
    # threads = min(nitems, config.threads)
    # blocks = cld(nitems, threads)
    # s = ""
    # s *= "Launching kernel $f! with following config:\n"
    # s *= "  nitems: $nitems\n"
    # s *= "  threads: $threads\n"
    # s *= "  blocks: $blocks\n"
    # @info s
    # kernel(args...; threads, blocks) # This knows to use always_inline from above.
end
58+
59+
"""
    kernel_indexes(n)

Return a tuple of indexes from the kernel,
where `n` is a tuple of max lengths along each
dimension of the accessed data.

Threads whose global linear index falls outside
`1:prod(n)` receive a tuple of `-1`s, which callers
should reject via `valid_range`.
"""
function kernel_indexes(n::Tuple)
    # Global 1-based linear thread index across all blocks.
    tidx = (CUDA.blockIdx().x - 1) * CUDA.blockDim().x + CUDA.threadIdx().x
    inds = if 1 ≤ tidx ≤ prod(n)
        # Convert the linear index to per-dimension Cartesian indexes.
        CartesianIndices(map(x -> Base.OneTo(x), n))[tidx].I
    else
        # `map` over the tuple `n` yields a fixed-length NTuple{N, Int},
        # keeping this branch type-stable on device (unlike
        # `ntuple(f, length(n))`, whose length is a runtime value).
        map(_ -> -1, n)
    end
    return inds
end
74+
75+
"""
    valid_range(inds, n)

Returns a `Bool` indicating if the thread index
is in the valid range, based on `inds` (the result
of `kernel_indexes`) and `n`, a tuple of max lengths
along each dimension of the accessed data.
```julia
function kernel!(data, n)
    inds = kernel_indexes(n)
    if valid_range(inds, n)
        do_work!(data[inds...])
    end
end
```
"""
valid_range(inds::NTuple, n::Tuple) = all(i -> 1 ≤ inds[i] ≤ n[i], 1:length(n))
# Convenience method: compute this thread's indexes and validate them.
# Delegates to the two-argument method to avoid duplicating the range check.
valid_range(n::Tuple) = valid_range(kernel_indexes(n), n)

ext/cuda/data_layouts.jl

Lines changed: 50 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,12 @@ function Base.copyto!(
5656
) where {S, Nij, A <: CUDA.CuArray}
5757
_, _, _, _, Nh = size(bc)
5858
if Nh > 0
59-
CUDA.@cuda always_inline = true threads = (Nij, Nij) blocks = (Nh, 1) knl_copyto!(
60-
dest,
61-
bc,
59+
auto_launch!(
60+
knl_copyto!,
61+
(dest, bc),
62+
dest;
63+
threads_s = (Nij, Nij),
64+
blocks_s = (Nh, 1),
6265
)
6366
end
6467
return dest
@@ -73,9 +76,12 @@ function Base.fill!(
7376
}
7477
_, _, _, _, Nh = size(dest)
7578
if Nh > 0
76-
CUDA.@cuda always_inline = true threads = (Nij, Nij) blocks = (Nh, 1) knl_fill!(
77-
dest,
78-
val,
79+
auto_launch!(
80+
knl_fill!,
81+
(dest, val),
82+
dest;
83+
threads_s = (Nij, Nij),
84+
blocks_s = (Nh, 1),
7985
)
8086
end
8187
return dest
@@ -91,8 +97,13 @@ function Base.copyto!(
9197
if Nv > 0 && Nh > 0
9298
Nv_per_block = min(Nv, fld(256, Nij * Nij))
9399
Nv_blocks = cld(Nv, Nv_per_block)
94-
CUDA.@cuda always_inline = true threads = (Nij, Nij, Nv_per_block) blocks =
95-
(Nh, Nv_blocks) knl_copyto!(dest, bc)
100+
auto_launch!(
101+
knl_copyto!,
102+
(dest, bc),
103+
dest;
104+
threads_s = (Nij, Nij, Nv_per_block),
105+
blocks_s = (Nh, Nv_blocks),
106+
)
96107
end
97108
return dest
98109
end
@@ -104,8 +115,13 @@ function Base.fill!(
104115
if Nv > 0 && Nh > 0
105116
Nv_per_block = min(Nv, fld(256, Nij * Nij))
106117
Nv_blocks = cld(Nv, Nv_per_block)
107-
CUDA.@cuda always_inline = true threads = (Nij, Nij, Nv_per_block) blocks =
108-
(Nh, Nv_blocks) knl_fill!(dest, val)
118+
auto_launch!(
119+
knl_fill!,
120+
(dest, val),
121+
dest;
122+
threads_s = (Nij, Nij, Nv_per_block),
123+
blocks_s = (Nh, Nv_blocks),
124+
)
109125
end
110126
return dest
111127
end
@@ -117,19 +133,25 @@ function Base.copyto!(
117133
) where {S, A <: CUDA.CuArray}
118134
_, _, _, Nv, Nh = size(bc)
119135
if Nv > 0 && Nh > 0
120-
CUDA.@cuda always_inline = true threads = (1, 1) blocks = (Nh, Nv) knl_copyto!(
121-
dest,
122-
bc,
136+
auto_launch!(
137+
knl_copyto!,
138+
(dest, bc),
139+
dest;
140+
threads_s = (1, 1),
141+
blocks_s = (Nh, Nv),
123142
)
124143
end
125144
return dest
126145
end
127146
function Base.fill!(dest::VF{S, A}, val) where {S, A <: CUDA.CuArray}
128147
_, _, _, Nv, Nh = size(dest)
129148
if Nv > 0 && Nh > 0
130-
CUDA.@cuda always_inline = true threads = (1, 1) blocks = (Nh, Nv) knl_fill!(
131-
dest,
132-
val,
149+
auto_launch!(
150+
knl_fill!,
151+
(dest, val),
152+
dest;
153+
threads_s = (1, 1),
154+
blocks_s = (Nh, Nv),
133155
)
134156
end
135157
return dest
@@ -139,16 +161,22 @@ function Base.copyto!(
139161
dest::DataF{S},
140162
bc::Union{DataF{S, A}, Base.Broadcast.Broadcasted{DataFStyle{A}}},
141163
) where {S, A <: CUDA.CuArray}
142-
CUDA.@cuda always_inline = true threads = (1, 1) blocks = (1, 1) knl_copyto!(
143-
dest,
144-
bc,
164+
auto_launch!(
165+
knl_copyto!,
166+
(dest, bc),
167+
dest;
168+
threads_s = (1, 1),
169+
blocks_s = (1, 1),
145170
)
146171
return dest
147172
end
148173
function Base.fill!(dest::DataF{S, A}, val) where {S, A <: CUDA.CuArray}
149-
CUDA.@cuda always_inline = true threads = (1, 1) blocks = (1, 1) knl_fill!(
150-
dest,
151-
val,
174+
auto_launch!(
175+
knl_fill!,
176+
(dest, val),
177+
dest;
178+
threads_s = (1, 1),
179+
blocks_s = (1, 1),
152180
)
153181
return dest
154182
end

ext/cuda/limiters.jl

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function compute_element_bounds!(
3131
(Ni, Nj, _, Nv, Nh) = S
3232
nthreads, nblocks = config_threadblock(Nv, Nh)
3333

34-
CUDA.@cuda always_inline = true threads = nthreads blocks = nblocks compute_element_bounds_kernel!(
34+
args = (
3535
limiter,
3636
Fields.field_values(Operators.strip_space(ρq, axes(ρq))),
3737
Fields.field_values(Operators.strip_space(ρ, axes(ρ))),
@@ -40,6 +40,13 @@ function compute_element_bounds!(
4040
Val(Ni),
4141
Val(Nj),
4242
)
43+
auto_launch!(
44+
compute_element_bounds_kernel!,
45+
args,
46+
ρ;
47+
threads_s = nthreads,
48+
blocks_s = nblocks,
49+
)
4350
return nothing
4451
end
4552

@@ -87,7 +94,7 @@ function compute_neighbor_bounds_local!(
8794
topology = Spaces.topology(axes(ρ))
8895
Ni, Nj, _, Nv, Nh = size(Fields.field_values(ρ))
8996
nthreads, nblocks = config_threadblock(Nv, Nh)
90-
CUDA.@cuda always_inline = true threads = nthreads blocks = nblocks compute_neighbor_bounds_local_kernel!(
97+
args = (
9198
limiter,
9299
topology.local_neighbor_elem,
93100
topology.local_neighbor_elem_offset,
@@ -96,6 +103,13 @@ function compute_neighbor_bounds_local!(
96103
Val(Ni),
97104
Val(Nj),
98105
)
106+
auto_launch!(
107+
compute_neighbor_bounds_local_kernel!,
108+
args,
109+
ρ;
110+
threads_s = nthreads,
111+
blocks_s = nblocks,
112+
)
99113
end
100114

101115
function compute_neighbor_bounds_local_kernel!(
@@ -140,7 +154,7 @@ function apply_limiter!(
140154
maxiter = Ni * Nj
141155
WJ = Spaces.local_geometry_data(axes(ρq)).WJ
142156
nthreads, nblocks = config_threadblock(Nv, Nh)
143-
CUDA.@cuda always_inline = true threads = nthreads blocks = nblocks apply_limiter_kernel!(
157+
args = (
144158
limiter,
145159
Fields.field_values(Operators.strip_space(ρq, axes(ρq))),
146160
Fields.field_values(Operators.strip_space(ρ, axes(ρ))),
@@ -152,6 +166,13 @@ function apply_limiter!(
152166
Val(Nj),
153167
Val(maxiter),
154168
)
169+
auto_launch!(
170+
apply_limiter_kernel!,
171+
args,
172+
ρ;
173+
threads_s = nthreads,
174+
blocks_s = nblocks,
175+
)
155176
return nothing
156177
end
157178

ext/cuda/matrix_fields_multiple_field_solve.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import ClimaComms
33
import LinearAlgebra: UniformScaling
44
import ClimaCore.Operators
55
import ClimaCore.MatrixFields
6+
import ClimaCore.MatrixFields: _single_field_solve!
67
import ClimaCore.MatrixFields: multiple_field_solve!
78
import ClimaCore.MatrixFields: is_CuArray_type
89
import ClimaCore.MatrixFields: allow_scalar_func
@@ -30,11 +31,16 @@ function multiple_field_solve!(::ClimaComms.CUDADevice, cache, x, A, b, x1)
3031
tups = (cache_tup, x_tup, A_tup, b_tup)
3132

3233
device = ClimaComms.device(x[first(names)])
33-
CUDA.@cuda threads = nthreads blocks = nblocks multiple_field_solve_kernel!(
34-
device,
35-
tups,
36-
x1,
37-
Val(Nnames),
34+
35+
args = (device, tups, x1, Val(Nnames))
36+
# TODO: use always_inline=true
37+
auto_launch!(
38+
multiple_field_solve_kernel!,
39+
args,
40+
x1;
41+
threads_s = nthreads,
42+
blocks_s = nblocks,
43+
always_inline = false,
3844
)
3945
end
4046

ext/cuda/operators_finite_difference.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function Base.copyto!(
3131
max_threads = 256
3232
nitems = Nv * Nq * Nq * Nh # # of independent items
3333
(nthreads, nblocks) = _configure_threadblock(max_threads, nitems)
34-
@cuda always_inline = true threads = (nthreads,) blocks = (nblocks,) copyto_stencil_kernel!(
34+
args = (
3535
strip_space(out, space),
3636
strip_space(bc, space),
3737
axes(out),
@@ -40,6 +40,13 @@ function Base.copyto!(
4040
Nh,
4141
Nv,
4242
)
43+
auto_launch!(
44+
copyto_stencil_kernel!,
45+
args,
46+
out;
47+
threads_s = (nthreads,),
48+
blocks_s = (nblocks,),
49+
)
4350
return out
4451
end
4552

0 commit comments

Comments
 (0)