Skip to content

Commit 4411c53

Browse files
committed
add NaN check for state
Adds a function that quantitatively checks how many NaNs are present in the state, and displays this information to the user. This can be used to inspect sol.u[end] after a simulation has run, to see if any NaNs were produced at the end of the simulation. If no NaNs are found, this information is logged in an info statement when run in verbose mode. If NaNs are found, this information is logged in a warn statement.
1 parent c44b057 commit 4411c53

File tree

7 files changed

+235
-8
lines changed

7 files changed

+235
-8
lines changed

experiments/long_runs/bucket.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,11 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 7))
144144
diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)
145145

146146
driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc)
147-
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb)
147+
148+
nancheck_freq = Dates.Month(1)
149+
nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
150+
151+
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb)
148152
end
149153

150154
function setup_and_solve_problem(; greet = false)

experiments/long_runs/land.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,10 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 15))
374374
diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)
375375

376376
driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc)
377-
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb)
377+
378+
nancheck_freq = Dates.Month(1)
379+
nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
380+
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb)
378381
end
379382

380383
function setup_and_solve_problem(; greet = false)

experiments/long_runs/land_region.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,10 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (10, 10, 15))
389389
diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)
390390

391391
driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc)
392-
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb)
392+
393+
nancheck_freq = Dates.Month(1)
394+
nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
395+
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb)
393396
end
394397

395398
function setup_and_solve_problem(; greet = false)

experiments/long_runs/snowy_land.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,11 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 15))
396396
end
397397
report_cb = SciMLBase.DiscreteCallback(every1000steps, report)
398398

399-
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, report_cb)
399+
nancheck_freq = Dates.Month(1)
400+
nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
401+
402+
return prob,
403+
SciMLBase.CallbackSet(driver_cb, diag_cb, report_cb, nancheck_cb)
400404
end
401405

402406
function setup_and_solve_problem(; greet = false)

experiments/long_runs/soil.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,10 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 15))
215215
diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)
216216

217217
driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc)
218-
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb)
218+
219+
nancheck_freq = Dates.Month(1)
220+
nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
221+
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb)
219222
end
220223

221224
function setup_and_solve_problem(; greet = false)

src/shared_utilities/utils.jl

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import SciMLBase
33
import ClimaDiagnostics.Schedules: EveryCalendarDtSchedule
44
import Dates
55

6-
export FTfromY
6+
export FTfromY, count_nans_state
77

88
"""
99
heaviside(x::FT)::FT where {FT}
@@ -320,7 +320,7 @@ function CheckpointCallback(
320320

321321
if !isnothing(dt)
322322
dt_period = Dates.Millisecond(1000dt)
323-
if !isdivisible(checkpoint_frequency_period / dt_period)
323+
if !isdivisible(checkpoint_frequency_period, dt_period)
324324
@warn "Checkpoint frequency ($(checkpoint_frequency_period)) is not an integer multiple of dt $(dt_period)"
325325
end
326326
end
@@ -492,3 +492,109 @@ function isdivisible(
492492
# have any common divisor)
493493
return isinteger(Dates.Day(1) / dt_small)
494494
end
495+
496+
"""
497+
count_nans_state(state; mask = nothing, verbose = false)
498+
499+
Count the number of NaNs in the state, e.g. the FieldVector given by
500+
`sol.u[end]` after calling `solve`. This function is useful for
501+
debugging simulations to determine quantitatively if a simulation is stable.
502+
503+
If this function is called on a FieldVector, it will recursively call itself
504+
on each Field in the FieldVector. If it is called on a Field, it will count
505+
the number of NaNs in the Field and produce a warning if any are found.
506+
507+
If a ClimaCore Field is provided as `mask`, the function will only count NaNs
508+
in the state variables where the mask is 1. This is intended to be used with
509+
the land/sea mask, to avoid counting NaNs over the ocean. Note this assumes
510+
the mask is 1 over land and 0 over ocean.
511+
512+
The `verbose` argument toggles whether the function produces output when no
513+
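NaNs are found. When called on a FieldVector, the "Checking NaNs in ..." info message for each property is logged regardless of `verbose`.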
514+
"""
515+
function count_nans_state(
516+
state::ClimaCore.Fields.FieldVector;
517+
mask = nothing,
518+
verbose = false,
519+
)
520+
for pn in propertynames(state)
521+
state_new = getproperty(state, pn)
522+
@info "Checking NaNs in $pn"
523+
count_nans_state(state_new; mask, verbose)
524+
end
525+
return nothing
526+
end
527+
528+
function count_nans_state(
529+
state::ClimaCore.Fields.Field;
530+
mask = nothing,
531+
verbose = false,
532+
)
533+
num_nans =
534+
isnothing(mask) ? count(==(1), isnan.(parent(state))) :
535+
count(==(1), isnan.(parent(state)) .* parent(mask))
536+
if num_nans > 0
537+
@warn "$num_nans NaNs found"
538+
else
539+
verbose && @info "No NaNs found"
540+
end
541+
return nothing
542+
end
543+
544+
"""
545+
NaNCheckCallback(nancheck_frequency::Union{AbstractFloat, Dates.Period},
546+
start_date, t_start, dt)
547+
548+
Constructs a DiscreteCallback which counts the number of NaNs in the state
549+
and produces a warning if any are found.
550+
551+
# Arguments
552+
- `nancheck_frequency`: The frequency at which the state is checked for NaNs.
553+
Can be specified as a float (in seconds) or a `Dates.Period`.
554+
- `start_date`: The start date of the simulation.
555+
- `t_start`: The starting time of the simulation (in seconds).
556+
- `dt`: The timestep of the model (in seconds), or `nothing` to skip the check that `nancheck_frequency` is an integer multiple of the timestep.
557+
558+
The callback uses `ClimaDiagnostics.EveryCalendarDtSchedule` to determine when
559+
to check for NaNs based on the `nancheck_frequency`. The schedule is
560+
initialized with the `start_date` and `t_start` to ensure that it is first
561+
called at the correct time.
562+
"""
563+
function NaNCheckCallback(
564+
nancheck_frequency::Union{AbstractFloat, Dates.Period},
565+
start_date,
566+
t_start,
567+
dt,
568+
)
569+
# TODO: Move to a more general callback system. For the time being, we use
570+
# the ClimaDiagnostics one because it is flexible and it supports calendar
571+
# dates.
572+
573+
if nancheck_frequency isa AbstractFloat
574+
# Assume it is in seconds, but go through Millisecond to support
575+
# fractional seconds
576+
nancheck_frequency_period = Dates.Millisecond(1000nancheck_frequency)
577+
else
578+
nancheck_frequency_period = nancheck_frequency
579+
end
580+
581+
schedule = EveryCalendarDtSchedule(
582+
nancheck_frequency_period;
583+
start_date,
584+
date_last = start_date + Dates.Millisecond(1000t_start),
585+
)
586+
587+
if !isnothing(dt)
588+
dt_period = Dates.Millisecond(1000dt)
589+
if !isdivisible(nancheck_frequency_period, dt_period)
590+
@warn "Callback frequency ($(nancheck_frequency_period)) is not an integer multiple of dt $(dt_period)"
591+
end
592+
end
593+
594+
cond = let schedule = schedule
595+
(u, t, integrator) -> schedule(integrator)
596+
end
597+
affect! = (integrator) -> count_nans_state(integrator.u)
598+
599+
SciMLBase.DiscreteCallback(cond, affect!)
600+
end

test/shared_utilities/utilities.jl

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ using Test
22
import ClimaComms
33
ClimaComms.@import_required_backends
44
using ClimaCore: Spaces, Geometry, Fields
5-
import ClimaComms
65
using ClimaLand
76
using ClimaLand: Domains, condition, SavingAffect, saving_initialize
87

@@ -256,3 +255,108 @@ end
256255
@test Y.subfields.subfield2 == Y_copy.subfields.subfield2
257256
end
258257
end
258+
259+
@testset "count_nans_state, FT = $FT" begin
260+
# Test on a 3D spherical domain
261+
domain = ClimaLand.Domains.SphericalShell(;
262+
radius = FT(2),
263+
depth = FT(1.0),
264+
nelements = (10, 5),
265+
npolynomial = 3,
266+
)
267+
268+
# Construct some fields
269+
space = domain.space.subsurface
270+
var1 = Fields.zeros(space)
271+
var2 = Fields.zeros(space)
272+
var3 = Fields.zeros(space)
273+
fieldvec = Fields.FieldVector(var2 = var2, var3 = var3)
274+
275+
# Construct a FieldVector containing the fields and a nested FieldVector
276+
Y = Fields.FieldVector(var1 = var1, fieldvec = fieldvec)
277+
278+
# Count and log the number of NaNs in the state (test verbose and non-verbose cases)
279+
@test_logs (:info, "Checking NaNs in var1") (:info, "No NaNs found") (
280+
:info,
281+
"Checking NaNs in fieldvec",
282+
) (:info, "Checking NaNs in var2") (:info, "No NaNs found") (
283+
:info,
284+
"Checking NaNs in var3",
285+
) (:info, "No NaNs found") ClimaLand.count_nans_state(Y, verbose = true)
286+
287+
@test_logs (:info, "Checking NaNs in var1") (
288+
:info,
289+
"Checking NaNs in fieldvec",
290+
) (:info, "Checking NaNs in var2") (:info, "Checking NaNs in var3") ClimaLand.count_nans_state(
291+
Y,
292+
)
293+
294+
# Add some NaNs to the fields without scalar indexing
295+
ArrayType = ClimaComms.array_type(ClimaComms.device())
296+
var1_copy = Array(parent(var1))
297+
var1_copy[1] = NaN
298+
parent(var1) .= ArrayType(var1_copy)
299+
var2_copy = Array(parent(var1))
300+
var2_copy[1] = NaN
301+
var2_copy[2] = NaN
302+
parent(var2) .= ArrayType(var2_copy)
303+
304+
# Count and log the number of NaNs in the state (test verbose and non-verbose cases)
305+
@test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") (
306+
:info,
307+
"Checking NaNs in fieldvec",
308+
) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") (
309+
:info,
310+
"Checking NaNs in var3",
311+
) (:info, "No NaNs found") ClimaLand.count_nans_state(Y, verbose = true)
312+
313+
@test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") (
314+
:info,
315+
"Checking NaNs in fieldvec",
316+
) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") (
317+
:info,
318+
"Checking NaNs in var3",
319+
) ClimaLand.count_nans_state(Y)
320+
321+
# Test with a mask
322+
mask_zeros = Fields.zeros(space)
323+
@test_logs (:info, "Checking NaNs in var1") (:info, "No NaNs found") (
324+
:info,
325+
"Checking NaNs in fieldvec",
326+
) (:info, "Checking NaNs in var2") (:info, "No NaNs found") (
327+
:info,
328+
"Checking NaNs in var3",
329+
) (:info, "No NaNs found") ClimaLand.count_nans_state(
330+
Y,
331+
mask = mask_zeros,
332+
verbose = true,
333+
)
334+
@test_logs (:info, "Checking NaNs in var1") (
335+
:info,
336+
"Checking NaNs in fieldvec",
337+
) (:info, "Checking NaNs in var2") (:info, "Checking NaNs in var3") ClimaLand.count_nans_state(
338+
Y,
339+
mask = mask_zeros,
340+
)
341+
342+
mask_ones = Fields.ones(space)
343+
@test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") (
344+
:info,
345+
"Checking NaNs in fieldvec",
346+
) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") (
347+
:info,
348+
"Checking NaNs in var3",
349+
) (:info, "No NaNs found") ClimaLand.count_nans_state(
350+
Y,
351+
mask = mask_ones,
352+
verbose = true,
353+
)
354+
355+
@test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") (
356+
:info,
357+
"Checking NaNs in fieldvec",
358+
) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") (
359+
:info,
360+
"Checking NaNs in var3",
361+
) ClimaLand.count_nans_state(Y, mask = mask_ones)
362+
end

0 commit comments

Comments
 (0)