Skip to content

Commit 5fc53d6

Browse files
Add callback that estimates walltime
Increase allocation limits Bump allocation limits Finish comment in flame graph Fix doc string Improve names, add docs Improve names Improve docs and names Improve docs, qualify Period Improve names Add comment for eval function Maintain log after 50% Add warning for restarted simulations
1 parent 06a3da7 commit 5fc53d6

File tree

6 files changed

+175
-7
lines changed

6 files changed

+175
-7
lines changed

perf/flame.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ simulation = CA.get_simulation(config)
1717

1818
import SciMLBase
1919
SciMLBase.step!(integrator) # compile first
20+
SciMLBase.step!(integrator) # compile print_walltime_estimate, which skips the first step to avoid timing compilation
2021
CA.call_all_callbacks!(integrator) # compile callbacks
2122
import Profile, ProfileCanvas
2223
output_dir = job_id
@@ -36,17 +37,17 @@ ProfileCanvas.html_file(joinpath(output_dir, "flame.html"), results)
3637
#####
3738

3839
allocs_limit = Dict()
39-
allocs_limit["flame_perf_target"] = 147_520
40-
allocs_limit["flame_perf_target_tracers"] = 179_776
40+
allocs_limit["flame_perf_target"] = 148_256
41+
allocs_limit["flame_perf_target_tracers"] = 180_512
4142
allocs_limit["flame_perf_target_edmfx"] = 7_005_552
4243
allocs_limit["flame_perf_diagnostics"] = 25_356_928
43-
allocs_limit["flame_perf_target_diagnostic_edmfx"] = 1_309_968
44+
allocs_limit["flame_perf_target_diagnostic_edmfx"] = 1_311_040
4445
allocs_limit["flame_sphere_baroclinic_wave_rhoe_equilmoist_expvdiff"] =
4546
4_018_252_656
4647
allocs_limit["flame_perf_target_threaded"] = 1_276_864
4748
allocs_limit["flame_perf_target_callbacks"] = 37_277_112
48-
allocs_limit["flame_perf_gw"] = 3_226_428_736
49-
allocs_limit["flame_perf_target_prognostic_edmfx_aquaplanet"] = 1_257_712
49+
allocs_limit["flame_perf_gw"] = 3_226_429_472
50+
allocs_limit["flame_perf_target_prognostic_edmfx_aquaplanet"] = 1_258_848
5051

5152
# Ideally, we would like to track all the allocations, but this becomes too
5253
# expensive when there are too many of them. Here, we set the default sample rate to

