Skip to content

Commit 2da4504

Browse files
committed
add NaN check callbacks to longruns
1 parent 4b4b25f commit 2da4504

File tree

6 files changed

+86
-6
lines changed

6 files changed

+86
-6
lines changed

experiments/long_runs/bucket.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,11 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 7))
144144
diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)
145145

146146
driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc)
147-
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb)
147+
148+
nancheck_freq = Dates.Month(1)
149+
nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
150+
151+
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb)
148152
end
149153

150154
function setup_and_solve_problem(; greet = false)

experiments/long_runs/land.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,10 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 15))
374374
diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)
375375

376376
driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc)
377-
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb)
377+
378+
nancheck_freq = Dates.Month(1)
379+
nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
380+
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb)
378381
end
379382

380383
function setup_and_solve_problem(; greet = false)

experiments/long_runs/land_region.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,10 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (10, 10, 15))
389389
diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)
390390

391391
driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc)
392-
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb)
392+
393+
nancheck_freq = Dates.Month(1)
394+
nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
395+
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb)
393396
end
394397

395398
function setup_and_solve_problem(; greet = false)

experiments/long_runs/snowy_land.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,11 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 15))
396396
end
397397
report_cb = SciMLBase.DiscreteCallback(every1000steps, report)
398398

399-
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, report_cb)
399+
nancheck_freq = Dates.Month(1)
400+
nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
401+
402+
return prob,
403+
SciMLBase.CallbackSet(driver_cb, diag_cb, report_cb, nancheck_cb)
400404
end
401405

402406
function setup_and_solve_problem(; greet = false)

experiments/long_runs/soil.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,10 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 15))
215215
diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)
216216

217217
driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc)
218-
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb)
218+
219+
nancheck_freq = Dates.Month(1)
220+
nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
221+
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb)
219222
end
220223

221224
function setup_and_solve_problem(; greet = false)

src/shared_utilities/utils.jl

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ function CheckpointCallback(
320320

321321
if !isnothing(dt)
322322
dt_period = Dates.Millisecond(1000dt)
323-
if !isdivisible(checkpoint_frequency_period / dt_period)
323+
if !isdivisible(checkpoint_frequency_period, dt_period)
324324
@warn "Checkpoint frequency ($(checkpoint_frequency_period)) is not an integer multiple of dt $(dt_period)"
325325
end
326326
end
@@ -529,3 +529,66 @@ function count_nans_state(state::ClimaCore.Fields.Field, mask = nothing)
529529
end
530530
return nothing
531531
end
532+
"""
    NaNCheckCallback(nancheck_frequency::Union{AbstractFloat, Dates.Period},
                     start_date, t_start, dt = nothing)

Construct a `SciMLBase.DiscreteCallback` which calls `count_nans_state` on the
integrator state (`integrator.u`) every `nancheck_frequency`.

# Arguments
- `nancheck_frequency`: The frequency at which the state is checked for NaNs.
  Can be specified as a float (in seconds) or a `Dates.Period`.
- `start_date`: The start date of the simulation.
- `t_start`: The starting time of the simulation (in seconds).
- `dt`: The timestep of the model (in seconds), or `nothing` (the default).
  It is only used to emit a warning when `nancheck_frequency` is not an
  integer multiple of `dt`.

The callback uses `ClimaDiagnostics.EveryCalendarDtSchedule` to determine when
to check the state. The schedule is constructed with `start_date` and with
`date_last` set to `start_date` plus `t_start`.
"""
function NaNCheckCallback(
    nancheck_frequency::Union{AbstractFloat, Dates.Period},
    start_date,
    t_start,
    dt = nothing,
)
    # TODO: Move to a more general callback system. For the time being, we use
    # the ClimaDiagnostics one because it is flexible and it supports calendar
    # dates.

    if nancheck_frequency isa AbstractFloat
        # Assume it is in seconds, but go through Millisecond to support
        # fractional seconds
        nancheck_frequency_period = Dates.Millisecond(1000nancheck_frequency)
    else
        nancheck_frequency_period = nancheck_frequency
    end

    schedule = EveryCalendarDtSchedule(
        nancheck_frequency_period;
        start_date,
        date_last = start_date + Dates.Millisecond(1000t_start),
    )

    # Consistency check only: a mismatch between the check frequency and the
    # timestep produces a warning, it does not prevent the callback from
    # being built.
    if !isnothing(dt)
        dt_period = Dates.Millisecond(1000dt)
        if !isdivisible(nancheck_frequency_period, dt_period)
            @warn "Callback frequency ($(nancheck_frequency_period)) is not an integer multiple of dt $(dt_period)"
        end
    end

    # Capture `schedule` through `let` so the closure holds a binding that is
    # never reassigned.
    cond = let schedule = schedule
        (u, t, integrator) -> schedule(integrator)
    end
    # The affect only needs the integrator state. The previous
    # `let output_dir = output_dir, model = model` referred to names that are
    # not defined in this function.
    affect! = (integrator) -> count_nans_state(integrator.u)

    return SciMLBase.DiscreteCallback(cond, affect!)
end

0 commit comments

Comments
 (0)