From b3f3351d5aaa752aac1cbf9fdf42c3568e40b7a9 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Thu, 12 Jun 2025 10:13:02 -0400 Subject: [PATCH] Use columnwise for setting tendencies to zero Update src/utils/utilities.jl Co-authored-by: Gregory L. Wagner Update src/utils/utilities.jl Co-authored-by: Gregory L. Wagner --- .../remaining_tendency.jl | 131 +++++++++++++++++- src/utils/utilities.jl | 66 +++++++++ 2 files changed, 196 insertions(+), 1 deletion(-) diff --git a/src/prognostic_equations/remaining_tendency.jl b/src/prognostic_equations/remaining_tendency.jl index 7ebd4b7fc5..0a42e77f76 100644 --- a/src/prognostic_equations/remaining_tendency.jl +++ b/src/prognostic_equations/remaining_tendency.jl @@ -37,6 +37,115 @@ NVTX.@annotate function hyperdiffusion_tendency!(Yₜ, Yₜ_lim, Y, p, t) apply_hyperdiffusion_tendency!(Yₜ, Y, p, t) end +using ClimaCore.RecursiveApply: rzero + +##### +##### Cell center tendencies +##### + +""" + ᶜremaining_tendency(ᶜY, ᶠY, p, t) + +Returns a Broadcasted object, for evaluating the cell center remaining +tendency. This method calls `ᶜremaining_tendency(Val(name), ᶜY, ᶠY, p, t)` for +all `propertynames` of `ᶜY`. +""" +function ᶜremaining_tendency(ᶜY, ᶠY, p, t) + names = propertynames(ᶜY) + tends = construct_tendencies(Val(names), ᶜremaining_tendency, ᶜY, ᶠY, p, t) + # We cannot broadcast over a NamedTuple, so we need to check that edge case + # first. + if all(t -> !(t isa Base.Broadcast.Broadcasted), tends) + return make_named_tuple(Val(names), tends...) + else + return lazy.(make_named_tuple.(Val(names), tends...)) + end +end + +##### +##### Cell face tendencies +##### + +""" + ᶠremaining_tendency(ᶜY, ᶠY, p, t) + +Returns a Broadcasted object, for evaluating the cell center remaining +tendency. This method calls `ᶠremaining_tendency(Val(name), ᶜY, ᶠY, p, t)` for +all `propertynames` of `ᶠY`. +""" +function ᶠremaining_tendency(ᶜY, ᶠY, p, t) + names = propertynames(ᶠY) + tends = construct_tendencies(Val(names), ᶠremaining_tendency, ᶜY, ᶠY, p, t) + # We cannot broadcast over a NamedTuple, so we need to check that edge case + # first. + if all(t -> !(t isa Base.Broadcast.Broadcasted), tends) + return make_named_tuple(Val(names), tends...) + else + return lazy.(make_named_tuple.(Val(names), tends...)) + end +end + +##### +##### Individual tendencies +##### + +function ᶜremaining_tendency(::Val{:ρ}, ᶜY, ᶠY, p, t) + ∑tendencies = zero(eltype(ᶜY.ρ)) + return ∑tendencies +end +function ᶜremaining_tendency(::Val{:uₕ}, ᶜY, ᶠY, p, t) + ∑tendencies = zero(eltype(ᶜY.uₕ)) + return ∑tendencies +end +function ᶜremaining_tendency(::Val{:ρe_tot}, ᶜY, ᶠY, p, t) + ∑tendencies = zero(eltype(ᶜY.ρe_tot)) + return ∑tendencies +end +function ᶜremaining_tendency(::Val{:ρq_tot}, ᶜY, ᶠY, p, t) + ∑tendencies = zero(eltype(ᶜY.ρq_tot)) + return ∑tendencies +end +function ᶜremaining_tendency(::Val{:ρq_liq}, ᶜY, ᶠY, p, t) + ∑tendencies = zero(eltype(ᶜY.ρq_liq)) + return ∑tendencies +end +function ᶜremaining_tendency(::Val{:ρq_ice}, ᶜY, ᶠY, p, t) + ∑tendencies = zero(eltype(ᶜY.ρq_ice)) + return ∑tendencies +end +function ᶜremaining_tendency(::Val{:ρn_liq}, ᶜY, ᶠY, p, t) + ∑tendencies = zero(eltype(ᶜY.ρn_liq)) + return ∑tendencies +end +function ᶜremaining_tendency(::Val{:ρn_rai}, ᶜY, ᶠY, p, t) + ∑tendencies = zero(eltype(ᶜY.ρn_rai)) + return ∑tendencies +end +function ᶜremaining_tendency(::Val{:ρq_rai}, ᶜY, ᶠY, p, t) + ∑tendencies = zero(eltype(ᶜY.ρq_rai)) + return ∑tendencies +end +function ᶜremaining_tendency(::Val{:ρq_sno}, ᶜY, ᶠY, p, t) + ∑tendencies = zero(eltype(ᶜY.ρq_sno)) + return ∑tendencies +end +function ᶜremaining_tendency(::Val{:sgsʲs}, ᶜY, ᶠY, p, t) + ∑tendencies = rzero(eltype(ᶜY.sgsʲs)) + return ∑tendencies +end +function ᶜremaining_tendency(::Val{:sgs⁰}, ᶜY, ᶠY, p, t) + ∑tendencies = rzero(eltype(ᶜY.sgs⁰)) + return ∑tendencies +end +function ᶠremaining_tendency(::Val{:u₃}, ᶜY, ᶠY, p, t) + ∑tendencies = zero(eltype(ᶠY.u₃)) + return ∑tendencies +end +function ᶠremaining_tendency(::Val{:sgsʲs}, ᶜY, ᶠY, p, t) + ∑tendencies = rzero(eltype(ᶠY.sgsʲs)) + return ∑tendencies +end + """ remaining_tendency!(Yₜ, Yₜ_lim, Y, p, t) @@ -64,7 +173,27 @@ Returns: """ NVTX.@annotate function remaining_tendency!(Yₜ, Yₜ_lim, Y, p, t) Yₜ_lim .= zero(eltype(Yₜ_lim)) - Yₜ .= zero(eltype(Yₜ)) + device = ClimaComms.device(axes(Y.c)) + p_kernel = (; + zmax = Spaces.z_max(axes(Y.f)), + atmos = p.atmos, + params = p.params, + dt = p.dt, + ) + if :sfc in propertynames(Y) # columnwise! does not yet handle .sfc + parent(Yₜ.sfc) .= zero(Spaces.undertype(axes(Y.c))) + end + Operators.columnwise!( + device, + ᶜremaining_tendency, + ᶠremaining_tendency, + Yₜ.c, + Yₜ.f, + Y.c, + Y.f, + p_kernel, + t, + ) horizontal_tracer_advection_tendency!(Yₜ_lim, Y, p, t) fill_with_nans!(p) # TODO: would be better to limit this to debug mode (e.g., if p.debug_mode...) horizontal_dynamics_tendency!(Yₜ, Y, p, t) diff --git a/src/utils/utilities.jl b/src/utils/utilities.jl index be779d6418..8769a0047e 100644 --- a/src/utils/utilities.jl +++ b/src/utils/utilities.jl @@ -496,3 +496,69 @@ function issphere(space) return Meshes.domain(Spaces.topology(Spaces.horizontal_space(space))) isa Domains.SphereDomain end + + +""" + construct_tendencies(::Val{names}, f, ᶜY, ᶠY, p, t) + +Return a tuple of calls to + +``` +f(Val(name), ᶜY, ᶠY, p, t) +``` +for all names in `names`. + +For example, `f(Val((:a, :b)), ᶜY, ᶠY, p, t)` will return: + +``` +( + f(Val(:a), ᶜY, ᶠY, p, t), + f(Val(:b), ᶜY, ᶠY, p, t), +) +``` +""" +@generated function construct_tendencies( + ::Val{names}, + f, + ᶜY, + ᶠY, + p, + t, +) where {names} + calls = [] + for name in names + push!(calls, :(f(Val($(QuoteNode(name))), ᶜY, ᶠY, p, t))) + end + return quote + ($(calls...),) + end +end + +""" + make_named_tuple(::Val{names}, vals...) where {names} + +Construct a NamedTuple given the names `names` and values `vals`. +""" +make_named_tuple(::Val{names}, vals...) where {names} = NamedTuple{names}(vals) + +""" + add_tendency(∑tendencies, tendency) + +A helper function which returns `∑tendencies` when `tendency` is a +`NullBroadcasted` and `lazy.(∑tendencies + tendency)` when `tendency` is not a +`NullBroadcasted`. +""" +function add_tendency end +add_tendency(∑tends, t) = lazy.(∑tends .+ t) +add_tendency(∑tends, ::NullBroadcasted) = ∑tends + +""" + subtract_tendency(∑tendencies, tendency) + +A helper function which returns `∑tendencies` when `tendency` is a +`NullBroadcasted` and `lazy.(∑tendencies - tendency)` when `tendency` is not a +`NullBroadcasted`. +""" +function subtract_tendency end +subtract_tendency(∑tends, ::NullBroadcasted) = ∑tends +subtract_tendency(∑tends, t) = lazy.(∑tends .- t)