diff --git a/ext/cuda/data_layouts_threadblock.jl b/ext/cuda/data_layouts_threadblock.jl
index 02a0aeff6c..5786aa27ed 100644
--- a/ext/cuda/data_layouts_threadblock.jl
+++ b/ext/cuda/data_layouts_threadblock.jl
@@ -170,6 +170,25 @@ end
 ##### Custom partitions
 #####
 
+##### linear partition
+@inline function linear_partition(
+    us::DataLayouts.UniversalSize,
+    n_max_threads::Integer,
+)
+    nitems = prod(DataLayouts.universal_size(us))
+    threads = min(nitems, n_max_threads)
+    blocks = cld(nitems, threads)
+    return (; threads, blocks)
+end
+@inline function linear_universal_index(us::UniversalSize)
+    i = (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x + CUDA.threadIdx().x
+    inds = DataLayouts.universal_size(us)
+    CI = CartesianIndices(map(x -> Base.OneTo(x), inds))
+    return (CI, i)
+end
+@inline linear_is_valid_index(i::Integer, us::UniversalSize) =
+    1 ≤ i ≤ DataLayouts.get_N(us)
+
 ##### Column-wise
 @inline function columnwise_partition(
     us::DataLayouts.UniversalSize,
@@ -194,6 +213,26 @@ end
 @inline columnwise_is_valid_index(I::CI5, us::UniversalSize) =
     1 ≤ I[5] ≤ DataLayouts.get_Nh(us)
 
+@inline function columnwise_linear_partition(
+    us::DataLayouts.UniversalSize,
+    n_max_threads::Integer,
+)
+    (Nij, _, _, _, Nh) = DataLayouts.universal_size(us)
+    nitems = prod((Nij, Nij, Nh))
+    threads = min(nitems, n_max_threads)
+    blocks = cld(nitems, threads)
+    return (; threads, blocks)
+end
+@inline function columnwise_linear_universal_index(us)
+    i = (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x + CUDA.threadIdx().x
+    (Nij, _, _, _, Nh) = DataLayouts.universal_size(us)
+    n = (Nij, Nij, Nh)
+    CI = CartesianIndices(map(x -> Base.OneTo(x), n))
+    return (CI, i)
+end
+@inline columnwise_linear_is_valid_index(i_linear::Integer, N::Integer) =
+    1 ≤ i_linear ≤ N
+
 ##### Element-wise (e.g., limiters)
 # TODO
 
@@ -204,16 +243,27 @@ end
     Nnames,
 )
     (Nij, _, _, _, Nh) = DataLayouts.universal_size(us)
-    @assert prod((Nij, Nij, Nnames)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nij, Nij, Nnames))),$n_max_threads)"
-    return (; threads = (Nij, Nij, Nnames), blocks = (Nh,))
+    # @assert prod((Nij, Nij, Nnames)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nij, Nij, Nnames))),$n_max_threads)"
+    # return (; threads = (Nij, Nij, Nnames), blocks = (Nh,))
+    nitems = prod((Nh, Nij, Nij, Nnames))
+    threads = min(nitems, n_max_threads)
+    blocks = cld(nitems, threads)
+    return (; threads, blocks)
 end
