Skip to content

Commit e04ca27

Browse files
Try to fix gpu inference
1 parent 6c232af commit e04ca27

File tree

1 file changed

+59
-42
lines changed

1 file changed

+59
-42
lines changed

src/prognostic_equations/remaining_tendency.jl

Lines changed: 59 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -24,32 +24,49 @@ end
2424
prognostic_nt(::Val{names}, ::Val{K}, vals...) where {names, K} =
2525
sorted_nt(Val(names), Val(K), vals...)
2626

27+
@generated function construct_tends(::Val{names}, f, ᶜY, ᶠY, p, t) where {names}
28+
calls = []
29+
for name in names
30+
push!(calls, :(f(Val($name), ᶜY, ᶠY, p, t)))
31+
end
32+
return quote
33+
$(calls...)
34+
end
35+
end
36+
2737
function ᶜremaining_tendency(ᶜY, ᶠY, p, t)
2838
names = propertynames(ᶜY)
29-
tends = (;
30-
ᶜremaining_tendency_ρ(ᶜY, ᶠY, p, t)...,
31-
ᶜremaining_tendency_uₕ(ᶜY, ᶠY, p, t)...,
32-
ᶜremaining_tendency_ρe_tot(ᶜY, ᶠY, p, t)...,
33-
ᶜremaining_tendency_ρq_tot(ᶜY, ᶠY, p, t)...,
34-
ᶜremaining_tendency_ρq_liq(ᶜY, ᶠY, p, t)...,
35-
ᶜremaining_tendency_ρq_ice(ᶜY, ᶠY, p, t)...,
36-
ᶜremaining_tendency_ρq_rai(ᶜY, ᶠY, p, t)...,
37-
ᶜremaining_tendency_ρq_sno(ᶜY, ᶠY, p, t)...,
38-
ᶜremaining_tendency_sgs⁰(ᶜY, ᶠY, p, t)...,
39-
ᶜremaining_tendency_sgsʲs(ᶜY, ᶠY, p, t)...,
40-
)
41-
return lazy.(prognostic_nt.(Val(names), Val(keys(tends)), values(tends)...))
39+
tends = construct_tends(Val(names), ᶜremaining_tendency, ᶜY, ᶠY, p, t)
40+
# tends = (;
41+
# ᶜremaining_tendency(Val(:ρ), ᶜY, ᶠY, p, t)...,
42+
# ᶜremaining_tendency(Val(:uₕ), ᶜY, ᶠY, p, t)...,
43+
# ᶜremaining_tendency(Val(:ρe_tot), ᶜY, ᶠY, p, t)...,
44+
# ᶜremaining_tendency(Val(:ρq_tot), ᶜY, ᶠY, p, t)...,
45+
# ᶜremaining_tendency(Val(:ρq_liq), ᶜY, ᶠY, p, t)...,
46+
# ᶜremaining_tendency(Val(:ρq_ice), ᶜY, ᶠY, p, t)...,
47+
# ᶜremaining_tendency(Val(:ρq_rai), ᶜY, ᶠY, p, t)...,
48+
# ᶜremaining_tendency(Val(:ρq_sno), ᶜY, ᶠY, p, t)...,
49+
# ᶜremaining_tendency(Val(:sgs⁰), ᶜY, ᶠY, p, t)...,
50+
# ᶜremaining_tendency(Val(:sgsʲs), ᶜY, ᶠY, p, t)...,
51+
# )
52+
53+
# return lazy.(prognostic_nt.(Val(names), Val(keys(tends)), values(tends)...))
54+
prog_nt = (NamedTuple{names} tuple)
55+
return lazy.(prog_nt.(tends))
4256
end
4357
function ᶠremaining_tendency(ᶜY, ᶠY, p, t)
4458
names = propertynames(ᶠY)
45-
tends = (;
46-
ᶠremaining_tendency_u₃(ᶜY, ᶠY, p, t)...,
47-
ᶠremaining_tendency_sgsʲs(ᶜY, ᶠY, p, t)...,
48-
)
49-
return lazy.(prognostic_nt.(Val(names), Val(keys(tends)), values(tends)...))
59+
tends = construct_tends(Val(names), ᶠremaining_tendency, ᶜY, ᶠY, p, t)
60+
# tends = (;
61+
# ᶠremaining_tendency(Val(:u₃), ᶜY, ᶠY, p, t)...,
62+
# ᶠremaining_tendency(Val(:sgsʲs), ᶜY, ᶠY, p, t)...,
63+
# )
64+
# return lazy.(prognostic_nt.(Val(names), Val(keys(tends)), values(tends)...))
65+
prog_nt = (NamedTuple{names} tuple)
66+
return lazy.(prog_nt.(tends))
5067
end
5168
using ClimaCore.RecursiveApply: rzero
52-
function ᶜremaining_tendency_ρ(ᶜY, ᶠY, p, t)
69+
function ᶜremaining_tendency(::Val{:ρ}, ᶜY, ᶠY, p, t)
5370
in propertynames(ᶜY) || return ()
5471
∑tendencies = zero(eltype(ᶜY.ρ))
5572
ᶜJ = Fields.local_geometry_field(ᶜY).J
@@ -60,25 +77,25 @@ function ᶜremaining_tendency_ρ(ᶜY, ᶠY, p, t)
6077
ᶜwₜqₜ = compute_ᶜwₜqₜ(ᶜY, ᶠY, p, t)
6178
∑tendencies = sub_tend(∑tendencies, water_adv(ᶜρ, ᶜJ, ᶠJ, ᶜwₜqₜ))
6279
end
63-
return (;ρ=∑tendencies)
80+
return ∑tendencies
6481
end
6582

