add state NaN checker #970

Merged
merged 2 commits into from
Jan 16, 2025
2 changes: 1 addition & 1 deletion Project.toml
@@ -41,7 +41,7 @@ ArtifactWrappers = "0.2"
BSON = "0.3.9"
CSV = "0.10.14"
CUDA = "5.5"
ClimaComms = "0.6"
ClimaComms = "0.6.2"
ClimaCore = "0.14.19"
ClimaDiagnostics = "0.2.10"
ClimaParams = "0.10.2"
5 changes: 3 additions & 2 deletions docs/make.jl
@@ -52,14 +52,15 @@ include("list_diagnostics.jl")
pages = Any[
"Home" => "index.md",
"Getting Started" => "getting_started.md",
"Repository structure" => "folderstructure.md",
"Tutorials" => tutorials,
"Standalone models" => standalone_models,
"Diagnostics" => diagnostics,
"Leaderboard" => "leaderboard/leaderboard.md",
"Restarts" => "restarts.md",
"Contribution guide" => "Contributing.md",
"Repository structure" => "folderstructure.md",
"Misc. utilities" => "shared_utilities.md",
"APIs" => apis,
"Contribution guide" => "Contributing.md",
]

mathengine = MathJax(
36 changes: 36 additions & 0 deletions docs/src/shared_utilities.md
@@ -0,0 +1,36 @@
# ClimaLand Shared Utilities

## State NaN Counter - `count_nans_state`
The function `count_nans_state` recursively traverses the entire state and
reports the number of NaNs found in each state variable. It is intended as a
debugging aid for determining quantitatively whether a simulation is stable.

If NaNs are found in a particular variable, the count is reported in a
warning printed to the console. The `verbose` keyword argument (default
`false`) toggles whether the function also prints output when no NaNs are found.

If a ClimaCore Field is provided as `mask`, the function will only count NaNs
in the state variables where the mask is 1. This is intended to be used with
the land/sea mask, to avoid counting NaNs over the ocean. Note this assumes
the mask is 1 over land and 0 over ocean.
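
As a rough illustration of the masking behavior, the sketch below uses plain
Julia arrays as hypothetical stand-ins for the data underlying a state variable
and a mask; the names `vals` and `mask` are assumptions for this example and
it does not call ClimaLand or ClimaCore.

```julia
# Plain arrays standing in for a state variable and a land/sea mask
# (assumption: real usage passes ClimaCore Fields instead).
vals = [NaN, 1.0, NaN, 2.0]
mask = [1.0, 1.0, 0.0, 0.0]   # 1 over land, 0 over ocean

# Convert the 0/1 mask to Bool and keep only the land points.
land = Bool.(mask)

# Count NaNs over land only; `init = 0` gives 0 for an empty selection.
masked_count = sum(isnan, vals[land]; init = 0)

# Count NaNs everywhere, ignoring the mask.
unmasked_count = sum(isnan, vals)

# An all-zero mask selects nothing, so the count is 0.
empty_count = sum(isnan, vals[Bool.(zeros(4))]; init = 0)
```

Here the NaN at the third point lies over the ocean, so it is excluded from
the masked count but included in the unmasked one.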

This function does not distinguish between surface and subsurface
variables, so a variable defined on the subsurface will report more NaNs
than one defined on the surface, even if both are NaN at the same
horizontal locations.
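
The difference in counts can be sketched with plain arrays. The layout below
(one value per column for the surface, `nlevels` values per column for the
subsurface) is an assumption made for illustration, not the actual ClimaCore
data layout.

```julia
# Hypothetical stand-ins for fields: plain arrays, not ClimaCore Fields.
ncolumns, nlevels = 4, 5
surface_var = zeros(ncolumns)               # one value per column
subsurface_var = zeros(nlevels, ncolumns)   # nlevels values per column

# Make the same horizontal location (column 2) NaN in both variables.
surface_var[2] = NaN
subsurface_var[:, 2] .= NaN

# The subsurface variable reports one NaN per vertical level.
surface_count = sum(isnan, surface_var)
subsurface_count = sum(isnan, subsurface_var)
```

A single bad column therefore shows up as one NaN in the surface variable and
as `nlevels` NaNs in the subsurface variable.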

### Usage examples
To inspect the state after a simulation finishes running, call
`count_nans_state(sol.u[end])`.

To check for NaNs periodically during a simulation, trigger the function
via a callback. The `NaNCheckCallback` is designed for this purpose and
can be set up as follows:
```julia
import Dates
# `start_date`, `t0`, and `Δt` are defined earlier in the simulation setup
nancheck_freq = Dates.Month(1)
nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
```
and then included along with any other callbacks in a `SciMLBase.CallbackSet`.
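
The check frequency can be given either as a float (in seconds) or as a
`Dates.Period`. The sketch below mirrors how a float is converted through
milliseconds and how a frequency can be compared against the timestep. The
helpers `to_period` and `divides` are hypothetical names for this example;
`divides` is a simplified stand-in for the package's `isdivisible`, covering
only fixed-length periods.

```julia
import Dates

# Float frequencies are interpreted as seconds and converted via milliseconds
# so that fractional seconds are supported.
to_period(freq::AbstractFloat) = Dates.Millisecond(1000freq)
to_period(freq::Dates.Period) = freq

half_second = to_period(0.5)
monthly = to_period(Dates.Month(1))   # periods pass through unchanged

# Hypothetical, simplified divisibility check for millisecond periods only.
divides(freq::Dates.Millisecond, dt::Dates.Millisecond) =
    rem(Dates.value(freq), Dates.value(dt)) == 0

hourly = to_period(3600.0)
dt = to_period(450.0)
ok = divides(hourly, dt)                    # 3600 s is a multiple of 450 s
not_ok = divides(hourly, to_period(7.0))    # 3600 s is not a multiple of 7 s
```

Under this conversion, a float that is not a whole number of milliseconds
(for example `0.0001`) raises an `InexactError`, because `Dates.Millisecond`
stores an integer value.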

Please see our long-run experiments (`experiments/long_runs`) for examples of this callback in action!
6 changes: 5 additions & 1 deletion experiments/long_runs/bucket.jl
@@ -144,7 +144,11 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 7))
diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)

driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc)
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb)

nancheck_freq = Dates.Month(1)
nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)

return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb)
end

function setup_and_solve_problem(; greet = false)
5 changes: 4 additions & 1 deletion experiments/long_runs/land.jl
@@ -374,7 +374,10 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 15))
diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)

driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc)
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb)

nancheck_freq = Dates.Month(1)
nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb)
end

function setup_and_solve_problem(; greet = false)
5 changes: 4 additions & 1 deletion experiments/long_runs/land_region.jl
@@ -389,7 +389,10 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (10, 10, 15))
diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)

driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc)
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb)

nancheck_freq = Dates.Month(1)
nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb)
end

function setup_and_solve_problem(; greet = false)
6 changes: 5 additions & 1 deletion experiments/long_runs/snowy_land.jl
@@ -396,7 +396,11 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 15))
end
report_cb = SciMLBase.DiscreteCallback(every1000steps, report)

return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, report_cb)
nancheck_freq = Dates.Month(1)
nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)

return prob,
SciMLBase.CallbackSet(driver_cb, diag_cb, report_cb, nancheck_cb)
end

function setup_and_solve_problem(; greet = false)
5 changes: 4 additions & 1 deletion experiments/long_runs/soil.jl
@@ -215,7 +215,10 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 15))
diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler)

driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc)
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb)

nancheck_freq = Dates.Month(1)
nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt)
return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb)
end

function setup_and_solve_problem(; greet = false)
111 changes: 109 additions & 2 deletions src/shared_utilities/utils.jl
@@ -3,7 +3,7 @@ import SciMLBase
import ClimaDiagnostics.Schedules: EveryCalendarDtSchedule
import Dates

export FTfromY
export FTfromY, count_nans_state

"""
heaviside(x::FT)::FT where {FT}
@@ -320,7 +320,7 @@ function CheckpointCallback(

if !isnothing(dt)
dt_period = Dates.Millisecond(1000dt)
if !isdivisible(checkpoint_frequency_period / dt_period)
if !isdivisible(checkpoint_frequency_period, dt_period)
@warn "Checkpoint frequency ($(checkpoint_frequency_period)) is not an integer multiple of dt $(dt_period)"
end
end
@@ -492,3 +492,110 @@ function isdivisible(
# have any common divisor)
return isinteger(Dates.Day(1) / dt_small)
end

"""
count_nans_state(state; mask = nothing, verbose = false)

Count the number of NaNs in the state, e.g. the FieldVector given by
`sol.u[end]` after calling `solve`. This function is useful for
debugging simulations to determine quantitatively if a simulation is stable.

If this function is called on a FieldVector, it will recursively call itself
on each Field in the FieldVector. If it is called on a Field, it will count
the number of NaNs in the Field and produce a warning if any are found.

If a ClimaCore Field is provided as `mask`, the function will only count NaNs
in the state variables where the mask is 1. This is intended to be used with
the land/sea mask, to avoid counting NaNs over the ocean. Note this assumes
the mask is 1 over land and 0 over ocean.

The `verbose` argument toggles whether the function produces output when no
NaNs are found.
"""
function count_nans_state(
state::ClimaCore.Fields.FieldVector;
mask = nothing,
verbose = false,
)
for pn in propertynames(state)
state_new = getproperty(state, pn)
@info "Checking NaNs in $pn"
count_nans_state(state_new; mask, verbose)
end
return nothing
end

function count_nans_state(
state::ClimaCore.Fields.Field;
mask = nothing,
verbose = false,
)
# Note: this code uses `parent`; this pattern should not be replicated
num_nans =
isnothing(mask) ? round(sum(isnan, parent(state))) :
round(sum(isnan, parent(state)[Bool.(parent(mask))]; init = 0))
if isapprox(num_nans, 0)
verbose && @info "No NaNs found"
else
@warn "$num_nans NaNs found"
end
return nothing
end

"""
NaNCheckCallback(nancheck_frequency::Union{AbstractFloat, Dates.Period},
start_date, t_start, dt)

Constructs a DiscreteCallback which counts the number of NaNs in the state
and produces a warning if any are found.

# Arguments
- `nancheck_frequency`: The frequency at which the state is checked for NaNs.
Can be specified as a float (in seconds) or a `Dates.Period`.
- `start_date`: The start date of the simulation.
- `t_start`: The starting time of the simulation (in seconds).
- `dt`: The timestep of the model, used to check for consistency with
`nancheck_frequency`. Pass `nothing` to skip this check.

The callback uses `ClimaDiagnostics.EveryCalendarDtSchedule` to determine when
to check for NaNs based on the `nancheck_frequency`. The schedule is
initialized with the `start_date` and `t_start` to ensure that it is first
called at the correct time.
"""
function NaNCheckCallback(
nancheck_frequency::Union{AbstractFloat, Dates.Period},
start_date,
t_start,
dt,
)
# TODO: Move to a more general callback system. For the time being, we use
# the ClimaDiagnostics one because it is flexible and it supports calendar
# dates.

if nancheck_frequency isa AbstractFloat
# Assume it is in seconds, but go through Millisecond to support
# fractional seconds
nancheck_frequency_period = Dates.Millisecond(1000nancheck_frequency)
else
nancheck_frequency_period = nancheck_frequency
end

schedule = EveryCalendarDtSchedule(
nancheck_frequency_period;
start_date,
date_last = start_date + Dates.Millisecond(1000t_start),
)

if !isnothing(dt)
dt_period = Dates.Millisecond(1000dt)
if !isdivisible(nancheck_frequency_period, dt_period)
@warn "Callback frequency ($(nancheck_frequency_period)) is not an integer multiple of dt $(dt_period)"
end
end

cond = let schedule = schedule
(u, t, integrator) -> schedule(integrator)
end
affect! = (integrator) -> count_nans_state(integrator.u)

SciMLBase.DiscreteCallback(cond, affect!)
end
105 changes: 104 additions & 1 deletion test/shared_utilities/utilities.jl
@@ -2,7 +2,6 @@ using Test
import ClimaComms
ClimaComms.@import_required_backends
using ClimaCore: Spaces, Geometry, Fields
import ClimaComms
using ClimaLand
using ClimaLand: Domains, condition, SavingAffect, saving_initialize

@@ -256,3 +255,107 @@ end
@test Y.subfields.subfield2 == Y_copy.subfields.subfield2
end
end

@testset "count_nans_state, FT = $FT" begin
# Test on a 3D spherical domain
domain = ClimaLand.Domains.SphericalShell(;
radius = FT(2),
depth = FT(1.0),
nelements = (10, 5),
npolynomial = 3,
)

# Construct some fields
space = domain.space.subsurface
var1 = Fields.zeros(space)
var2 = Fields.zeros(space)
var3 = Fields.zeros(space)
fieldvec = Fields.FieldVector(var2 = var2, var3 = var3)

# Construct a FieldVector containing the fields and a nested FieldVector
Y = Fields.FieldVector(var1 = var1, fieldvec = fieldvec)

# Count and log the number of NaNs in the state (test verbose and non-verbose cases)
@test_logs (:info, "Checking NaNs in var1") (:info, "No NaNs found") (
:info,
"Checking NaNs in fieldvec",
) (:info, "Checking NaNs in var2") (:info, "No NaNs found") (
:info,
"Checking NaNs in var3",
) (:info, "No NaNs found") ClimaLand.count_nans_state(Y, verbose = true)

@test_logs (:info, "Checking NaNs in var1") (
:info,
"Checking NaNs in fieldvec",
) (:info, "Checking NaNs in var2") (:info, "Checking NaNs in var3") ClimaLand.count_nans_state(
Y,
)

# Add some NaNs to the fields
# Note: this code uses `parent` and scalar indexing,
# which shouldn't be replicated outside of tests
ClimaComms.allowscalar(ClimaComms.device()) do
parent(var1)[1] = NaN
parent(var2)[1] = NaN
parent(var2)[2] = NaN
end

# Count and log the number of NaNs in the state (test verbose and non-verbose cases)
@test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") (
:info,
"Checking NaNs in fieldvec",
) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") (
:info,
"Checking NaNs in var3",
) (:info, "No NaNs found") ClimaLand.count_nans_state(Y, verbose = true)

@test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") (
:info,
"Checking NaNs in fieldvec",
) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") (
:info,
"Checking NaNs in var3",
) ClimaLand.count_nans_state(Y)

# Test with a mask
mask_zeros = Fields.zeros(space)
@test_logs (:info, "Checking NaNs in var1") (:info, "No NaNs found") (
:info,
"Checking NaNs in fieldvec",
) (:info, "Checking NaNs in var2") (:info, "No NaNs found") (
:info,
"Checking NaNs in var3",
) (:info, "No NaNs found") ClimaLand.count_nans_state(
Y,
mask = mask_zeros,
verbose = true,
)
@test_logs (:info, "Checking NaNs in var1") (
:info,
"Checking NaNs in fieldvec",
) (:info, "Checking NaNs in var2") (:info, "Checking NaNs in var3") ClimaLand.count_nans_state(
Y,
mask = mask_zeros,
)

mask_ones = Fields.ones(space)
@test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") (
:info,
"Checking NaNs in fieldvec",
) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") (
:info,
"Checking NaNs in var3",
) (:info, "No NaNs found") ClimaLand.count_nans_state(
Y,
mask = mask_ones,
verbose = true,
)

@test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") (
:info,
"Checking NaNs in fieldvec",
) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") (
:info,
"Checking NaNs in var3",
) ClimaLand.count_nans_state(Y, mask = mask_ones)
end