diff --git a/Project.toml b/Project.toml index 3e497eaf6e..27d7bbd653 100644 --- a/Project.toml +++ b/Project.toml @@ -41,7 +41,7 @@ ArtifactWrappers = "0.2" BSON = "0.3.9" CSV = "0.10.14" CUDA = "5.5" -ClimaComms = "0.6" +ClimaComms = "0.6.2" ClimaCore = "0.14.19" ClimaDiagnostics = "0.2.10" ClimaParams = "0.10.2" diff --git a/docs/make.jl b/docs/make.jl index 584e9e74d8..06649b24bc 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -52,14 +52,15 @@ include("list_diagnostics.jl") pages = Any[ "Home" => "index.md", "Getting Started" => "getting_started.md", + "Repository structure" => "folderstructure.md", "Tutorials" => tutorials, "Standalone models" => standalone_models, "Diagnostics" => diagnostics, "Leaderboard" => "leaderboard/leaderboard.md", "Restarts" => "restarts.md", - "Contribution guide" => "Contributing.md", - "Repository structure" => "folderstructure.md", + "Misc. utilities" => "shared_utilities.md", "APIs" => apis, + "Contribution guide" => "Contributing.md", ] mathengine = MathJax( diff --git a/docs/src/shared_utilities.md b/docs/src/shared_utilities.md new file mode 100644 index 0000000000..3562515b35 --- /dev/null +++ b/docs/src/shared_utilities.md @@ -0,0 +1,36 @@ +# ClimaLand Shared Utilities + +## State NaN Counter - `count_nans_state` +We have implemented a function `count_nans_state` which recursively goes +through the entire state and displays the number of NaNs found for each state +variable. This function is intended to be used to debug simulations to +determine quantitatively if a simulation is stable. + +If NaNs are found for a particular variable, this will be displayed via +a warning printed to the console. The `verbose` argument toggles whether +the function prints output when no NaNs are found. + +If a ClimaCore Field is provided as `mask`, the function will only count NaNs +in the state variables where the mask is 1. This is intended to be used with +the land/sea mask, to avoid counting NaNs over the ocean. Note this assumes +the mask is 1 over land and 0 over ocean. + +This function does not distinguish between surface or subsurface +variables, so a variable defined on the subsurface will display more NaNs +than one defined on the surface, even if they are NaN at the same +spatial locations in the horizontal. + +### Usage examples +This function can be used to inspect the state after a simulation finishes +running by calling `count_nans_state(sol.u[end])`. + +This function can be used throughout the duration of a simulation by +triggering it via a callback. The `NaNCheckCallback` is designed for this +purpose, and can be set up as follows: +```julia +nancheck_freq = Dates.Month(1) +nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt) +``` +and then included along with any other callbacks in a `SciMLBase.CallbackSet`. + +Please see our longrun experiments to see examples of this callback in action! diff --git a/experiments/long_runs/bucket.jl b/experiments/long_runs/bucket.jl index a43485a607..8c9ba902a2 100644 --- a/experiments/long_runs/bucket.jl +++ b/experiments/long_runs/bucket.jl @@ -144,7 +144,11 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 7)) diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler) driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc) - return prob, SciMLBase.CallbackSet(driver_cb, diag_cb) + + nancheck_freq = Dates.Month(1) + nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt) + + return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb) end function setup_and_solve_problem(; greet = false) diff --git a/experiments/long_runs/land.jl b/experiments/long_runs/land.jl index a1c15fc2fa..f9541a983a 100644 --- a/experiments/long_runs/land.jl +++ b/experiments/long_runs/land.jl @@ -374,7 +374,10 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 15)) diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler) driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc) - return prob, SciMLBase.CallbackSet(driver_cb, diag_cb) + + nancheck_freq = Dates.Month(1) + nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt) + return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb) end function setup_and_solve_problem(; greet = false) diff --git a/experiments/long_runs/land_region.jl b/experiments/long_runs/land_region.jl index 129cb24755..4f638983b4 100644 --- a/experiments/long_runs/land_region.jl +++ b/experiments/long_runs/land_region.jl @@ -389,7 +389,10 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (10, 10, 15)) diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler) driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc) - return prob, SciMLBase.CallbackSet(driver_cb, diag_cb) + + nancheck_freq = Dates.Month(1) + nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt) + return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb) end function setup_and_solve_problem(; greet = false) diff --git a/experiments/long_runs/snowy_land.jl b/experiments/long_runs/snowy_land.jl index 43e526b8d5..35c8b8a296 100644 --- a/experiments/long_runs/snowy_land.jl +++ b/experiments/long_runs/snowy_land.jl @@ -396,7 +396,11 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 15)) end report_cb = SciMLBase.DiscreteCallback(every1000steps, report) - return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, report_cb) + nancheck_freq = Dates.Month(1) + nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt) + + return prob, + SciMLBase.CallbackSet(driver_cb, diag_cb, report_cb, nancheck_cb) end function setup_and_solve_problem(; greet = false) diff --git a/experiments/long_runs/soil.jl b/experiments/long_runs/soil.jl index fc9fa3ebb6..ced7621b37 100644 --- a/experiments/long_runs/soil.jl +++ b/experiments/long_runs/soil.jl @@ -215,7 +215,10 @@ function setup_prob(t0, tf, Δt; outdir = outdir, nelements = (101, 15)) diag_cb = ClimaDiagnostics.DiagnosticsCallback(diagnostic_handler) driver_cb = ClimaLand.DriverUpdateCallback(updateat, updatefunc) - return prob, SciMLBase.CallbackSet(driver_cb, diag_cb) + + nancheck_freq = Dates.Month(1) + nancheck_cb = ClimaLand.NaNCheckCallback(nancheck_freq, start_date, t0, Δt) + return prob, SciMLBase.CallbackSet(driver_cb, diag_cb, nancheck_cb) end function setup_and_solve_problem(; greet = false) diff --git a/src/shared_utilities/utils.jl b/src/shared_utilities/utils.jl index 68740bf4b2..5f27ddad3e 100644 --- a/src/shared_utilities/utils.jl +++ b/src/shared_utilities/utils.jl @@ -3,7 +3,7 @@ import SciMLBase import ClimaDiagnostics.Schedules: EveryCalendarDtSchedule import Dates -export FTfromY +export FTfromY, count_nans_state """ heaviside(x::FT)::FT where {FT} @@ -320,7 +320,7 @@ function CheckpointCallback( if !isnothing(dt) dt_period = Dates.Millisecond(1000dt) - if !isdivisible(checkpoint_frequency_period / dt_period) + if !isdivisible(checkpoint_frequency_period, dt_period) @warn "Checkpoint frequency ($(checkpoint_frequency_period)) is not an integer multiple of dt $(dt_period)" end end @@ -492,3 +492,110 @@ function isdivisible( # have any common divisor) return isinteger(Dates.Day(1) / dt_small) end + +""" + count_nans_state(state, mask::ClimaCore.Fields.Field = nothing, verbose = false) + +Count the number of NaNs in the state, e.g. the FieldVector given by +`sol.u[end]` after calling `solve`. This function is useful for +debugging simulations to determine quantitatively if a simulation is stable. + +If this function is called on a FieldVector, it will recursively call itself +on each Field in the FieldVector. If it is called on a Field, it will count +the number of NaNs in the Field and produce a warning if any are found. + +If a ClimaCore Field is provided as `mask`, the function will only count NaNs +in the state variables where the mask is 1. This is intended to be used with +the land/sea mask, to avoid counting NaNs over the ocean. Note this assumes +the mask is 1 over land and 0 over ocean. + +The `verbose` argument toggles whether the function produces output when no +NaNs are found. +""" +function count_nans_state( + state::ClimaCore.Fields.FieldVector; + mask = nothing, + verbose = false, +) + for pn in propertynames(state) + state_new = getproperty(state, pn) + @info "Checking NaNs in $pn" + count_nans_state(state_new; mask, verbose) + end + return nothing +end + +function count_nans_state( + state::ClimaCore.Fields.Field; + mask = nothing, + verbose = false, +) + # Note: this code uses `parent`; this pattern should not be replicated + num_nans = + isnothing(mask) ? round(sum(isnan, parent(state))) : + round(sum(isnan, parent(state)[Bool.(parent(mask))]; init = 0)) + if isapprox(num_nans, 0) + verbose && @info "No NaNs found" + else + @warn "$num_nans NaNs found" + end + return nothing +end + +""" + NaNCheckCallback(nancheck_frequency::Union{AbstractFloat, Dates.Period}, + start_date, t_start, dt) + +Constructs a DiscreteCallback which counts the number of NaNs in the state +and produces a warning if any are found. + +# Arguments +- `nancheck_frequency`: The frequency at which the state is checked for NaNs. + Can be specified as a float (in seconds) or a `Dates.Period`. +- `start_date`: The start date of the simulation. +- `t_start`: The starting time of the simulation (in seconds). +- `dt`: The timestep of the model (optional), used to check for consistency. + +The callback uses `ClimaDiagnostics.EveryCalendarDtSchedule` to determine when +to check for NaNs based on the `nancheck_frequency`. The schedule is +initialized with the `start_date` and `t_start` to ensure that it is first +called at the correct time. +""" +function NaNCheckCallback( + nancheck_frequency::Union{AbstractFloat, Dates.Period}, + start_date, + t_start, + dt, +) + # TODO: Move to a more general callback system. For the time being, we use + # the ClimaDiagnostics one because it is flexible and it supports calendar + # dates. + + if nancheck_frequency isa AbstractFloat + # Assume it is in seconds, but go through Millisecond to support + # fractional seconds + nancheck_frequency_period = Dates.Millisecond(1000nancheck_frequency) + else + nancheck_frequency_period = nancheck_frequency + end + + schedule = EveryCalendarDtSchedule( + nancheck_frequency_period; + start_date, + date_last = start_date + Dates.Millisecond(1000t_start), + ) + + if !isnothing(dt) + dt_period = Dates.Millisecond(1000dt) + if !isdivisible(nancheck_frequency_period, dt_period) + @warn "Callback frequency ($(nancheck_frequency_period)) is not an integer multiple of dt $(dt_period)" + end + end + + cond = let schedule = schedule + (u, t, integrator) -> schedule(integrator) + end + affect! = (integrator) -> count_nans_state(integrator.u) + + SciMLBase.DiscreteCallback(cond, affect!) +end diff --git a/test/shared_utilities/utilities.jl b/test/shared_utilities/utilities.jl index e44fbe8a16..c935760f31 100644 --- a/test/shared_utilities/utilities.jl +++ b/test/shared_utilities/utilities.jl @@ -2,7 +2,6 @@ using Test import ClimaComms ClimaComms.@import_required_backends using ClimaCore: Spaces, Geometry, Fields -import ClimaComms using ClimaLand using ClimaLand: Domains, condition, SavingAffect, saving_initialize @@ -256,3 +255,107 @@ end @test Y.subfields.subfield2 == Y_copy.subfields.subfield2 end end + +@testset "count_nans_state, FT = $FT" begin + # Test on a 3D spherical domain + domain = ClimaLand.Domains.SphericalShell(; + radius = FT(2), + depth = FT(1.0), + nelements = (10, 5), + npolynomial = 3, + ) + + # Construct some fields + space = domain.space.subsurface + var1 = Fields.zeros(space) + var2 = Fields.zeros(space) + var3 = Fields.zeros(space) + fieldvec = Fields.FieldVector(var2 = var2, var3 = var3) + + # Construct a FieldVector containing the fields and a nested FieldVector + Y = Fields.FieldVector(var1 = var1, fieldvec = fieldvec) + + # Count and log the number of NaNs in the state (test verbose and non-verbose cases) + @test_logs (:info, "Checking NaNs in var1") (:info, "No NaNs found") ( + :info, + "Checking NaNs in fieldvec", + ) (:info, "Checking NaNs in var2") (:info, "No NaNs found") ( + :info, + "Checking NaNs in var3", + ) (:info, "No NaNs found") ClimaLand.count_nans_state(Y, verbose = true) + + @test_logs (:info, "Checking NaNs in var1") ( + :info, + "Checking NaNs in fieldvec", + ) (:info, "Checking NaNs in var2") (:info, "Checking NaNs in var3") ClimaLand.count_nans_state( + Y, + ) + + # Add some NaNs to the fields + # Note: this code uses `parent` and scalar indexing, + # which shouldn't be replicated outside of tests + ClimaComms.allowscalar(ClimaComms.device()) do + parent(var1)[1] = NaN + parent(var2)[1] = NaN + parent(var2)[2] = NaN + end + + # Count and log the number of NaNs in the state (test verbose and non-verbose cases) + @test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") ( + :info, + "Checking NaNs in fieldvec", + ) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") ( + :info, + "Checking NaNs in var3", + ) (:info, "No NaNs found") ClimaLand.count_nans_state(Y, verbose = true) + + @test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") ( + :info, + "Checking NaNs in fieldvec", + ) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") ( + :info, + "Checking NaNs in var3", + ) ClimaLand.count_nans_state(Y) + + # Test with a mask + mask_zeros = Fields.zeros(space) + @test_logs (:info, "Checking NaNs in var1") (:info, "No NaNs found") ( + :info, + "Checking NaNs in fieldvec", + ) (:info, "Checking NaNs in var2") (:info, "No NaNs found") ( + :info, + "Checking NaNs in var3", + ) (:info, "No NaNs found") ClimaLand.count_nans_state( + Y, + mask = mask_zeros, + verbose = true, + ) + @test_logs (:info, "Checking NaNs in var1") ( + :info, + "Checking NaNs in fieldvec", + ) (:info, "Checking NaNs in var2") (:info, "Checking NaNs in var3") ClimaLand.count_nans_state( + Y, + mask = mask_zeros, + ) + + mask_ones = Fields.ones(space) + @test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") ( + :info, + "Checking NaNs in fieldvec", + ) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") ( + :info, + "Checking NaNs in var3", + ) (:info, "No NaNs found") ClimaLand.count_nans_state( + Y, + mask = mask_ones, + verbose = true, + ) + + @test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") ( + :info, + "Checking NaNs in fieldvec", + ) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") ( + :info, + "Checking NaNs in var3", + ) ClimaLand.count_nans_state(Y, mask = mask_ones) +end