Skip to content

Add simple simulation progress callback #1999

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 16 commits into from
Closed
3 changes: 3 additions & 0 deletions config/default_configs/default_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ dt_save_restart:
dt_save_to_sol:
help: "Time between saving solution. Examples: [`10days`, `1hours`, `Inf` (do not save)]"
value: "1days"
dt_show_progress:
help: "Simulation time between displaying progress update. Examples: [`600secs`, `1mins`, `Inf` (do not display)]"
value: "600secs"
moist:
help: "Moisture model [`dry` (default), `equil`, `non_equil`]"
value: "dry"
Expand Down
22 changes: 16 additions & 6 deletions src/callbacks/callback_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@ function call_every_n_steps(f!, n = 1; skip_first = false, call_at_end = false)
)
end

function call_every_dt(f!, dt; skip_first = false, call_at_end = false)
function call_every_dt(
f!,
dt;
skip_first = false,
call_at_end = false,
initialize = nothing,
)
cb! = AtmosCallback(f!, EveryΔt(dt))
@assert dt ≠ Inf "Adding callback that never gets called!"
next_t = Ref{typeof(dt)}()
Expand All @@ -34,11 +40,15 @@ function call_every_dt(f!, dt; skip_first = false, call_at_end = false)
return ODE.DiscreteCallback(
(u, t, integrator) -> t >= next_t[],
affect!;
initialize = (cb, u, t, integrator) -> begin
skip_first || cb!(integrator)
t_end = integrator.sol.prob.tspan[2]
next_t[] =
(call_at_end && t < t_end) ? min(t_end, t + dt) : t + dt
initialize = if isnothing(initialize)
(cb, u, t, integrator) -> begin
skip_first || cb!(integrator)
t_end = integrator.sol.prob.tspan[2]
next_t[] =
(call_at_end && t < t_end) ? min(t_end, t + dt) : t + dt
end
else
initialize
end,
save_positions = (false, false),
)
Expand Down
48 changes: 48 additions & 0 deletions src/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,54 @@ import ClimaCore.Fields: ColumnField

include("callback_helpers.jl")

"""
display_status_callback!()

Returns a callback to display simulation information.
Adapted from ClimaTimeSteppers.jl #89.
"""
function display_status_callback!(::Type{tType}) where {tType}
# start_time = Ref{Float64}()
# prev_time = Ref{Float64}()
# time = Ref{Float64}()

# prev_t = Ref{tType}()
# eta = Ref{tType}()
# speed = Ref{tType}()
# is_not_first_step = Ref{Bool}()
# eta_string = Ref{String}()
# output_string = Ref{String}()

function affect!(integrator)
# t_end = maximum(integrator.tstops.valtree)
# nsteps = ceil(Int64, t_end / integrator.dt)
# t = integrator.t
# step = ceil(Int64, t / integrator.dt)
# time[] = time_ns() / 1e9
# speed[] = (time[] - prev_time[]) / (t - prev_t[])
# eta[] = speed[] * (t_end - t)
# # eta_string[] = eta[] == Inf ? "..." : string(round(eta[])) * " seconds"

# if !is_not_first_step[]
# is_not_first_step[] = true
# start_time[] = time[]
# end
# output_string[] = "dssda"
# println(output_string[])
println("")
# $(Dates.format(Dates.now(), "HH:MM:SS:ss u d")) \n")
# println("Timestep: $(step) / $(nsteps); Simulation Time: $(t) seconds \n ")
# Walltime: $(round(time[] - start_time[], digits=2)) seconds; \
# Time/Step: $(round(speed[] * integrator.dt, digits=2)) seconds \n")
# Time Remaining: $eta_string"

# prev_t[] = t
# prev_time[] = time[]
return nothing
end
return affect!
end

function dss_callback!(integrator)
Y = integrator.u
ghost_buffer = integrator.p.ghost_buffer
Expand Down
12 changes: 12 additions & 0 deletions src/solver/type_getters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,18 @@ function get_callbacks(parsed_args, simulation, atmos, params)
(callbacks..., call_every_dt(save_restart_func, dt_save_restart))
end

dt_show_progress = time_to_seconds(parsed_args["dt_show_progress"])
if !(dt_show_progress == Inf)
callbacks = (
callbacks...,
call_every_dt(
display_status_callback!(typeof(dt)),
dt_show_progress;
skip_first = true,
),
)
end

if is_distributed(simulation.comms_ctx)
callbacks = (
callbacks...,
Expand Down