Skip to content

Commit 3a8464b

Browse files
Define a linear partition, and use in FD stencils
1 parent bd20629 commit 3a8464b

File tree

3 files changed

+25
-7
lines changed

3 files changed

+25
-7
lines changed

ext/cuda/data_layouts_threadblock.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,25 @@ end
170170
##### Custom partitions
171171
#####
172172

173+
##### linear partition
174+
@inline function linear_partition(
175+
us::DataLayouts.UniversalSize,
176+
n_max_threads::Integer,
177+
)
178+
nitems = prod(DataLayouts.universal_size(us))
179+
threads = min(nitems, n_max_threads)
180+
blocks = cld(nitems, threads)
181+
return (; threads, blocks)
182+
end
183+
@inline function linear_universal_index(us::UniversalSize)
184+
i = (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x + CUDA.threadIdx().x
185+
inds = DataLayouts.universal_size(us)
186+
CI = CartesianIndices(map(x -> Base.OneTo(x), inds))
187+
return (CI[i], i)
188+
end
189+
@inline linear_is_valid_index(i::Integer, us::UniversalSize) =
190+
1 i DataLayouts.get_N(us)
191+
173192
##### Column-wise
174193
@inline function columnwise_partition(
175194
us::DataLayouts.UniversalSize,

ext/cuda/operators_finite_difference.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@ function Base.copyto!(
2121
bounds = Operators.window_bounds(space, bc)
2222
out_fv = Fields.field_values(out)
2323
us = DataLayouts.UniversalSize(out_fv)
24+
nitems = prod(DataLayouts.universal_size(us))
2425
args =
2526
(strip_space(out, space), strip_space(bc, space), axes(out), bounds, us)
2627

2728
threads = threads_via_occupancy(copyto_stencil_kernel!, args)
28-
n_max_threads = min(threads, get_N(us))
29-
p = partition(out_fv, n_max_threads)
29+
n_max_threads = min(threads, nitems)
30+
p = linear_partition(us, n_max_threads)
3031

3132
auto_launch!(
3233
copyto_stencil_kernel!,
@@ -40,9 +41,8 @@ import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh
4041

4142
function copyto_stencil_kernel!(out, bc, space, bds, us)
4243
@inbounds begin
43-
out_fv = Fields.field_values(out)
44-
I = universal_index(out_fv)
45-
if is_valid_index(out_fv, I, us)
44+
(I, i_linear) = linear_universal_index(us)
45+
if linear_is_valid_index(i_linear, us)
4646
(li, lw, rw, ri) = bds
4747
(i, j, _, v, h) = I.I
4848
hidx = (i, j, h)

src/DataLayouts/DataLayouts.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@ is excluded and is returned as 1.
8787
8888
Statically returns `prod((Ni, Nj, Nv, Nh))`
8989
"""
90-
@inline get_N(::UniversalSize{Ni, Nj, Nv, Nh}) where {Ni, Nj, Nv, Nh} =
91-
prod((Ni, Nj, Nv, Nh))
90+
@inline get_N(us::UniversalSize) = prod(universal_size(us))
9291

9392
"""
9493
get_Nv(::UniversalSize)

0 commit comments

Comments
 (0)