@@ -13,24 +13,43 @@ include("reduction_identities.jl")
13
13
A struct that contains the scheduled diagnostics, ancillary data and areas of memory needed
14
14
to store and accumulate results.
15
15
"""
16
- struct DiagnosticsHandler{SD, STORAGE <: Dict , ACC <: Dict , COUNT <: Dict }
16
+ struct DiagnosticsHandler{SD, V <: Vector{Int} , STORAGE, ACC <: Dict , COUNT}
17
17
""" An iterable with the `ScheduledDiagnostic`s that are scheduled."""
18
18
scheduled_diagnostics:: SD
19
19
20
- """ Dictionary that maps a given `ScheduledDiagnostic` to a potentially pre-allocated
20
+ """ A Vector containing keys to index into `scheduled_diagnostics`."""
21
+ scheduled_diagnostics_keys:: V
22
+
23
+ """ Container holding a potentially pre-allocated
21
24
area of memory where to save the newly computed results."""
22
25
storage:: STORAGE
23
26
24
- """ Dictionary that maps a given `ScheduledDiagnostic` to a potentially pre-allocated
27
+ """ Container holding a potentially pre-allocated
25
28
area of memory where to accumulate results."""
26
29
accumulators:: ACC
27
30
28
- """ Dictionary that maps a given `ScheduledDiagnostic` to the counter that tracks how
29
- many times the given diagnostics was computed from the last time it was output to
30
- disk."""
31
+ """ Container holding a counter that tracks how many times the given
32
+ diagnostics was computed from the last time it was output to disk."""
31
33
counters:: COUNT
32
34
end
33
35
36
+ """
37
+ value_types(
38
+ data;
39
+ value_map = unionall_type,
40
+ )
41
+
42
+ Given `data`, return a type `Union{V...}` where `V` are the `Union` of all found types in
43
+ the values of `data`.
44
+ """
45
+ function value_types (data)
46
+ ret_types = Set ([])
47
+ for k in eachindex (data)
48
+ push! (ret_types, typeof (data[k]))
49
+ end
50
+ return Union{ret_types... }
51
+ end
52
+
34
53
"""
35
54
DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
36
55
@@ -52,16 +71,18 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
52
71
53
72
# For diagnostics that perform reductions, the storage is used for the values computed
54
73
# at each call. Reductions also save the accumulated value in accumulators.
55
- storage = Dict ()
56
- accumulators = Dict ()
57
- counters = Dict ()
74
+ storage = []
75
+ # Not all diagnostics need an accumulator, so we put them in a dictionary key-ed over the diagnostic index
76
+ accumulators = Dict {Int, Any} ()
77
+ counters = Int[]
78
+ scheduled_diagnostics_keys = Int[]
58
79
59
80
unique_scheduled_diagnostics = unique (scheduled_diagnostics)
60
81
if length (unique_scheduled_diagnostics) != length (scheduled_diagnostics)
61
82
@warn " Given list of diagnostics contains duplicates, removing them"
62
83
end
63
84
64
- for diag in unique_scheduled_diagnostics
85
+ for (i, diag) in enumerate ( unique_scheduled_diagnostics)
65
86
if isnothing (dt)
66
87
@warn " dt was not passed to DiagnosticsHandler. No checks will be performed on the frequency of the diagnostics"
67
88
else
@@ -80,33 +101,37 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
80
101
)
81
102
end
82
103
end
104
+ push! (scheduled_diagnostics_keys, i)
83
105
84
106
variable = diag. variable
85
107
isa_time_reduction = ! isnothing (diag. reduction_time_func)
86
108
87
109
# The first time we call compute! we use its return value. All the subsequent times
88
110
# (in the callbacks), we will write the result in place
89
- storage[diag] = variable. compute! (nothing , Y, p, t)
90
- counters[diag] = 1
111
+ push! ( storage, variable. compute! (nothing , Y, p, t) )
112
+ push! ( counters, 1 )
91
113
92
114
# If it is not a reduction, call the output writer as well
93
115
if ! isa_time_reduction
94
- interpolate_field! (diag. output_writer, storage[diag ], diag, Y, p, t)
95
- write_field! (diag. output_writer, storage[diag ], diag, Y, p, t)
116
+ interpolate_field! (diag. output_writer, storage[i ], diag, Y, p, t)
117
+ write_field! (diag. output_writer, storage[i ], diag, Y, p, t)
96
118
else
97
119
# Add to the accumulator
98
120
99
121
# We use similar + .= instead of copy because CUDA 5.2 does not supported nested
100
122
# wrappers with view(reshape(view)) objects. See discussion in
101
123
# https://github.com/CliMA/ClimaAtmos.jl/pull/2579 and
102
124
# https://github.com/JuliaGPU/Adapt.jl/issues/21
103
- accumulators[diag ] = similar (storage[diag ])
104
- accumulators[diag ] .= storage[diag ]
125
+ accumulators[i ] = similar (storage[i ])
126
+ accumulators[i ] .= storage[i ]
105
127
end
106
128
end
129
+ storage = value_types (storage)[storage... ]
130
+ accumulators = Dict {Int, value_types(accumulators)} (accumulators... )
107
131
108
132
return DiagnosticsHandler (
109
133
unique_scheduled_diagnostics,
134
+ scheduled_diagnostics_keys,
110
135
storage,
111
136
accumulators,
112
137
counters,
@@ -132,7 +157,7 @@ function orchestrate_diagnostics(
132
157
integrator,
133
158
diagnostic_handler:: DiagnosticsHandler ,
134
159
)
135
- scheduled_diagnostics = diagnostic_handler. scheduled_diagnostics
160
+ (; scheduled_diagnostics, scheduled_diagnostics_keys) = diagnostic_handler
136
161
active_compute = Bool[]
137
162
active_output = Bool[]
138
163
active_sync = Bool[]
@@ -144,30 +169,30 @@ function orchestrate_diagnostics(
144
169
end
145
170
146
171
# Compute
147
- for diag_index in 1 : length (scheduled_diagnostics)
172
+ for diag_index in scheduled_diagnostics_keys
148
173
active_compute[diag_index] || continue
149
174
diag = scheduled_diagnostics[diag_index]
150
175
151
176
diag. variable. compute! (
152
- diagnostic_handler. storage[diag ],
177
+ diagnostic_handler. storage[diag_index ],
153
178
integrator. u,
154
179
integrator. p,
155
180
integrator. t,
156
181
)
157
- diagnostic_handler. counters[diag ] += 1
182
+ diagnostic_handler. counters[diag_index ] += 1
158
183
159
184
isa_time_reduction = ! isnothing (diag. reduction_time_func)
160
185
if isa_time_reduction
161
- diagnostic_handler. accumulators[diag ] .=
186
+ diagnostic_handler. accumulators[diag_index ] .=
162
187
diag. reduction_time_func .(
163
- diagnostic_handler. accumulators[diag ],
164
- diagnostic_handler. storage[diag ],
188
+ diagnostic_handler. accumulators[diag_index ],
189
+ diagnostic_handler. storage[diag_index ],
165
190
)
166
191
end
167
192
end
168
193
169
194
# Pre-output (averages/interpolation)
170
- for diag_index in 1 : length (scheduled_diagnostics)
195
+ for diag_index in scheduled_diagnostics_keys
171
196
active_output[diag_index] || continue
172
197
diag = scheduled_diagnostics[diag_index]
173
198
@@ -176,20 +201,20 @@ function orchestrate_diagnostics(
176
201
# additional copy. If this copy turns out to be too expensive, we can move the if
177
202
# statement below.
178
203
isnothing (diag. reduction_time_func) || (
179
- diagnostic_handler. storage[diag ] .=
180
- diagnostic_handler. accumulators[diag ]
204
+ diagnostic_handler. storage[diag_index ] .=
205
+ diagnostic_handler. accumulators[diag_index ]
181
206
)
182
207
183
208
# Any operations we have to perform before writing to output? Here is where we would
184
209
# divide by N to obtain an arithmetic average
185
210
diag. pre_output_hook! (
186
- diagnostic_handler. storage[diag ],
187
- diagnostic_handler. counters[diag ],
211
+ diagnostic_handler. storage[diag_index ],
212
+ diagnostic_handler. counters[diag_index ],
188
213
)
189
214
190
215
interpolate_field! (
191
216
diag. output_writer,
192
- diagnostic_handler. storage[diag ],
217
+ diagnostic_handler. storage[diag_index ],
193
218
diag,
194
219
integrator. u,
195
220
integrator. p,
@@ -198,13 +223,13 @@ function orchestrate_diagnostics(
198
223
end
199
224
200
225
# Save to disk
201
- for diag_index in 1 : length (scheduled_diagnostics)
226
+ for diag_index in scheduled_diagnostics_keys
202
227
active_output[diag_index] || continue
203
228
diag = scheduled_diagnostics[diag_index]
204
229
205
230
write_field! (
206
231
diag. output_writer,
207
- diagnostic_handler. storage[diag ],
232
+ diagnostic_handler. storage[diag_index ],
208
233
diag,
209
234
integrator. u,
210
235
integrator. p,
@@ -213,7 +238,7 @@ function orchestrate_diagnostics(
213
238
end
214
239
215
240
# Post-output clean-up
216
- for diag_index in 1 : length (scheduled_diagnostics)
241
+ for diag_index in scheduled_diagnostics_keys
217
242
diag = scheduled_diagnostics[diag_index]
218
243
219
244
# First, maybe call sync for the writer. This might happen regardless of
@@ -229,10 +254,10 @@ function orchestrate_diagnostics(
229
254
# identity_of_reduction works by dispatching over operation.
230
255
# The function is defined in reduction_identities.jl
231
256
identity = identity_of_reduction (diag. reduction_time_func)
232
- fill! (parent (diagnostic_handler. accumulators[diag ]), identity)
257
+ fill! (parent (diagnostic_handler. accumulators[diag_index ]), identity)
233
258
end
234
259
# Reset counter
235
- diagnostic_handler. counters[diag ] = 0
260
+ diagnostic_handler. counters[diag_index ] = 0
236
261
end
237
262
238
263
return nothing
0 commit comments