Skip to content

Commit 1ca7ad5

Browse files
Improve DiagnosticsHandler field concreteness
Update src/clima_diagnostics.jl Co-authored-by: Gabriele Bozzola <sbozzolator@gmail.com> Bump patch version
1 parent 7b08031 commit 1ca7ad5

File tree

3 files changed

+64
-35
lines changed

3 files changed

+64
-35
lines changed

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
main
44
-------
5+
6+
v0.2.7
7+
-------
8+
59
## Bug fixes
610

711
- `scheduled_diagnostics` are now internally saved as vectors instead of tuples.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ClimaDiagnostics"
22
uuid = "1ecacbb8-0713-4841-9a07-eb5aa8a2d53f"
33
authors = ["Gabriele Bozzola <gbozzola@caltech.edu>"]
4-
version = "0.2.6"
4+
version = "0.2.7"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

src/clima_diagnostics.jl

Lines changed: 59 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,43 @@ include("reduction_identities.jl")
1313
A struct that contains the scheduled diagnostics, ancillary data and areas of memory needed
1414
to store and accumulate results.
1515
"""
16-
struct DiagnosticsHandler{SD, STORAGE <: Dict, ACC <: Dict, COUNT <: Dict}
16+
struct DiagnosticsHandler{SD, V <: Vector{Int}, STORAGE, ACC <: Dict, COUNT}
1717
"""An iterable with the `ScheduledDiagnostic`s that are scheduled."""
1818
scheduled_diagnostics::SD
1919

20-
"""Dictionary that maps a given `ScheduledDiagnostic` to a potentially pre-allocated
20+
"""A Vector containing keys to index into `scheduled_diagnostics`."""
21+
scheduled_diagnostics_keys::V
22+
23+
"""Container holding a potentially pre-allocated
2124
area of memory where to save the newly computed results."""
2225
storage::STORAGE
2326

24-
"""Dictionary that maps a given `ScheduledDiagnostic` to a potentially pre-allocated
27+
"""Container holding a potentially pre-allocated
2528
area of memory where to accumulate results."""
2629
accumulators::ACC
2730

28-
"""Dictionary that maps a given `ScheduledDiagnostic` to the counter that tracks how
29-
many times the given diagnostics was computed from the last time it was output to
30-
disk."""
31+
"""Container holding a counter that tracks how many times the given
32+
diagnostics was computed from the last time it was output to disk."""
3133
counters::COUNT
3234
end
3335

36+
"""
37+
value_types(
38+
data;
39+
value_map = unionall_type,
40+
)
41+
42+
Given `data`, return a type `Union{V...}` where `V` are the `Union` of all found types in
43+
the values of `data`.
44+
"""
45+
function value_types(data)
46+
ret_types = Set([])
47+
for k in eachindex(data)
48+
push!(ret_types, typeof(data[k]))
49+
end
50+
return Union{ret_types...}
51+
end
52+
3453
"""
3554
DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
3655
@@ -52,16 +71,18 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
5271

5372
# For diagnostics that perform reductions, the storage is used for the values computed
5473
# at each call. Reductions also save the accumulated value in accumulators.
55-
storage = Dict()
56-
accumulators = Dict()
57-
counters = Dict()
74+
storage = []
75+
# Not all diagnostics need an accumulator, so we put them in a dictionary key-ed over the diagnostic index
76+
accumulators = Dict{Int, Any}()
77+
counters = Int[]
78+
scheduled_diagnostics_keys = Int[]
5879

5980
unique_scheduled_diagnostics = unique(scheduled_diagnostics)
6081
if length(unique_scheduled_diagnostics) != length(scheduled_diagnostics)
6182
@warn "Given list of diagnostics contains duplicates, removing them"
6283
end
6384

64-
for diag in unique_scheduled_diagnostics
85+
for (i, diag) in enumerate(unique_scheduled_diagnostics)
6586
if isnothing(dt)
6687
@warn "dt was not passed to DiagnosticsHandler. No checks will be performed on the frequency of the diagnostics"
6788
else
@@ -80,33 +101,37 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
80101
)
81102
end
82103
end
104+
push!(scheduled_diagnostics_keys, i)
83105

84106
variable = diag.variable
85107
isa_time_reduction = !isnothing(diag.reduction_time_func)
86108

87109
# The first time we call compute! we use its return value. All the subsequent times
88110
# (in the callbacks), we will write the result in place
89-
storage[diag] = variable.compute!(nothing, Y, p, t)
90-
counters[diag] = 1
111+
push!(storage, variable.compute!(nothing, Y, p, t))
112+
push!(counters, 1)
91113

92114
# If it is not a reduction, call the output writer as well
93115
if !isa_time_reduction
94-
interpolate_field!(diag.output_writer, storage[diag], diag, Y, p, t)
95-
write_field!(diag.output_writer, storage[diag], diag, Y, p, t)
116+
interpolate_field!(diag.output_writer, storage[i], diag, Y, p, t)
117+
write_field!(diag.output_writer, storage[i], diag, Y, p, t)
96118
else
97119
# Add to the accumulator
98120

