Skip to content

Commit 74950d0

Browse files
Merge #1264
1264: Speed up inertial gravity wave examples (part of #1263) r=charleskawczynski a=charleskawczynski This PR only includes a few commits in #1263, to see where things are going wrong. Co-authored-by: Charles Kawczynski <kawczynski.charles@gmail.com>
2 parents ac616b7 + bb97c08 commit 74950d0

File tree

2 files changed

+106
-60
lines changed

2 files changed

+106
-60
lines changed

.buildkite/pipeline.yml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -773,21 +773,25 @@ steps:
773773
- label: ":computer: 2D plane inertial gravity wave"
774774
key: "cpu_inertial_gravity_wave"
775775
command:
776-
- "julia --color=yes --project=examples examples/hybrid/driver.jl"
776+
- "julia --threads 8 --color=yes --project=examples examples/hybrid/driver.jl"
777777
artifact_paths:
778778
- "examples/hybrid/plane/output/inertial_gravity_wave/Float32/*"
779779
env:
780780
TEST_NAME: "plane/inertial_gravity_wave"
781+
agents:
782+
slurm_cpus_per_task: 8
781783

782784
- label: ":computer: stretched 2D plane inertial gravity wave"
783785
key: "cpu_stretch_inertial_gravity_wave"
784786
command:
785-
- "julia --color=yes --project=examples examples/hybrid/driver.jl"
787+
- "julia --threads 8 --color=yes --project=examples examples/hybrid/driver.jl"
786788
artifact_paths:
787789
- "examples/hybrid/plane/output/stretched_inertial_gravity_wave/Float32/*"
788790
env:
789791
TEST_NAME: "plane/inertial_gravity_wave"
790792
Z_STRETCH: "true"
793+
agents:
794+
slurm_cpus_per_task: 8
791795

792796
- group: "Performance"
793797
steps:

examples/hybrid/plane/inertial_gravity_wave.jl

Lines changed: 100 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,46 @@
1+
#=
2+
julia --threads=8 --project=examples
3+
ENV["TEST_NAME"] = "plane/inertial_gravity_wave"
4+
include(joinpath("examples", "hybrid", "driver.jl"))
5+
=#
16
using Printf
27
using ProgressLogging
38
using ClimaCorePlots, Plots
49

510
# Reference paper: https://rmets.onlinelibrary.wiley.com/doi/pdf/10.1002/qj.2105
611

12+
# min_λx = 2 * (x_max / x_elem) / upsampling_factor # this should include npoly
13+
# min_λz = 2 * (FT( / z_)elem) / upsampling_factor
14+
# min_λx = 2 * π / max_kx = x_max / max_ikx
15+
# min_λz = 2 * π / max_kz = 2 * z_max / max_ikz
16+
# max_ikx = x_max / min_λx = upsampling_factor * x_elem / 2
17+
# max_ikz = 2 * z_max / min_λz = upsampling_factor * z_elem
18+
function ρfb_init_coefs!(::Type{FT}, params) where {FT}
19+
(; max_ikz, max_ikx, x_max, z_max, unit_integral) = params
20+
(; ρfb_init_array, ᶜρb_init_xz) = params
21+
# Since the coefficients are for a modified domain of height 2 * z_max, the
22+
# unit integral over the domain must be multiplied by 2 to ensure correct
23+
# normalization. On the other hand, ᶜρb_init is assumed to be 0 outside of
24+
# the "true" domain, so the integral of
25+
# ᶜintegrand (`ᶜintegrand = ᶜρb_init / ᶜfourier_factor`) should not be modified.
26+
# where `ᶜfourier_factor = exp(im * (kx * x + kz * z))`.
27+
@inbounds begin
28+
Threads.@threads for ikx in (-max_ikx):max_ikx
29+
for ikz in (-max_ikz):max_ikz
30+
kx::FT = 2 * π / x_max * ikx
31+
kz::FT = 2 * π / (2 * z_max) * ikz
32+
ρfb_init_array[ikx + max_ikx + 1, ikz + max_ikz + 1] =
33+
sum(ᶜρb_init_xz) do nt
34+
(; ρ, x, z) = nt
35+
ρ / exp(im * (kx * x + kz * z))
36+
end / unit_integral
37+
38+
end
39+
end
40+
end
41+
return nothing
42+
end
43+
744
# Constants for switching between different experiment setups
845
const is_small_scale = true
946
const ᶜ𝔼_name = :ρe
@@ -76,7 +113,7 @@ function discrete_hydrostatic_balance!(ᶠΔz, ᶜΔz, grav)
76113
ᶜp1 = Fields.level(ᶜp, 1)
77114
ᶜΔz1 = Fields.level(ᶜΔz, 1)
78115
@. ᶜp1 = p_0 * (1 - δ * ᶜΔz1 / 4) / (1 + δ * ᶜΔz1 / 4)
79-
for i in 1:(Spaces.nlevels(axes(ᶜp)) - 1)
116+
@inbounds for i in 1:(Spaces.nlevels(axes(ᶜp)) - 1)
80117
ᶜpi = parent(Fields.level(ᶜp, i))
81118
ᶜpi1 = parent(Fields.level(ᶜp, i + 1))
82119
ᶠΔzi1 = parent(Fields.level(ᶠΔz, Spaces.PlusHalf(i)))
@@ -150,21 +187,24 @@ function postprocessing(sol, output_dir)
150187
v′ = Y -> @. Geometry.UVVector(Y.c.uₕ).components.data.:2 - v₀
151188
w′ = Y -> @. Geometry.WVector(Y.f.w).components.data.:1
152189

153-
for iframe in (1, length(sol.t))
154-
t = sol.t[iframe]
155-
Y = sol.u[iframe]
156-
linear_solution!(Y_lin, lin_cache, t)
157-
println("Error norms at time t = $t:")
158-
for (name, f) in ((:ρ′, ρ′), (:T′, T′), (:u′, u′), (:v′, v′), (:w′, w′))
159-
var = f(Y)
160-
var_lin = f(Y_lin)
161-
strings = (
162-
norm_strings(var, var_lin, 2)...,
163-
norm_strings(var, var_lin, Inf)...,
164-
)
165-
println("ϕ = $name: ", join(strings, ", "))
190+
@time "print norms" @inbounds begin
191+
for iframe in (1, length(sol.t))
192+
t = sol.t[iframe]
193+
Y = sol.u[iframe]
194+
linear_solution!(Y_lin, lin_cache, t)
195+
println("Error norms at time t = $t:")
196+
for (name, f) in
197+
((:ρ′, ρ′), (:T′, T′), (:u′, u′), (:v′, v′), (:w′, w′))
198+
var = f(Y)
199+
var_lin = f(Y_lin)
200+
strings = (
201+
norm_strings(var, var_lin, 2)...,
202+
norm_strings(var, var_lin, Inf)...,
203+
)
204+
println("ϕ = $name: ", join(strings, ", "))
205+
end
206+
println()
166207
end
167-
println()
168208
end
169209

170210
anim_vars = (
@@ -173,24 +213,34 @@ function postprocessing(sol, output_dir)
173213
(:wprime, w′, is_small_scale ? 0.0042 : 0.0014),
174214
)
175215
anims = [Animation() for _ in 1:(3 * length(anim_vars))]
176-
@progress "Animations" for iframe in 1:length(sol.t)
177-
t = sol.t[iframe]
178-
Y = sol.u[iframe]
179-
linear_solution!(Y_lin, lin_cache, t)
180-
for (ivar, (_, f, lim)) in enumerate(anim_vars)
181-
var = f(Y)
182-
var_lin = f(Y_lin)
183-
var_rel_err = @. (var - var_lin) / (abs(var_lin) + eps(FT))
184-
# adding eps(FT) to the denominator prevents divisions by 0
185-
frame(anims[3 * ivar - 2], plot(var_lin, clim = (-lim, lim)))
186-
frame(anims[3 * ivar - 1], plot(var, clim = (-lim, lim)))
187-
frame(anims[3 * ivar], plot(var_rel_err, clim = (-10, 10)))
216+
@inbounds begin
217+
@progress "Animations" for iframe in 1:length(sol.t)
218+
t = sol.t[iframe]
219+
Y = sol.u[iframe]
220+
linear_solution!(Y_lin, lin_cache, t)
221+
for (ivar, (_, f, lim)) in enumerate(anim_vars)
222+
var = f(Y)
223+
var_lin = f(Y_lin)
224+
var_rel_err = @. (var - var_lin) / (abs(var_lin) + eps(FT))
225+
# adding eps(FT) to the denominator prevents divisions by 0
226+
frame(anims[3 * ivar - 2], plot(var_lin, clim = (-lim, lim)))
227+
frame(anims[3 * ivar - 1], plot(var, clim = (-lim, lim)))
228+
frame(anims[3 * ivar], plot(var_rel_err, clim = (-10, 10)))
229+
end
230+
end
231+
for (ivar, (name, _, _)) in enumerate(anim_vars)
232+
mp4(
233+
anims[3 * ivar - 2],
234+
joinpath(output_dir, "$(name)_lin.mp4");
235+
fps,
236+
)
237+
mp4(anims[3 * ivar - 1], joinpath(output_dir, "$name.mp4"); fps)
238+
mp4(
239+
anims[3 * ivar],
240+
joinpath(output_dir, "$(name)_rel_err.mp4");
241+
fps,
242+
)
188243
end
189-
end
190-
for (ivar, (name, _, _)) in enumerate(anim_vars)
191-
mp4(anims[3 * ivar - 2], joinpath(output_dir, "$(name)_lin.mp4"); fps)
192-
mp4(anims[3 * ivar - 1], joinpath(output_dir, "$name.mp4"); fps)
193-
mp4(anims[3 * ivar], joinpath(output_dir, "$(name)_rel_err.mp4"); fps)
194244
end
195245
end
196246

@@ -204,13 +254,7 @@ function norm_strings(var, var_lin, p)
204254
)
205255
end
206256

