Commit 74dce08

Add support for unevaluated compute! functions
`LazyBroadcast.jl` provides a way to return an unevaluated broadcast expression. This is useful in two cases:

1. it reduces the code verbosity needed to handle the `isnothing(out)` case
2. it allows clustering all the broadcasted expressions in a single place

In turn, 2. is useful because it is the first step in fusing different broadcasted calls.

This commit adds support for `compute` functions that return such expressions.
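
To make the first point concrete, here is a minimal sketch of how a caller could take over the `isnothing(out)` handling once `compute` returns either an array-like object or an unevaluated expression. It uses only `Base.Broadcast`; `write_diagnostic!` is a hypothetical helper, not the implementation in this commit, plain `Vector`s stand in for `Field`s, and `broadcasted(-, x, y)` is assumed to play the role of `lazy.(x .- y)`.

```julia
import Base.Broadcast: Broadcasted, broadcasted, materialize, materialize!

# Hypothetical helper (not part of ClimaDiagnostics): store the result of a
# `compute` function, allocating only when no output buffer exists yet.
write_diagnostic!(::Nothing, result::Broadcasted) = materialize(result)
write_diagnostic!(::Nothing, result::AbstractArray) = copy(result)
function write_diagnostic!(out::AbstractArray, result::Broadcasted)
    materialize!(out, result)  # evaluate the expression directly into `out`
    return out
end
function write_diagnostic!(out::AbstractArray, result::AbstractArray)
    out .= result
    return out
end

# The user-facing function no longer needs an `isnothing(out)` branch.
compute_ta(state, cache, time) = broadcasted(-, state.ta, 273.15)

state = (; ta = [273.15, 283.15, 293.15])
out = write_diagnostic!(nothing, compute_ta(state, nothing, 0.0))  # first call allocates
state.ta .+= 1.0
write_diagnostic!(out, compute_ta(state, nothing, 0.0))            # later calls reuse `out`
```

The branching on `nothing` lives in one place (the helper) instead of being repeated in every diagnostic, which is the verbosity reduction the commit message refers to.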
1 parent 2c8b275 commit 74dce08

10 files changed: 309 additions & 78 deletions

.buildkite/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
1212
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
1313
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
1414
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
15+
LazyBroadcast = "9dccce8e-a116-406d-9fcc-a88ed4f510c8"
1516
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
1617
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
1718
Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"

NEWS.md

Lines changed: 33 additions & 0 deletions
@@ -1,4 +1,5 @@
11
# NEWS
2+
23
v0.2.13
34
-------
45

@@ -13,6 +14,37 @@ the interval `[0.0, 10.0]`. If one knows that the data represents a time
1314
average, then the time of `10.0` is the time average over the interval
1415
`[0.0, 10.0]`.
1516

17+
### Support for `lazy`
18+
19+
Starting with version `0.2.13`, `ClimaDiagnostics` supports diagnostic variables
20+
specified with un-evaluated expressions (as provided by
21+
[LazyBroadcast.jl](https://github.com/CliMA/LazyBroadcast.jl)).
22+
23+
Instead of
24+
```julia
25+
function compute_ta!(out, state, cache, time)
26+
if isnothing(out)
27+
return state.ta .- 273.15
28+
else
29+
out .= state.ta .- 273.15
30+
end
31+
end
32+
```
33+
You can now write
34+
```julia
35+
import LazyBroadcast: lazy
36+
37+
function compute_ta(state, cache, time)
38+
return lazy.(state.ta .- 273.15)
39+
end
40+
```
41+
Or, when the diagnostic is already a `Field`, return it directly
42+
```julia
43+
function compute_ta(state, cache, time)
44+
return state.ta
45+
end
46+
```
47+
1648
v0.2.12
1749
-------
1850

@@ -130,6 +162,7 @@ v0.2.4
130162

131163
- Add `EveryCalendarDtSchedule` for schedules with calendar periods.
132164

165+
133166
v0.2.3
134167
-------
135168

Project.toml

Lines changed: 3 additions & 1 deletion
@@ -24,6 +24,7 @@ ClimaUtilities = "0.1.22"
2424
Dates = "1"
2525
Documenter = "1"
2626
ExplicitImports = "1.6"
27+
LazyBroadcast = "1"
2728
JuliaFormatter = "1"
2829
NCDatasets = "0.14"
2930
OrderedCollections = "1.4"
@@ -41,10 +42,11 @@ ClimaTimeSteppers = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
4142
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
4243
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
4344
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
45+
LazyBroadcast = "9dccce8e-a116-406d-9fcc-a88ed4f510c8"
4446
Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
4547
ProfileCanvas = "efd6af41-a80b-495e-886c-e51b0c7d77a3"
4648
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4749
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4850

4951
[targets]
50-
test = ["Aqua", "BenchmarkTools", "ClimaTimeSteppers", "Documenter", "ExplicitImports", "JuliaFormatter", "Profile", "ProfileCanvas", "SafeTestsets", "Test"]
52+
test = ["Aqua", "BenchmarkTools", "ClimaTimeSteppers", "Documenter", "ExplicitImports", "JuliaFormatter", "LazyBroadcast", "Profile", "ProfileCanvas", "SafeTestsets", "Test"]

docs/src/developer_guide.md

Lines changed: 39 additions & 31 deletions
@@ -105,7 +105,7 @@ Let us see the simplest example to accomplish this
105105
import ClimaDiagnostics: DiagnosticVariable, ScheduledDiagnostic
106106
import ClimaDiagnostics.Writers: DictWriter
107107

108-
myvar = DiagnosticVariable(; compute! = (out, u, p, t) -> u.var1)
108+
myvar = DiagnosticVariable(; compute = (u, p, t) -> u.var1)
109109

110110
myschedule = (integrator) -> maximum(integrator.u.var2) > 10.0
111111

@@ -148,7 +148,7 @@ const ALL_DIAGNOSTICS = Dict{String, DiagnosticVariable}()
148148
standard_name,
149149
units,
150150
description,
151-
compute!)
151+
compute)
152152
153153
154154
Add a new variable to the `ALL_DIAGNOSTICS` dictionary (this function mutates the state of
@@ -173,27 +173,25 @@ Keyword arguments
173173
- `comments`: More verbose explanation of what the variable is, or comments related to how
174174
it is defined or computed.
175175
176-
- `compute!`: Function that compute the diagnostic variable from the state. It has to take
177-
two arguments: the `integrator`, and a pre-allocated area of memory where to
178-
write the result of the computation. It the no pre-allocated area is
179-
available, a new one will be allocated. To avoid extra allocations, this
180-
function should perform the calculation in-place (i.e., using `.=`).
181-
176+
- `compute`: Function that computes the diagnostic variable from the state, cache, and time. The function
177+
should return a `Field` or a `Base.Broadcast.Broadcasted` expression. It should not allocate
178+
a new `Field`: if you find yourself using a dot (broadcast), that is a good indication you should be using
179+
`lazy`.
182180
"""
183181
function add_diagnostic_variable!(;
184182
short_name,
185183
long_name,
186184
standard_name = "",
187185
units,
188186
comments = "",
189-
compute!,
187+
compute,
190188
)
191189
haskey(ALL_DIAGNOSTICS, short_name) && @warn(
192190
"overwriting diagnostic `$short_name` entry containing fields\n" *
193191
"$(map(
194192
field -> "$(getfield(ALL_DIAGNOSTICS[short_name], field))",
195193
# We cannot really compare functions...
196-
filter(field -> field != :compute!, fieldnames(DiagnosticVariable)),
194+
filter(field -> !(field in (:compute!, :compute)), fieldnames(DiagnosticVariable)),
197195
))"
198196
)
199197

@@ -203,7 +201,7 @@ function add_diagnostic_variable!(;
203201
standard_name,
204202
units,
205203
comments,
206-
compute!,
204+
compute,
207205
)
208206

209207
"""
@@ -236,15 +234,30 @@ add_diagnostic_variable!(
236234
long_name = "Air Density",
237235
standard_name = "air_density",
238236
units = "kg m^-3",
239-
compute! = (out, state, cache, time) -> begin
240-
if isnothing(out)
241-
return state.c.ρ
242-
else
243-
out .= state.c.ρ
244-
end
245-
end,
237+
compute = (state, cache, time) -> state.c.ρ,
238+
)
239+
```
240+
241+
When writing compute functions, make them lazy with
242+
[LazyBroadcast.jl](https://github.com/CliMA/LazyBroadcast.jl) to improve
243+
performance and avoid intermediate allocations. To do that, add `LazyBroadcast`
244+
to your dependencies and import `lazy`. A slight variation of the previous
245+
example would look like
246+
247+
```julia
248+
###
249+
# Density (3d)
250+
###
251+
add_diagnostic_variable!(
252+
short_name = "rhoa",
253+
long_name = "Air Density",
254+
standard_name = "air_density",
255+
units = "kg m^-3",
256+
compute = (state, cache, time) -> lazy.(1000 .* state.c.ρ),
246257
)
247258
```
259+
Here, the factor of `1000` simulates a more complex expression. Without
260+
`lazy`, the diagnostic would allocate an intermediate `Field`, severely hurting performance.
248261

249262
It is a good idea to put safeguards in place to ensure that your users will not
250263
be allowed to call diagnostics that do not make sense for the simulation they
@@ -254,26 +267,21 @@ can dispatch over that and return an error. A simple example might be
254267
###
255268
# Specific Humidity
256269
###
257-
compute_hus!(out, state, cache, time) =
258-
compute_hus!(out, state, cache, time, cache.atmos.moisture_model)
270+
compute_hus(state, cache, time) =
271+
compute_hus(state, cache, time, cache.atmos.moisture_model)
259272

260-
compute_hus!(out, state, cache, time) =
261-
compute_hus!(out, state, cache, time, cache.model.moisture_model)
262-
compute_hus!(_, _, _, _, model::T) where {T} =
273+
compute_hus(state, cache, time) =
274+
compute_hus(state, cache, time, cache.model.moisture_model)
275+
compute_hus(_, _, _, model::T) where {T} =
263276
error("Cannot compute hus with $model")
264277

265-
function compute_hus!(
266-
out,
278+
function compute_hus(
267279
state,
268280
cache,
269281
time,
270282
moisture_model::T,
271283
) where {T <: Union{EquilMoistModel, NonEquilMoistModel}}
272-
if isnothing(out)
273-
return state.c.ρq_tot ./ state.c.ρ
274-
else
275-
out .= state.c.ρq_tot ./ state.c.ρ
276-
end
284+
return lazy.(state.c.ρq_tot ./ state.c.ρ)
277285
end
278286

279287
add_diagnostic_variable!(
@@ -282,7 +290,7 @@ add_diagnostic_variable!(
282290
standard_name = "specific_humidity",
283291
units = "kg kg^-1",
284292
comments = "Mass of all water phases per mass of air",
285-
compute! = compute_hus!,
293+
compute = compute_hus,
286294
)
287295
```
288296
This relies on dispatching over `moisture_model`. If `model` is not in

docs/src/index.md

Lines changed: 2 additions & 1 deletion
@@ -24,5 +24,6 @@ the Developer guide page.
2424
- Allow users to define arbitrary new diagnostics;
2525
- Trigger diagnostics on arbitrary conditions;
2626
- Save output to HDF5 or NetCDF files, or a dictionary in memory;
27-
27+
- Work with lazy expressions (such as the ones produced by
28+
[LazyBroadcast.jl](https://github.com/CliMA/LazyBroadcast.jl)).
2829

docs/src/user_guide.md

Lines changed: 54 additions & 17 deletions
@@ -37,12 +37,8 @@ Let us see how we would define a `DiagnosticVariable`
3737
```julia
3838
import ClimaDiagnostics: DiagnosticVariable
3939

40-
function compute_ta!(out, state, cache, time)
41-
if isnothing(out)
42-
return state.ta
43-
else
44-
out .= state.ta
45-
end
40+
function compute_ta(state, cache, time)
41+
return state.ta
4642
end
4743

4844
var = DiagnosticVariable(;
@@ -51,25 +47,66 @@ var = DiagnosticVariable(;
5147
standard_name = "air_temperature",
5248
comments = "Measured assuming that the air is in quantum equilibrium with the metaverse",
5349
units = "K",
54-
compute! = compute_ta!
50+
compute = compute_ta
5551
)
5652
```
5753

58-
`compute_ta!` is the key function here. It determines how the variable should be
54+
`compute_ta` is the key function here. It determines how the variable should be
5955
computed from the `state`, `cache`, and `time` of the simulation. Typically,
6056
these are packaged within an `integrator` object (e.g., `state = integrator.u`
6157
or `integrator.Y`).
6258

63-
`compute_ta!` takes another argument, `out`. `out` is an area of memory managed
64-
by `ClimaDiagnostics` that is used to reduce the number of allocations needed
65-
when working with diagnostics. The first time the diagnostic is called, an area
66-
of memory is allocated and filled with the value (this is when `out` is
67-
`nothing`). All the subsequent times, the same space is overwritten, leading to
68-
much better performance. You should follow this pattern in all your diagnostics.
59+
!!! compat "ClimaDiagnostics 0.2.13"
60+
61+
Support for `compute` was introduced in version `0.2.13`. Prior to this
62+
version, the in-place `compute!` had to be provided. In this case, `compute!`
63+
has to take an extra argument, `out`. `out` is an area of memory managed by
64+
`ClimaDiagnostics` that is used to reduce the number of allocations needed
65+
when working with diagnostics. The first time the diagnostic is called, an
66+
area of memory is allocated and filled with the value (this is when `out` is
67+
`nothing`). All the subsequent times, the same space is overwritten, leading
68+
to much better performance. You should follow this pattern in all your
69+
diagnostics. This is left to the developer to implement, so `compute_ta!` would
70+
look like
71+
72+
```julia
73+
function compute_ta!(out, state, cache, time)
74+
if isnothing(out)
75+
return state.ta
76+
else
77+
out .= state.ta
78+
end
79+
end
80+
```
81+
82+
In general, we do not recommend implementing `compute!`, unless required for
83+
backward compatibility.
6984

70-
> Note, in the future, we hope to improve this rather clumsy way to write
71-
> diagnostics. Hopefully, at some point you will just have to write something like
72-
> `state.ta` and not worry about the `out` at all.
85+
When the expression is anything more complicated than just returning a `Field`,
86+
it is best to return an unevaluated expression represented by a
87+
`Base.Broadcast.Broadcasted` object (such as the ones produced with
88+
[LazyBroadcast.jl](https://github.com/CliMA/LazyBroadcast.jl)). Consider the
89+
following example where we want to shift the temperature to Celsius:
90+
```julia
91+
function compute_ta(state, cache, time)
92+
return state.ta .- 273.15
93+
end
94+
```
95+
96+
This `compute` function is inefficient because it allocates an entire `Field`
97+
before returning it. Instead, we can return just a recipe on how the diagnostic
98+
should be computed. Using `LazyBroadcast.jl`, the snippet above can be rewritten
99+
as
100+
```julia
101+
import LazyBroadcast: lazy
102+
103+
function compute_ta(state, cache, time)
104+
return lazy.(state.ta .- 273.15)
105+
end
106+
```
107+
The return value of `compute_ta` is a `Base.Broadcast.Broadcasted` object and
108+
`ClimaDiagnostics` knows how to handle it efficiently, avoiding the intermediate
109+
allocations.
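
To illustrate what "a recipe" means here, the following sketch contrasts an eager broadcast with an unevaluated one, using only `Base.Broadcast`. It is an illustration under stated assumptions rather than a description of how `ClimaDiagnostics` works internally: plain `Vector`s stand in for `Field`s, and `broadcasted(-, ta, 273.15)` is assumed to be comparable to what `lazy.(ta .- 273.15)` returns.

```julia
import Base.Broadcast: Broadcasted, broadcasted, materialize!

ta = [273.15, 300.0]

eager = ta .- 273.15                 # evaluated immediately into a new array
recipe = broadcasted(-, ta, 273.15)  # stores only the function and its arguments

# The recipe refers to the original data; nothing has been computed yet.
is_lazy = recipe isa Broadcasted
shares_data = recipe.args[1] === ta

# Evaluation happens only when the recipe is written into existing memory.
out = zeros(2)
materialize!(out, recipe)
```

Because the recipe only holds a reference to `ta`, the caller decides when and where the values are computed, for example into a buffer that is reused at every output step.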
73110

74111
A `DiagnosticVariable` defines what a variable is and how to compute it, but
75112
does not specify when to compute/output it. For that, we need
