Skip to content

Commit 7b8c68d

Browse files
charleskawczynskidennisYatunin
authored andcommitted
Make ᶜspecific lazy and remove matching_subfields
1 parent 2ff6481 commit 7b8c68d

File tree

3 files changed

+122
-163
lines changed

3 files changed

+122
-163
lines changed

src/cache/precomputed_quantities.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ function implicit_precomputed_quantities(Y, atmos)
4242
TST = thermo_state_type(moisture_model, FT)
4343
n = n_mass_flux_subdomains(turbconv_model)
4444
gs_quantities = (;
45-
ᶜspecific = specific_gs.(Y.c),
45+
ᶜspecific = Base.materialize(ᶜspecific_gs_tracers(Y)),
4646
ᶜu = similar(Y.c, C123{FT}),
4747
ᶠu³ = similar(Y.f, CT3{FT}),
4848
ᶠu = similar(Y.f, CT123{FT}),
@@ -461,7 +461,7 @@ NVTX.@annotate function set_implicit_precomputed_quantities!(Y, p, t)
461461
thermo_params = CAP.thermodynamics_params(p.params)
462462
thermo_args = (thermo_params, moisture_model, precip_model)
463463

464-
@. ᶜspecific = specific_gs(Y.c)
464+
ᶜspecific .= ᶜspecific_gs_tracers(Y)
465465
@. ᶠuₕ³ = $compute_ᶠuₕ³(Y.c.uₕ, Y.c.ρ)
466466

467467
# TODO: We might want to move this to dss! (and rename dss! to something

src/prognostic_equations/hyperdiffusion.jl

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ function hyperdiffusion_cache(
3434
gs_quantities = (;
3535
ᶜ∇²u = similar(Y.c, C123{FT}),
3636
ᶜ∇²specific_energy = similar(Y.c, FT),
37-
ᶜ∇²specific_tracers = remove_energy_var.(specific_gs.(Y.c)),
37+
ᶜ∇²specific_tracers = Base.materialize(ᶜspecific_gs_tracers(Y)),
3838
)
3939

4040
# Sub-grid scale quantities
@@ -85,7 +85,7 @@ NVTX.@annotate function prep_hyperdiffusion_tendency!(Yₜ, Y, p, t)
8585

8686
n = n_mass_flux_subdomains(turbconv_model)
8787
diffuse_tke = use_prognostic_tke(turbconv_model)
88-
(; ᶜp, ᶜspecific) = p.precomputed
88+
(; ᶜp) = p.precomputed
8989
(; ᶜ∇²u, ᶜ∇²specific_energy) = p.hyperdiff
9090
if turbconv_model isa PrognosticEDMFX
9191
(; ᶜ∇²uₕʲs, ᶜ∇²uᵥʲs, ᶜ∇²uʲs, ᶜ∇²mseʲs) = p.hyperdiff
@@ -137,7 +137,7 @@ NVTX.@annotate function apply_hyperdiffusion_tendency!(Yₜ, Y, p, t)
137137
diffuse_tke = use_prognostic_tke(turbconv_model)
138138
ᶜJ = Fields.local_geometry_field(Y.c).J
139139
point_type = eltype(Fields.coordinate_field(Y.c))
140-
(; ᶜp, ᶜspecific) = p.precomputed
140+
(; ᶜp) = p.precomputed
141141
(; ᶜ∇²u, ᶜ∇²specific_energy) = p.hyperdiff
142142
if turbconv_model isa PrognosticEDMFX
143143
(; ᶜρa⁰) = p.precomputed
@@ -240,11 +240,12 @@ NVTX.@annotate function prep_tracer_hyperdiffusion_tendency!(Yₜ, Y, p, t)
240240
(; hyperdiff, turbconv_model) = p.atmos
241241
isnothing(hyperdiff) && return nothing
242242

243-
(; ᶜspecific) = p.precomputed
244243
(; ᶜ∇²specific_tracers) = p.hyperdiff
245244

246-
for χ_name in propertynames(ᶜ∇²specific_tracers)
247-
@. ᶜ∇²specific_tracers.:($$χ_name) = wdivₕ(gradₕ(ᶜspecific.:($$χ_name)))
245+
# TODO: Fix RecursiveApply bug in gradₕ to fuse this operation.
246+
# ᶜ∇²specific_tracers .= wdivₕ.(gradₕ.(ᶜspecific_gs_tracers(Y)))
247+
foreach_gs_tracer(Y, ᶜ∇²specific_tracers) do ᶜρχ, ᶜ∇²χ, _
248+
@. ᶜ∇²χ = wdivₕ(gradₕ(specific(ᶜρχ, Y.c.ρ)))
248249
end
249250

250251
if turbconv_model isa PrognosticEDMFX
@@ -275,31 +276,30 @@ NVTX.@annotate function apply_tracer_hyperdiffusion_tendency!(Yₜ, Y, p, t)
275276
(; hyperdiff, turbconv_model) = p.atmos
276277
isnothing(hyperdiff) && return nothing
277278

279+
# When α_hyperdiff_tracer is 0, precipitating species are not hyperdiffused.
278280
α_hyperdiff_tracer = CAP.α_hyperdiff_tracer(p.params)
279281
(; ν₄_scalar_coeff) = hyperdiff
280282
h_space = Spaces.horizontal_space(axes(Y.c))
281283
h_length_scale = Spaces.node_horizontal_length_scale(h_space) # mean nodal distance
282284
ν₄_scalar = ν₄_scalar_coeff * h_length_scale^3
285+
ν₄_scalar_for_precip = α_hyperdiff_tracer * ν₄_scalar
283286
n = n_mass_flux_subdomains(turbconv_model)
284287

285288
(; ᶜ∇²specific_tracers) = p.hyperdiff
286289

287290
# TODO: Since we are not applying the limiter to density (or area-weighted
288291
# density), the mass redistributed by hyperdiffusion will not be conserved
289292
# by the limiter. Is this a significant problem?
290-
# TODO: Figure out why caching the duplicated tendencies in ᶜtemp_scalar
291-
# triggers allocations.
292-
for (ᶜρχₜ, ᶜ∇²χ, χ_name) in matching_subfields(Yₜ.c, ᶜ∇²specific_tracers)
293-
ν₄_scalar = ifelse(
294-
χ_name in (:q_rai, :q_sno, :n_rai),
295-
α_hyperdiff_tracer * ν₄_scalar,
296-
ν₄_scalar,
297-
)
298-
@. ᶜρχₜ -= ν₄_scalar * wdivₕ(Y.c.ρ * gradₕ(ᶜ∇²χ))
299-
300-
# Exclude contributions from hyperdiffusion of condensate,
301-
# precipitating species from mass tendency.
302-
if χ_name == :q_tot
293+
foreach_gs_tracer(Yₜ, ᶜ∇²specific_tracers) do ᶜρχₜ, ᶜ∇²χ, ρχ_name
294+
ν₄_scalar_for_χ =
295+
ρχ_name in (@name(ρq_rai), @name(ρq_sno), @name(ρn_rai)) ?
296+
ν₄_scalar_for_precip : ν₄_scalar
297+
@. ᶜρχₜ -= ν₄_scalar_for_χ * wdivₕ(Y.c.ρ * gradₕ(ᶜ∇²χ))
298+
299+
# Precipitating species are part of q_tot, so they are included in the
300+
# contribution of tracer hyperdiffusion to the mass tendency, even if
301+
# they are not hyperdiffused themselves.
302+
if ρχ_name == @name(ρq_tot)
303303
@. Yₜ.c.ρ -= ν₄_scalar * wdivₕ(Y.c.ρ * gradₕ(ᶜ∇²χ))
304304
end
305305
end
@@ -322,13 +322,9 @@ NVTX.@annotate function apply_tracer_hyperdiffusion_tendency!(Yₜ, Y, p, t)
322322
@. Yₜ.c.sgsʲs.:($$j).q_ice -=
323323
ν₄_scalar * wdivₕ(gradₕ(ᶜ∇²q_iceʲs.:($$j)))
324324
@. Yₜ.c.sgsʲs.:($$j).q_rai -=
325-
α_hyperdiff_tracer *
326-
ν₄_scalar *
327-
wdivₕ(gradₕ(ᶜ∇²q_raiʲs.:($$j)))
325+
ν₄_scalar_for_precip * wdivₕ(gradₕ(ᶜ∇²q_raiʲs.:($$j)))
328326
@. Yₜ.c.sgsʲs.:($$j).q_sno -=
329-
α_hyperdiff_tracer *
330-
ν₄_scalar *
331-
wdivₕ(gradₕ(ᶜ∇²q_snoʲs.:($$j)))
327+
ν₄_scalar_for_precip * wdivₕ(gradₕ(ᶜ∇²q_snoʲs.:($$j)))
332328
end
333329
end
334330
end

src/utils/variable_manipulations.jl

Lines changed: 99 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -53,64 +53,128 @@ Arguments:
5353
return ρa == 0 ? ρχ / ρ : weight * ρaχ / ρa + (1 - weight) * ρχ / ρ
5454
end
5555

56+
# Internal method that checks if its input is @name(ρχ) for some variable χ.
57+
@generated is_ρ_weighted_name(
58+
::MatrixFields.FieldName{name_chain},
59+
) where {name_chain} =
60+
length(name_chain) == 1 && startswith(string(name_chain[1]), "ρ")
61+
62+
# Internal method that converts @name(ρχ) to @name(χ) for some variable χ.
63+
@generated function specific_tracer_name(
64+
::MatrixFields.FieldName{ρχ_name_chain},
65+
) where {ρχ_name_chain}
66+
χ_symbol = Symbol(string(ρχ_name_chain[1])[(ncodeunits("ρ") + 1):end])
67+
return :(@name($χ_symbol))
68+
end
69+
5670
"""
57-
tracer_names(field)
71+
gs_tracer_names(Y)
5872
59-
Filters and returns the names of the variables from a given state
60-
vector component, excluding `ρ`, `ρe_tot`, and `uₕ` and SGS fields.
73+
`Tuple` of `@name`s for the grid-scale tracers in the center field `Y.c`
74+
(excluding `ρ`, `ρe_tot`, velocities, and SGS fields).
75+
"""
76+
gs_tracer_names(Y) =
77+
unrolled_filter(MatrixFields.top_level_names(Y.c)) do name
78+
is_ρ_weighted_name(name) && !(name in (@name(ρ), @name(ρe_tot)))
79+
end
6180

62-
Arguments:
81+
"""
82+
specific_gs_tracer_names(Y)
83+
84+
`Tuple` of the specific tracer names `@name(χ)` that correspond to the
85+
density-weighted tracer names `@name(ρχ)` in `gs_tracer_names(Y)`.
86+
"""
87+
specific_gs_tracer_names(Y) =
88+
unrolled_map(specific_tracer_name, gs_tracer_names(Y))
89+
90+
"""
91+
ᶜempty(Y)
92+
93+
Lazy center `Field` of empty `NamedTuple`s.
94+
"""
95+
ᶜempty(Y) = lazy.(Returns((;)).(Y.c))
6396

64-
- `field`: A component of the state vector `Y.c`.
97+
"""
98+
ᶜgs_tracers(Y)
6599
66-
Returns:
100+
Lazy center `Field` of `NamedTuple`s that contain the values of all grid-scale
101+
tracers given by `gs_tracer_names(Y)`.
102+
"""
103+
function ᶜgs_tracers(Y)
104+
isempty(gs_tracer_names(Y)) && return ᶜempty(Y)
105+
ρχ_symbols = unrolled_map(MatrixFields.extract_first, gs_tracer_names(Y))
106+
ρχ_fields = unrolled_map(gs_tracer_names(Y)) do ρχ_name
107+
MatrixFields.get_field(Y.c, ρχ_name)
108+
end
109+
return @. lazy(NamedTuple{ρχ_symbols}(tuple(ρχ_fields...)))
110+
end
67111

68-
- A `Tuple` of `ClimaCore.MatrixFields.FieldName`s corresponding to the tracers.
69112
"""
70-
tracer_names(field) =
71-
unrolled_filter(MatrixFields.top_level_names(field)) do name
72-
!(
73-
name in
74-
(@name(ρ), @name(ρe_tot), @name(uₕ), @name(sgs⁰), @name(sgsʲs))
75-
)
113+
ᶜspecific_gs_tracers(Y)
114+
115+
Lazy center `Field` of `NamedTuple`s that contain the values of all specific
116+
grid-scale tracers given by `specific_gs_tracer_names(Y)`.
117+
"""
118+
function ᶜspecific_gs_tracers(Y)
119+
isempty(gs_tracer_names(Y)) && return ᶜempty(Y)
120+
χ_symbols =
121+
unrolled_map(MatrixFields.extract_first, specific_gs_tracer_names(Y))
122+
χ_fields = unrolled_map(gs_tracer_names(Y)) do ρχ_name
123+
ρχ_field = MatrixFields.get_field(Y.c, ρχ_name)
124+
@. lazy(specific(ρχ_field, Y.c.ρ))
76125
end
126+
return @. lazy(NamedTuple{χ_symbols}(tuple(χ_fields...)))
127+
end
77128

78129
"""
79-
foreach_gs_tracer(f::F, Yₜ, Y) where {F}
130+
foreach_gs_tracer(f, Y_or_similar_values...)
131+
132+
Applies a function `f` to each grid-scale tracer in the state `Y` or any similar
133+
value like the tendency `Yₜ`. This is used to implement performant loops over
134+
all tracers given by `gs_tracer_names(Y)`.
80135
81-
Applies a given function `f` to each grid-scale scalar variable (except `ρ` and `ρe_tot`)
82-
in the state `Y` and its corresponding tendency `Yₜ`.
83-
This utility abstracts the process of iterating over all scalars. It uses
84-
`tracer_names` to identify the relevant variables and `unrolled_foreach` to
85-
ensure a performant loop. For each tracer, it calls the provided function `f`
86-
with the tendency field, the state field, and a boolean flag indicating if
87-
the current tracer is `ρq_tot` (to allow for special handling).
136+
Although the first input value needs to be similar to `Y`, the remaining values
137+
can also be center `Field`s similar to `Y.c`, and they can use specific tracers
138+
given by `specific_gs_tracer_names(Y)` instead of density-weighted tracers.
88139
89140
Arguments:
90141
91-
- `f`: A function to apply to each grid-scale scalar. It must have the signature `f
92-
(ᶜρχₜ, ᶜρχ, ρχ_name)`, where `ᶜρχₜ` is the tendency field, `ᶜρχ`
93-
is the state field, and `ρχ_name` is a `MatrixFields.@name` object.
94-
- `Yₜ`: The tendency state vector.
95-
- `Y`: The current state vector.
142+
- `f`: The function applied to each grid-scale tracer, which must have the
143+
signature `f(ρχ_or_χ_fields..., ρχ_name)`, where `ρχ_or_χ_fields` are
144+
grid-scale tracer subfields (either density-weighted or specific) and
145+
`ρχ_name` is the `MatrixFields.FieldName` of the tracer.
146+
- `Y_or_similar_values`: The state `Y` or similar values like the tendency `Yₜ`.
96147
97-
# Example
148+
# Examples
98149
99150
```julia
100151
foreach_gs_tracer(Yₜ, Y) do ᶜρχₜ, ᶜρχ, ρχ_name
101-
# Apply some operation, e.g., a sponge layer
102-
@. ᶜρχₜ += some_sponge_function(ᶜρχ)
152+
ᶜρχₜ .+= tendency_of_ρχ(ᶜρχ)
103153
if ρχ_name == @name(ρq_tot)
104-
# Perform an additional operation only for ρq_tot
154+
ᶜρχₜ .+= additional_tendency_of_ρq_tot(ᶜρχ)
155+
end
156+
end
157+
```
158+
159+
```julia
160+
foreach_gs_tracer(Yₜ, Base.materialize(ᶜspecific_gs_tracers(Y))) do ᶜρχₜ, ᶜχ, ρχ_name
161+
ᶜρχₜ .+= Y.c.ρ .* tendency_of_χ(ᶜχ)
162+
if ρχ_name == @name(ρq_tot)
163+
ᶜρχₜ .+= Y.c.ρ .* additional_tendency_of_q_tot(ᶜχ)
105164
end
106165
end
107166
```
108167
"""
109-
foreach_gs_tracer(f::F, Yₜ, Y) where {F} =
110-
unrolled_foreach(tracer_names(Y.c)) do scalar_name
111-
ᶜρχₜ = MatrixFields.get_field(Yₜ.c, scalar_name)
112-
ᶜρχ = MatrixFields.get_field(Y.c, scalar_name)
113-
f(ᶜρχₜ, ᶜρχ, scalar_name)
168+
foreach_gs_tracer(f::F, Y_or_similar_values...) where {F} =
169+
unrolled_foreach(gs_tracer_names(Y_or_similar_values[1])) do ρχ_name
170+
ρχ_or_χ_fields = unrolled_map(Y_or_similar_values) do value
171+
field = value isa Fields.Field ? value : value.c
172+
ρχ_or_χ_name =
173+
MatrixFields.has_field(field, ρχ_name) ? ρχ_name :
174+
specific_tracer_name(ρχ_name)
175+
MatrixFields.get_field(field, ρχ_or_χ_name)
176+
end
177+
f(ρχ_or_χ_fields..., ρχ_name)
114178
end
115179

116180
"""
@@ -206,94 +270,6 @@ function divide_by_ρa(ρaχ, ρa, ρχ, ρ, turbconv_model)
206270
return ρa == 0 ? ρχ / ρ : weight * ρaχ / ρa + (1 - weight) * ρχ / ρ
207271
end
208272

209-
# Helper functions for manipulating symbols in the generated functions:
210-
has_prefix(symbol, prefix_symbol) =
211-
startswith(string(symbol), string(prefix_symbol))
212-
remove_prefix(symbol, prefix_symbol) =
213-
Symbol(string(symbol)[(ncodeunits(string(prefix_symbol)) + 1):end])
214-
# Note that we need to use ncodeunits instead of length because prefix_symbol
215-
# can contain non-ASCII characters like 'ρ'.
216-
217-
"""
218-
specific_gs(gs)
219-
220-
Converts every variable of the form `ρχ` in the grid-scale state `gs` into the
221-
specific variable `χ` by dividing it by `ρ`. All other variables in `gs` are
222-
omitted from the result.
223-
"""
224-
@generated function specific_gs(gs)
225-
gs_names = Base._nt_names(gs)
226-
relevant_gs_names =
227-
filter(name -> has_prefix(name, ) && name != , gs_names)
228-
specific_gs_names = map(name -> remove_prefix(name, ), relevant_gs_names)
229-
specific_gs_values = map(name -> :(gs.$name / gs.ρ), relevant_gs_names)
230-
return :(NamedTuple{$specific_gs_names}(($(specific_gs_values...),)))
231-
end
232-
233-
"""
234-
specific_sgs(sgs, gs, turbconv_model)
235-
236-
Converts every variable of the form `ρaχ` in the sub-grid-scale state `sgs` into
237-
the specific variable `χ` by dividing it by `ρa`. All other variables in `sgs`
238-
are omitted from the result. The division is computed as
239-
`divide_by_ρa(ρaχ, ρa, ρχ, ρ, turbconv_model)`, which is preferable to simply
240-
calling `ρaχ / ρa` because it avoids numerical issues that arise when `a` is
241-
small. The values of `ρ` and `ρχ` are taken from `gs`, but, when `ρχ` is not
242-
available in `gs` (e.g., when `χ` is a second moment variable like `tke`), its
243-
value is assumed to be equal to the value of `ρaχ` in `sgs`.
244-
"""
245-
@generated function specific_sgs(sgs, gs, turbconv_model)
246-
sgs_names = Base._nt_names(sgs)
247-
gs_names = Base._nt_names(gs)
248-
relevant_sgs_names =
249-
filter(name -> has_prefix(name, :ρa) && name != :ρa, sgs_names)
250-
specific_sgs_names =
251-
map(name -> remove_prefix(name, :ρa), relevant_sgs_names)
252-
relevant_gs_names = map(name -> Symbol(, name), specific_sgs_names)
253-
specific_sgs_values = map(
254-
(sgs_name, gs_name) -> :(divide_by_ρa(
255-
sgs.$sgs_name,
256-
sgs.ρa,
257-
$(gs_name in gs_names ? :(gs.$gs_name) : :(sgs.$sgs_name)),
258-
gs.ρ,
259-
turbconv_model,
260-
)),
261-
relevant_sgs_names,
262-
relevant_gs_names,
263-
)
264-
return :(NamedTuple{$specific_sgs_names}(($(specific_sgs_values...),)))
265-
end
266-
267-
"""
268-
matching_subfields(tendency_field, specific_field)
269-
270-
Given a field that contains the tendencies of variables of the form `ρχ` or
271-
`ρaχ` and another field that contains the values of specific variables `χ`,
272-
returns all tuples `(tendency_field.<ρχ or ρaχ>, specific_field.<χ>, :<χ>)`.
273-
Variables in `tendency_field` that do not have matching variables in
274-
`specific_field` are omitted, as are variables in `specific_field` that do not
275-
have matching variables in `tendency_field`. This function is needed to avoid
276-
allocations due to failures in type inference, which are triggered when the
277-
`propertynames` of these fields are manipulated during runtime in order to pick
278-
out the matching subfields (as of Julia 1.8).
279-
"""
280-
@generated function matching_subfields(tendency_field, specific_field)
281-
tendency_names = Base._nt_names(eltype(tendency_field))
282-
specific_names = Base._nt_names(eltype(specific_field))
283-
prefix = :ρa in tendency_names ? :ρa :
284-
relevant_specific_names =
285-
filter(name -> Symbol(prefix, name) in tendency_names, specific_names)
286-
subfield_tuples = map(
287-
name -> :((
288-
tendency_field.$(Symbol(prefix, name)),
289-
specific_field.$name,
290-
$(QuoteNode(name)),
291-
)),
292-
relevant_specific_names,
293-
)
294-
return :(($(subfield_tuples...),))
295-
end
296-
297273
"""
298274
ρa⁺(gs)
299275
@@ -401,19 +377,6 @@ u₃⁰(ρaʲs, u₃ʲs, ρ, u₃, turbconv_model) = divide_by_ρa(
401377
turbconv_model,
402378
)
403379

404-
"""
405-
remove_energy_var(specific_state)
406-
407-
Creates a copy of `specific_state` with the energy variable
408-
removed, where `specific_state` is the result of calling, e.g., `specific_gs`,
409-
`specific_sgsʲs`, or `specific_sgs⁰`.
410-
"""
411-
remove_energy_var(specific_state::NamedTuple) =
412-
Base.structdiff(specific_state, NamedTuple{(:e_tot,)})
413-
remove_energy_var(specific_state::Tuple) =
414-
map(remove_energy_var, specific_state)
415-
416-
417380
import ClimaCore.RecursiveApply: , , rzero, rpromote_type
418381
function mapreduce_with_init(f, op, iter...)
419382
r₀ = rzero(rpromote_type(typeof(f(map(first, iter)...))))

0 commit comments

Comments
 (0)