6683
add_tend(∑tends, t) = lazy.(∑tends .+ t)
6784
add_tend(∑tends, ::NullBroadcasted) = ∑tends
6885
sub_tend(∑tends, ::NullBroadcasted) = ∑tends
6986
sub_tend(∑tends, t) = lazy.(∑tends .- t)
7087

71-
function ᶜremaining_tendency_uₕ(ᶜY, ᶠY, p, t)
88+
function ᶜremaining_tendency(::Val{:uₕ}, ᶜY, ᶠY, p, t)
7289
:uₕ in propertynames(ᶜY) || return ()
7390
∑tendencies = zero(eltype(ᶜY.uₕ))
7491
(; viscous_sponge, rayleigh_sponge) = p.atmos
7592
ᶜuₕ = ᶜY.uₕ
7693
∑tendencies = add_tend(∑tendencies, viscous_sponge_tendency_uₕ(ᶜuₕ, viscous_sponge))
7794
∑tendencies = add_tend(∑tendencies, rayleigh_sponge_tendency_uₕ(ᶜuₕ, rayleigh_sponge))
7895

79-
return (;uₕ=∑tendencies)
96+
return ∑tendencies
8097
end
81-
function ᶜremaining_tendency_ρe_tot(ᶜY, ᶠY, p, t)
98+
function ᶜremaining_tendency(::Val{:ρe_tot}, ᶜY, ᶠY, p, t)
8299
:ρe_tot in propertynames(ᶜY) || return ()
83100
∑tendencies = zero(eltype(ᶜY.ρe_tot))
84101

@@ -112,9 +129,9 @@ function ᶜremaining_tendency_ρe_tot(ᶜY, ᶠY, p, t)
112129
end
113130

114131
∑tendencies = add_tend(∑tendencies, viscous_sponge_tendency_ρe_tot(ᶜρ, ᶜh_tot, viscous_sponge))
115-
return (;ρe_tot=∑tendencies)
132+
return ∑tendencies
116133
end
117-
function ᶜremaining_tendency_ρq_tot(ᶜY, ᶠY, p, t)
134+
function ᶜremaining_tendency(::Val{:ρq_tot}, ᶜY, ᶠY, p, t)
118135
:ρq_tot in propertynames(ᶜY) || return ()
119136
∑tendencies = zero(eltype(ᶜY.ρq_tot))
120137
ᶜJ = Fields.local_geometry_field(ᶜY).J
@@ -138,39 +155,39 @@ function ᶜremaining_tendency_ρq_tot(ᶜY, ᶠY, p, t)
138155
∑tendencies = sub_tend(∑tendencies, vtt_central)
139156
end
140157