src/cache/cache.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
struct AtmosCache{
22
FT <: AbstractFloat,
3+
FTE,
4+
WTE,
35
SD,
46
AM,
57
NUM,
@@ -28,6 +30,12 @@ struct AtmosCache{
2830
"""Timestep of the simulation (in seconds). This is also used by callbacks and tendencies"""
2931
dt::FT
3032

33+
"""End time of the simulation (in seconds). This used by callbacks"""
34+
t_end::FTE
35+
36+
"""Walltime estimate"""
37+
walltime_estimate::WTE
38+
3139
"""Start date (used for insolation)."""
3240
start_date::SD
3341

@@ -93,7 +101,7 @@ end
93101

94102
# The model also depends on f_plane_coriolis_frequency(params)
95103
# This is a constant Coriolis frequency that is only used if space is flat
96-
function build_cache(Y, atmos, params, surface_setup, dt, start_date)
104+
function build_cache(Y, atmos, params, surface_setup, dt, t_end, start_date)
97105
FT = eltype(params)
98106

99107
ᶜcoord = Fields.local_geometry_field(Y.c).coordinates
@@ -184,6 +192,8 @@ function build_cache(Y, atmos, params, surface_setup, dt, start_date)
184192

185193
args = (
186194
dt,
195+
t_end,
196+
WallTimeEstimate(),
187197
start_date,
188198
atmos,
189199
numerics,

src/callbacks/callbacks.jl

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,94 @@ function save_restart_func(integrator, output_dir)
434434
return nothing
435435
end
436436

437+
"""
    WallTimeEstimate(; n_calls, n_next, t_wall_last, ∑Δt_wall, n_fixed_increment)

Mutable bookkeeping for the wall-time progress report. All fields have
defaults, so `WallTimeEstimate()` yields a fresh state. The negative defaults
of `t_wall_last` and `n_fixed_increment` act as "not set yet" markers.
"""
Base.@kwdef mutable struct WallTimeEstimate
    """Number of calls to the callback"""
    n_calls::Int = 0
    """Value of `n_calls` at which the callback will next print to the log"""
    n_next::Int = 1
    """Wall time (from `time()`) of the previous update; `-1` until first set"""
    t_wall_last::Float64 = -1
    """Sum of elapsed walltime over calls to `step!`"""
    ∑Δt_wall::Float64 = 0
    """
    Fixed number of calls by which `n_next` is advanced once the fixed-increment
    regime starts; `-1` until it has been chosen. Stored as an `Int` because it
    is copied from, and added to, the integer counter `n_next`.
    """
    n_fixed_increment::Int = -1
end
449+
import Dates
450+
"""
    print_walltime_estimate(integrator)

Update the `WallTimeEstimate` stored in `integrator.p.walltime_estimate` and,
when the call counter reaches `n_next`, log a "Progress" report (time per
step, time spent/remaining, percent complete, SYPD, estimated finish date).

Reads `dt` and `t_end` from `integrator.p` and the current time from
`integrator.t`. Always returns `nothing`.
"""
function print_walltime_estimate(integrator)
    cache = integrator.p
    wte = cache.walltime_estimate
    dt = cache.dt
    t_end = cache.t_end

    # When is a report meaningful?
    #  - `n_calls == 0`: nothing has been timed yet (this call happens during
    #    initialization, before `step!` has been called).
    #  - `n_calls == 1`: the first `step!` just finished, but its duration
    #    includes compilation, so it is not used.
    #  - `n_calls > 1`: timings exclude compilation and can be reported.
    past_compilation = wte.n_calls > 1
    if past_compilation
        # The interval skipped at `n_calls == 1` is accounted for by counting
        # the first measured interval twice.
        weight = wte.n_calls == 2 ? 2 : 1
        elapsed = weight * (time() - wte.t_wall_last)
    else
        if wte.n_calls == 1
            @info "Progress: Completed first step"
        end
        elapsed = Float64(0)
        wte.n_next = wte.n_calls + 1
    end
    wte.∑Δt_wall += elapsed
    wte.t_wall_last = time()

    if past_compilation && wte.n_calls == wte.n_next
        t = integrator.t
        steps_total = ceil(Int, t_end / dt)
        steps_done = ceil(Int, t / dt)
        secs_per_step = wte.∑Δt_wall / steps_done
        per_step_str = time_and_units_str(secs_per_step)
        percent_complete = round(t / t_end * 100; digits = 1)
        secs_remaining = secs_per_step * (steps_total - steps_done)
        remaining_str = time_and_units_str(secs_remaining)
        total_str = time_and_units_str(secs_per_step * steps_total)
        spent_str = time_and_units_str(wte.∑Δt_wall)
        sim_time_str = time_and_units_str(Float64(t))
        stats = EfficiencyStats((zero(t), t), wte.∑Δt_wall)
        sypd = round(simulated_years_per_day(stats); digits = 3)
        finish_date =
            Dates.now() + compound_period(secs_remaining, Dates.Second)
        @info "Progress" simulation_time = sim_time_str n_steps_completed =
            steps_done wall_time_per_step = per_step_str wall_time_total =
            total_str wall_time_remaining = remaining_str wall_time_spent =
            spent_str percent_complete = "$percent_complete%" sypd = sypd date_now =
            Dates.now() estimated_finish_date = finish_date

        # Scheduling of the next report:
        #  - below 5% completion the interval doubles (to reduce log noise);
        #  - from 5% on, the interval is frozen at the value `n_next` has the
        #    first time this branch runs, so that reports keep appearing at a
        #    steady rate for the rest of the run.
        if percent_complete < 5
            wte.n_next *= 2
        else
            if wte.n_fixed_increment == -1
                wte.n_fixed_increment = wte.n_next
            end
            wte.n_next += wte.n_fixed_increment
        end
    end
    wte.n_calls += 1

    return nothing
end
524+
437525
function gc_func(integrator)
438526
num_pre = Base.gc_num()
439527
alloc_since_last = (num_pre.allocd + num_pre.deferred_alloc) / 2^20

src/solver/type_getters.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,12 @@ function get_callbacks(parsed_args, sim_info, atmos, params, comms_ctx)
455455
FT = eltype(params)
456456
(; dt, output_dir) = sim_info
457457

458-
callbacks = ()
458+
callbacks = (
459+
call_every_n_steps(
460+
(integrator) -> print_walltime_estimate(integrator);
461+
skip_first = true,
462+
),
463+
)
459464
dt_save_to_disk = time_to_seconds(parsed_args["dt_save_to_disk"])
460465
if !(dt_save_to_disk == Inf)
461466
callbacks = (
@@ -759,6 +764,7 @@ function get_simulation(config::AtmosConfig)
759764
if sim_info.restart
760765
(Y, t_start) = get_state_restart(config.comms_ctx)
761766
spaces = get_spaces_restart(Y)
767+
@warn "Progress estimates do not support restarted simulations"
762768
else
763769
spaces = get_spaces(config.parsed_args, params, config.comms_ctx)
764770
Y = ICs.atmos_state(
@@ -779,6 +785,7 @@ function get_simulation(config::AtmosConfig)
779785
params,
780786
surface_setup,
781787
sim_info.dt,
788+
sim_info.t_end,
782789
sim_info.start_date,
783790
)
784791
end

src/utils/utilities.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,66 @@ function prettytime(t)
210210
return @sprintf("%.3f %s", value, units)
211211
end
212212

213+
import Dates
214+
215+
# A period expressed in its own units: the conversion factor is one.
time_per_time(::Type{P}, ::Type{P}) where {P} = 1

"""
    define_time_per_times(periods)

Generate (via `@eval`) the methods
`time_per_time(::Type{Dates.Fine}, ::Type{Dates.Coarse})` for the period names
listed in `periods`.

`periods` is an indexable collection of `(name, n)` tuples ordered from the
coarsest period to the finest, where `n` is how many units of the *following*
entry fit in one unit of `name`. The methods are defined in a triangular
fashion: each entry gets a conversion factor against every coarser entry that
precedes it (e.g., `Nanosecond`s per `Second`), but not against finer ones.
"""
function define_time_per_times(periods)
    for (i, (fine, _)) in pairs(periods)
        factor = Int64(1)
        # Walk from the neighbouring coarser period up to the coarsest one,
        # accumulating the number of `fine` units along the way.
        for j in reverse(firstindex(periods):(i - 1))
            coarse, per_coarse = periods[j]
            factor *= per_coarse
            @eval time_per_time(::Type{Dates.$fine}, ::Type{Dates.$coarse}) =
                $factor
        end
    end
end

# Ratios between consecutive fixed periods, as in `Dates`.
define_time_per_times([
    (:Week, 7),           # days per week
    (:Day, 24),           # hours per day
    (:Hour, 60),          # minutes per hour
    (:Minute, 60),        # seconds per minute
    (:Second, 1000),      # milliseconds per second
    (:Millisecond, 1000), # microseconds per millisecond
    (:Microsecond, 1000), # nanoseconds per microsecond
    (:Nanosecond, 1),     # finest period; its multiplier is never read
])
247+
248+
"""
249+
time_and_units_str(x::Real)
250+
251+
Returns a truncated string of time and units,
252+
given a time `x` in Seconds.
253+
"""
254+
function time_and_units_str(x::Real)
    # Interpret `x` as seconds, express it as a canonical compound period,
    # then shorten its printed form with `trunc_time`.
    period = compound_period(x, Dates.Second)
    return trunc_time(string(period))
end
256+
257+
"""
258+
compound_period(x::Real, ::Type{T}) where {T <: Dates.Period}
259+
260+
A canonicalized `Dates.CompoundPeriod` given a real value
261+
`x`, and its units via the period type `T`.
262+
"""
263+
function compound_period(x::Real, ::Type{T}) where {T <: Dates.Period}
    # Express `x` in nanoseconds, rounded up to a whole number of them, and
    # let `Dates` redistribute that amount over coarser periods.
    ns_per_unit = time_per_time(Dates.Nanosecond, T)
    whole_ns = Dates.Nanosecond(ceil(x * ns_per_unit))
    return Dates.canonicalize(Dates.CompoundPeriod(whole_ns))
end
269+
270+
# Keep only the first two comma-separated components of `s`; a string with
# fewer than two commas is returned unchanged.
function trunc_time(s::String)
    pieces = split(s, ",")
    return length(pieces) > 2 ? join(pieces[1:2], ",") : s
end
271+
272+
213273
function prettymemory(b)
214274
if b < 1024
215275
return string(b, " bytes")

test/coupler_compatibility.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ const T2 = 290
5858
@. sfc_setup = (surface_state,)
5959
p_overwritten = CA.AtmosCache(
6060
p.dt,
61+
simulation.t_end,
62+
CA.WallTimeEstimate(),
6163
p.start_date,
6264
p.atmos,
6365
p.numerics,

0 commit comments

Comments
 (0)