99121
# We use similar + .= instead of copy because CUDA 5.2 does not supported nested
100122
# wrappers with view(reshape(view)) objects. See discussion in
101123
# https://github.com/CliMA/ClimaAtmos.jl/pull/2579 and
102124
# https://github.com/JuliaGPU/Adapt.jl/issues/21
103-
accumulators[diag] = similar(storage[diag])
104-
accumulators[diag] .= storage[diag]
125+
accumulators[i] = similar(storage[i])
126+
accumulators[i] .= storage[i]
105127
end
106128
end
129+
storage = value_types(storage)[storage...]
130+
accumulators = Dict{Int, value_types(accumulators)}(accumulators...)
107131

108132
return DiagnosticsHandler(
109133
unique_scheduled_diagnostics,
134+
scheduled_diagnostics_keys,
110135
storage,
111136
accumulators,
112137
counters,
@@ -132,7 +157,7 @@ function orchestrate_diagnostics(
132157
integrator,
133158
diagnostic_handler::DiagnosticsHandler,
134159
)
135-
scheduled_diagnostics = diagnostic_handler.scheduled_diagnostics
160+
(; scheduled_diagnostics, scheduled_diagnostics_keys) = diagnostic_handler
136161
active_compute = Bool[]
137162
active_output = Bool[]
138163
active_sync = Bool[]
@@ -144,30 +169,30 @@ function orchestrate_diagnostics(
144169
end
145170

146171
# Compute
147-
for diag_index in 1:length(scheduled_diagnostics)
172+
for diag_index in scheduled_diagnostics_keys
148173
active_compute[diag_index] || continue
149174
diag = scheduled_diagnostics[diag_index]
150175

151176
diag.variable.compute!(
152-
diagnostic_handler.storage[diag],
177+
diagnostic_handler.storage[diag_index],
153178
integrator.u,
154179
integrator.p,
155180
integrator.t,
156181
)
157-
diagnostic_handler.counters[diag] += 1
182+
diagnostic_handler.counters[diag_index] += 1
158183

159184
isa_time_reduction = !isnothing(diag.reduction_time_func)
160185
if isa_time_reduction
161-
diagnostic_handler.accumulators[diag] .=
186+
diagnostic_handler.accumulators[diag_index] .=
162187
diag.reduction_time_func.(
163-
diagnostic_handler.accumulators[diag],
164-
diagnostic_handler.storage[diag],
188+
diagnostic_handler.accumulators[diag_index],
189+
diagnostic_handler.storage[diag_index],
165190
)
166191
end
167192
end
168193

169194
# Pre-output (averages/interpolation)
170-
for diag_index in 1:length(scheduled_diagnostics)
195+
for diag_index in scheduled_diagnostics_keys
171196
active_output[diag_index] || continue
172197
diag = scheduled_diagnostics[diag_index]
173198

@@ -176,20 +201,20 @@ function orchestrate_diagnostics(
176201
# additional copy. If this copy turns out to be too expensive, we can move the if
177202
# statement below.
178203
isnothing(diag.reduction_time_func) || (
179-
diagnostic_handler.storage[diag] .=
180-
diagnostic_handler.accumulators[diag]
204+
diagnostic_handler.storage[diag_index] .=
205+
diagnostic_handler.accumulators[diag_index]
181206
)
182207

183208
# Any operations we have to perform before writing to output? Here is where we would
184209
# divide by N to obtain an arithmetic average
185210
diag.pre_output_hook!(
186-
diagnostic_handler.storage[diag],
187-
diagnostic_handler.counters[diag],
211+
diagnostic_handler.storage[diag_index],
212+
diagnostic_handler.counters[diag_index],
188213
)
189214

190215
interpolate_field!(
191216
diag.output_writer,
192-
diagnostic_handler.storage[diag],
217+
diagnostic_handler.storage[diag_index],
193218
diag,
194219
integrator.u,
195220
integrator.p,
@@ -198,13 +223,13 @@ function orchestrate_diagnostics(
198223
end
199224

200225
# Save to disk
201-
for diag_index in 1:length(scheduled_diagnostics)
226+
for diag_index in scheduled_diagnostics_keys
202227
active_output[diag_index] || continue
203228
diag = scheduled_diagnostics[diag_index]
204229

205230
write_field!(
206231
diag.output_writer,
207-
diagnostic_handler.storage[diag],
232+
diagnostic_handler.storage[diag_index],
208233
diag,
209234
integrator.u,
210235
integrator.p,
@@ -213,7 +238,7 @@ function orchestrate_diagnostics(
213238
end
214239

215240
# Post-output clean-up
216-
for diag_index in 1:length(scheduled_diagnostics)
241+
for diag_index in scheduled_diagnostics_keys
217242
diag = scheduled_diagnostics[diag_index]
218243

219244
# First, maybe call sync for the writer. This might happen regardless of
@@ -229,10 +254,10 @@ function orchestrate_diagnostics(
229254
# identity_of_reduction works by dispatching over operation.
230255
# The function is defined in reduction_identities.jl
231256
identity = identity_of_reduction(diag.reduction_time_func)
232-
fill!(parent(diagnostic_handler.accumulators[diag]), identity)
257+
fill!(parent(diagnostic_handler.accumulators[diag_index]), identity)
233258
end
234259
# Reset counter
235-
diagnostic_handler.counters[diag] = 0
260+
diagnostic_handler.counters[diag_index] = 0
236261
end
237262

238263
return nothing

0 commit comments

Comments
 (0)