Skip to content

Commit 357e32d

Browse files
authored
Merge pull request #970 from CliMA/js/check-nans
add state NaN checker
2 parents ca2ad88 + cb6ee6c commit 357e32d

File tree

10 files changed

+275
-11
lines changed

10 files changed

+275
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ ArtifactWrappers = "0.2"
4141
BSON = "0.3.9"
4242
CSV = "0.10.14"
4343
CUDA = "5.5"
44-
ClimaComms = "0.6"
44+
ClimaComms = "0.6.2"
4545
ClimaCore = "0.14.19"
4646
ClimaDiagnostics = "0.2.10"
4747
ClimaParams = "0.10.2"

docs/make.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,15 @@ include("list_diagnostics.jl")
5252
pages = Any[
5353
"Home" => "index.md",
5454
"Getting Started" => "getting_started.md",
55+
"Repository structure" => "folderstructure.md",
5556
"Tutorials" => tutorials,
5657
"Standalone models" => standalone_models,
5758
"Diagnostics" => diagnostics,
5859
"Leaderboard" => "leaderboard/leaderboard.md",
5960
"Restarts" => "restarts.md",
60-
"Contribution guide" => "Contributing.md",
61-
"Repository structure" => "folderstructure.md",
61+
"Misc. utilities" => "shared_utilities.md",
6262
"APIs" => apis,
63+
"Contribution guide" => "Contributing.md",
6364
]
6465

6566
mathengine = MathJax(

docs/src/shared_utilities.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# ClimaLand Shared Utilities
2+
3+
## State NaN Counter - `count_nans_state`
4+
We have implemented a function `count_nans_state` which recursively goes
5+
through the entire state and displays the number of NaNs found for each state
6+
variable. This function is intended to be used to debug simulations to
7+
determine quantitatively if a simulation is stable.
8+
9+
If NaNs are found for a particular variable, this will be displayed via
10+
a warning printed to the console. The `verbose` argument toggles whether
11+
the function prints output when no NaNs are found.
12+
13+
If a ClimaCore Field is provided as `mask`, the function will only count NaNs
14+
in the state variables where the mask is 1. This is intended to be used with
15+
the land/sea mask, to avoid counting NaNs over the ocean. Note this assumes
16+
the mask is 1 over land and 0 over ocean.
17+
18+
This function does not distinguish between surface or subsurface
19+
variables, so a variable defined on the subsurface will display more NaNs
20+
than one defined on the surface, even if they are NaN at the same
21+
spatial locations in the horizontal.
22+
23+
### Usage examples
24+
This function can be used to inspect the state after a simulation finishes
25+
running by calling `count_nans_state(sol.u[end])`.
26+
27+
This function can be used throughout the duration of a simulation by
28+
triggering it via a callback. The `NaNCheckCallback` is designed for this
29+
purpose, and can be set up as follows:
30+
```julia
31+
nancheck_freq = Dates.Month(1)
32+
nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
33+
```
34+
and then included along with any other callbacks in a `SciMLBase.CallbackSet`.
35+
36+
Please see our longrun experiments to see examples of this callback in action!

experiments/long_runs/bucket.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,11 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 7))
144144
diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)
145145

146146
driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc)
147-
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb)
147+
148+
nancheck_freq = Dates.Month(1)
149+
nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
150+
151+
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb)
148152
end
149153

150154
function setup_and_solve_problem(; greet = false)

experiments/long_runs/land.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,10 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 15))
374374
diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)
375375

376376
driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc)
377-
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb)
377+
378+
nancheck_freq = Dates.Month(1)
379+
nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
380+
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb)
378381
end
379382

380383
function setup_and_solve_problem(; greet = false)

experiments/long_runs/land_region.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,10 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (10, 10, 15))
389389
diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)
390390

391391
driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc)
392-
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb)
392+
393+
nancheck_freq = Dates.Month(1)
394+
nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
395+
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb)
393396
end
394397

395398
function setup_and_solve_problem(; greet = false)

experiments/long_runs/snowy_land.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,11 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 15))
396396
end
397397
report_cb = SciMLBase.DiscreteCallback(every1000steps, report)
398398

399-
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, report_cb)
399+
nancheck_freq = Dates.Month(1)
400+
nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
401+
402+
return prob,
403+
SciMLBase.CallbackSet(driver_cb, diag_cb, report_cb, nancheck_cb)
400404
end
401405

402406
function setup_and_solve_problem(; greet = false)

experiments/long_runs/soil.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,10 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 15))
215215
diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)
216216

217217
driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc)
218-
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb)
218+
219+
nancheck_freq = Dates.Month(1)
220+
nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
221+
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb)
219222
end
220223

221224
function setup_and_solve_problem(; greet = false)

src/shared_utilities/utils.jl

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import SciMLBase
33
import ClimaDiagnostics.Schedules: EveryCalendarDtSchedule
44
import Dates
55

