From c746c194ff5d856e7fa1e8101d975fc3912a7979 Mon Sep 17 00:00:00 2001 From: Charlie Kawczynski Date: Fri, 18 Apr 2025 12:57:45 -0700 Subject: [PATCH] Refactor FD shmem index management --- ext/cuda/operators_fd_shmem.jl | 89 +++++++++++++------------ ext/cuda/operators_finite_difference.jl | 25 +++++++ 2 files changed, 73 insertions(+), 41 deletions(-) diff --git a/ext/cuda/operators_fd_shmem.jl b/ext/cuda/operators_fd_shmem.jl index da1d3ffdd9..a86169f976 100644 --- a/ext/cuda/operators_fd_shmem.jl +++ b/ext/cuda/operators_fd_shmem.jl @@ -29,20 +29,21 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!( arg, ) @inbounds begin - vt = threadIdx().x + si = FDShmemIndex() + bi = FDShmemBoundaryIndex() lg = Geometry.LocalGeometry(space, idx, hidx) if !on_boundary(idx, space, op) u³ = Operators.getidx(space, arg, idx, hidx) - Ju³[vt] = Geometry.Jcontravariant3(u³, lg) + Ju³[si] = Geometry.Jcontravariant3(u³, lg) elseif on_left_boundary(idx, space, op) bloc = Operators.left_boundary_window(space) bc = Operators.get_boundary(op, bloc) ub = Operators.getidx(space, bc.val, nothing, hidx) bJu³ = on_left_boundary(idx, space) ? lJu³ : rJu³ if bc isa Operators.SetValue - bJu³[1] = Geometry.Jcontravariant3(ub, lg) + bJu³[bi] = Geometry.Jcontravariant3(ub, lg) elseif bc isa Operators.SetDivergence - bJu³[1] = ub + bJu³[bi] = ub elseif bc isa Operators.Extrapolate # no shmem needed end elseif on_right_boundary(idx, space, op) @@ -51,9 +52,9 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!( ub = Operators.getidx(space, bc.val, nothing, hidx) bJu³ = on_left_boundary(idx, space) ? lJu³ : rJu³ if bc isa Operators.SetValue - bJu³[1] = Geometry.Jcontravariant3(ub, lg) + bJu³[bi] = Geometry.Jcontravariant3(ub, lg) elseif bc isa Operators.SetDivergence - bJu³[1] = ub + bJu³[bi] = ub elseif bc isa Operators.Extrapolate # no shmem needed end end @@ -70,11 +71,12 @@ Base.@propagate_inbounds function fd_operator_evaluate( arg, ) @inbounds begin - vt = threadIdx().x + si = FDShmemIndex() + bi = FDShmemBoundaryIndex() lg = Geometry.LocalGeometry(space, idx, hidx) if !on_boundary(idx, space, op) - Ju³₋ = Ju³[vt] # corresponds to idx - half - Ju³₊ = Ju³[vt + 1] # corresponds to idx + half + Ju³₋ = Ju³[si] # corresponds to idx - half + Ju³₊ = Ju³[si + 1] # corresponds to idx + half return (Ju³₊ ⊟ Ju³₋) ⊠ lg.invJ else bloc = @@ -85,22 +87,22 @@ Base.@propagate_inbounds function fd_operator_evaluate( @assert bc isa Operators.SetValue || bc isa Operators.SetDivergence if on_left_boundary(idx, space) if bc isa Operators.SetValue - Ju³₋ = lJu³[1] # corresponds to idx - half - Ju³₊ = Ju³[vt + 1] # corresponds to idx + half + Ju³₋ = lJu³[bi] # corresponds to idx - half + Ju³₊ = Ju³[si + 1] # corresponds to idx + half return (Ju³₊ ⊟ Ju³₋) ⊠ lg.invJ else # @assert bc isa Operators.SetDivergence - return lJu³[1] + return lJu³[bi] end else @assert on_right_boundary(idx, space) if bc isa Operators.SetValue - Ju³₋ = Ju³[vt] # corresponds to idx - half - Ju³₊ = rJu³[1] # corresponds to idx + half + Ju³₋ = Ju³[si] # corresponds to idx - half + Ju³₊ = rJu³[bi] # corresponds to idx + half return (Ju³₊ ⊟ Ju³₋) ⊠ lg.invJ else @assert bc isa Operators.SetDivergence - return rJu³[1] + return rJu³[bi] end end end @@ -133,10 +135,11 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!( ) @inbounds begin is_out_of_bounds(idx, space) && return nothing - vt = threadIdx().x + si = FDShmemIndex() + bi = FDShmemBoundaryIndex() cov3 = Geometry.Covariant3Vector(1) if in_domain(idx, arg_space) - u[vt] = cov3 ⊗ Operators.getidx(space, arg, idx, hidx) + u[si] = cov3 ⊗ Operators.getidx(space, arg, idx, hidx) end if on_any_boundary(idx, space, op) lloc = Operators.left_boundary_window(space) @@ -149,10 +152,10 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!( ub = Operators.getidx(space, bc.val, nothing, hidx) bu = on_left_boundary(idx, space) ? lb : rb if bc isa Operators.SetValue - bu[1] = cov3 ⊗ ub + bu[bi] = cov3 ⊗ ub elseif bc isa Operators.SetGradient lg = Geometry.LocalGeometry(space, idx, hidx) - bu[1] = Geometry.project(Geometry.Covariant3Axis(), ub, lg) + bu[bi] = Geometry.project(Geometry.Covariant3Axis(), ub, lg) elseif bc isa Operators.Extrapolate # no shmem needed end end @@ -169,11 +172,12 @@ Base.@propagate_inbounds function fd_operator_evaluate( args..., ) @inbounds begin - vt = threadIdx().x + si = FDShmemIndex() + bi = FDShmemBoundaryIndex() lg = Geometry.LocalGeometry(space, idx, hidx) if !on_boundary(idx, space, op) - u₋ = u[vt - 1] # corresponds to idx - half - u₊ = u[vt] # corresponds to idx + half + u₋ = u[si - 1] # corresponds to idx - half + u₊ = u[si] # corresponds to idx + half return u₊ ⊟ u₋ else bloc = @@ -184,15 +188,15 @@ Base.@propagate_inbounds function fd_operator_evaluate( @assert bc isa Operators.SetValue if on_left_boundary(idx, space) if bc isa Operators.SetValue - u₋ = 2 * lb[1] # corresponds to idx - half - u₊ = 2 * u[vt] # corresponds to idx + half + u₋ = 2 * lb[bi] # corresponds to idx - half + u₊ = 2 * u[si] # corresponds to idx + half return u₊ ⊟ u₋ end else @assert on_right_boundary(idx, space) if bc isa Operators.SetValue - u₋ = 2 * u[vt - 1] # corresponds to idx - half - u₊ = 2 * rb[1] # corresponds to idx + half + u₋ = 2 * u[si - 1] # corresponds to idx - half + u₊ = 2 * rb[bi] # corresponds to idx + half return u₊ ⊟ u₋ end end @@ -226,9 +230,10 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!( ) @inbounds begin is_out_of_bounds(idx, space) && return nothing - ᶜidx = get_cent_idx(idx) + si = FDShmemIndex(idx) + bi = FDShmemBoundaryIndex() if in_domain(idx, arg_space) - u[idx] = Operators.getidx(space, arg, idx, hidx) + u[si] = Operators.getidx(space, arg, idx, hidx) else lloc = Operators.left_boundary_window(space) rloc = Operators.right_boundary_window(space) @@ -242,16 +247,16 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!( bc isa Operators.NullBoundaryCondition if bc isa Operators.NullBoundaryCondition || bc isa Operators.Extrapolate - u[idx] = Operators.getidx(space, arg, idx, hidx) + u[si] = Operators.getidx(space, arg, idx, hidx) return nothing end bu = on_left_boundary(idx, space) ? lb : rb ub = Operators.getidx(space, bc.val, nothing, hidx) if bc isa Operators.SetValue - bu[1] = ub + bu[bi] = ub elseif bc isa Operators.SetGradient lg = Geometry.LocalGeometry(space, idx, hidx) - bu[1] = Geometry.covariant3(ub, lg) + bu[bi] = Geometry.covariant3(ub, lg) end end end @@ -270,9 +275,11 @@ Base.@propagate_inbounds function fd_operator_evaluate( vt = threadIdx().x lg = Geometry.LocalGeometry(space, idx, hidx) ᶜidx = get_cent_idx(idx) + si = FDShmemIndex(ᶜidx) + bi = FDShmemBoundaryIndex() if !on_boundary(idx, space, op) - u₋ = u[ᶜidx - 1] # corresponds to idx - half - u₊ = u[ᶜidx] # corresponds to idx + half + u₋ = u[si - 1] # corresponds to idx - half + u₊ = u[si] # corresponds to idx + half return RecursiveApply.rdiv(u₊ ⊞ u₋, 2) else bloc = @@ -285,26 +292,26 @@ Base.@propagate_inbounds function fd_operator_evaluate( bc isa Operators.Extrapolate if on_left_boundary(idx, space) if bc isa Operators.SetValue - return lb[1] + return lb[bi] elseif bc isa Operators.SetGradient - u₋ = lb[1] # corresponds to idx - half - u₊ = u[ᶜidx] # corresponds to idx + half + u₋ = lb[bi] # corresponds to idx - half + u₊ = u[si] # corresponds to idx + half return u₊ ⊟ RecursiveApply.rdiv(u₋, 2) else @assert bc isa Operators.Extrapolate - return u[ᶜidx] + return u[si] end else @assert on_right_boundary(idx, space) if bc isa Operators.SetValue - return rb[1] + return rb[bi] elseif bc isa Operators.SetGradient - u₋ = u[ᶜidx - 1] # corresponds to idx - half - u₊ = rb[1] # corresponds to idx + half + u₋ = u[si - 1] # corresponds to idx - half + u₊ = rb[bi] # corresponds to idx + half return u₋ ⊞ RecursiveApply.rdiv(u₊, 2) else @assert bc isa Operators.Extrapolate - return u[ᶜidx - 1] + return u[si - 1] end end end diff --git a/ext/cuda/operators_finite_difference.jl b/ext/cuda/operators_finite_difference.jl index cd2f8ad09a..fcacc1c3f0 100644 --- a/ext/cuda/operators_finite_difference.jl +++ b/ext/cuda/operators_finite_difference.jl @@ -26,6 +26,31 @@ struct ShmemParams{Nv} end interior_size(::ShmemParams{Nv}) where {Nv} = (Nv,) boundary_size(::ShmemParams{Nv}) where {Nv} = (1,) +struct ShmemIndex{T} + v::T + col_id::T +end +@inline function FDShmemIndex() + v = threadIdx().x + return ShmemIndex(v, typeof(v)(1)) +end +@inline function FDShmemIndex(v) + return ShmemIndex(v, typeof(v)(1)) +end +@inline FDShmemBoundaryIndex() = ShmemIndex(1, 1) + +# Base.getindex(a::AbstractArray, si::ShmemIndex) = Base.getindex(a, si.v, si.col_id) +# Base.setindex!(a::AbstractArray, val, si::ShmemIndex) = Base.setindex!(a, val, si.v, si.col_id) +Base.@propagate_inbounds Base.getindex(a::AbstractArray, si::ShmemIndex) = + Base.getindex(a, si.v) +Base.@propagate_inbounds Base.setindex!(a::AbstractArray, val, si::ShmemIndex) = + Base.setindex!(a, val, si.v) + +@inline Base.:+(si::ShmemIndex{T}, i::Integer) where {T} = + ShmemIndex{T}(si.v + T(i), si.col_id) +@inline Base.:-(si::ShmemIndex{T}, i::Integer) where {T} = + ShmemIndex{T}(si.v - T(i), si.col_id) + function Base.copyto!( out::Field, bc::Union{