Commit 1d2e827

add state NaN checker
Adds a function that quantitatively checks how many NaNs are present in the state, and displays this information to the user. This can be used to inspect sol.u[end] after a simulation has run, to see if any NaNs were produced at the end of the simulation. If no NaNs are found, this information is logged in an info statement when run in verbose mode. If NaNs are found, this information is logged in a warn statement.
1 parent c44b057 commit 1d2e827

9 files changed: +274 -10 lines changed

docs/make.jl

Lines changed: 3 additions & 2 deletions
@@ -52,14 +52,15 @@ include("list_diagnostics.jl")
 pages = Any[
     "Home" => "index.md",
     "Getting Started" => "getting_started.md",
+    "Repository structure" => "folderstructure.md",
     "Tutorials" => tutorials,
     "Standalone models" => standalone_models,
     "Diagnostics" => diagnostics,
     "Leaderboard" => "leaderboard/leaderboard.md",
     "Restarts" => "restarts.md",
-    "Contribution guide" => "Contributing.md",
-    "Repository structure" => "folderstructure.md",
+    "Misc. utilities" => "shared_utilities.md",
     "APIs" => apis,
+    "Contribution guide" => "Contributing.md",
 ]
 
 mathengine = MathJax(

docs/src/shared_utilities.md

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+# ClimaLand Shared Utilities
+
+## State NaN Counter - `count_nans_state`
+We have implemented a function `count_nans_state` which recursively goes
+through the entire state and displays the number of NaNs found for each state
+variable. This function is intended to be used to debug simulations to
+determine quantitatively if a simulation is stable.
+
+If NaNs are found for a particular variable, this will be displayed via
+a warning printed to the console. The `verbose` argument toggles whether
+the function prints output when no NaNs are found.
+
+If a ClimaCore Field is provided as `mask`, the function will only count NaNs
+in the state variables where the mask is 1. This is intended to be used with
+the land/sea mask, to avoid counting NaNs over the ocean. Note this assumes
+the mask is 1 over land and 0 over ocean.
+
+This function does not distinguish between surface or subsurface
+variables, so a variable defined on the subsurface will display more NaNs
+than one defined on the surface, even if they are NaN at the same
+spatial locations in the horizontal.
+
+### Usage examples
+This function can be used to inspect the state after a simulation finishes
+running by calling `count_nans_state(sol.u[end])`.
+
+This function can be used throughout the duration of a simulation by
+triggering it via a callback. The `NaNCheckCallback` is designed for this
+purpose, and can be set up as follows:
+```julia
+nancheck_freq = Dates.Month(1)
+nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
+```
+and then included along with any other callbacks in a `SciMLBase.CallbackSet`.
+
+Please see our longrun experiments to see examples of this callback in action!
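
The masking and the surface/subsurface caveat described in this new docs page can be illustrated without ClimaCore. The following is a minimal sketch, assuming plain Julia arrays stand in for Fields; `count_nans_array` and the sample arrays are hypothetical names introduced only for illustration and are not part of ClimaLand.

```julia
# Hypothetical stand-in for the Field method: plain arrays replace ClimaCore Fields.
function count_nans_array(data::AbstractArray; mask = nothing)
    isnothing(mask) && return count(isnan, data)
    # Keep only the entries where the mask is 1, then count the NaNs among them.
    return count(isnan, data[Bool.(mask)])
end

# 2 vertical levels x 3 horizontal columns; column 2 is NaN at both levels.
subsurface = [1.0 NaN 3.0; 4.0 NaN 6.0]
# The matching surface variable has a single level.
surface = [1.0 NaN 3.0]

n_sub = count_nans_array(subsurface)
n_sfc = count_nans_array(surface)

# A mask that is 0 over the NaN column (e.g. ocean) hides it from the count.
land_mask = [1.0 0.0 1.0; 1.0 0.0 1.0]
n_masked = count_nans_array(subsurface; mask = land_mask)
```

Here the subsurface array reports one NaN per level for the same horizontal column, which is why a subsurface variable can show more NaNs than a surface one at identical horizontal locations.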

experiments/long_runs/bucket.jl

Lines changed: 5 additions & 1 deletion
@@ -144,7 +144,11 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 7))
     diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)
 
     driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc)
-    return prob, SciMLBase.CallbackSet(driver_cb, diag_cb)
+
+    nancheck_freq = Dates.Month(1)
+    nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
+
+    return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb)
 end
 
 function setup_and_solve_problem(; greet = false)

experiments/long_runs/land.jl

Lines changed: 4 additions & 1 deletion
@@ -374,7 +374,10 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 15))
     diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)
 
     driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc)
-    return prob, SciMLBase.CallbackSet(driver_cb, diag_cb)
+
+    nancheck_freq = Dates.Month(1)
+    nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
+    return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb)
 end
 
 function setup_and_solve_problem(; greet = false)

experiments/long_runs/land_region.jl

Lines changed: 4 additions & 1 deletion
@@ -389,7 +389,10 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (10, 10, 15))
     diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)
 
     driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc)
-    return prob, SciMLBase.CallbackSet(driver_cb, diag_cb)
+
+    nancheck_freq = Dates.Month(1)
+    nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
+    return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb)
 end
 
 function setup_and_solve_problem(; greet = false)

experiments/long_runs/snowy_land.jl

Lines changed: 5 additions & 1 deletion
@@ -396,7 +396,11 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 15))
     end
     report_cb = SciMLBase.DiscreteCallback(every1000steps, report)
 
-    return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, report_cb)
+    nancheck_freq = Dates.Month(1)
+    nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
+
+    return prob,
+    SciMLBase.CallbackSet(driver_cb, diag_cb, report_cb, nancheck_cb)
 end
 
 function setup_and_solve_problem(; greet = false)

experiments/long_runs/soil.jl

Lines changed: 4 additions & 1 deletion
@@ -215,7 +215,10 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 15))
     diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)
 
     driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc)
-    return prob, SciMLBase.CallbackSet(driver_cb, diag_cb)
+
+    nancheck_freq = Dates.Month(1)
+    nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
+    return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb)
 end
 
 function setup_and_solve_problem(; greet = false)

src/shared_utilities/utils.jl

Lines changed: 109 additions & 2 deletions
@@ -3,7 +3,7 @@ import SciMLBase
 import ClimaDiagnostics.Schedules: EveryCalendarDtSchedule
 import Dates
 
-export FTfromY
+export FTfromY, count_nans_state
 
 """
     heaviside(x::FT)::FT where {FT}
@@ -320,7 +320,7 @@ function CheckpointCallback(
 
     if !isnothing(dt)
         dt_period = Dates.Millisecond(1000dt)
-        if !isdivisible(checkpoint_frequency_period / dt_period)
+        if !isdivisible(checkpoint_frequency_period, dt_period)
             @warn "Checkpoint frequency ($(checkpoint_frequency_period)) is not an integer multiple of dt $(dt_period)"
         end
     end
@@ -492,3 +492,110 @@ function isdivisible(
     # have any common divisor)
     return isinteger(Dates.Day(1) / dt_small)
 end
+
+"""
+    count_nans_state(state, mask::ClimaCore.Fields.Field = nothing, verbose = false)
+
+Count the number of NaNs in the state, e.g. the FieldVector given by
+`sol.u[end]` after calling `solve`. This function is useful for
+debugging simulations to determine quantitatively if a simulation is stable.
+
+If this function is called on a FieldVector, it will recursively call itself
+on each Field in the FieldVector. If it is called on a Field, it will count
+the number of NaNs in the Field and produce a warning if any are found.
+
+If a ClimaCore Field is provided as `mask`, the function will only count NaNs
+in the state variables where the mask is 1. This is intended to be used with
+the land/sea mask, to avoid counting NaNs over the ocean. Note this assumes
+the mask is 1 over land and 0 over ocean.
+
+The `verbose` argument toggles whether the function produces output when no
+NaNs are found.
+"""
+function count_nans_state(
+    state::ClimaCore.Fields.FieldVector;
+    mask = nothing,
+    verbose = false,
+)
+    for pn in propertynames(state)
+        state_new = getproperty(state, pn)
+        @info "Checking NaNs in $pn"
+        count_nans_state(state_new; mask, verbose)
+    end
+    return nothing
+end
+
+function count_nans_state(
+    state::ClimaCore.Fields.Field;
+    mask = nothing,
+    verbose = false,
+)
+    # Note: this code uses `parent`; this pattern should not be replicated
+    num_nans =
+        isnothing(mask) ? round(sum(isnan, parent(state))) :
+        round(sum(isnan, parent(state)[Bool.(parent(mask))]))
+    if isapprox(num_nans, 0)
+        verbose && @info "No NaNs found"
+    else
+        @warn "$num_nans NaNs found"
+    end
+    return nothing
+end
+
+"""
+    NaNCheckCallback(nancheck_frequency::Union{AbstractFloat, Dates.Period},
+                     start_date, t_start, dt)
+
+Constructs a DiscreteCallback which counts the number of NaNs in the state
+and produces a warning if any are found.
+
+# Arguments
+- `nancheck_frequency`: The frequency at which the state is checked for NaNs.
+  Can be specified as a float (in seconds) or a `Dates.Period`.
+- `start_date`: The start date of the simulation.
+- `t_start`: The starting time of the simulation (in seconds).
+- `dt`: The timestep of the model (optional), used to check for consistency.
+
+The callback uses `ClimaDiagnostics.EveryCalendarDtSchedule` to determine when
+to check for NaNs based on the `nancheck_frequency`. The schedule is
+initialized with the `start_date` and `t_start` to ensure that it is first
+called at the correct time.
+"""
+function NaNCheckCallback(
+    nancheck_frequency::Union{AbstractFloat, Dates.Period},
+    start_date,
+    t_start,
+    dt,
+)
+    # TODO: Move to a more general callback system. For the time being, we use
+    # the ClimaDiagnostics one because it is flexible and it supports calendar
+    # dates.
+
+    if nancheck_frequency isa AbstractFloat
+        # Assume it is in seconds, but go through Millisecond to support
+        # fractional seconds
+        nancheck_frequency_period = Dates.Millisecond(1000nancheck_frequency)
+    else
+        nancheck_frequency_period = nancheck_frequency
+    end
+
+    schedule = EveryCalendarDtSchedule(
+        nancheck_frequency_period;
+        start_date,
+        date_last = start_date + Dates.Millisecond(1000t_start),
+    )
+
+    if !isnothing(dt)
+        dt_period = Dates.Millisecond(1000dt)
+        if !isdivisible(nancheck_frequency_period, dt_period)
+            @warn "Callback frequency ($(nancheck_frequency_period)) is not an integer multiple of dt $(dt_period)"
+        end
+    end
+
+    cond = let schedule = schedule
+        (u, t, integrator) -> schedule(integrator)
+    end
+    affect! = (integrator) -> count_nans_state(integrator.u)
+
+    SciMLBase.DiscreteCallback(cond, affect!)
+end
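
The frequency handling in `NaNCheckCallback` (float seconds routed through `Dates.Millisecond`, then a consistency check against `dt`) can be sketched with the `Dates` standard library alone. This is an illustration under stated assumptions, not the ClimaLand implementation: `to_period` and `divides_evenly` are hypothetical helpers, and unlike the package's `isdivisible` this check only handles fixed-length periods (it does not treat `Dates.Month` or `Dates.Year`).

```julia
import Dates

# Hypothetical helper: a float is taken to be seconds and converted via
# milliseconds so that fractional seconds such as 0.5 survive the conversion.
to_period(freq::AbstractFloat) = Dates.Millisecond(1000 * freq)
to_period(freq::Dates.Period) = freq

# Hypothetical helper: true when `freq` is an integer multiple of `dt`.
# Assumes both are fixed-length periods (Millisecond through Week).
function divides_evenly(freq::Dates.Period, dt::Dates.Period)
    freq_ms = Dates.value(convert(Dates.Millisecond, freq))
    dt_ms = Dates.value(convert(Dates.Millisecond, dt))
    return freq_ms % dt_ms == 0
end

half_second = to_period(0.5)
hourly_ok = divides_evenly(Dates.Hour(1), to_period(1800.0))
hourly_bad = divides_evenly(Dates.Hour(1), Dates.Minute(7))
```

A check like `hourly_bad` returning `false` corresponds to the case where the callback in the diff above would emit its "not an integer multiple of dt" warning.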

test/shared_utilities/utilities.jl

Lines changed: 104 additions & 1 deletion
@@ -2,7 +2,6 @@ using Test
 import ClimaComms
 ClimaComms.@import_required_backends
 using ClimaCore: Spaces, Geometry, Fields
-import ClimaComms
 using ClimaLand
 using ClimaLand: Domains, condition, SavingAffect, saving_initialize
 
@@ -256,3 +255,107 @@ end
         @test Y.subfields.subfield2 == Y_copy.subfields.subfield2
     end
 end
+
+@testset "count_nans_state, FT = $FT" begin
+    # Test on a 3D spherical domain
+    domain = ClimaLand.Domains.SphericalShell(;
+        radius = FT(2),
+        depth = FT(1.0),
+        nelements = (10, 5),
+        npolynomial = 3,
+    )
+
+    # Construct some fields
+    space = domain.space.subsurface
+    var1 = Fields.zeros(space)
+    var2 = Fields.zeros(space)
+    var3 = Fields.zeros(space)
+    fieldvec = Fields.FieldVector(var2 = var2, var3 = var3)
+
+    # Construct a FieldVector containing the fields and a nested FieldVector
+    Y = Fields.FieldVector(var1 = var1, fieldvec = fieldvec)
+
+    # Count and log the number of NaNs in the state (test verbose and non-verbose cases)
+    @test_logs (:info, "Checking NaNs in var1") (:info, "No NaNs found") (
+        :info,
+        "Checking NaNs in fieldvec",
+    ) (:info, "Checking NaNs in var2") (:info, "No NaNs found") (
+        :info,
+        "Checking NaNs in var3",
+    ) (:info, "No NaNs found") ClimaLand.count_nans_state(Y, verbose = true)
+
+    @test_logs (:info, "Checking NaNs in var1") (
+        :info,
+        "Checking NaNs in fieldvec",
+    ) (:info, "Checking NaNs in var2") (:info, "Checking NaNs in var3") ClimaLand.count_nans_state(
+        Y,
+    )
+
+    # Add some NaNs to the fields
+    # Note: this code uses `parent` and scalar indexing,
+    # which shouldn't be replicated outside of tests
+    ClimaComms.allowscalar(ClimaComms.device()) do
+        parent(var1)[1] = NaN
+        parent(var2)[1] = NaN
+        parent(var2)[2] = NaN
+    end
+
+    # Count and log the number of NaNs in the state (test verbose and non-verbose cases)
+    @test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") (
+        :info,
+        "Checking NaNs in fieldvec",
+    ) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") (
+        :info,
+        "Checking NaNs in var3",
+    ) (:info, "No NaNs found") ClimaLand.count_nans_state(Y, verbose = true)
+
+    @test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") (
+        :info,
+        "Checking NaNs in fieldvec",
+    ) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") (
+        :info,
+        "Checking NaNs in var3",
+    ) ClimaLand.count_nans_state(Y)
+
+    # Test with a mask
+    mask_zeros = Fields.zeros(space)
+    @test_logs (:info, "Checking NaNs in var1") (:info, "No NaNs found") (
+        :info,
+        "Checking NaNs in fieldvec",
+    ) (:info, "Checking NaNs in var2") (:info, "No NaNs found") (
+        :info,
+        "Checking NaNs in var3",
+    ) (:info, "No NaNs found") ClimaLand.count_nans_state(
+        Y,
+        mask = mask_zeros,
+        verbose = true,
+    )
+    @test_logs (:info, "Checking NaNs in var1") (
+        :info,
+        "Checking NaNs in fieldvec",
+    ) (:info, "Checking NaNs in var2") (:info, "Checking NaNs in var3") ClimaLand.count_nans_state(
+        Y,
+        mask = mask_zeros,
+    )
+
+    mask_ones = Fields.ones(space)
+    @test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") (
+        :info,
+        "Checking NaNs in fieldvec",
+    ) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") (
+        :info,
+        "Checking NaNs in var3",
+    ) (:info, "No NaNs found") ClimaLand.count_nans_state(
+        Y,
+        mask = mask_ones,
+        verbose = true,
+    )
+
+    @test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") (
+        :info,
+        "Checking NaNs in fieldvec",
+    ) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") (
+        :info,
+        "Checking NaNs in var3",
+    ) ClimaLand.count_nans_state(Y, mask = mask_ones)
+end
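
The tests above rely on `Test.@test_logs`, which matches the sequence of log records emitted by an expression against `(level, message)` patterns. A minimal sketch of that mechanism, assuming a hypothetical `report_nans` function that only mirrors the commit's log messages and does not touch ClimaCore:

```julia
using Test

# Hypothetical reporter that mirrors the commit's log messages for a given count.
function report_nans(num_nans::Integer; verbose = false)
    if num_nans == 0
        verbose && @info "No NaNs found"
    else
        @warn "$num_nans NaNs found"
    end
    return nothing
end

# Each pattern is matched, in order, against the records the call emits.
@test_logs (:warn, "2 NaNs found") report_nans(2)
@test_logs (:info, "No NaNs found") report_nans(0; verbose = true)
# With no patterns, @test_logs checks that nothing was logged.
@test_logs report_nans(0)
```

Because the patterns are matched as an ordered sequence, the real tests list every "Checking NaNs in ..." record alongside the per-field result.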
