Skip to content

Commit a8c308c

Browse files
committed
lazy levels. remove some scratch
1 parent 7365926 commit a8c308c

File tree

7 files changed

+43
-41
lines changed

7 files changed

+43
-41
lines changed

.buildkite/pipeline.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ steps:
2929
- echo "--- Instantiate .buildkite"
3030
- "julia --project=.buildkite -e 'using Pkg; Pkg.instantiate(;verbose=true); Pkg.precompile(;strict=true); using CUDA; CUDA.precompile_runtime(); Pkg.status()'"
3131

32+
- echo "--- dev package"
33+
- "julia --project=.buildkite -e 'using Pkg; Pkg.add(Pkg.PackageSpec(;name=\"ClimaCore\", rev=\"dy/lazy_field_levels\"))'"
34+
3235
agents:
3336
slurm_cpus_per_task: 8
3437
slurm_gpus: 1

src/cache/diagnostic_edmf_precomputed_quantities.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,10 @@ NVTX.@annotate function set_diagnostic_edmf_precomputed_quantities_bottom_bc!(
113113
uₕ_int_level = Fields.field_values(Fields.level(Y.c.uₕ, 1))
114114
u³_int_halflevel = Fields.field_values(Fields.level(ᶠu³, half))
115115
h_tot_int_level =
116-
Fields.field_values(Fields.level(Base.materialize(ᶜh_tot), 1))
116+
Fields.field_values(Fields.level(ᶜh_tot, 1))
117117
K_int_level = Fields.field_values(Fields.level(ᶜK, 1))
118118
q_tot_int_level =
119-
Fields.field_values(Fields.level(Base.materialize(q_tot), 1))
119+
Fields.field_values(Fields.level(q_tot, 1))
120120

121121
p_int_level = Fields.field_values(Fields.level(ᶜp, 1))
122122
Φ_int_level = Fields.field_values(Fields.level(ᶜΦ, 1))
@@ -369,9 +369,9 @@ NVTX.@annotate function set_diagnostic_edmf_precomputed_quantities_do_integral!(
369369
u³_halflevel = Fields.field_values(Fields.level(ᶠu³, i - half))
370370
K_level = Fields.field_values(Fields.level(ᶜK, i))
371371
h_tot_level =
372-
Fields.field_values(Fields.level(Base.materialize(ᶜh_tot), i))
372+
Fields.field_values(Fields.level(ᶜh_tot, i))
373373
q_tot_level =
374-
Fields.field_values(Fields.level(Base.materialize(q_tot), i))
374+
Fields.field_values(Fields.level(q_tot, i))
375375
p_level = Fields.field_values(Fields.level(ᶜp, i))
376376
Φ_level = Fields.field_values(Fields.level(ᶜΦ, i))
377377
local_geometry_level = Fields.field_values(
@@ -396,9 +396,9 @@ NVTX.@annotate function set_diagnostic_edmf_precomputed_quantities_do_integral!(
396396
u³⁰_data_prev_halflevel = u³⁰_prev_halflevel.components.data.:1
397397
K_prev_level = Fields.field_values(Fields.level(ᶜK, i - 1))
398398
h_tot_prev_level =
399-
Fields.field_values(Fields.level(Base.materialize(ᶜh_tot), i - 1))
399+
Fields.field_values(Fields.level(ᶜh_tot, i - 1))
400400
q_tot_prev_level =
401-
Fields.field_values(Fields.level(Base.materialize(q_tot), i - 1))
401+
Fields.field_values(Fields.level(q_tot, i - 1))
402402
ts_prev_level = Fields.field_values(Fields.level(ᶜts, i - 1))
403403
p_prev_level = Fields.field_values(Fields.level(ᶜp, i - 1))
404404
z_prev_level = Fields.field_values(Fields.level(ᶜz, i - 1))
@@ -499,7 +499,7 @@ NVTX.@annotate function set_diagnostic_edmf_precomputed_quantities_do_integral!(
499499
end
500500

501501
tke_prev_level = Fields.field_values(
502-
Fields.level(Base.materialize(ᶜtke⁰), i - 1),
502+
Fields.level(ᶜtke⁰, i - 1),
503503
)
504504

505505
@. entrʲ_prev_level = entrainment(
@@ -1041,7 +1041,7 @@ NVTX.@annotate function set_diagnostic_edmf_precomputed_quantities_env_closures!
10411041
(1 / 2 * norm_sqr(ᶜinterp(ᶠu³⁰) - ᶜinterp(ᶠu³ʲs.:($$j))) - ᶜtke⁰)
10421042
end
10431043

1044-
sfc_tke = Fields.level(Base.materialize(ᶜtke⁰), 1)
1044+
sfc_tke = Fields.level(ᶜtke⁰, 1)
10451045
z_sfc = Fields.level(Fields.coordinate_field(Y.f).z, half)
10461046
@. ᶜmixing_length_tuple = mixing_length(
10471047
params,

src/cache/precipitation_precomputed_quantities.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -474,24 +474,28 @@ function set_precipitation_surface_fluxes!(
474474
sfc_ρ = @. lazy(int_ρ * int_J / sfc_J)
475475

476476
# Constant extrapolation to surface, consistent with simple downwinding
477-
ᶜq_rai = ᶜspecific(Y.c.ρq_rai, Y.c.ρ)
478-
ᶜq_sno = ᶜspecific(Y.c.ρq_sno, Y.c.ρ)
479-
ᶜq_liq = ᶜspecific(Y.c.ρq_liq, Y.c.ρ)
480-
ᶜq_ice = ᶜspecific(Y.c.ρq_ice, Y.c.ρ)
477+
ᶜq_rai = p.scratch.ᶜtemp_scalar
478+
ᶜq_sno = p.scratch.ᶜtemp_scalar_2
479+
ᶜq_liq = p.scratch.ᶜtemp_scalar_3
480+
ᶜq_ice = p.scratch.ᶜtemp_scalar_4
481+
ᶜq_rai .= ᶜspecific(Y.c.ρq_rai, Y.c.ρ)
482+
ᶜq_sno .= ᶜspecific(Y.c.ρq_sno, Y.c.ρ)
483+
ᶜq_liq .= ᶜspecific(Y.c.ρq_liq, Y.c.ρ)
484+
ᶜq_ice .= ᶜspecific(Y.c.ρq_ice, Y.c.ρ)
481485
sfc_qᵣ = Fields.Field(
482-
Fields.field_values(Fields.level(Base.materialize(ᶜq_rai), 1)),
486+
Fields.field_values(Fields.level(ᶜq_rai, 1)),
483487
sfc_space,
484488
)
485489
sfc_qₛ = Fields.Field(
486-
Fields.field_values(Fields.level(Base.materialize(ᶜq_sno), 1)),
490+
Fields.field_values(Fields.level(ᶜq_sno, 1)),
487491
sfc_space,
488492
)
489493
sfc_qₗ = Fields.Field(
490-
Fields.field_values(Fields.level(Base.materialize(ᶜq_liq), 1)),
494+
Fields.field_values(Fields.level(ᶜq_liq, 1)),
491495
sfc_space,
492496
)
493497
sfc_qᵢ = Fields.Field(
494-
Fields.field_values(Fields.level(Base.materialize(ᶜq_ice), 1)),
498+
Fields.field_values(Fields.level(ᶜq_ice, 1)),
495499
sfc_space,
496500
)
497501
sfc_wᵣ = Fields.Field(Fields.field_values(Fields.level(ᶜwᵣ, 1)), sfc_space)

src/cache/prognostic_edmf_precomputed_quantities.jl

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,13 @@ NVTX.@annotate function set_prognostic_edmf_precomputed_quantities_environment!(
2424
(; ᶜp, ᶜK) = p.precomputed
2525
(; ᶠu₃⁰, ᶜu⁰, ᶠu³⁰, ᶜK⁰, ᶜts⁰) = p.precomputed
2626

27-
ᶜρa⁰_vals = ᶜρa⁰(Y, p)
2827
ᶜtke⁰ = ᶜspecific_tke(Y, p)
2928
set_sgs_ᶠu₃!(u₃⁰, ᶠu₃⁰, Y, turbconv_model)
3029
set_velocity_quantities!(ᶜu⁰, ᶠu³⁰, ᶜK⁰, ᶠu₃⁰, Y.c.uₕ, ᶠuₕ³)
3130
# @. ᶜK⁰ += ᶜtke⁰
3231
ᶜq_tot⁰ = ᶜspecific_env_value(Val(:q_tot), Y, p)
3332

34-
ᶜmse⁰ = p.scratch.ᶜtemp_scalar_2
35-
ᶜmse⁰ .= ᶜspecific_env_mse(Y, p)
33+
ᶜmse⁰ = ᶜspecific_env_mse(Y, p)
3634

3735
if p.atmos.moisture_model isa NonEquilMoistModel &&
3836
p.atmos.microphysics_model isa Microphysics1Moment
@@ -184,7 +182,7 @@ NVTX.@annotate function set_prognostic_edmf_precomputed_quantities_bottom_bc!(
184182
),
185183
)
186184
ᶜh_tot_int_val =
187-
Fields.field_values(Fields.level(Base.materialize(ᶜh_tot), 1))
185+
Fields.field_values(Fields.level(ᶜh_tot, 1))
188186
ᶜK_int_val = Fields.field_values(Fields.level(ᶜK, 1))
189187
ᶜmseʲ_int_val = Fields.field_values(Fields.level(ᶜmseʲ, 1))
190188
@. ᶜmseʲ_int_val = sgs_scalar_first_interior_bc(
@@ -203,7 +201,7 @@ NVTX.@annotate function set_prognostic_edmf_precomputed_quantities_bottom_bc!(
203201

204202
ᶜq_tot = ᶜspecific(Y.c.ρq_tot, Y.c.ρ)
205203
ᶜq_tot_int_val =
206-
Fields.field_values(Fields.level(Base.materialize(ᶜq_tot), 1))
204+
Fields.field_values(Fields.level(ᶜq_tot, 1))
207205
ᶜq_totʲ_int_val = Fields.field_values(Fields.level(ᶜq_totʲ, 1))
208206
@. ᶜq_totʲ_int_val = sgs_scalar_first_interior_bc(
209207
ᶜz_int_val - z_sfc_val,
@@ -225,22 +223,22 @@ NVTX.@annotate function set_prognostic_edmf_precomputed_quantities_bottom_bc!(
225223
ᶜq_rai = ᶜspecific(Y.c.ρq_rai, Y.c.ρ)
226224
ᶜq_sno = ᶜspecific(Y.c.ρq_sno, Y.c.ρ)
227225
ᶜq_liq_int_val =
228-
Fields.field_values(Fields.level(Base.materialize(ᶜq_liq), 1))
226+
Fields.field_values(Fields.level(ᶜq_liq, 1))
229227
ᶜq_liqʲ_int_val = Fields.field_values(Fields.level(ᶜq_liqʲ, 1))
230228
@. ᶜq_liqʲ_int_val = ᶜq_liq_int_val
231229

232230
ᶜq_ice_int_val =
233-
Fields.field_values(Fields.level(Base.materialize(ᶜq_ice), 1))
231+
Fields.field_values(Fields.level(ᶜq_ice, 1))
234232
ᶜq_iceʲ_int_val = Fields.field_values(Fields.level(ᶜq_iceʲ, 1))
235233
@. ᶜq_iceʲ_int_val = ᶜq_ice_int_val
236234

237235
ᶜq_rai_int_val =
238-
Fields.field_values(Fields.level(Base.materialize(ᶜq_rai), 1))
236+
Fields.field_values(Fields.level(ᶜq_rai, 1))
239237
ᶜq_raiʲ_int_val = Fields.field_values(Fields.level(ᶜq_raiʲ, 1))
240238
@. ᶜq_raiʲ_int_val = ᶜq_rai_int_val
241239

242240
ᶜq_sno_int_val =
243-
Fields.field_values(Fields.level(Base.materialize(ᶜq_sno), 1))
241+
Fields.field_values(Fields.level(ᶜq_sno, 1))
244242
ᶜq_snoʲ_int_val = Fields.field_values(Fields.level(ᶜq_snoʲ, 1))
245243
@. ᶜq_snoʲ_int_val = ᶜq_sno_int_val
246244
end
@@ -492,7 +490,7 @@ NVTX.@annotate function set_prognostic_edmf_precomputed_quantities_explicit_clos
492490
(1 / 2 * norm_sqr(ᶜinterp(ᶠu³⁰) - ᶜinterp(ᶠu³ʲs.:($$j))) - ᶜtke⁰)
493491
end
494492

495-
sfc_tke = Fields.level(Base.materialize(ᶜtke⁰), 1)
493+
sfc_tke = Fields.level(ᶜtke⁰, 1)
496494
@. ᶜmixing_length_tuple = mixing_length(
497495
p.params,
498496
ustar,
@@ -516,7 +514,7 @@ NVTX.@annotate function set_prognostic_edmf_precomputed_quantities_explicit_clos
516514

517515
ρatke_flux_values = Fields.field_values(ρatke_flux)
518516
ρa_sfc_values =
519-
Fields.field_values(Fields.level(Base.materialize(ᶜρa⁰_vals), 1)) # TODO: replace by surface value
517+
Fields.field_values(Fields.level(ᶜρa⁰_vals, 1)) # TODO: replace by surface value
520518
ustar_values = Fields.field_values(ustar)
521519
sfc_local_geometry_values = Fields.field_values(
522520
Fields.level(Fields.local_geometry_field(Y.f), half),

src/prognostic_equations/edmfx_entr_detr.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,8 +530,7 @@ function edmfx_entr_detr_tendency!(Yₜ, Y, p, t, turbconv_model::PrognosticEDMF
530530
(; ᶜturb_entrʲs, ᶜentrʲs, ᶜdetrʲs) = p.precomputed
531531
(; ᶠu₃⁰) = p.precomputed
532532

533-
ᶜmse⁰ = p.scratch.ᶜtemp_scalar
534-
ᶜmse⁰ .= ᶜspecific_env_mse(Y, p)
533+
ᶜmse⁰ = ᶜspecific_env_mse(Y, p)
535534
if p.atmos.moisture_model isa NonEquilMoistModel &&
536535
p.atmos.microphysics_model isa Microphysics1Moment
537536
ᶜq_liq⁰ = ᶜspecific_env_value(Val(:q_liq), Y, p)

src/prognostic_equations/edmfx_sgs_flux.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ function edmfx_sgs_mass_flux_tendency!(
4444
(; ᶠu³ʲs, ᶜKʲs, ᶜρʲs) = p.precomputed
4545
(; ᶠu³⁰, ᶜK⁰, ᶜts⁰, ᶜts) = p.precomputed
4646
thermo_params = CAP.thermodynamics_params(p.params)
47-
ᶜρ⁰ = @. TD.air_density(thermo_params, ᶜts⁰)
47+
ᶜρ⁰ = @. lazy(TD.air_density(thermo_params, ᶜts⁰))
4848
ᶜρa⁰_vals = ᶜρa⁰(Y, p)
4949
(; dt) = p
5050
ᶜJ = Fields.local_geometry_field(Y.c).J
@@ -80,8 +80,7 @@ function edmfx_sgs_mass_flux_tendency!(
8080
# Add the environment fluxes
8181
@. ᶠu³_diff = ᶠu³⁰ - ᶠu³
8282

83-
ᶜmse⁰ = p.scratch.ᶜtemp_scalar_2
84-
ᶜmse⁰ .= ᶜspecific_env_mse(Y, p)
83+
ᶜmse⁰ = ᶜspecific_env_mse(Y, p)
8584
@. ᶜa_scalar = (ᶜmse⁰ + ᶜK⁰ - ᶜh_tot) * draft_area(ᶜρa⁰_vals, ᶜρ⁰)
8685
vtt = vertical_transport(
8786
ᶜρ⁰,
@@ -420,8 +419,7 @@ function edmfx_sgs_diffusive_flux_tendency!(
420419
bottom = Operators.SetValue(C3(FT(0))),
421420
)
422421

423-
ᶜmse⁰ = p.scratch.ᶜtemp_scalar_2
424-
ᶜmse⁰ .= ᶜspecific_env_mse(Y, p)
422+
ᶜmse⁰ = ᶜspecific_env_mse(Y, p)
425423
@. Yₜ.c.ρe_tot -= ᶜdivᵥ_ρe_tot(-(ᶠρaK_h * ᶠgradᵥ(ᶜmse⁰ + ᶜK⁰)))
426424
if use_prognostic_tke(turbconv_model)
427425
# Turbulent TKE transport (diffusion)

src/utils/variable_manipulations.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ function ᶜenv_value(
269269
grid_scale_value,
270270
f_draft,
271271
gs,
272-
turbconv_model::PrognosticEDMFX,
272+
turbconv_model::PrognosticEDMFX
273273
)
274274
return @. lazy(grid_scale_value - draft_sum(f_draft, gs.sgsʲs))
275275
end
@@ -313,10 +313,10 @@ function ᶜspecific_env_value(::Val{χ_name}, Y, p) where {χ_name}
313313
# environment density-area-weighted mse (`ρa⁰χ⁰`).
314314
# Numerator: ρa⁰χ⁰ = ρχ - (Σ ρaʲ * χʲ)
315315
if turbconv_model isa PrognosticEDMFX
316-
# Numerator: ρa⁰χ⁰ = ρχ - (Σ sgsʲ.ρa * sgsʲ.χ)
316+
#Numerator: ρa⁰χ⁰ = ρχ - (Σ sgsʲ.ρa * sgsʲ.χ)
317317
ᶜρaχ⁰ = ᶜenv_value(
318318
ᶜρχ,
319-
sgsʲ -> getproperty(sgsʲ, :ρa) * getproperty(sgsʲ, χ_name),
319+
sgsʲ -> getfield(sgsʲ, :ρa) * getfield(sgsʲ, χ_name),
320320
Y.c,
321321
turbconv_model,
322322
)
@@ -326,7 +326,7 @@ function ᶜspecific_env_value(::Val{χ_name}, Y, p) where {χ_name}
326326
n = n_mass_flux_subdomains(turbconv_model)
327327

328328
# Σ ρaʲ * χʲ
329-
ᶜρaχʲs_sum = p.scratch.ᶜtemp_scalar
329+
ᶜρaχʲs_sum = p.scratch.ᶜtemp_scalar_3
330330
@. ᶜρaχʲs_sum = 0
331331
for j in 1:n
332332
ᶜρaʲ = p.precomputed.ᶜρaʲs.:($j)
@@ -370,9 +370,10 @@ Returns:
370370

371371
function ᶜρa⁰(Y, p)
372372
turbconv_model = p.atmos.turbconv_model
373-
373+
# ρ - Σ ρaʲ
374374
if turbconv_model isa PrognosticEDMFX
375375
return ᶜenv_value(Y.c.ρ, sgsʲ -> sgsʲ.ρa, Y.c, turbconv_model)
376+
376377
elseif turbconv_model isa DiagnosticEDMFX
377378
(; ᶜρaʲs) = p.precomputed
378379
return ᶜenv_value(Y.c.ρ, ᶜρaʲ -> ᶜρaʲ, ᶜρaʲs, turbconv_model)
@@ -464,9 +465,8 @@ function ᶜspecific_env_mse(Y, p)
464465
elseif turbconv_model isa DiagnosticEDMFX || turbconv_model isa EDOnlyEDMFX
465466

466467
n = n_mass_flux_subdomains(turbconv_model)
467-
ᶜρamseʲ_sum = p.scratch.ᶜtemp_scalar
468+
ᶜρamseʲ_sum = p.scratch.ᶜtemp_scalar_2
468469
@. ᶜρamseʲ_sum = 0
469-
# Numerator: ρa⁰mse⁰ = ρmse - (Σ ρaʲ * mseʲ)
470470
for j in 1:n
471471
ᶜρaʲ = p.precomputed.ᶜρaʲs.:($j)
472472
ᶜmseʲ = p.precomputed.ᶜmseʲs.:($j)

0 commit comments

Comments
 (0)