Skip to content

Try improving low vertical resolution cases #2341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 104 additions & 13 deletions ext/cuda/operators_columnwise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ function columnwise!(
us = DataLayouts.UniversalSize(Fields.field_values(ᶜcf))
(Ni, Nj, _, _, Nh) = DataLayouts.universal_size(us)
nitems = Ni * Nj * 1 * ᶠNv * Nh

threads_per_column = ᶠNv
threads_per_block = 256 # figure out a better way to estimate this
ntotal_columns = Nh * Ni * Nj
n_columns_per_block = fld(threads_per_block, threads_per_column)
blocks = cld(ntotal_columns, n_columns_per_block)
threads = threads_per_block

kernel = CUDA.@cuda(
always_inline = true,
launch = false,
Expand All @@ -51,8 +59,8 @@ function columnwise!(
Val(localmem_state),
)
)
threads = (ᶠNv,)
blocks = (Nh, Ni * Nj)
# threads = (ᶠNv,)
# blocks = (Nh, Ni * Nj)
kernel(
device,
ᶜf,
Expand All @@ -71,15 +79,98 @@ function columnwise!(
)
end

@inline function universal_index_columnwise(
device::ClimaComms.CUDADevice,
UI,
us,
)
(v,) = CUDA.threadIdx()
(h, ij) = CUDA.blockIdx()
(Ni, Nj, _, _, _) = DataLayouts.universal_size(us)
Ni * Nj < ij && return CartesianIndex((-1, -1, 1, -1, -1))
@inbounds (i, j) = CartesianIndices((Ni, Nj))[ij].I
return CartesianIndex((i, j, 1, v, h))
# @inline function universal_index_columnwise(
# device::ClimaComms.CUDADevice,
# UI,
# (v,) = CUDA.threadIdx()
# (h, ij) = CUDA.blockIdx()
# (Ni, Nj, _, _, _) = DataLayouts.universal_size(us)
# Ni * Nj < ij && return CartesianIndex((-1, -1, 1, -1, -1))
# @inbounds (i, j) = CartesianIndices((Ni, Nj))[ij].I
# return CartesianIndex((i, j, 1, v, h))
# end

@inline function universal_index_columnwise(device::ClimaComms.CUDADevice, UI, us, ::Val{ᶠNv}) where {ᶠNv}
(Ni, Nj, _, _, Nh) = DataLayouts.universal_size(us)
(v, i, j, h) = get_indices(Ni, Nj, ᶠNv, Nh)

v_safe = 1 ≤ v ≤ ᶠNv ? v : 1
i_safe = 1 ≤ i ≤ Ni ? i : 1
j_safe = 1 ≤ j ≤ Nj ? j : 1
h_safe = 1 ≤ h ≤ Nh ? h : 1

# (1 ≤ v ≤ ᶠNv && oob(v, i, j, h)) || CUDA.@cuprint("bad v=$v, i=$i, j=$j, h=$h, ᶠNv=$ᶠNv\n")
# (1 ≤ i ≤ Ni && oob(v, i, j, h)) || CUDA.@cuprint("bad i = $i\n")
# (1 ≤ j ≤ Nj && oob(v, i, j, h)) || CUDA.@cuprint("bad j = $j\n")
# (1 ≤ h ≤ Nh && oob(v, i, j, h)) || CUDA.@cuprint("bad h = $h, Nh = $Nh\n")

ui = CartesianIndex((i_safe, j_safe, 1, v_safe, h_safe))
return ui
end

oob(v, i, j, h) = v == -1 && i == -1 && j == -1 && h == -1

function get_indices(Ni, Nj, ᶠNv, Nh)
# static parameters
threads_per_column = ᶠNv # number of threads per column
threads_per_block = 256 # number of threads per block
ntotal_columns = Nh * Ni * Nj # total number of columns
n_columns_per_block = fld(threads_per_block, threads_per_column) # number of columns per block
n_blocks = cld(ntotal_columns, n_columns_per_block) # number of blocks (which have multiple columns)
@assert gridDim().x == n_blocks
@assert n_blocks * n_columns_per_block ≥ ntotal_columns
# Indices
tv = threadIdx().x # thread index
vblock_idx = div(tv-1,ᶠNv)+1 # "column index" per block
if tv > ᶠNv * n_columns_per_block # oob
return (-1, -1, -1, -1)
end
v = mod(tv - 1, ᶠNv) + 1 # block-local column vertical index
bidx = blockIdx().x # block index
block_inds = LinearIndices((1:n_blocks, 1:n_columns_per_block)) # linearized block indices
# if bidx > n_blocks || vblock_idx > n_columns_per_block
if !(1 ≤ bidx*vblock_idx ≤ length(block_inds))
return (-1, -1, -1, -1)
else
gbi = block_inds[bidx,vblock_idx] # global block index
CI = CartesianIndices((1:Ni,1:Nj,1:Nh))
if 1 ≤ gbi ≤ length(CI)
(i, j, h) = CI[gbi].I
return (v, i, j, h)
else
return (-1, -1, -1, -1)
end
end
end

function get_indices_gemini(Ni, Nj, fNv, Nh)
# Get global thread index
global_thread_idx = (blockIdx().x - 1) * blockDim().x + threadIdx().x

# Calculate v (1-based)
v = mod(threadIdx().x - 1, fNv) + 1

# Calculate the linear index of the column (1-based)
# Each column requires fNv threads.
column_linear_idx = fld(global_thread_idx - 1, fNv) + 1

# Calculate h, j, i from the linear column index (0-based for calculation, then convert to 1-based)
# The total number of unique (i, j, h) combinations is Ni * Nj * Nh
# h varies slowest, then j, then i.
# column_linear_idx is 1-based, so subtract 1 for 0-based calculations
zero_based_column_idx = column_linear_idx - 1

# Calculate h (1-based)
h = fld(zero_based_column_idx, (Ni * Nj)) + 1

# Calculate remaining index for j and i
remaining_idx_for_ji = mod(zero_based_column_idx, (Ni * Nj))

# Calculate j (1-based)
j = fld(remaining_idx_for_ji, Ni) + 1

# Calculate i (1-based)
i = mod(remaining_idx_for_ji, Ni) + 1

return (v, i, j, h)
end
5 changes: 3 additions & 2 deletions src/Operators/columnwise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ function columnwise_kernel!(
SLG = partial_lg_type(eltype(ᶜlg))
ᶜTS_lg = DataLayouts.typesize(FT, SLG)

ᶜui = universal_index_columnwise(device, UI, ᶜus)
ᶠui = universal_index_columnwise(device, UI, ᶠus)
ᶜui = universal_index_columnwise(device, UI, ᶜus, Val(ᶠNv))
ᶠui = universal_index_columnwise(device, UI, ᶠus, Val(ᶠNv))
colidx = Grids.ColumnIndex((ᶠui.I[1], ᶠui.I[2]), ᶠui.I[5])

if localmem_state
Expand Down Expand Up @@ -367,4 +367,5 @@ end
device::ClimaComms.AbstractCPUDevice,
UI,
us,
::Val
) = UI
32 changes: 21 additions & 11 deletions test/Operators/finitedifference/unit_columnwise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ julia --project=.buildkite
using Revise; include("test/Operators/finitedifference/unit_columnwise.jl")
ENV["CLIMACOMMS_DEVICE"] = "CPU";
ENV["CLIMACOMMS_DEVICE"] = "CUDA";

ncu -o columnwise_report.ncu-rep --section=WarpStateStats --set=full -f julia --project=.buildkite test/Operators/finitedifference/unit_columnwise.jl
scp -r clima:/home/charliek/CliMA/ClimaCore.jl/columnwise_report.ncu-rep ./
=#
high_res = true;
@info "high_res: $high_res"

ENV["CLIMACOMMS_DEVICE"] = "CUDA";
# using CUDA
import ClimaComms
using ClimaParams # needed in environment to load convenience parameter struct wrappers
Expand Down Expand Up @@ -228,7 +231,7 @@ FT = Float32;
if high_res
ᶜspace = ExtrudedCubedSphereSpace(
FT;
z_elem = 63,
z_elem = 8,
z_min = 0,
z_max = 30000.0,
radius = 6.371e6,
Expand All @@ -239,7 +242,7 @@ if high_res
else
ᶜspace = ExtrudedCubedSphereSpace(
FT;
z_elem = 8,
z_elem = 5,
z_min = 0,
z_max = 30000.0,
radius = 6.371e6,
Expand All @@ -249,14 +252,14 @@ else
)
end
# ᶜspace = SliceXZSpace(FT;
# z_elem = 10,
# z_elem = 8,
# x_min = 0,
# x_max = 1,
# z_min = 0,
# z_max = 30000.0,
# periodic_x = false,
# n_quad_points = 4,
# x_elem = 4,
# n_quad_points = 2,
# x_elem = 2,
# staggering = CellCenter()
# )
# ᶜspace = Box3DSpace(FT;
Expand Down Expand Up @@ -327,6 +330,7 @@ fill!(parent(Yc.ρ), 1);
parent(Yf.u₃) .+= 0.001 .* sin.(parent(zf));
fill!(parent(Yc.uₕ), 0.01);
fill!(parent(Yc.ρe_tot), 100000.0);
# parent(Yc.ρe_tot) .+= 0.001 .* parent(Yc.ρe_tot) .* sin.(parent(zc));

t₀ = zero(Spaces.undertype(axes(Yc)))

Expand All @@ -345,14 +349,20 @@ Operators.columnwise!(
t₀,
)
implicit_tendency_bc!(Yₜ_bc, Y, p, t₀)
abs_err_c = maximum(Array(abs.(parent(Yₜ.c) .- parent(Yₜ_bc.c))))
abs_err_f = maximum(Array(abs.(parent(Yₜ.f) .- parent(Yₜ_bc.f))))
results_match = abs_err_c < 6e-9 && abs_err_c < 6e-9
abs_err_c = Array(abs.(parent(Yₜ.c) .- parent(Yₜ_bc.c)))
abs_err_f = Array(abs.(parent(Yₜ.f) .- parent(Yₜ_bc.f)))
maxabs_err_c = maximum(Array(abs.(parent(Yₜ.c) .- parent(Yₜ_bc.c))))
maxabs_err_f = maximum(Array(abs.(parent(Yₜ.f) .- parent(Yₜ_bc.f))))
results_match = maxabs_err_c < 6e-9 && maxabs_err_c < 6e-9
if !results_match
@show norm(Array(parent(Yₜ_bc.c))), norm(Array(parent(Yₜ.c)))
@show norm(Array(parent(Yₜ_bc.f))), norm(Array(parent(Yₜ.f)))
@show abs_err_c
@show abs_err_f
# @show count(x->x!=0, abs_err_c)
# @show count(x->x!=0, abs_err_f)
@show maximum(Array(abs.(parent(Yₜ.c.ρ) .- parent(Yₜ_bc.c.ρ))))
@show maximum(Array(abs.(parent(Yₜ.c.ρe_tot) .- parent(Yₜ_bc.c.ρe_tot))))
@show maxabs_err_c
@show maxabs_err_f
end
@test results_match
#! format: off
Expand Down
Loading