Skip to content

Commit 42d41d2

Browse files
committed
add optional mask
1 parent 8d6245e commit 42d41d2

File tree

2 files changed

+40
-12
lines changed

2 files changed

+40
-12
lines changed

src/shared_utilities/utils.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -494,28 +494,34 @@ function isdivisible(
494494
end
495495

496496
"""
497-
count_nans_state(sol)
497+
count_nans_state(state, mask::ClimaCore.Fields.Field = nothing)
498498
499-
Count the number of NaNs in the state variables. This function is useful for
499+
Count the number of NaNs in the state, e.g. the FieldVector given by
500+
`sol.u[end]` after calling `solve`. This function is useful for
500501
debugging simulations to determine quantitatively if a simulation is stable.
501502
502503
If this function is called on a FieldVector, it will recursively call itself
503504
on each Field in the FieldVector. If it is called on a Field, it will count
504505
the number of NaNs in the Field and produce a warning if any are found.
505506
506-
Input: `state` - e.g. the FieldVector given by `sol.u[end]` after calling `solve`
507+
If a ClimaCore Field is provided as `mask`, the function will only count NaNs
508+
in the state variables where the mask is 1. This is intended to be used with
509+
the land/sea mask, to avoid counting NaNs over the ocean. Note this assumes
510+
the mask is 1 over land and 0 over ocean.
507511
"""
508-
function count_nans_state(state::ClimaCore.Fields.FieldVector)
512+
function count_nans_state(state::ClimaCore.Fields.FieldVector, mask = nothing)
509513
for pn in propertynames(state)
510514
state_new = getproperty(state, pn)
511515
@info "Checking NaNs in $pn"
512-
count_nans_state(state_new)
516+
count_nans_state(state_new, mask)
513517
end
514518
return nothing
515519
end
516520

517-
function count_nans_state(state::ClimaCore.Fields.Field)
518-
num_nans = count(isnan.(Array(parent(state))))
521+
function count_nans_state(state::ClimaCore.Fields.Field, mask = nothing)
522+
num_nans =
523+
isnothing(mask) ? count(==(1), isnan.(parent(state))) :
524+
count(==(1), isnan.(parent(state)) .* parent(mask))
519525
if num_nans > 0
520526
@warn "$num_nans NaNs found"
521527
else

test/shared_utilities/utilities.jl

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ using Test
22
import ClimaComms
33
ClimaComms.@import_required_backends
44
using ClimaCore: Spaces, Geometry, Fields
5-
import ClimaComms
65
using ClimaLand
76
using ClimaLand: Domains, condition, SavingAffect, saving_initialize
87

@@ -285,10 +284,15 @@ end
285284
"Checking NaNs in var3",
286285
) (:info, "No NaNs found") ClimaLand.count_nans_state(Y)
287286

288-
# Add some NaNs to the fields
289-
Array(parent(var1))[1] = NaN
290-
Array(parent(var2))[1] = NaN
291-
Array(parent(var2))[2] = NaN
287+
# Add some NaNs to the fields without scalar indexing
288+
ArrayType = ClimaComms.array_type(ClimaComms.device())
289+
var1_copy = Array(parent(var1))
290+
var1_copy[1] = NaN
291+
parent(var1) .= ArrayType(var1_copy)
292+
var2_copy = Array(parent(var1))
293+
var2_copy[1] = NaN
294+
var2_copy[2] = NaN
295+
parent(var2) .= ArrayType(var2_copy)
292296

293297
# Count and log the number of NaNs in the state
294298
@test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") (
@@ -299,4 +303,22 @@ end
299303
"Checking NaNs in var3",
300304
) (:info, "No NaNs found") ClimaLand.count_nans_state(Y)
301305

306+
# Test with a mask
307+
mask_zeros = Fields.zeros(space)
308+
@test_logs (:info, "Checking NaNs in var1") (:info, "No NaNs found") (
309+
:info,
310+
"Checking NaNs in fieldvec",
311+
) (:info, "Checking NaNs in var2") (:info, "No NaNs found") (
312+
:info,
313+
"Checking NaNs in var3",
314+
) (:info, "No NaNs found") ClimaLand.count_nans_state(Y, mask_zeros)
315+
316+
mask_ones = Fields.ones(space)
317+
@test_logs (:info, "Checking NaNs in var1") (:warn, "1 NaNs found") (
318+
:info,
319+
"Checking NaNs in fieldvec",
320+
) (:info, "Checking NaNs in var2") (:warn, "2 NaNs found") (
321+
:info,
322+
"Checking NaNs in var3",
323+
) (:info, "No NaNs found") ClimaLand.count_nans_state(Y, mask_ones)
302324
end

0 commit comments

Comments
 (0)