Skip to content

Commit 8628c64

Browse files
Reduce use of DataLayouts internals in ClimaCore
1 parent c2451ad commit 8628c64

File tree

9 files changed

+166
-162
lines changed

9 files changed

+166
-162
lines changed

ext/cuda/cuda_utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,15 @@ to benchmark compare against auto-determined threads/blocks (if `auto=false`).
3737
function auto_launch!(
3838
f!::F!,
3939
args,
40-
data;
40+
nitems::Union{Integer, Nothing} = nothing;
4141
auto = false,
4242
threads_s = nothing,
4343
blocks_s = nothing,
4444
always_inline = true,
4545
caller = :unknown,
4646
) where {F!}
4747
if auto
48-
nitems = get_n_items(data)
48+
@assert !isnothing(nitems)
4949
if nitems 0
5050
kernel = CUDA.@cuda always_inline = true launch = false f!(args...)
5151
config = CUDA.launch_configuration(kernel.fun)
@@ -64,7 +64,7 @@ function auto_launch!(
6464
# CUDA.registers(kernel) > 50 || return nothing # for debugging
6565
# occursin("single_field_solve_kernel", string(nameof(F!))) || return nothing
6666
if !haskey(reported_stats, key)
67-
nitems = get_n_items(data)
67+
@assert !isnothing(nitems)
6868
kernel = CUDA.@cuda always_inline = true launch = false f!(args...)
6969
config = CUDA.launch_configuration(kernel.fun)
7070
threads = min(nitems, config.threads)

ext/cuda/data_layouts_copyto.jl

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ function Base.copyto!(
2323
if Nh > 0
2424
auto_launch!(
2525
knl_copyto!,
26-
(dest, bc),
27-
dest;
26+
(dest, bc);
2827
threads_s = (Nij, Nij),
2928
blocks_s = (Nh, 1),
3029
)
@@ -42,8 +41,7 @@ function Base.copyto!(
4241
Nv_blocks = cld(Nv, Nv_per_block)
4342
auto_launch!(
4443
knl_copyto!,
45-
(dest, bc),
46-
dest;
44+
(dest, bc);
4745
threads_s = (Nij, Nij, Nv_per_block),
4846
blocks_s = (Nh, Nv_blocks),
4947
)
@@ -59,8 +57,7 @@ function Base.copyto!(
5957
if Nv > 0
6058
auto_launch!(
6159
knl_copyto!,
62-
(dest, bc),
63-
dest;
60+
(dest, bc);
6461
threads_s = (1, 1),
6562
blocks_s = (1, Nv),
6663
)
@@ -73,13 +70,7 @@ function Base.copyto!(
7370
bc::DataLayouts.BroadcastedUnionDataF{S},
7471
::ToCUDA,
7572
) where {S}
76-
auto_launch!(
77-
knl_copyto!,
78-
(dest, bc),
79-
dest;
80-
threads_s = (1, 1),
81-
blocks_s = (1, 1),
82-
)
73+
auto_launch!(knl_copyto!, (dest, bc); threads_s = (1, 1), blocks_s = (1, 1))
8374
return dest
8475
end
8576

@@ -100,7 +91,8 @@ function cuda_copyto!(dest::AbstractData, bc)
10091
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
10192
us = DataLayouts.UniversalSize(dest)
10293
if Nv > 0 && Nh > 0
103-
auto_launch!(knl_copyto_flat!, (dest, bc, us), dest; auto = true)
94+
nitems = prod(DataLayouts.universal_size(dest))
95+
auto_launch!(knl_copyto_flat!, (dest, bc, us), nitems; auto = true)
10496
end
10597
return dest
10698
end

ext/cuda/data_layouts_fill.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ function cuda_fill!(dest::AbstractData, val)
1414
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
1515
us = DataLayouts.UniversalSize(dest)
1616
if Nv > 0 && Nh > 0
17-
auto_launch!(knl_fill_flat!, (dest, val, us), dest; auto = true)
17+
nitems = prod(DataLayouts.universal_size(dest))
18+
auto_launch!(knl_fill_flat!, (dest, val, us), nitems; auto = true)
1819
end
1920
return dest
2021
end

ext/cuda/data_layouts_fused_copyto.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ function fused_copyto!(
5050
Nv_blocks = cld(Nv, Nv_per_block)
5151
auto_launch!(
5252
knl_fused_copyto!,
53-
(fmbc,),
54-
dest1;
53+
(fmbc,);
5554
threads_s = (Nij, Nij, Nv_per_block),
5655
blocks_s = (Nh, Nv_blocks),
5756
)
@@ -68,8 +67,7 @@ function fused_copyto!(
6867
if Nh > 0
6968
auto_launch!(
7069
knl_fused_copyto!,
71-
(fmbc,),
72-
dest1;
70+
(fmbc,);
7371
threads_s = (Nij, Nij),
7472
blocks_s = (Nh, 1),
7573
)
@@ -85,8 +83,7 @@ function fused_copyto!(
8583
if Nv > 0 && Nh > 0
8684
auto_launch!(
8785
knl_fused_copyto!,
88-
(fmbc,),
89-
dest1;
86+
(fmbc,);
9087
threads_s = (1, 1),
9188
blocks_s = (Nh, Nv),
9289
)
@@ -101,8 +98,7 @@ function fused_copyto!(
10198
) where {S}
10299
auto_launch!(
103100
knl_fused_copyto!,
104-
(fmbc,),
105-
dest1;
101+
(fmbc,);
106102
threads_s = (1, 1),
107103
blocks_s = (1, 1),
108104
)

ext/cuda/data_layouts_mapreduce.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function mapreduce_cuda(
2828
pdata = parent(data)
2929
T = eltype(pdata)
3030
(Ni, Nj, Nk, Nv, Nh) = size(data)
31-
Nf = div(length(pdata), prod(size(data))) # length of field dimension
31+
Nf = DataLayouts.ncomponents(data) # length of field dimension
3232
pwt = parent(weighted_jacobian)
3333

3434
nitems = Nv * Ni * Nj * Nk * Nh

ext/cuda/limiters.jl

Lines changed: 19 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -21,39 +21,26 @@ function compute_element_bounds!(
2121
ρ,
2222
::ClimaComms.CUDADevice,
2323
)
24-
S = size(Fields.field_values(ρ))
25-
(Ni, Nj, _, Nv, Nh) = S
24+
(_, _, Nv, _, Nh) = DataLayouts.universal_size(ρ)
2625
nthreads, nblocks = config_threadblock(Nv, Nh)
2726

2827
args = (
2928
limiter,
3029
Fields.field_values(Operators.strip_space(ρq, axes(ρq))),
3130
Fields.field_values(Operators.strip_space(ρ, axes(ρ))),
32-
Nv,
33-
Nh,
34-
Val(Ni),
35-
Val(Nj),
3631
)
3732
auto_launch!(
3833
compute_element_bounds_kernel!,
39-
args,
40-
ρ;
34+
args;
4135
threads_s = nthreads,
4236
blocks_s = nblocks,
4337
)
4438
return nothing
4539
end
4640

4741

48-
function compute_element_bounds_kernel!(
49-
limiter,
50-
ρq,
51-
ρ,
52-
Nv,
53-
Nh,
54-
::Val{Ni},
55-
::Val{Nj},
56-
) where {Ni, Nj}
42+
function compute_element_bounds_kernel!(limiter, ρq, ρ)
43+
(Ni, Nj, Nv, _, Nh) = DataLayouts.universal_size(ρ)
5744
n = (Nv, Nh)
5845
tidx = thread_index()
5946
@inbounds if valid_range(tidx, prod(n))
@@ -88,21 +75,18 @@ function compute_neighbor_bounds_local!(
8875
::ClimaComms.CUDADevice,
8976
)
9077
topology = Spaces.topology(axes(ρ))
91-
Ni, Nj, _, Nv, Nh = size(Fields.field_values(ρ))
78+
us = DataLayouts.UniversalSize(ρ)
79+
(Ni, Nj, Nv, _, Nh) = DataLayouts.universal_size(us)
9280
nthreads, nblocks = config_threadblock(Nv, Nh)
9381
args = (
9482
limiter,
9583
topology.local_neighbor_elem,
9684
topology.local_neighbor_elem_offset,
97-
Nv,
98-
Nh,
99-
Val(Ni),
100-
Val(Nj),
85+
us,
10186
)
10287
auto_launch!(
10388
compute_neighbor_bounds_local_kernel!,
104-
args,
105-
ρ;
89+
args;
10690
threads_s = nthreads,
10791
blocks_s = nblocks,
10892
)
@@ -112,12 +96,9 @@ function compute_neighbor_bounds_local_kernel!(
11296
limiter,
11397
local_neighbor_elem,
11498
local_neighbor_elem_offset,
115-
Nv,
116-
Nh,
117-
::Val{Ni},
118-
::Val{Nj},
119-
) where {Ni, Nj}
120-
99+
us::DataLayouts.UniversalSize,
100+
)
101+
(Ni, Nj, Nv, _, Nh) = DataLayouts.universal_size(us)
121102
n = (Nv, Nh)
122103
tidx = thread_index()
123104
@inbounds if valid_range(tidx, prod(n))
@@ -147,27 +128,24 @@ function apply_limiter!(
147128
::ClimaComms.CUDADevice,
148129
)
149130
ρq_data = Fields.field_values(ρq)
150-
(Ni, Nj, _, Nv, Nh) = size(ρq_data)
151-
Nf = DataLayouts.ncomponents(ρq_data)
131+
us = DataLayouts.UniversalSize(ρq_data)
132+
(Ni, Nj, _, Nv, Nh) = DataLayouts.universal_size(us)
152133
maxiter = Ni * Nj
134+
Nf = DataLayouts.ncomponents(ρq_data)
153135
WJ = Spaces.local_geometry_data(axes(ρq)).WJ
154136
nthreads, nblocks = config_threadblock(Nv, Nh)
155137
args = (
156138
limiter,
157139
Fields.field_values(Operators.strip_space(ρq, axes(ρq))),
158140
Fields.field_values(Operators.strip_space(ρ, axes(ρ))),
159141
WJ,
160-
Nv,
161-
Nh,
142+
us,
162143
Val(Nf),
163-
Val(Ni),
164-
Val(Nj),
165144
Val(maxiter),
166145
)
167146
auto_launch!(
168147
apply_limiter_kernel!,
169-
args,
170-
ρ;
148+
args;
171149
threads_s = nthreads,
172150
blocks_s = nblocks,
173151
)
@@ -179,15 +157,13 @@ function apply_limiter_kernel!(
179157
ρq_data,
180158
ρ_data,
181159
WJ_data,
182-
Nv,
183-
Nh,
160+
us::DataLayouts.UniversalSize,
184161
::Val{Nf},
185-
::Val{Ni},
186-
::Val{Nj},
187162
::Val{maxiter},
188-
) where {Nf, Ni, Nj, maxiter}
163+
) where {Nf, maxiter}
189164
(; q_bounds_nbr, rtol) = limiter
190165
converged = true
166+
(Ni, Nj, Nv, _, Nh) = DataLayouts.universal_size(us)
191167
n = (Nv, Nh)
192168
tidx = thread_index()
193169
@inbounds if valid_range(tidx, prod(n))

0 commit comments

Comments
 (0)