Skip to content

Commit 6dfa7c6

Browse files
Add igw utils module to ensure no closures
1 parent 74950d0 commit 6dfa7c6

File tree

2 files changed

+176
-128
lines changed

2 files changed

+176
-128
lines changed

examples/hybrid/plane/inertial_gravity_wave.jl

Lines changed: 22 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -9,37 +9,8 @@ using ClimaCorePlots, Plots
99

1010
# Reference paper: https://rmets.onlinelibrary.wiley.com/doi/pdf/10.1002/qj.2105
1111

12-
# min_λx = 2 * (x_max / x_elem) / upsampling_factor # this should include npoly
13-
# min_λz = 2 * (FT( / z_)elem) / upsampling_factor
14-
# min_λx = 2 * π / max_kx = x_max / max_ikx
15-
# min_λz = 2 * π / max_kz = 2 * z_max / max_ikz
16-
# max_ikx = x_max / min_λx = upsampling_factor * x_elem / 2
17-
# max_ikz = 2 * z_max / min_λz = upsampling_factor * z_elem
18-
function ρfb_init_coefs!(::Type{FT}, params) where {FT}
19-
(; max_ikz, max_ikx, x_max, z_max, unit_integral) = params
20-
(; ρfb_init_array, ᶜρb_init_xz) = params
21-
# Since the coefficients are for a modified domain of height 2 * z_max, the
22-
# unit integral over the domain must be multiplied by 2 to ensure correct
23-
# normalization. On the other hand, ᶜρb_init is assumed to be 0 outside of
24-
# the "true" domain, so the integral of
25-
# ᶜintegrand (`ᶜintegrand = ᶜρb_init / ᶜfourier_factor`) should not be modified.
26-
# where `ᶜfourier_factor = exp(im * (kx * x + kz * z))`.
27-
@inbounds begin
28-
Threads.@threads for ikx in (-max_ikx):max_ikx
29-
for ikz in (-max_ikz):max_ikz
30-
kx::FT = 2 * π / x_max * ikx
31-
kz::FT = 2 * π / (2 * z_max) * ikz
32-
ρfb_init_array[ikx + max_ikx + 1, ikz + max_ikz + 1] =
33-
sum(ᶜρb_init_xz) do nt
34-
(; ρ, x, z) = nt
35-
ρ / exp(im * (kx * x + kz * z))
36-
end / unit_integral
37-
38-
end
39-
end
40-
end
41-
return nothing
42-
end
12+
include("intertial_gravity_wave_utils.jl")
13+
import .InertialGravityWaveUtils as IGWU
4314

