Skip to content

Commit 8c96331

Browse files
Reduce use of DataLayouts internals
1 parent cfd8901 commit 8c96331

20 files changed

+201
-207
lines changed

ext/cuda/cuda_utils.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,6 @@ import ClimaCore.Fields
33
import ClimaCore.DataLayouts
44
import ClimaCore.DataLayouts: empty_kernel_stats
55

6-
get_n_items(field::Fields.Field) = get_n_items(Fields.field_values(field))
7-
get_n_items(data::DataLayouts.AbstractData) = get_n_items(size(data))
8-
get_n_items(arr::AbstractArray) = get_n_items(size(parent(arr)))
9-
get_n_items(tup::Tuple) = prod(tup)
10-
116
const reported_stats = Dict()
127
# Call via ClimaCore.DataLayouts.empty_kernel_stats()
138
empty_kernel_stats(::ClimaComms.CUDADevice) = empty!(reported_stats)
@@ -37,15 +32,15 @@ to benchmark compare against auto-determined threads/blocks (if `auto=false`).
3732
function auto_launch!(
3833
f!::F!,
3934
args,
40-
data;
35+
nitems::Union{Integer, Nothing} = nothing;
4136
auto = false,
4237
threads_s = nothing,
4338
blocks_s = nothing,
4439
always_inline = true,
4540
caller = :unknown,
4641
) where {F!}
4742
if auto
48-
nitems = get_n_items(data)
43+
@assert !isnothing(nitems)
4944
if nitems 0
5045
kernel = CUDA.@cuda always_inline = true launch = false f!(args...)
5146
config = CUDA.launch_configuration(kernel.fun)
@@ -64,7 +59,7 @@ function auto_launch!(
6459
# CUDA.registers(kernel) > 50 || return nothing # for debugging
6560
# occursin("single_field_solve_kernel", string(nameof(F!))) || return nothing
6661
if !haskey(reported_stats, key)
67-
nitems = get_n_items(data)
62+
@assert !isnothing(nitems)
6863
kernel = CUDA.@cuda always_inline = true launch = false f!(args...)
6964
config = CUDA.launch_configuration(kernel.fun)
7065
threads = min(nitems, config.threads)

ext/cuda/data_layouts_copyto.jl

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ function Base.copyto!(
2323
if Nh > 0
2424
auto_launch!(
2525
knl_copyto!,
26-
(dest, bc),
27-
dest;
26+
(dest, bc);
2827
threads_s = (Nij, Nij),
2928
blocks_s = (Nh, 1),
3029
)
@@ -42,8 +41,7 @@ function Base.copyto!(
4241
Nv_blocks = cld(Nv, Nv_per_block)
4342
auto_launch!(
4443
knl_copyto!,
45-
(dest, bc),
46-
dest;
44+
(dest, bc);
4745
threads_s = (Nij, Nij, Nv_per_block),
4846
blocks_s = (Nh, Nv_blocks),
4947
)
@@ -59,8 +57,7 @@ function Base.copyto!(
5957
if Nv > 0
6058
auto_launch!(
6159
knl_copyto!,
62-
(dest, bc),
63-
dest;
60+
(dest, bc);
6461
threads_s = (1, 1),
6562
blocks_s = (1, Nv),
6663
)
@@ -73,13 +70,7 @@ function Base.copyto!(
7370
bc::DataLayouts.BroadcastedUnionDataF{S},
7471
::ToCUDA,
7572
) where {S}
76-
auto_launch!(
77-
knl_copyto!,
78-
(dest, bc),
79-
dest;
80-
threads_s = (1, 1),
81-
blocks_s = (1, 1),
82-
)
73+
auto_launch!(knl_copyto!, (dest, bc); threads_s = (1, 1), blocks_s = (1, 1))
8374
return dest
8475
end
8576

@@ -100,7 +91,8 @@ function cuda_copyto!(dest::AbstractData, bc)
10091
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
10192
us = DataLayouts.UniversalSize(dest)
10293
if Nv > 0 && Nh > 0
103-
auto_launch!(knl_copyto_flat!, (dest, bc, us), dest; auto = true)
94+
nitems = prod(DataLayouts.universal_size(dest))
95+
auto_launch!(knl_copyto_flat!, (dest, bc, us), nitems; auto = true)
10496
end
10597
return dest
10698
end

ext/cuda/data_layouts_fill.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ function cuda_fill!(dest::AbstractData, val)
1414
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
1515
us = DataLayouts.UniversalSize(dest)
1616
if Nv > 0 && Nh > 0
17-
auto_launch!(knl_fill_flat!, (dest, val, us), dest; auto = true)
17+
nitems = prod(DataLayouts.universal_size(dest))
18+
auto_launch!(knl_fill_flat!, (dest, val, us), nitems; auto = true)
1819
end
1920
return dest
2021
end

ext/cuda/data_layouts_fused_copyto.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ function fused_copyto!(
5050
Nv_blocks = cld(Nv, Nv_per_block)
5151
auto_launch!(
5252
knl_fused_copyto!,
53-
(fmbc,),
54-
dest1;
53+
(fmbc,);
5554
threads_s = (Nij, Nij, Nv_per_block),
5655
blocks_s = (Nh, Nv_blocks),
5756
)
@@ -68,8 +67,7 @@ function fused_copyto!(
6867
if Nh > 0
6968
auto_launch!(
7069
knl_fused_copyto!,
71-
(fmbc,),
72-
dest1;
70+
(fmbc,);
7371
threads_s = (Nij, Nij),
7472
blocks_s = (Nh, 1),
7573
)
@@ -85,8 +83,7 @@ function fused_copyto!(
8583
if Nv > 0 && Nh > 0
8684
auto_launch!(
8785
knl_fused_copyto!,
88-
(fmbc,),
89-
dest1;
86+
(fmbc,);
9087
threads_s = (1, 1),
9188
blocks_s = (Nh, Nv),
9289
)
@@ -101,8 +98,7 @@ function fused_copyto!(
10198
) where {S}
10299
auto_launch!(
103100
knl_fused_copyto!,
104-
(fmbc,),
105-
dest1;
101+
(fmbc,);
106102
threads_s = (1, 1),
107103
blocks_s = (1, 1),
108104
)

ext/cuda/data_layouts_mapreduce.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function mapreduce_cuda(
2828
pdata = parent(data)
2929
T = eltype(pdata)
3030
(Ni, Nj, Nk, Nv, Nh) = size(data)
31-
Nf = div(length(pdata), prod(size(data))) # length of field dimension
31+
Nf = DataLayouts.ncomponents(data) # length of field dimension
3232
pwt = parent(weighted_jacobian)
3333

3434
nitems = Nv * Ni * Nj * Nk * Nh

ext/cuda/limiters.jl

Lines changed: 22 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -21,39 +21,24 @@ function compute_element_bounds!(
2121
ρ,
2222
::ClimaComms.CUDADevice,
2323
)
24-
S = size(Fields.field_values(ρ))
25-
(Ni, Nj, _, Nv, Nh) = S
24+
ρ_values = Fields.field_values(Operators.strip_space(ρ, axes(ρ)))
25+
ρq_values = Fields.field_values(Operators.strip_space(ρq, axes(ρq)))
26+
(_, _, _, Nv, Nh) = DataLayouts.universal_size(ρ_values)
2627
nthreads, nblocks = config_threadblock(Nv, Nh)
2728

28-
args = (
29-
limiter,
30-
Fields.field_values(Operators.strip_space(ρq, axes(ρq))),
31-
Fields.field_values(Operators.strip_space(ρ, axes(ρ))),
32-
Nv,
33-
Nh,
34-
Val(Ni),
35-
Val(Nj),
36-
)
29+
args = (limiter, ρq_values, ρ_values)
3730
auto_launch!(
3831
compute_element_bounds_kernel!,
39-
args,
40-
ρ;
32+
args;
4133
threads_s = nthreads,
4234
blocks_s = nblocks,
4335
)
4436
return nothing
4537
end
4638

4739

48-
function compute_element_bounds_kernel!(
49-
limiter,
50-
ρq,
51-
ρ,
52-
Nv,
53-
Nh,
54-
::Val{Ni},
55-
::Val{Nj},
56-
) where {Ni, Nj}
40+
function compute_element_bounds_kernel!(limiter, ρq, ρ)
41+
(Ni, Nj, _, Nv, Nh) = DataLayouts.universal_size(ρ)
5742
n = (Nv, Nh)
5843
tidx = thread_index()
5944
@inbounds if valid_range(tidx, prod(n))
@@ -88,21 +73,18 @@ function compute_neighbor_bounds_local!(
8873
::ClimaComms.CUDADevice,
8974
)
9075
topology = Spaces.topology(axes(ρ))
91-
Ni, Nj, _, Nv, Nh = size(Fields.field_values(ρ))
76+
us = DataLayouts.UniversalSize(Fields.field_values(ρ))
77+
(_, _, _, Nv, Nh) = DataLayouts.universal_size(us)
9278
nthreads, nblocks = config_threadblock(Nv, Nh)
9379
args = (
9480
limiter,
9581
topology.local_neighbor_elem,
9682
topology.local_neighbor_elem_offset,
97-
Nv,
98-
Nh,
99-
Val(Ni),
100-
Val(Nj),
83+
us,
10184
)
10285
auto_launch!(
10386
compute_neighbor_bounds_local_kernel!,
104-
args,
105-
ρ;
87+
args;
10688
threads_s = nthreads,
10789
blocks_s = nblocks,
10890
)
@@ -112,12 +94,9 @@ function compute_neighbor_bounds_local_kernel!(
11294
limiter,
11395
local_neighbor_elem,
11496
local_neighbor_elem_offset,
115-
Nv,
116-
Nh,
117-
::Val{Ni},
118-
::Val{Nj},
119-
) where {Ni, Nj}
120-
97+
us::DataLayouts.UniversalSize,
98+
)
99+
(_, _, _, Nv, Nh) = DataLayouts.universal_size(us)
121100
n = (Nv, Nh)
122101
tidx = thread_index()
123102
@inbounds if valid_range(tidx, prod(n))
@@ -147,27 +126,24 @@ function apply_limiter!(
147126
::ClimaComms.CUDADevice,
148127
)
149128
ρq_data = Fields.field_values(ρq)
150-
(Ni, Nj, _, Nv, Nh) = size(ρq_data)
151-
Nf = DataLayouts.ncomponents(ρq_data)
129+
us = DataLayouts.UniversalSize(ρq_data)
130+
(Ni, Nj, _, Nv, Nh) = DataLayouts.universal_size(us)
152131
maxiter = Ni * Nj
132+
Nf = DataLayouts.ncomponents(ρq_data)
153133
WJ = Spaces.local_geometry_data(axes(ρq)).WJ
154134
nthreads, nblocks = config_threadblock(Nv, Nh)
155135
args = (
156136
limiter,
157137
Fields.field_values(Operators.strip_space(ρq, axes(ρq))),
158138
Fields.field_values(Operators.strip_space(ρ, axes(ρ))),
159139
WJ,
160-
Nv,
161-
Nh,
140+
us,
162141
Val(Nf),
163-
Val(Ni),
164-
Val(Nj),
165142
Val(maxiter),
166143
)
167144
auto_launch!(
168145
apply_limiter_kernel!,
169-
args,
170-
ρ;
146+
args;
171147
threads_s = nthreads,
172148
blocks_s = nblocks,
173149
)
@@ -179,15 +155,13 @@ function apply_limiter_kernel!(
179155
ρq_data,
180156
ρ_data,
181157
WJ_data,
182-
Nv,
183-
Nh,
158+
us::DataLayouts.UniversalSize,
184159
::Val{Nf},
185-
::Val{Ni},
186-
::Val{Nj},
187160
::Val{maxiter},
188-
) where {Nf, Ni, Nj, maxiter}
161+
) where {Nf, maxiter}
189162
(; q_bounds_nbr, rtol) = limiter
190163
converged = true
164+
(Ni, Nj, _, Nv, Nh) = DataLayouts.universal_size(us)
191165
n = (Nv, Nh)
192166
tidx = thread_index()
193167
@inbounds if valid_range(tidx, prod(n))

ext/cuda/matrix_fields_multiple_field_solve.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ NVTX.@annotate function multiple_field_solve!(
3838

3939
auto_launch!(
4040
multiple_field_solve_kernel!,
41-
args,
42-
x1;
41+
args;
4342
threads_s = nthreads,
4443
blocks_s = nblocks,
4544
always_inline = true,

ext/cuda/matrix_fields_single_field_solve.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)
2121
args = (device, cache, x, A, b)
2222
auto_launch!(
2323
single_field_solve_kernel!,
24-
args,
25-
x;
24+
args;
2625
threads_s = nthreads,
2726
blocks_s = nblocks,
2827
)

ext/cuda/operators_finite_difference.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ function Base.copyto!(
3636
(strip_space(out, space), strip_space(bc, space), axes(out), bounds, us)
3737
auto_launch!(
3838
copyto_stencil_kernel!,
39-
args,
40-
out;
39+
args;
4140
threads_s = (nthreads,),
4241
blocks_s = (nblocks,),
4342
)

ext/cuda/operators_integral.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ function column_reduce_device!(
2929
init,
3030
space,
3131
)
32-
auto_launch!(bycolumn_kernel!, args, (); threads_s, blocks_s)
32+
auto_launch!(bycolumn_kernel!, args; threads_s, blocks_s)
3333
end
3434

3535
function column_accumulate_device!(
@@ -52,7 +52,7 @@ function column_accumulate_device!(
5252
init,
5353
space,
5454
)
55-
auto_launch!(bycolumn_kernel!, args, (); threads_s, blocks_s)
55+
auto_launch!(bycolumn_kernel!, args; threads_s, blocks_s)
5656
end
5757

5858
bycolumn_kernel!(

0 commit comments

Comments
 (0)