-@inline function multiple_field_solve_universal_index(us::UniversalSize)
-    (i, j, iname) = CUDA.threadIdx()
-    (h,) = CUDA.blockIdx()
-    return (CartesianIndex((i, j, 1, 1, h)), iname)
+@inline function multiple_field_solve_universal_index(us::DataLayouts.UniversalSize, ::Val{Nnames}) where {Nnames}
+    # (i, j, iname) = CUDA.threadIdx()
+    # (h,) = CUDA.blockIdx()
+    # return (CartesianIndex((i, j, 1, 1, h)), iname)
+    i = (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x + CUDA.threadIdx().x
+    (Nij, _, _, _, Nh) = DataLayouts.universal_size(us)
+    n = (Nij, Nij, Nh, Nnames)
+    CI = CartesianIndices(n)
+    return (CI, i)
 end
-@inline multiple_field_solve_is_valid_index(I::CI5, us::UniversalSize) =
-    1 ≤ I[5] ≤ DataLayouts.get_Nh(us)
+# @inline multiple_field_solve_is_valid_index(I::CI5, us::UniversalSize) =
+#     1 ≤ I[5] ≤ DataLayouts.get_Nh(us)
+@inline multiple_field_solve_is_valid_index(i_linear::Integer, N::Integer) =
+    1 ≤ i_linear ≤ N
 
 ##### spectral kernel partition
 @inline function spectral_partition(
diff --git a/ext/cuda/matrix_fields_multiple_field_solve.jl b/ext/cuda/matrix_fields_multiple_field_solve.jl
index 3955aabaa7..3af9e9bb04 100644
--- a/ext/cuda/matrix_fields_multiple_field_solve.jl
+++ b/ext/cuda/matrix_fields_multiple_field_solve.jl
@@ -89,9 +89,10 @@ function multiple_field_solve_kernel!(
     ::Val{Nnames},
 ) where {Nnames}
     @inbounds begin
-        (I, iname) = multiple_field_solve_universal_index(us)
-        if multiple_field_solve_is_valid_index(I, us)
-            (i, j, _, _, h) = I.I
+        (CI, i_linear) = multiple_field_solve_universal_index(us, Val(Nnames))
+        if multiple_field_solve_is_valid_index(i_linear, length(CI))
+            # CI spans (Nij, Nij, Nh, Nnames); decode this thread's linear index.
+            (i, j, h, iname) = CI[i_linear].I
             generated_single_field_solve!(
                 device,
                 caches,
diff --git a/ext/cuda/operators_finite_difference.jl b/ext/cuda/operators_finite_difference.jl
index c93b4e0797..fdb176a4b8 100644
--- a/ext/cuda/operators_finite_difference.jl
+++ b/ext/cuda/operators_finite_difference.jl
@@ -21,12 +21,13 @@ function Base.copyto!(
     bounds = Operators.window_bounds(space, bc)
     out_fv = Fields.field_values(out)
     us = DataLayouts.UniversalSize(out_fv)
+    nitems = prod(DataLayouts.universal_size(us))
     args =
         (strip_space(out, space), strip_space(bc, space), axes(out), bounds, us)
 
     threads = threads_via_occupancy(copyto_stencil_kernel!, args)
-    n_max_threads = min(threads, get_N(us))
-    p = partition(out_fv, n_max_threads)
+    n_max_threads = min(threads, nitems)
+    p = linear_partition(us, n_max_threads)
 
     auto_launch!(
         copyto_stencil_kernel!,
@@ -40,9 +41,9 @@ import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh
 
 function copyto_stencil_kernel!(out, bc, space, bds, us)
     @inbounds begin
-        out_fv = Fields.field_values(out)
-        I = universal_index(out_fv)
-        if is_valid_index(out_fv, I, us)
+        (CI, i_linear) = linear_universal_index(us)
+        if linear_is_valid_index(i_linear, us)
+            I = CI[i_linear]
             (li, lw, rw, ri) = bds
             (i, j, _, v, h) = I.I
             hidx = (i, j, h)
diff --git a/src/DataLayouts/DataLayouts.jl b/src/DataLayouts/DataLayouts.jl
index 5d45024e5c..c8392826ee 100644
--- a/src/DataLayouts/DataLayouts.jl
+++ b/src/DataLayouts/DataLayouts.jl
@@ -87,8 +87,7 @@ is excluded and is returned as 1.
 
 Statically returns `prod((Ni, Nj, Nv, Nh))`
 """
-@inline get_N(::UniversalSize{Ni, Nj, Nv, Nh}) where {Ni, Nj, Nv, Nh} =
-    prod((Ni, Nj, Nv, Nh))
+@inline get_N(us::UniversalSize) = prod(universal_size(us))
 
 """
     get_Nv(::UniversalSize)