diff --git a/ext/cuda/operators_columnwise.jl b/ext/cuda/operators_columnwise.jl index c02a5373df..8adfa7427b 100644 --- a/ext/cuda/operators_columnwise.jl +++ b/ext/cuda/operators_columnwise.jl @@ -33,6 +33,14 @@ function columnwise!( us = DataLayouts.UniversalSize(Fields.field_values(ᶜcf)) (Ni, Nj, _, _, Nh) = DataLayouts.universal_size(us) nitems = Ni * Nj * 1 * ᶠNv * Nh + + threads_per_column = ᶠNv + threads_per_block = 256 # figure out a better way to estimate this + ntotal_columns = Nh * Ni * Nj + n_columns_per_block = fld(threads_per_block, threads_per_column) + blocks = cld(ntotal_columns, n_columns_per_block) + threads = threads_per_block + kernel = CUDA.@cuda( always_inline = true, launch = false, @@ -51,8 +59,8 @@ function columnwise!( Val(localmem_state), ) ) - threads = (ᶠNv,) - blocks = (Nh, Ni * Nj) + # threads = (ᶠNv,) + # blocks = (Nh, Ni * Nj) kernel( device, ᶜf, @@ -71,15 +79,98 @@ function columnwise!( ) end -@inline function universal_index_columnwise( - device::ClimaComms.CUDADevice, - UI, - us, -) - (v,) = CUDA.threadIdx() - (h, ij) = CUDA.blockIdx() - (Ni, Nj, _, _, _) = DataLayouts.universal_size(us) - Ni * Nj < ij && return CartesianIndex((-1, -1, 1, -1, -1)) - @inbounds (i, j) = CartesianIndices((Ni, Nj))[ij].I - return CartesianIndex((i, j, 1, v, h)) +# @inline function universal_index_columnwise( +# device::ClimaComms.CUDADevice, +# UI, +# (v,) = CUDA.threadIdx() +# (h, ij) = CUDA.blockIdx() +# (Ni, Nj, _, _, _) = DataLayouts.universal_size(us) +# Ni * Nj < ij && return CartesianIndex((-1, -1, 1, -1, -1)) +# @inbounds (i, j) = CartesianIndices((Ni, Nj))[ij].I +# return CartesianIndex((i, j, 1, v, h)) +# end + +@inline function universal_index_columnwise(device::ClimaComms.CUDADevice, UI, us, ::Val{ᶠNv}) where {ᶠNv} + (Ni, Nj, _, _, Nh) = DataLayouts.universal_size(us) + (v, i, j, h) = get_indices(Ni, Nj, ᶠNv, Nh) + + v_safe = 1 ≤ v ≤ ᶠNv ? v : 1 + i_safe = 1 ≤ i ≤ Ni ? i : 1 + j_safe = 1 ≤ j ≤ Nj ? j : 1 + h_safe = 1 ≤ h ≤ Nh ? h : 1 + + # (1 ≤ v ≤ ᶠNv && oob(v, i, j, h)) || CUDA.@cuprint("bad v=$v, i=$i, j=$j, h=$h, ᶠNv=$ᶠNv\n") + # (1 ≤ i ≤ Ni && oob(v, i, j, h)) || CUDA.@cuprint("bad i = $i\n") + # (1 ≤ j ≤ Nj && oob(v, i, j, h)) || CUDA.@cuprint("bad j = $j\n") + # (1 ≤ h ≤ Nh && oob(v, i, j, h)) || CUDA.@cuprint("bad h = $h, Nh = $Nh\n") + + ui = CartesianIndex((i_safe, j_safe, 1, v_safe, h_safe)) + return ui +end + +oob(v, i, j, h) = v == -1 && i == -1 && j == -1 && h == -1 + +function get_indices(Ni, Nj, ᶠNv, Nh) + # static parameters + threads_per_column = ᶠNv # number of threads per column + threads_per_block = 256 # number of threads per block + ntotal_columns = Nh * Ni * Nj # total number of columns + n_columns_per_block = fld(threads_per_block, threads_per_column) # number of columns per block + n_blocks = cld(ntotal_columns, n_columns_per_block) # number of blocks (which have multiple columns) + @assert gridDim().x == n_blocks + @assert n_blocks * n_columns_per_block ≥ ntotal_columns + # Indices + tv = threadIdx().x # thread index + vblock_idx = div(tv-1,ᶠNv)+1 # "column index" per block + if tv > ᶠNv * n_columns_per_block # oob + return (-1, -1, -1, -1) + end + v = mod(tv - 1, ᶠNv) + 1 # block-local column vertical index + bidx = blockIdx().x # block index + block_inds = LinearIndices((1:n_blocks, 1:n_columns_per_block)) # linearized block indices + # if bidx > n_blocks || vblock_idx > n_columns_per_block + if !(1 ≤ bidx*vblock_idx ≤ length(block_inds)) + return (-1, -1, -1, -1) + else + gbi = block_inds[bidx,vblock_idx] # global block index + CI = CartesianIndices((1:Ni,1:Nj,1:Nh)) + if 1 ≤ gbi ≤ length(CI) + (i, j, h) = CI[gbi].I + return (v, i, j, h) + else + return (-1, -1, -1, -1) + end + end +end + +function get_indices_gemini(Ni, Nj, fNv, Nh) + # Get global thread index + global_thread_idx = (blockIdx().x - 1) * blockDim().x + threadIdx().x + + # Calculate v (1-based) + v = mod(threadIdx().x - 1, fNv) + 1 + + # Calculate the linear index of the column (1-based) + # Each column requires fNv threads. + column_linear_idx = fld(global_thread_idx - 1, fNv) + 1 + + # Calculate h, j, i from the linear column index (0-based for calculation, then convert to 1-based) + # The total number of unique (i, j, h) combinations is Ni * Nj * Nh + # h varies slowest, then j, then i. + # column_linear_idx is 1-based, so subtract 1 for 0-based calculations + zero_based_column_idx = column_linear_idx - 1 + + # Calculate h (1-based) + h = fld(zero_based_column_idx, (Ni * Nj)) + 1 + + # Calculate remaining index for j and i + remaining_idx_for_ji = mod(zero_based_column_idx, (Ni * Nj)) + + # Calculate j (1-based) + j = fld(remaining_idx_for_ji, Ni) + 1 + + # Calculate i (1-based) + i = mod(remaining_idx_for_ji, Ni) + 1 + + return (v, i, j, h) end diff --git a/src/Operators/columnwise.jl b/src/Operators/columnwise.jl index 42e4609f94..7855eed9f3 100644 --- a/src/Operators/columnwise.jl +++ b/src/Operators/columnwise.jl @@ -137,8 +137,8 @@ function columnwise_kernel!( SLG = partial_lg_type(eltype(ᶜlg)) ᶜTS_lg = DataLayouts.typesize(FT, SLG) - ᶜui = universal_index_columnwise(device, UI, ᶜus) - ᶠui = universal_index_columnwise(device, UI, ᶠus) + ᶜui = universal_index_columnwise(device, UI, ᶜus, Val(ᶠNv)) + ᶠui = universal_index_columnwise(device, UI, ᶠus, Val(ᶠNv)) colidx = Grids.ColumnIndex((ᶠui.I[1], ᶠui.I[2]), ᶠui.I[5]) if localmem_state @@ -367,4 +367,5 @@ end device::ClimaComms.AbstractCPUDevice, UI, us, + ::Val ) = UI diff --git a/test/Operators/finitedifference/unit_columnwise.jl b/test/Operators/finitedifference/unit_columnwise.jl index 4251dd6fc4..2d40755e42 100644 --- a/test/Operators/finitedifference/unit_columnwise.jl +++ b/test/Operators/finitedifference/unit_columnwise.jl @@ -3,10 +3,13 @@ julia --project=.buildkite using Revise; include("test/Operators/finitedifference/unit_columnwise.jl") ENV["CLIMACOMMS_DEVICE"] = "CPU"; ENV["CLIMACOMMS_DEVICE"] = "CUDA"; + +ncu -o columnwise_report.ncu-rep --section=WarpStateStats --set=full -f julia --project=.buildkite test/Operators/finitedifference/unit_columnwise.jl +scp -r clima:/home/charliek/CliMA/ClimaCore.jl/columnwise_report.ncu-rep ./ =# high_res = true; @info "high_res: $high_res" - +ENV["CLIMACOMMS_DEVICE"] = "CUDA"; # using CUDA import ClimaComms using ClimaParams # needed in environment to load convenience parameter struct wrappers @@ -228,7 +231,7 @@ FT = Float32; if high_res ᶜspace = ExtrudedCubedSphereSpace( FT; - z_elem = 63, + z_elem = 8, z_min = 0, z_max = 30000.0, radius = 6.371e6, @@ -239,7 +242,7 @@ if high_res else ᶜspace = ExtrudedCubedSphereSpace( FT; - z_elem = 8, + z_elem = 5, z_min = 0, z_max = 30000.0, radius = 6.371e6, @@ -249,14 +252,14 @@ else ) end # ᶜspace = SliceXZSpace(FT; -# z_elem = 10, +# z_elem = 8, # x_min = 0, # x_max = 1, # z_min = 0, # z_max = 30000.0, # periodic_x = false, -# n_quad_points = 4, -# x_elem = 4, +# n_quad_points = 2, +# x_elem = 2, # staggering = CellCenter() # ) # ᶜspace = Box3DSpace(FT; @@ -327,6 +330,7 @@ fill!(parent(Yc.ρ), 1); parent(Yf.u₃) .+= 0.001 .* sin.(parent(zf)); fill!(parent(Yc.uₕ), 0.01); fill!(parent(Yc.ρe_tot), 100000.0); +# parent(Yc.ρe_tot) .+= 0.001 .* parent(Yc.ρe_tot) .* sin.(parent(zc)); t₀ = zero(Spaces.undertype(axes(Yc))) @@ -345,14 +349,20 @@ Operators.columnwise!( t₀, ) implicit_tendency_bc!(Yₜ_bc, Y, p, t₀) -abs_err_c = maximum(Array(abs.(parent(Yₜ.c) .- parent(Yₜ_bc.c)))) -abs_err_f = maximum(Array(abs.(parent(Yₜ.f) .- parent(Yₜ_bc.f)))) -results_match = abs_err_c < 6e-9 && abs_err_c < 6e-9 +abs_err_c = Array(abs.(parent(Yₜ.c) .- parent(Yₜ_bc.c))) +abs_err_f = Array(abs.(parent(Yₜ.f) .- parent(Yₜ_bc.f))) +maxabs_err_c = maximum(Array(abs.(parent(Yₜ.c) .- parent(Yₜ_bc.c)))) +maxabs_err_f = maximum(Array(abs.(parent(Yₜ.f) .- parent(Yₜ_bc.f)))) +results_match = maxabs_err_c < 6e-9 && maxabs_err_c < 6e-9 if !results_match @show norm(Array(parent(Yₜ_bc.c))), norm(Array(parent(Yₜ.c))) @show norm(Array(parent(Yₜ_bc.f))), norm(Array(parent(Yₜ.f))) - @show abs_err_c - @show abs_err_f + # @show count(x->x!=0, abs_err_c) + # @show count(x->x!=0, abs_err_f) + @show maximum(Array(abs.(parent(Yₜ.c.ρ) .- parent(Yₜ_bc.c.ρ)))) + @show maximum(Array(abs.(parent(Yₜ.c.ρe_tot) .- parent(Yₜ_bc.c.ρe_tot)))) + @show maxabs_err_c + @show maxabs_err_f end @test results_match #! format: off