From 203173580524a4d003457f6429256a7a1c6bcfd6 Mon Sep 17 00:00:00 2001 From: Julia Sloan Date: Thu, 2 Jan 2025 13:28:59 -0800 Subject: [PATCH 1/2] add state NaN checker Adds a function that quantitatively checks how many NaNs are present in the state, and displays this information to the user. This can be used to inspect sol.u[end] after a simulation has run, to see if any NaNs were produced at the end of the simulation. If no NaNs are found, this information is logged in an info statement when run in verbose mode. If NaNs are found, this information is logged in a warn statement. --- docs/make.jl | 5 +- docs/src/shared_utilities.md | 36 +++++++++ experiments/long_runs/bucket.jl | 6 +- experiments/long_runs/land.jl | 5 +- experiments/long_runs/land_region.jl | 5 +- experiments/long_runs/snowy_land.jl | 6 +- experiments/long_runs/soil.jl | 5 +- src/shared_utilities/utils.jl | 111 ++++++++++++++++++++++++++- test/shared_utilities/utilities.jl | 105 ++++++++++++++++++++++++- 9 files changed, 274 insertions(+), 10 deletions(-) create mode 100644 docs/src/shared_utilities.md diff --git a/docs/make.jl b/docs/make.jl index 584e9e74d8..06649b24bc 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -52,14 +52,15 @@ include("list_diagnostics.jl") pages = Any[ "Home" => "index.md", "Getting Started" => "getting_started.md", + "Repository structure" => "folderstructure.md", "Tutorials" => tutorials, "Standalone models" => standalone_models, "Diagnostics" => diagnostics, "Leaderboard" => "leaderboard/leaderboard.md", "Restarts" => "restarts.md", - "Contribution guide" => "Contributing.md", - "Repository structure" => "folderstructure.md", + "Misc. utilities" => "shared_utilities.md", "APIs" => apis, + "Contribution guide" => "Contributing.md", ] mathengine = MathJax( diff --git a/docs/src/shared_utilities.md b/docs/src/shared_utilities.md new file mode 100644 index 0000000000..3562515b35 --- /dev/null +++ b/docs/src/shared_utilities.md @@ -0,0 +1,36 @@ +# ClimaLand Shared Utilities + +## State NaN Counter - `count_nans_state` +We have implemented a function `count_nans_state` which recursively goes +through the entire state and displays the number of NaNs found for each state +variable. This function is intended to be used to debug simulations to +determine quantitatively if a simulation is stable. + +If NaNs are found for a particular variable, this will be displayed via +a warning printed to the console. The `verbose` argument toggles whether +the function prints output when no NaNs are found. + +If a ClimaCore Field is provided as `mask`, the function will only count NaNs +in the state variables where the mask is 1. This is intended to be used with +the land/sea mask, to avoid counting NaNs over the ocean. Note this assumes +the mask is 1 over land and 0 over ocean. + +This function does not distinguish between surface or subsurface +variables, so a variable defined on the subsurface will display more NaNs +than one defined on the surface, even if they are NaN at the same +spatial locations in the horizontal. + +### Usage examples +This function can be used to inspect the state after a simulation finishes +running by calling `count_nans_state(sol.u[end])`. + +This function can be used throughout the duration of a simulation by +triggering it via a callback. The `NaNCheckCallback` is designed for this +purpose, and can be set up as follows: +```julia +nancheck_freq = Dates.Month(1) +nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt) +``` +and then included along with any other callbacks in a `SciMLBase.CallbackSet`. + +Please see our longrun experiments to see examples of this callback in action! diff --git a/experiments/long_runs/bucket.jl b/experiments/long_runs/bucket.jl index a43485a607..8c9ba902a2 100644 --- a/experiments/long_runs/bucket.jl +++ b/experiments/long_runs/bucket.jl @@ -144,7 +144,11 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 7)) diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler) driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc) - return prob, SciMLBase.CallbackSet(driver_cb, diag_cb) + + nancheck_freq = Dates.Month(1) + nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt) + + return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb) end function setup_and_solve_problem(; greet = false) diff --git a/experiments/long_runs/land.jl b/experiments/long_runs/land.jl index a1c15fc2fa..f9541a983a 100644 --- a/experiments/long_runs/land.jl +++ b/experiments/long_runs/land.jl @@ -374,7 +374,10 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 15)) diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler) driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc) - return prob, SciMLBase.CallbackSet(driver_cb, diag_cb) + + nancheck_freq = Dates.Month(1) + nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt) + return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb) end function setup_and_solve_problem(; greet = false) diff --git a/experiments/long_runs/land_region.jl b/experiments/long_runs/land_region.jl index 129cb24755..4f638983b4 100644 --- a/experiments/long_runs/land_region.jl +++ b/experiments/long_runs/land_region.jl @@ -389,7 +389,10 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (10, 10, 15)) diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler) driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc) - return prob, SciMLBase.CallbackSet(driver_cb, diag_cb) + + nancheck_freq = Dates.Month(1) + nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt) + return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb) end function setup_and_solve_problem(; greet = false) diff --git a/experiments/long_runs/snowy_land.jl b/experiments/long_runs/snowy_land.jl index 43e526b8d5..35c8b8a296 100644 --- a/experiments/long_runs/snowy_land.jl +++ b/experiments/long_runs/snowy_land.jl @@ -396,7 +396,11 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 15)) end report_cb = SciMLBase.DiscreteCallback(every1000steps, report) - return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, report_cb) + nancheck_freq = Dates.Month(1) + nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt) + + return prob, + SciMLBase.CallbackSet(driver_cb, diag_cb, report_cb, nancheck_cb) end function setup_and_solve_problem(; greet = false) diff --git a/experiments/long_runs/soil.jl b/experiments/long_runs/soil.jl index fc9fa3ebb6..ced7621b37 100644 --- a/experiments/long_runs/soil.jl +++ b/experiments/long_runs/soil.jl @@ -215,7 +215,10 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 15)) diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler) driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc) - return prob, SciMLBase.CallbackSet(driver_cb, diag_cb) + + nancheck_freq = Dates.Month(1) + nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt) + return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb) end function setup_and_solve_problem(; greet = false) diff --git a/src/shared_utilities/utils.jl b/src/shared_utilities/utils.jl index 68740bf4b2..5f27ddad3e 100644 --- a/src/shared_utilities/utils.jl +++ b/src/shared_utilities/utils.jl @@ -3,7 +3,7 @@ import SciMLBase import ClimaDiagnostics.Schedules: EveryCalendarDtSchedule import Dates -export FTfromY +export FTfromY, count_nans_state """ heaviside(x::FT)::FT where {FT} @@ -320,7 +320,7 @@ function CheckpointCallback( if !isnothing(dt) dt_period = Dates.Millisecond(1000dt) - if !isdivisible(checkpoint_frequency_period / dt_period) + if !isdivisible(checkpoint_frequency_period, dt_period) @warn "Checkpoint frequency ($(checkpoint_frequency_period)) is not an integer multiple of dt $(dt_period)" end end @@ -492,3 +492,110 @@ function isdivisible( # have any common divisor) return isinteger(Dates.Day(1) / dt_small) end + +""" + count_nans_state(state, mask::ClimaCore.Fields.Field = nothing, verbose = false) + +Count the number of NaNs in the state, e.g. the FieldVector given by +`sol.u[end]` after calling `solve`. This function is useful for +debugging simulations to determine quantitatively if a simulation is stable. + +If this function is called on a FieldVector, it will recursively call itself +on each Field in the FieldVector. If it is called on a Field, it will count +the number of NaNs in the Field and produce a warning if any are found. + +If a ClimaCore Field is provided as `mask`, the function will only count NaNs +in the state variables where the mask is 1. This is intended to be used with +the land/sea mask, to avoid counting NaNs over the ocean. Note this assumes +the mask is 1 over land and 0 over ocean. + +The `verbose` argument toggles whether the function produces output when no +NaNs are found. +""" +function count_nans_state( + state::ClimaCore.Fields.FieldVector; + mask = nothing, + verbose = false, +) + for pn in propertynames(state) + state_new = getproperty(state, pn) + @info "Checking NaNs in $pn" + count_nans_state(state_new; mask, verbose) + end + return nothing +end + +function count_nans_state( + state::ClimaCore.Fields.Field; + mask = nothing, + verbose = false, +) + # Note: this code uses `parent`; this pattern should not be replicated + num_nans = + isnothing(mask) ? round(sum(isnan, parent(state))) : + round(sum(isnan, parent(state)[Bool.(parent(mask))]; init = 0)) + if isapprox(num_nans, 0) + verbose && @info "No NaNs found" + else + @warn "$num_nans NaNs found" + end + return nothing +end + +""" + NaNCheckCallback(nancheck_frequency::Union{AbstractFloat, Dates.Period}, + start_date, t_start, dt) + +Constructs a DiscreteCallback which counts the number of NaNs in the state +and produces a warning if any are found. + +# Arguments +- `nancheck_frequency`: The frequency at which the state is checked for NaNs. + Can be specified as a float (in seconds) or a `Dates.Period`. +- `start_date`: The start date of the simulation. +- `t_start`: The starting time of the simulation (in seconds). +- `dt`: The timestep of the model (optional), used to check for consistency. + +The callback uses `ClimaDiagnostics.EveryCalendarDtSchedule` to determine when +to check for NaNs based on the `nancheck_frequency`. The schedule is +initialized with the `start_date` and `t_start` to ensure that it is first +called at the correct time. +""" +function NaNCheckCallback( + nancheck_frequency::Union{AbstractFloat, Dates.Period}, + start_date, + t_start, + dt, +) + # TODO: Move to a more general callback system. For the time being, we use + # the ClimaDiagnostics one because it is flexible and it supports calendar + # dates. + + if nancheck_frequency isa AbstractFloat + # Assume it is in seconds, but go through Millisecond to support + # fractional seconds + nancheck_frequency_period = Dates.Millisecond(1000nancheck_frequency) + else + nancheck_frequency_period = nancheck_frequency + end + + schedule = EveryCalendarDtSchedule( + nancheck_frequency_period; + start_date, + date_last = start_date + Dates.Millisecond(1000t_start), + ) + + if !isnothing(dt) + dt_period = Dates.Millisecond(1000dt) + if !isdivisible(nancheck_frequency_period, dt_period) + @warn "Callback frequency ($(nancheck_frequency_period)) is not an integer multiple of dt $(dt_period)" + end + end + + cond = let schedule = schedule + (u, t, integrator) -> schedule(integrator) + end + affect! = (integrator) -> count_nans_state(integrator.u) + + SciMLBase.DiscreteCallback(cond, affect!) +end diff --git a/test/shared_utilities/utilities.jl b/test/shared_utilities/utilities.jl index e44fbe8a16..c935760f31 100644 --- a/test/shared_utilities/utilities.jl +++ b/test/shared_utilities/utilities.jl @@ -2,7 +2,6 @@ using Test import ClimaComms ClimaComms.@import_required_backends using ClimaCore: Spaces, Geometry, Fields -import ClimaComms using ClimaLand using ClimaLand: Domains, condition, SavingAffect, saving_initialize @@ -256,3 +255,107 @@ end @test Y.subfields.subfield2 == Y_copy.subfields.subfield2 end end + +@testset "count_nans_state, FT = $FT" begin + # Test on a 3D spherical domain + domain = ClimaLand.Domains.SphericalShell(; + radius = FT(2), + depth = FT(1.0), + nelements = (10, 5), + npolynomial = 3, + ) + + # Construct some fields + space = domain.space.subsurface + var1 = Fields.zeros(space) + var2 = Fields.zeros(space) + var3 = Fields.zeros(space) + fieldvec = Fields.FieldVector(var2 = var2, var3 = var3) + + # Construct a FieldVector containing the fields and a nested FieldVector + Y = Fields.FieldVector(var1 = var1, fieldvec = fieldvec) + + # Count and log the number of NaNs in the state (test verbose and non-verbose cases) + @test_logs (:info, "Checking NaNs in var1") (:info, "No NaNs found") ( + :info, + "Checking NaNs in fieldvec", + ) (:info, "Checking NaNs in var2") (:info, "No NaNs found") ( + :info, + "Checking NaNs in var3", + ) (:info, "No NaNs found") ClimaLand.count_nans_state(Y, verbose = true) + + @test_logs (:info, "Checking NaNs in var1") ( + :info, + "Checking NaNs in fieldvec", + ) (:info, "Checking NaNs in var2") (:info, "Checking NaNs in var3") ClimaLand.count_nans_state( + Y, + ) + + # Add some NaNs to the fields + # Note: this code uses `parent` and scalar indexing, + # which shouldn't be replicated outside of tests + ClimaComms.allowscalar(ClimaComms.device()) do + parent(var1)[1] = NaN + parent(var2)[1] = NaN + parent(var2)[2] = NaN + end + + # Count and log the number of NaNs in the state (test verbose and non-verbose cases) + @test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") ( + :info, + "Checking NaNs in fieldvec", + ) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") ( + :info, + "Checking NaNs in var3", + ) (:info, "No NaNs found") ClimaLand.count_nans_state(Y, verbose = true) + + @test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") ( + :info, + "Checking NaNs in fieldvec", + ) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") ( + :info, + "Checking NaNs in var3", + ) ClimaLand.count_nans_state(Y) + + # Test with a mask + mask_zeros = Fields.zeros(space) + @test_logs (:info, "Checking NaNs in var1") (:info, "No NaNs found") ( + :info, + "Checking NaNs in fieldvec", + ) (:info, "Checking NaNs in var2") (:info, "No NaNs found") ( + :info, + "Checking NaNs in var3", + ) (:info, "No NaNs found") ClimaLand.count_nans_state( + Y, + mask = mask_zeros, + verbose = true, + ) + @test_logs (:info, "Checking NaNs in var1") ( + :info, + "Checking NaNs in fieldvec", + ) (:info, "Checking NaNs in var2") (:info, "Checking NaNs in var3") ClimaLand.count_nans_state( + Y, + mask = mask_zeros, + ) + + mask_ones = Fields.ones(space) + @test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") ( + :info, + "Checking NaNs in fieldvec", + ) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") ( + :info, + "Checking NaNs in var3", + ) (:info, "No NaNs found") ClimaLand.count_nans_state( + Y, + mask = mask_ones, + verbose = true, + ) + + @test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") ( + :info, + "Checking NaNs in fieldvec", + ) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") ( + :info, + "Checking NaNs in var3", + ) ClimaLand.count_nans_state(Y, mask = mask_ones) +end From cb6ee6ceb0f0edf04a53106e2fdfcb2656ce8bc1 Mon Sep 17 00:00:00 2001 From: Julia Sloan Date: Wed, 15 Jan 2025 12:11:08 -0800 Subject: [PATCH 2/2] bump ClimaComms compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 3e497eaf6e..27d7bbd653 100644 --- a/Project.toml +++ b/Project.toml @@ -41,7 +41,7 @@ ArtifactWrappers = "0.2" BSON = "0.3.9" CSV = "0.10.14" CUDA = "5.5" -ClimaComms = "0.6" +ClimaComms = "0.6.2" ClimaCore = "0.14.19" ClimaDiagnostics = "0.2.10" ClimaParams = "0.10.2"