6-
export FTfromY
6+
export FTfromY, count_nans_state
77

88
"""
99
heaviside(x::FT)::FT where {FT}
@@ -320,7 +320,7 @@ function CheckpointCallback(
320320

321321
if !isnothing(dt)
322322
dt_period = Dates.Millisecond(1000dt)
323-
if !isdivisible(checkpoint_frequency_period / dt_period)
323+
if !isdivisible(checkpoint_frequency_period, dt_period)
324324
@warn "Checkpoint frequency ($(checkpoint_frequency_period)) is not an integer multiple of dt $(dt_period)"
325325
end
326326
end
@@ -492,3 +492,110 @@ function isdivisible(
492492
# have any common divisor)
493493
return isinteger(Dates.Day(1) / dt_small)
494494
end
495+
496+
"""
497+
count_nans_state(state, mask::ClimaCore.Fields.Field = nothing, verbose = false)
498+
499+
Count the number of NaNs in the state, e.g. the FieldVector given by
500+
`sol.u[end]` after calling `solve`. This function is useful for
501+
debugging simulations to determine quantitatively if a simulation is stable.
502+
503+
If this function is called on a FieldVector, it will recursively call itself
504+
on each Field in the FieldVector. If it is called on a Field, it will count
505+
the number of NaNs in the Field and produce a warning if any are found.
506+
507+
If a ClimaCore Field is provided as `mask`, the function will only count NaNs
508+
in the state variables where the mask is 1. This is intended to be used with
509+
the land/sea mask, to avoid counting NaNs over the ocean. Note this assumes
510+
the mask is 1 over land and 0 over ocean.
511+
512+
The `verbose` argument toggles whether the function produces output when no
513+
NaNs are found.
514+
"""
515+
function count_nans_state(
516+
state::ClimaCore.Fields.FieldVector;
517+
mask = nothing,
518+
verbose = false,
519+
)
520+
for pn in propertynames(state)
521+
state_new = getproperty(state, pn)
522+
@info "Checking NaNs in $pn"
523+
count_nans_state(state_new; mask, verbose)
524+
end
525+
return nothing
526+
end
527+
528+
function count_nans_state(
529+
state::ClimaCore.Fields.Field;
530+
mask = nothing,
531+
verbose = false,
532+
)
533+
# Note: this code uses `parent`; this pattern should not be replicated
534+
num_nans =
535+
isnothing(mask) ? round(sum(isnan, parent(state))) :
536+
round(sum(isnan, parent(state)[Bool.(parent(mask))]; init = 0))
537+
if isapprox(num_nans, 0)
538+
verbose && @info "No NaNs found"
539+
else
540+
@warn "$num_nans NaNs found"
541+
end
542+
return nothing
543+
end
544+
545+
"""
546+
NaNCheckCallback(nancheck_frequency::Union{AbstractFloat, Dates.Period},
547+
start_date, t_start, dt)
548+
549+
Constructs a DiscreteCallback which counts the number of NaNs in the state
550+
and produces a warning if any are found.
551+
552+
# Arguments
553+
- `nancheck_frequency`: The frequency at which the state is checked for NaNs.
554+
Can be specified as a float (in seconds) or a `Dates.Period`.
555+
- `start_date`: The start date of the simulation.
556+
- `t_start`: The starting time of the simulation (in seconds).
557+
- `dt`: The timestep of the model (optional), used to check for consistency.
558+
559+
The callback uses `ClimaDiagnostics.EveryCalendarDtSchedule` to determine when
560+
to check for NaNs based on the `nancheck_frequency`. The schedule is
561+
initialized with the `start_date` and `t_start` to ensure that it is first
562+
called at the correct time.
563+
"""
564+
function NaNCheckCallback(
565+
nancheck_frequency::Union{AbstractFloat, Dates.Period},
566+
start_date,
567+
t_start,
568+
dt,
569+
)
570+
# TODO: Move to a more general callback system. For the time being, we use
571+
# the ClimaDiagnostics one because it is flexible and it supports calendar
572+
# dates.
573+
574+
if nancheck_frequency isa AbstractFloat
575+
# Assume it is in seconds, but go through Millisecond to support
576+
# fractional seconds
577+
nancheck_frequency_period = Dates.Millisecond(1000nancheck_frequency)
578+
else
579+
nancheck_frequency_period = nancheck_frequency
580+
end
581+
582+
schedule = EveryCalendarDtSchedule(
583+
nancheck_frequency_period;
584+
start_date,
585+
date_last = start_date + Dates.Millisecond(1000t_start),
586+
)
587+
588+
if !isnothing(dt)
589+
dt_period = Dates.Millisecond(1000dt)
590+
if !isdivisible(nancheck_frequency_period, dt_period)
591+
@warn "Callback frequency ($(nancheck_frequency_period)) is not an integer multiple of dt $(dt_period)"
592+
end
593+
end
594+
595+
cond = let schedule = schedule
596+
(u, t, integrator) -> schedule(integrator)
597+
end
598+
affect! = (integrator) -> count_nans_state(integrator.u)
599+
600+
SciMLBase.DiscreteCallback(cond, affect!)
601+
end

test/shared_utilities/utilities.jl

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ using Test
22
import ClimaComms
33
ClimaComms.@import_required_backends
44
using ClimaCore: Spaces, Geometry, Fields
5-
import ClimaComms
65
using ClimaLand
76
using ClimaLand: Domains, condition, SavingAffect, saving_initialize
87

@@ -256,3 +255,107 @@ end
256255
@test Y.subfields.subfield2 == Y_copy.subfields.subfield2
257256
end
258257
end
258+
259+
@testset "count_nans_state, FT = $FT" begin
260+
# Test on a 3D spherical domain
261+
domain = ClimaLand.Domains.SphericalShell(;
262+
radius = FT(2),
263+
depth = FT(1.0),
264+
nelements = (10, 5),
265+
npolynomial = 3,
266+
)
267+
268+
# Construct some fields
269+
space = domain.space.subsurface
270+
var1 = Fields.zeros(space)
271+
var2 = Fields.zeros(space)
272+
var3 = Fields.zeros(space)
273+
fieldvec = Fields.FieldVector(var2 = var2, var3 = var3)
274+
275+
# Construct a FieldVector containing the fields and a nested FieldVector
276+
Y = Fields.FieldVector(var1 = var1, fieldvec = fieldvec)
277+
278+
# Count and log the number of NaNs in the state (test verbose and non-verbose cases)
279+
@test_logs (:info, "Checking NaNs in var1") (:info, "No NaNs found") (
280+
:info,
281+
"Checking NaNs in fieldvec",
282+
) (:info, "Checking NaNs in var2") (:info, "No NaNs found") (
283+
:info,
284+
"Checking NaNs in var3",
285+
) (:info, "No NaNs found") ClimaLand.count_nans_state(Y, verbose = true)
286+
287+
@test_logs (:info, "Checking NaNs in var1") (
288+
:info,
289+
"Checking NaNs in fieldvec",
290+
) (:info, "Checking NaNs in var2") (:info, "Checking NaNs in var3") ClimaLand.count_nans_state(
291+
Y,
292+
)
293+
294+
# Add some NaNs to the fields
295+
# Note: this code uses `parent` and scalar indexing,
296+
# which shouldn't be replicated outside of tests
297+
ClimaComms.allowscalar(ClimaComms.device()) do
298+
parent(var1)[1] = NaN
299+
parent(var2)[1] = NaN
300+
parent(var2)[2] = NaN
301+
end
302+
303+
# Count and log the number of NaNs in the state (test verbose and non-verbose cases)
304+
@test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") (
305+
:info,
306+
"Checking NaNs in fieldvec",
307+
) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") (
308+
:info,
309+
"Checking NaNs in var3",
310+
) (:info, "No NaNs found") ClimaLand.count_nans_state(Y, verbose = true)
311+
312+
@test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") (
313+
:info,
314+
"Checking NaNs in fieldvec",
315+
) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") (
316+
:info,
317+
"Checking NaNs in var3",
318+
) ClimaLand.count_nans_state(Y)
319+
320+
# Test with a mask
321+
mask_zeros = Fields.zeros(space)
322+
@test_logs (:info, "Checking NaNs in var1") (:info, "No NaNs found") (
323+
:info,
324+
"Checking NaNs in fieldvec",
325+
) (:info, "Checking NaNs in var2") (:info, "No NaNs found") (
326+
:info,
327+
"Checking NaNs in var3",
328+
) (:info, "No NaNs found") ClimaLand.count_nans_state(
329+
Y,
330+
mask = mask_zeros,
331+
verbose = true,
332+
)
333+
@test_logs (:info, "Checking NaNs in var1") (
334+
:info,
335+
"Checking NaNs in fieldvec",
336+
) (:info, "Checking NaNs in var2") (:info, "Checking NaNs in var3") ClimaLand.count_nans_state(
337+
Y,
338+
mask = mask_zeros,
339+
)
340+
341+
mask_ones = Fields.ones(space)
342+
@test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") (
343+
:info,
344+
"Checking NaNs in fieldvec",
345+
) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") (
346+
:info,
347+
"Checking NaNs in var3",
348+
) (:info, "No NaNs found") ClimaLand.count_nans_state(
349+
Y,
350+
mask = mask_ones,
351+
verbose = true,
352+
)
353+
354+
@test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") (
355+
:info,
356+
"Checking NaNs in fieldvec",
357+
) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") (
358+
:info,
359+
"Checking NaNs in var3",
360+
) ClimaLand.count_nans_state(Y, mask = mask_ones)
361+
end

0 commit comments

Comments
 (0)