141-
return (;ρq_tot=∑tendencies)
158+
return ∑tendencies
142159
end
143-
function ᶜremaining_tendency_ρq_liq(ᶜY, ᶠY, p, t)
160+
function ᶜremaining_tendency(::Val{:ρq_liq}, ᶜY, ᶠY, p, t)
144161
:ρq_liq in propertynames(ᶜY) || return ()
145162
∑tendencies = zero(eltype(ᶜY.ρq_liq))
146-
return (; ρq_liq = ∑tendencies)
163+
return ∑tendencies
147164
end
148-
function ᶜremaining_tendency_ρq_ice(ᶜY, ᶠY, p, t)
165+
function ᶜremaining_tendency(::Val{:ρq_ice}, ᶜY, ᶠY, p, t)
149166
:ρq_ice in propertynames(ᶜY) || return ()
150167
∑tendencies = zero(eltype(ᶜY.ρq_ice))
151-
return (; ρq_ice = ∑tendencies)
168+
return ∑tendencies
152169
end
153-
function ᶜremaining_tendency_ρq_rai(ᶜY, ᶠY, p, t)
170+
function ᶜremaining_tendency(::Val{:ρq_rai}, ᶜY, ᶠY, p, t)
154171
:ρq_rai in propertynames(ᶜY) || return ()
155172
∑tendencies = zero(eltype(ᶜY.ρq_rai))
156-
return (; ρq_rai = ∑tendencies)
173+
return ∑tendencies
157174
end
158-
function ᶜremaining_tendency_ρq_sno(ᶜY, ᶠY, p, t)
175+
function ᶜremaining_tendency(::Val{:ρq_sno}, ᶜY, ᶠY, p, t)
159176
:ρq_sno in propertynames(ᶜY) || return ()
160177
∑tendencies = zero(eltype(ᶜY.ρq_sno))
161-
return (; ρq_sno = ∑tendencies)
178+
return ∑tendencies
162179
end
163-
function ᶜremaining_tendency_sgsʲs(ᶜY, ᶠY, p, t)
180+
function ᶜremaining_tendency(::Val{:sgsʲs}, ᶜY, ᶠY, p, t)
164181
:sgsʲs in propertynames(ᶜY) || return ()
165182
∑tendencies = rzero(eltype(ᶜY.sgsʲs))
166-
return (; sgsʲs = ∑tendencies)
183+
return ∑tendencies
167184
end
168-
function ᶜremaining_tendency_sgs⁰(ᶜY, ᶠY, p, t)
185+
function ᶜremaining_tendency(::Val{:sgs⁰}, ᶜY, ᶠY, p, t)
169186
:sgs⁰ in propertynames(ᶜY) || return ()
170187
∑tendencies = rzero(eltype(ᶜY.sgs⁰))
171-
return (; sgs⁰ = ∑tendencies)
188+
return ∑tendencies
172189
end
173-
function ᶠremaining_tendency_u₃(ᶜY, ᶠY, p, t)
190+
function ᶠremaining_tendency(::Val{:u₃}, ᶜY, ᶠY, p, t)
174191
:u₃ in propertynames(ᶠY) || return ()
175192
∑tendencies = zero(eltype(ᶠY.u₃))
176193
(; viscous_sponge) = p.atmos
@@ -179,12 +196,12 @@ function ᶠremaining_tendency_u₃(ᶜY, ᶠY, p, t)
179196
ᶠuₕ³ = compute_ᶠuₕ³(ᶜuₕ, ᶜρ)
180197
ᶠu₃ = compute_ᶠu₃_with_bcs(ᶠY.u₃, ᶠuₕ³)
181198
∑tendencies = add_tend(∑tendencies, viscous_sponge_tendency_u₃(ᶠu₃, viscous_sponge))
182-
return (;u₃=∑tendencies)
199+
return ∑tendencies
183200
end
184-
function ᶠremaining_tendency_sgsʲs(ᶜY, ᶠY, p, t)
201+
function ᶠremaining_tendency(::Val{:sgsʲs}, ᶜY, ᶠY, p, t)
185202
:sgsʲs in propertynames(ᶠY) || return ()
186203
∑tendencies = rzero(eltype(ᶠY.sgsʲs))
187-
return (; sgsʲs = ∑tendencies)
204+
return ∑tendencies
188205
end
189206

190207

0 commit comments

Comments
 (0)