207-
# min_λx = 2 * (x_max / x_elem) / upsampling_factor # this should include npoly
208-
# min_λz = 2 * (FT( / z_)elem) / upsampling_factor
209-
# min_λx = 2 * π / max_kx = x_max / max_ikx
210-
# min_λz = 2 * π / max_kz = 2 * z_max / max_ikz
211-
# max_ikx = x_max / min_λx = upsampling_factor * x_elem / 2
212-
# max_ikz = 2 * z_max / min_λz = upsampling_factor * z_elem
213-
function ρfb_init_coefs(
257+
function ρfb_init_coefs_params(
214258
upsampling_factor = 3,
215259
max_ikx = upsampling_factor * x_elem ÷ 2,
216260
max_ikz = upsampling_factor * z_elem,
@@ -242,32 +286,29 @@ function ρfb_init_coefs(
242286
ᶜbretherton_factor_pρ = @. exp(-δ * ᶜz / 2)
243287
ᶜρb_init = @. ᶜρ′_init / ᶜbretherton_factor_pρ
244288
end
289+
combine(ρ, lg) = (; ρ, x = lg.coordinates.x, z = lg.coordinates.z)
290+
ᶜρb_init_xz = combine.(ᶜρb_init, ᶜlocal_geometry)
245291

246292
# Fourier coefficients of Bretherton transform of initial perturbation
247293
ρfb_init_array = Array{Complex{FT}}(undef, 2 * max_ikx + 1, 2 * max_ikz + 1)
248-
ᶜfourier_factor = Fields.Field(Complex{FT}, axes(ᶜlocal_geometry))
249-
ᶜintegrand = Fields.Field(Complex{FT}, axes(ᶜlocal_geometry))
250294
unit_integral = 2 * sum(one.(ᶜρb_init))
251-
# Since the coefficients are for a modified domain of height 2 * z_max, the
252-
# unit integral over the domain must be multiplied by 2 to ensure correct
253-
# normalization. On the other hand, ᶜρb_init is assumed to be 0 outside of
254-
# the "true" domain, so the integral of ᶜintegrand should not be modified.
255-
@progress "ρfb_init" for ikx in (-max_ikx):max_ikx,
256-
ikz in (-max_ikz):max_ikz
257-
258-
kx = 2 * π / x_max * ikx
259-
kz = 2 * π / (2 * z_max) * ikz
260-
@. ᶜfourier_factor = exp(im * (kx * ᶜx + kz * ᶜz))
261-
@. ᶜintegrand = ᶜρb_init / ᶜfourier_factor
262-
ρfb_init_array[ikx + max_ikx + 1, ikz + max_ikz + 1] =
263-
sum(ᶜintegrand) / unit_integral
264-
end
265-
return ρfb_init_array
295+
return (;
296+
ρfb_init_array,
297+
ᶜρb_init_xz,
298+
max_ikz,
299+
max_ikx,
300+
x_max,
301+
z_max,
302+
unit_integral,
303+
)
266304
end
267305

268306
function linear_solution_cache(ᶜlocal_geometry, ᶠlocal_geometry)
269307
ᶜz = ᶜlocal_geometry.coordinates.z
270308
ᶠz = ᶠlocal_geometry.coordinates.z
309+
ρfb_init_array_params = ρfb_init_coefs_params()
310+
@time "ρfb_init_coefs!" ρfb_init_coefs!(FT, ρfb_init_array_params)
311+
(; ρfb_init_array) = ρfb_init_array_params
271312
ᶜp₀ = @. p₀(ᶜz)
272313
return (;
273314
# coordinates
@@ -289,7 +330,7 @@ function linear_solution_cache(ᶜlocal_geometry, ᶠlocal_geometry)
289330
ᶠbretherton_factor_uvwT = (@. exp* ᶠz / 2)),
290331

291332
# Fourier coefficients of Bretherton transform of initial perturbation
292-
ρfb_init_array = ρfb_init_coefs(),
333+
ρfb_init_array,
293334

294335
# Fourier transform factors
295336
ᶜfourier_factor = Fields.Field(Complex{FT}, axes(ᶜlocal_geometry)),
@@ -327,7 +368,7 @@ function linear_solution!(Y, lin_cache, t)
327368
ᶜvb .= FT(0)
328369
ᶠwb .= FT(0)
329370
max_ikx, max_ikz = (size(ρfb_init_array) .- 1) 2
330-
for ikx in (-max_ikx):max_ikx, ikz in (-max_ikz):max_ikz
371+
@inbounds for ikx in (-max_ikx):max_ikx, ikz in (-max_ikz):max_ikz
331372
kx = 2 * π / x_max * ikx
332373
kz = 2 * π / (2 * z_max) * ikz
333374

@@ -405,4 +446,5 @@ function linear_solution!(Y, lin_cache, t)
405446
@. Y.c.uₕ = Geometry.Covariant12Vector(Geometry.UVVector(ᶜu, FT(0.0)))
406447
@. Y.c.uₕ.components.data.:2 .= ᶜv
407448
@. Y.f.w = Geometry.Covariant3Vector(Geometry.WVector(ᶠw))
449+
return nothing
408450
end

0 commit comments

Comments
 (0)