Skip to content

Define a linear partition, and use in FD stencils #2002

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 58 additions & 8 deletions ext/cuda/data_layouts_threadblock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,25 @@ end
##### Custom partitions
#####

##### linear partition
@inline function linear_partition(
    us::DataLayouts.UniversalSize,
    n_max_threads::Integer,
)
    # Flat 1D launch configuration: one thread per item in `us`,
    # capped at `n_max_threads` threads per block.
    total = prod(DataLayouts.universal_size(us))
    nthreads = min(total, n_max_threads)
    nblocks = cld(total, nthreads)
    return (; threads = nthreads, blocks = nblocks)
end
@inline function linear_universal_index(us::UniversalSize)
    # 1-based global linear thread index for a 1D launch.
    tid = CUDA.threadIdx().x + (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x
    # Cartesian map over the full universal size, so `cart[tid]` recovers
    # the multi-dimensional index for this thread.
    dims = DataLayouts.universal_size(us)
    cart = CartesianIndices(map(Base.OneTo, dims))
    return (cart, tid)
end
# A linear index is valid when it falls within the total item count of `us`.
@inline function linear_is_valid_index(i::Integer, us::UniversalSize)
    return 1 ≤ i ≤ DataLayouts.get_N(us)
end

##### Column-wise
@inline function columnwise_partition(
us::DataLayouts.UniversalSize,
Expand All @@ -194,6 +213,26 @@ end
@inline columnwise_is_valid_index(I::CI5, us::UniversalSize) =
1 ≤ I[5] ≤ DataLayouts.get_Nh(us)

@inline function columnwise_linear_partition(
    us::DataLayouts.UniversalSize,
    n_max_threads::Integer,
)
    # Flat 1D launch over the horizontal items only: (Nij, Nij, Nh).
    (Nij, _, _, _, Nh) = DataLayouts.universal_size(us)
    total = Nij * Nij * Nh
    nthreads = min(total, n_max_threads)
    nblocks = cld(total, nthreads)
    return (; threads = nthreads, blocks = nblocks)
end
# Annotate `us` for consistency with `columnwise_linear_partition` and
# `linear_universal_index`, which both constrain to `UniversalSize`.
@inline function columnwise_linear_universal_index(us::DataLayouts.UniversalSize)
    # 1-based global linear thread index for a 1D launch.
    i = (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x + CUDA.threadIdx().x
    # Cartesian map over the horizontal items (Nij, Nij, Nh), matching the
    # launch configuration produced by `columnwise_linear_partition`.
    (Nij, _, _, _, Nh) = DataLayouts.universal_size(us)
    n = (Nij, Nij, Nh)
    CI = CartesianIndices(map(x -> Base.OneTo(x), n))
    return (CI, i)
end
# True when `i_linear` addresses one of the `N` launched work items.
@inline columnwise_linear_is_valid_index(i_linear::Integer, N::Integer) =
    in(i_linear, 1:N)

##### Element-wise (e.g., limiters)
# TODO

Expand All @@ -204,16 +243,27 @@ end
Nnames,
)
(Nij, _, _, _, Nh) = DataLayouts.universal_size(us)
@assert prod((Nij, Nij, Nnames)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nij, Nij, Nnames))),$n_max_threads)"
return (; threads = (Nij, Nij, Nnames), blocks = (Nh,))
# @assert prod((Nij, Nij, Nnames)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nij, Nij, Nnames))),$n_max_threads)"
# return (; threads = (Nij, Nij, Nnames), blocks = (Nh,))
nitems = prod((Nh, Nij, Nij, Nnames))
threads = min(nitems, n_max_threads)
blocks = cld(nitems, threads)
return (; threads, blocks)
end
@inline function multiple_field_solve_universal_index(us::UniversalSize)
(i, j, iname) = CUDA.threadIdx()
(h,) = CUDA.blockIdx()
return (CartesianIndex((i, j, 1, 1, h)), iname)
# Returns the Cartesian index map and this thread's 1-based global linear
# index for the multiple-field-solve launch. One thread is assigned per
# (i, j, h, field-name) entry; `Nnames` is lifted to the type domain via
# `Val` so the index space is known at compile time.
@inline function multiple_field_solve_universal_index(
    us::DataLayouts.UniversalSize,
    ::Val{Nnames},
) where {Nnames}
    i = (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x + CUDA.threadIdx().x
    (Nij, _, _, _, Nh) = DataLayouts.universal_size(us)
    # Index space (Nij, Nij, Nh, Nnames): quadrature point (i, j),
    # element h, and field name.
    CI = CartesianIndices((Nij, Nij, Nh, Nnames))
    return (CI, i)
end
@inline multiple_field_solve_is_valid_index(I::CI5, us::UniversalSize) =
1 ≤ I[5] ≤ DataLayouts.get_Nh(us)
# A linear index is valid when it falls within the `N` launched work items.
# (The superseded CartesianIndex-based method has been removed.)
@inline multiple_field_solve_is_valid_index(i_linear::Integer, N::Integer) =
    1 ≤ i_linear ≤ N

##### spectral kernel partition
@inline function spectral_partition(
Expand Down
6 changes: 3 additions & 3 deletions ext/cuda/matrix_fields_multiple_field_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ function multiple_field_solve_kernel!(
::Val{Nnames},
) where {Nnames}
@inbounds begin
(I, iname) = multiple_field_solve_universal_index(us)
if multiple_field_solve_is_valid_index(I, us)
(i, j, _, _, h) = I.I
(CI, i_linear) = multiple_field_solve_universal_index(us, Val(Nnames))
if multiple_field_solve_is_valid_index(i_linear, prod(CI.I))
(i, j, _, _, h, iname) = CI.I
generated_single_field_solve!(
device,
caches,
Expand Down
11 changes: 6 additions & 5 deletions ext/cuda/operators_finite_difference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@ function Base.copyto!(
bounds = Operators.window_bounds(space, bc)
out_fv = Fields.field_values(out)
us = DataLayouts.UniversalSize(out_fv)
nitems = prod(DataLayouts.universal_size(us))
args =
(strip_space(out, space), strip_space(bc, space), axes(out), bounds, us)

threads = threads_via_occupancy(copyto_stencil_kernel!, args)
n_max_threads = min(threads, get_N(us))
p = partition(out_fv, n_max_threads)
n_max_threads = min(threads, nitems)
p = linear_partition(us, n_max_threads)

auto_launch!(
copyto_stencil_kernel!,
Expand All @@ -40,9 +41,9 @@ import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh

function copyto_stencil_kernel!(out, bc, space, bds, us)
@inbounds begin
out_fv = Fields.field_values(out)
I = universal_index(out_fv)
if is_valid_index(out_fv, I, us)
(CI, i_linear) = linear_universal_index(us)
if linear_is_valid_index(i_linear, us)
I = CI[i_linear]
(li, lw, rw, ri) = bds
(i, j, _, v, h) = I.I
hidx = (i, j, h)
Expand Down
3 changes: 1 addition & 2 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ is excluded and is returned as 1.

Statically returns `prod((Ni, Nj, Nv, Nh))`
"""
@inline get_N(::UniversalSize{Ni, Nj, Nv, Nh}) where {Ni, Nj, Nv, Nh} =
prod((Ni, Nj, Nv, Nh))
@inline function get_N(us::UniversalSize)
    # Total item count: the product of every universal dimension.
    return prod(universal_size(us))
end

"""
get_Nv(::UniversalSize)
Expand Down
Loading