Skip to content

Commit 39f76d6

Browse files
Merge pull request #2006 from CliMA/ck/thread_partition_interface
Hoist UniversalSize computation outside of kernels
2 parents 65d0e30 + a540321 commit 39f76d6

File tree

6 files changed

+19
-15
lines changed

6 files changed

+19
-15
lines changed

ext/cuda/data_layouts_threadblock.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ end
185185
@assert prod((Nij, Nij, Nh_thread)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nij, Nij, Nh_thread))),$n_max_threads)"
186186
return (; threads = (Nij, Nij, Nh_thread), blocks = (Nh_blocks,))
187187
end
188-
@inline function columnwise_universal_index()
188+
@inline function columnwise_universal_index(us::UniversalSize)
189189
(i, j, th) = CUDA.threadIdx()
190190
(bh,) = CUDA.blockIdx()
191191
h = th + (bh - 1) * CUDA.blockDim().z
@@ -207,7 +207,7 @@ end
207207
@assert prod((Nij, Nij, Nnames)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nij, Nij, Nnames))),$n_max_threads)"
208208
return (; threads = (Nij, Nij, Nnames), blocks = (Nh,))
209209
end
210-
@inline function multiple_field_solve_universal_index()
210+
@inline function multiple_field_solve_universal_index(us::UniversalSize)
211211
(i, j, iname) = CUDA.threadIdx()
212212
(h,) = CUDA.blockIdx()
213213
return (CartesianIndex((i, j, 1, 1, h)), iname)

ext/cuda/matrix_fields_multiple_field_solve.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ NVTX.@annotate function multiple_field_solve!(
3333

3434
device = ClimaComms.device(x[first(names)])
3535

36-
args = (device, caches, xs, As, bs, x1, Val(Nnames))
37-
3836
us = UniversalSize(Fields.field_values(x1))
37+
args = (device, caches, xs, As, bs, x1, us, Val(Nnames))
38+
3939
nitems = Ni * Nj * Nh * Nnames
4040
threads = threads_via_occupancy(multiple_field_solve_kernel!, args)
4141
n_max_threads = min(threads, nitems)
@@ -85,11 +85,11 @@ function multiple_field_solve_kernel!(
8585
As,
8686
bs,
8787
x1,
88+
us::UniversalSize,
8889
::Val{Nnames},
8990
) where {Nnames}
9091
@inbounds begin
91-
us = UniversalSize(Fields.field_values(x1))
92-
(I, iname) = multiple_field_solve_universal_index()
92+
(I, iname) = multiple_field_solve_universal_index(us)
9393
if multiple_field_solve_is_valid_index(I, us)
9494
(i, j, _, _, h) = I.I
9595
generated_single_field_solve!(

ext/cuda/matrix_fields_single_field_solve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)
3030
end
3131

3232
function single_field_solve_kernel!(device, cache, x, A, b, us)
33-
I = columnwise_universal_index()
33+
I = columnwise_universal_index(us)
3434
if columnwise_is_valid_index(I, us)
3535
(i, j, _, _, h) = I.I
3636
_single_field_solve!(

ext/cuda/operators_integral.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ function column_reduce_device!(
1919
space,
2020
) where {F, T}
2121
Ni, Nj, _, _, Nh = size(Fields.field_values(output))
22+
us = UniversalSize(Fields.field_values(output))
2223
args = (
2324
single_column_reduce!,
2425
f,
@@ -27,8 +28,8 @@ function column_reduce_device!(
2728
strip_space(input, space),
2829
init,
2930
space,
31+
us,
3032
)
31-
us = UniversalSize(Fields.field_values(output))
3233
nitems = Ni * Nj * Nh
3334
threads = threads_via_occupancy(bycolumn_kernel!, args)
3435
n_max_threads = min(threads, nitems)
@@ -50,7 +51,8 @@ function column_accumulate_device!(
5051
init,
5152
space,
5253
) where {F, T}
53-
us = UniversalSize(Fields.field_values(output))
54+
out_fv = Fields.field_values(output)
55+
us = UniversalSize(out_fv)
5456
args = (
5557
single_column_accumulate!,
5658
f,
@@ -59,8 +61,9 @@ function column_accumulate_device!(
5961
strip_space(input, space),
6062
init,
6163
space,
64+
us,
6265
)
63-
Ni, Nj, _, _, Nh = size(Fields.field_values(output))
66+
(Ni, Nj, _, _, Nh) = DataLayouts.universal_size(us)
6467
nitems = Ni * Nj * Nh
6568
threads = threads_via_occupancy(bycolumn_kernel!, args)
6669
n_max_threads = min(threads, nitems)
@@ -81,12 +84,12 @@ bycolumn_kernel!(
8184
input,
8285
init,
8386
space,
87+
us::DataLayouts.UniversalSize,
8488
) where {S, F, T} =
8589
if space isa Spaces.FiniteDifferenceSpace
8690
single_column_function!(f, transform, output, input, init, space)
8791
else
88-
I = columnwise_universal_index()
89-
us = UniversalSize(Fields.field_values(output))
92+
I = columnwise_universal_index(us)
9093
if columnwise_is_valid_index(I, us)
9194
(i, j, _, _, h) = I.I
9295
single_column_function!(

ext/cuda/operators_thomas_algorithm.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import CUDA
66
using CUDA: @cuda
77
function column_thomas_solve!(::ClimaComms.CUDADevice, A, b)
88
us = UniversalSize(Fields.field_values(A))
9-
args = (A, b)
9+
args = (A, b, us)
1010
Ni, Nj, _, _, Nh = size(Fields.field_values(A))
1111
threads = threads_via_occupancy(thomas_algorithm_kernel!, args)
1212
nitems = Ni * Nj * Nh
@@ -23,9 +23,9 @@ end
2323
function thomas_algorithm_kernel!(
2424
A::Fields.ExtrudedFiniteDifferenceField,
2525
b::Fields.ExtrudedFiniteDifferenceField,
26+
us::DataLayouts.UniversalSize,
2627
)
27-
I = columnwise_universal_index()
28-
us = UniversalSize(Fields.field_values(A))
28+
I = columnwise_universal_index(us)
2929
if columnwise_is_valid_index(I, us)
3030
(i, j, _, _, h) = I.I
3131
thomas_algorithm!(Spaces.column(A, i, j, h), Spaces.column(b, i, j, h))

src/Operators/thomas_algorithm.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ column_thomas_solve!(::ClimaComms.AbstractCPUDevice, A, b) =
1717
thomas_algorithm_kernel!(
1818
A::Fields.FiniteDifferenceField,
1919
b::Fields.FiniteDifferenceField,
20+
us::DataLayouts.UniversalSize,
2021
) = thomas_algorithm!(A, b)
2122

2223
function thomas_algorithm!(

0 commit comments

Comments
 (0)