4415
# Constants for switching between different experiment setups
4516
const is_small_scale = true
@@ -191,7 +162,7 @@ function postprocessing(sol, output_dir)
191162
for iframe in (1, length(sol.t))
192163
t = sol.t[iframe]
193164
Y = sol.u[iframe]
194-
linear_solution!(Y_lin, lin_cache, t)
165+
IGWU.linear_solution!(Y_lin, lin_cache, t, FT)
195166
println("Error norms at time t = $t:")
196167
for (name, f) in
197168
((:ρ′, ρ′), (:T′, T′), (:u′, u′), (:v′, v′), (:w′, w′))
@@ -213,11 +184,12 @@ function postprocessing(sol, output_dir)
213184
(:wprime, w′, is_small_scale ? 0.0042 : 0.0014),
214185
)
215186
anims = [Animation() for _ in 1:(3 * length(anim_vars))]
187+
@info "Creating animation with $(length(sol.t)) frames."
216188
@inbounds begin
217189
@progress "Animations" for iframe in 1:length(sol.t)
218190
t = sol.t[iframe]
219191
Y = sol.u[iframe]
220-
linear_solution!(Y_lin, lin_cache, t)
192+
IGWU.linear_solution!(Y_lin, lin_cache, t, FT)
221193
for (ivar, (_, f, lim)) in enumerate(anim_vars)
222194
var = f(Y)
223195
var_lin = f(Y_lin)
@@ -307,10 +279,26 @@ function linear_solution_cache(ᶜlocal_geometry, ᶠlocal_geometry)
307279
ᶜz = ᶜlocal_geometry.coordinates.z
308280
ᶠz = ᶠlocal_geometry.coordinates.z
309281
ρfb_init_array_params = ρfb_init_coefs_params()
310-
@time "ρfb_init_coefs!" ρfb_init_coefs!(FT, ρfb_init_array_params)
282+
@time "ρfb_init_coefs!" IGWU.ρfb_init_coefs!(FT, ρfb_init_array_params)
311283
(; ρfb_init_array) = ρfb_init_array_params
312284
ᶜp₀ = @. p₀(ᶜz)
313285
return (;
286+
# globals
287+
R_d,
288+
ᶜ𝔼_name,
289+
x_max,
290+
z_max,
291+
p_0,
292+
cp_d,
293+
cv_d,
294+
grav,
295+
T_tri,
296+
u₀,
297+
δ,
298+
cₛ²,
299+
f,
300+
ρₛ,
301+
ᶜinterp,
314302
# coordinates
315303
ᶜx = ᶜlocal_geometry.coordinates.x,
316304
ᶠx = ᶠlocal_geometry.coordinates.x,
@@ -354,97 +342,3 @@ function linear_solution_cache(ᶜlocal_geometry, ᶠlocal_geometry)
354342
ᶜT = Fields.Field(FT, axes(ᶜlocal_geometry)),
355343
)
356344
end
357-
358-
function linear_solution!(Y, lin_cache, t)
359-
(; ᶜx, ᶠx, ᶜz, ᶠz, ᶜp₀, ᶜρ₀, ᶜu₀, ᶜv₀, ᶠw₀) = lin_cache
360-
(; ᶜbretherton_factor_pρ) = lin_cache
361-
(; ᶜbretherton_factor_uvwT, ᶠbretherton_factor_uvwT) = lin_cache
362-
(; ρfb_init_array, ᶜfourier_factor, ᶠfourier_factor) = lin_cache
363-
(; ᶜpb, ᶜρb, ᶜub, ᶜvb, ᶠwb, ᶜp, ᶜρ, ᶜu, ᶜv, ᶠw, ᶜT) = lin_cache
364-
365-
ᶜpb .= FT(0)
366-
ᶜρb .= FT(0)
367-
ᶜub .= FT(0)
368-
ᶜvb .= FT(0)
369-
ᶠwb .= FT(0)
370-
max_ikx, max_ikz = (size(ρfb_init_array) .- 1) 2
371-
@inbounds for ikx in (-max_ikx):max_ikx, ikz in (-max_ikz):max_ikz
372-
kx = 2 * π / x_max * ikx
373-
kz = 2 * π / (2 * z_max) * ikz
374-
375-
# Fourier coefficient of ᶜρb_init (for current kx and kz)
376-
ρfb_init = ρfb_init_array[ikx + max_ikx + 1, ikz + max_ikz + 1]
377-
378-
# Fourier factors, shifted by u₀ * t along the x-axis
379-
@. ᶜfourier_factor = exp(im * (kx * (ᶜx - u₀ * t) + kz * ᶜz))
380-
@. ᶠfourier_factor = exp(im * (kx * (ᶠx - u₀ * t) + kz * ᶠz))
381-
382-
# roots of a₁(s)
383-
p₁ = cₛ² * (kx^2 + kz^2 + δ^2 / 4) + f^2
384-
q₁ = grav * kx^2 * (cₛ² * δ - grav) + cₛ² * f^2 * (kz^2 + δ^2 / 4)
385-
α² = p₁ / 2 - sqrt(p₁^2 / 4 - q₁)
386-
β² = p₁ / 2 + sqrt(p₁^2 / 4 - q₁)
387-
α = sqrt(α²)
388-
β = sqrt(β²)
389-
390-
# inverse Laplace transform of s^p/((s^2 + α^2)(s^2 + β^2)) for p ∈ -1:3
391-
if α == 0
392-
L₋₁ = (β² * t^2 / 2 - 1 + cos* t)) / β^4
393-
L₀ =* t - sin* t)) / β^3
394-
else
395-
L₋₁ =
396-
(-cos* t) / α² + cos* t) / β²) / (β² - α²) + 1 / (α² * β²)
397-
L₀ = (sin* t) / α - sin* t) / β) / (β² - α²)
398-
end
399-
L₁ = (cos* t) - cos* t)) / (β² - α²)
400-
L₂ = (-sin* t) * α + sin* t) * β) / (β² - α²)
401-
L₃ = (-cos* t) * α² + cos* t) * β²) / (β² - α²)
402-
403-
# Fourier coefficients of Bretherton transforms of final perturbations
404-
C₁ = grav * (grav - cₛ² * (im * kz + δ / 2))
405-
C₂ = grav * (im * kz - δ / 2)
406-
pfb = -ρfb_init * (L₁ + L₋₁ * f^2) * C₁
407-
ρfb =
408-
ρfb_init *
409-
(L₃ + L₁ * (p₁ + C₂) + L₋₁ * f^2 * (cₛ² * (kz^2 + δ^2 / 4) + C₂))
410-
ufb = ρfb_init * L₀ * im * kx * C₁ / ρₛ
411-
vfb = -ρfb_init * L₋₁ * im * kx * f * C₁ / ρₛ
412-
wfb = -ρfb_init * (L₂ + L₀ * (f^2 + cₛ² * kx^2)) * grav / ρₛ
413-
414-
# Bretherton transforms of final perturbations
415-
@. ᶜpb += real(pfb * ᶜfourier_factor)
416-
@. ᶜρb += real(ρfb * ᶜfourier_factor)
417-
@. ᶜub += real(ufb * ᶜfourier_factor)
418-
@. ᶜvb += real(vfb * ᶜfourier_factor)
419-
@. ᶠwb += real(wfb * ᶠfourier_factor)
420-
# The imaginary components should be 0 (or at least very close to 0).
421-
end
422-
423-
# final state
424-
@. ᶜp = ᶜp₀ + ᶜpb * ᶜbretherton_factor_pρ
425-
@. ᶜρ = ᶜρ₀ + ᶜρb * ᶜbretherton_factor_pρ
426-
@. ᶜu = ᶜu₀ + ᶜub * ᶜbretherton_factor_uvwT
427-
@. ᶜv = ᶜv₀ + ᶜvb * ᶜbretherton_factor_uvwT
428-
@. ᶠw = ᶠw₀ + ᶠwb * ᶠbretherton_factor_uvwT
429-
@. ᶜT = ᶜp / (R_d * ᶜρ)
430-
431-
@. Y.c.ρ = ᶜρ
432-
if ᶜ𝔼_name == :ρθ
433-
@. Y.c.ρθ = ᶜρ * ᶜT * (p_0 / ᶜp)^(R_d / cp_d)
434-
elseif ᶜ𝔼_name == :ρe
435-
@. Y.c.ρe =
436-
ᶜρ * (
437-
cv_d * (ᶜT - T_tri) +
438-
(ᶜu^2 + ᶜv^2 + ᶜinterp(ᶠw)^2) / 2 +
439-
grav * ᶜz
440-
)
441-
elseif ᶜ𝔼_name == :ρe_int
442-
@. Y.c.ρe_int = ᶜρ * cv_d * (ᶜT - T_tri)
443-
end
444-
# NOTE: The following two lines are a temporary workaround b/c Covariant12Vector won't accept a non-zero second component in an XZ-space.
445-
# So we temporarily set it to zero and then reassign its intended non-zero value (since in case of large-scale config ᶜv is non-zero)
446-
@. Y.c.uₕ = Geometry.Covariant12Vector(Geometry.UVVector(ᶜu, FT(0.0)))
447-
@. Y.c.uₕ.components.data.:2 .= ᶜv
448-
@. Y.f.w = Geometry.Covariant3Vector(Geometry.WVector(ᶠw))
449-
return nothing
450-
end
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
module InertialGravityWaveUtils
2+
3+
import ClimaCore.Geometry as Geometry
4+
5+
# min_λx = 2 * (x_max / x_elem) / upsampling_factor # this should include npoly
6+
# min_λz = 2 * (FT( / z_)elem) / upsampling_factor
7+
# min_λx = 2 * π / max_kx = x_max / max_ikx
8+
# min_λz = 2 * π / max_kz = 2 * z_max / max_ikz
9+
# max_ikx = x_max / min_λx = upsampling_factor * x_elem / 2
10+
# max_ikz = 2 * z_max / min_λz = upsampling_factor * z_elem
11+
function ρfb_init_coefs!(::Type{FT}, params) where {FT}
12+
(; max_ikz, max_ikx, x_max, z_max, unit_integral) = params
13+
(; ρfb_init_array, ᶜρb_init_xz) = params
14+
# Since the coefficients are for a modified domain of height 2 * z_max, the
15+
# unit integral over the domain must be multiplied by 2 to ensure correct
16+
# normalization. On the other hand, ᶜρb_init is assumed to be 0 outside of
17+
# the "true" domain, so the integral of
18+
# ᶜintegrand (`ᶜintegrand = ᶜρb_init / ᶜfourier_factor`) should not be modified.
19+
# where `ᶜfourier_factor = exp(im * (kx * x + kz * z))`.
20+
@inbounds begin
21+
Threads.@threads for ikx in (-max_ikx):max_ikx
22+
for ikz in (-max_ikz):max_ikz
23+
kx::FT = 2 * π / x_max * ikx
24+
kz::FT = 2 * π / (2 * z_max) * ikz
25+
ρfb_init_array[ikx + max_ikx + 1, ikz + max_ikz + 1] =
26+
sum(ᶜρb_init_xz) do nt
27+
(; ρ, x, z) = nt
28+
ρ / exp(im * (kx * x + kz * z))
29+
end / unit_integral
30+
31+
end
32+
end
33+
end
34+
return nothing
35+
end
36+
37+
function Bretherton_transforms!(lin_cache, t, ::Type{FT}) where {FT}
38+
# Bretherton_transforms_partial_sums! is fastest because
39+
# we can multithread across
40+
# `Iterators.product((-max_ikx):max_ikx, (-max_ikz):max_ikz)`
41+
# and apply sums for center and face fields. Using mapreduce requires
42+
# two calls and, as a result in ~20 slower.
43+
44+
Bretherton_transforms_original!(lin_cache, t, FT)
45+
# Bretherton_transforms_partial_sums!(lin_cache, t, FT)
46+
# Bretherton_transforms_threaded_mapreduce!(lin_cache, t, FT)
47+
end
48+
49+
function Bretherton_transforms_original!(lin_cache, t, ::Type{FT}) where {FT}
50+
(; ᶜx, ᶠx, ᶜz, ᶠz) = lin_cache
51+
(; x_max, z_max, u₀, δ, cₛ², grav, f, ρₛ) = lin_cache
52+
(; ρfb_init_array, ᶜfourier_factor, ᶠfourier_factor) = lin_cache
53+
(; ᶜpb, ᶜρb, ᶜub, ᶜvb, ᶠwb) = lin_cache
54+
55+
ᶜpb .= FT(0)
56+
ᶜρb .= FT(0)
57+
ᶜub .= FT(0)
58+
ᶜvb .= FT(0)
59+
ᶠwb .= FT(0)
60+
max_ikx, max_ikz = (size(ρfb_init_array) .- 1) 2
61+
@inbounds for ikx in (-max_ikx):max_ikx, ikz in (-max_ikz):max_ikz
62+
kx = 2 * π / x_max * ikx
63+
kz = 2 * π / (2 * z_max) * ikz
64+
65+
# Fourier coefficient of ᶜρb_init (for current kx and kz)
66+
ρfb_init = ρfb_init_array[ikx + max_ikx + 1, ikz + max_ikz + 1]
67+
68+
# Fourier factors, shifted by u₀ * t along the x-axis
69+
@. ᶜfourier_factor = exp(im * (kx * (ᶜx - u₀ * t) + kz * ᶜz))
70+
@. ᶠfourier_factor = exp(im * (kx * (ᶠx - u₀ * t) + kz * ᶠz))
71+
72+
# roots of a₁(s)
73+
p₁ = cₛ² * (kx^2 + kz^2 + δ^2 / 4) + f^2
74+
q₁ = grav * kx^2 * (cₛ² * δ - grav) + cₛ² * f^2 * (kz^2 + δ^2 / 4)
75+
α² = p₁ / 2 - sqrt(p₁^2 / 4 - q₁)
76+
β² = p₁ / 2 + sqrt(p₁^2 / 4 - q₁)
77+
α = sqrt(α²)
78+
β = sqrt(β²)
79+
80+
# inverse Laplace transform of s^p/((s^2 + α^2)(s^2 + β^2)) for p ∈ -1:3
81+
if α == 0
82+
L₋₁ = (β² * t^2 / 2 - 1 + cos* t)) / β^4
83+
L₀ =* t - sin* t)) / β^3
84+
else
85+
L₋₁ =
86+
(-cos* t) / α² + cos* t) / β²) / (β² - α²) + 1 / (α² * β²)
87+
L₀ = (sin* t) / α - sin* t) / β) / (β² - α²)
88+
end
89+
L₁ = (cos* t) - cos* t)) / (β² - α²)
90+
L₂ = (-sin* t) * α + sin* t) * β) / (β² - α²)
91+
L₃ = (-cos* t) * α² + cos* t) * β²) / (β² - α²)
92+
93+
# Fourier coefficients of Bretherton transforms of final perturbations
94+
C₁ = grav * (grav - cₛ² * (im * kz + δ / 2))
95+
C₂ = grav * (im * kz - δ / 2)
96+
pfb = -ρfb_init * (L₁ + L₋₁ * f^2) * C₁
97+
ρfb =
98+
ρfb_init *
99+
(L₃ + L₁ * (p₁ + C₂) + L₋₁ * f^2 * (cₛ² * (kz^2 + δ^2 / 4) + C₂))
100+
ufb = ρfb_init * L₀ * im * kx * C₁ / ρₛ
101+
vfb = -ρfb_init * L₋₁ * im * kx * f * C₁ / ρₛ
102+
wfb = -ρfb_init * (L₂ + L₀ * (f^2 + cₛ² * kx^2)) * grav / ρₛ
103+
104+
# Bretherton transforms of final perturbations
105+
@. ᶜpb += real(pfb * ᶜfourier_factor)
106+
@. ᶜρb += real(ρfb * ᶜfourier_factor)
107+
@. ᶜub += real(ufb * ᶜfourier_factor)
108+
@. ᶜvb += real(vfb * ᶜfourier_factor)
109+
@. ᶠwb += real(wfb * ᶠfourier_factor)
110+
# The imaginary components should be 0 (or at least very close to 0).
111+
end
112+
return nothing
113+
end
114+
115+
function linear_solution!(Y, lin_cache, t, ::Type{FT}) where {FT}
116+
(; ᶜz, ᶜp₀, ᶜρ₀, ᶜu₀, ᶜv₀, ᶠw₀) = lin_cache
117+
(; ᶜinterp) = lin_cache
118+
(; R_d, ᶜ𝔼_name, x_max, z_max, p_0, cp_d, cv_d, grav, T_tri) = lin_cache
119+
(; ᶜbretherton_factor_pρ) = lin_cache
120+
(; ᶜbretherton_factor_uvwT, ᶠbretherton_factor_uvwT) = lin_cache
121+
(; ᶜpb, ᶜρb, ᶜub, ᶜvb, ᶠwb, ᶜp, ᶜρ, ᶜu, ᶜv, ᶠw, ᶜT) = lin_cache
122+
123+
Bretherton_transforms!(lin_cache, t, FT)
124+
125+
# final state
126+
@. ᶜp = ᶜp₀ + ᶜpb * ᶜbretherton_factor_pρ
127+
@. ᶜρ = ᶜρ₀ + ᶜρb * ᶜbretherton_factor_pρ
128+
@. ᶜu = ᶜu₀ + ᶜub * ᶜbretherton_factor_uvwT
129+
@. ᶜv = ᶜv₀ + ᶜvb * ᶜbretherton_factor_uvwT
130+
@. ᶠw = ᶠw₀ + ᶠwb * ᶠbretherton_factor_uvwT
131+
@. ᶜT = ᶜp / (R_d * ᶜρ)
132+
133+
@. Y.c.ρ = ᶜρ
134+
if ᶜ𝔼_name == :ρθ
135+
@. Y.c.ρθ = ᶜρ * ᶜT * (p_0 / ᶜp)^(R_d / cp_d)
136+
elseif ᶜ𝔼_name == :ρe
137+
@. Y.c.ρe =
138+
ᶜρ * (
139+
cv_d * (ᶜT - T_tri) +
140+
(ᶜu^2 + ᶜv^2 + ᶜinterp(ᶠw)^2) / 2 +
141+
grav * ᶜz
142+
)
143+
elseif ᶜ𝔼_name == :ρe_int
144+
@. Y.c.ρe_int = ᶜρ * cv_d * (ᶜT - T_tri)
145+
end
146+
# NOTE: The following two lines are a temporary workaround b/c Covariant12Vector won't accept a non-zero second component in an XZ-space.
147+
# So we temporarily set it to zero and then reassign its intended non-zero value (since in case of large-scale config ᶜv is non-zero)
148+
@. Y.c.uₕ = Geometry.Covariant12Vector(Geometry.UVVector(ᶜu, FT(0.0)))
149+
@. Y.c.uₕ.components.data.:2 .= ᶜv
150+
@. Y.f.w = Geometry.Covariant3Vector(Geometry.WVector(ᶠw))
151+
return nothing
152+
end
153+
154+
end # module

0 commit comments

Comments
 (0)