@@ -529,3 +529,66 @@ function count_nans_state(state::ClimaCore.Fields.Field, mask = nothing)
529
529
end
530
530
return nothing
531
531
end
532
+
533
+ """
534
+ NaNCheckCallback(nancheck_frequency::Union{AbstractFloat, Dates.Period},
535
+ output_dir, start_date, t_start; model, dt)
536
+
537
+ Constructs a DiscreteCallback which counts the number of NaNs in the state
538
+ and produces a warning if any are found.
539
+
540
+ # Arguments
541
+ - `nancheck_frequency`: The frequency at which the state is checked for NaNs.
542
+ Can be specified as a float (in seconds) or a `Dates.Period`.
543
+ - `start_date`: The start date of the simulation.
544
+ - `t_start`: The starting time of the simulation (in seconds).
545
+ - `dt`: The timestep of the model (optional), used to check for consistency.
546
+
547
+ The callback uses `ClimaDiagnostics.EveryCalendarDtSchedule` to determine when
548
+ to save checkpoints based on the `nancheck_frequency`. The schedule is
549
+ initialized with the `start_date` and `t_start` to ensure that the first
550
+ checkpoint is saved at the correct time.
551
+
552
+ The `save_checkpoint` function is called with the current state vector `u`, the
553
+ current time `t`, and the `output_dir` to save the checkpoint to disk.
554
+ """
555
+ function NaNCheckCallback (
556
+ nancheck_frequency:: Union{AbstractFloat, Dates.Period} ,
557
+ start_date,
558
+ t_start,
559
+ dt,
560
+ )
561
+ # TODO : Move to a more general callback system. For the time being, we use
562
+ # the ClimaDiagnostics one because it is flexible and it supports calendar
563
+ # dates.
564
+
565
+ if nancheck_frequency isa AbstractFloat
566
+ # Assume it is in seconds, but go through Millisecond to support
567
+ # fractional seconds
568
+ nancheck_frequency_period = Dates. Millisecond (1000 nancheck_frequency)
569
+ else
570
+ nancheck_frequency_period = nancheck_frequency
571
+ end
572
+
573
+ schedule = EveryCalendarDtSchedule (
574
+ nancheck_frequency_period;
575
+ start_date,
576
+ date_last = start_date + Dates. Millisecond (1000 t_start),
577
+ )
578
+
579
+ if ! isnothing (dt)
580
+ dt_period = Dates. Millisecond (1000 dt)
581
+ if ! isdivisible (nancheck_frequency_period / dt_period)
582
+ @warn " Callback frequency ($(nancheck_frequency_period) ) is not an integer multiple of dt $(dt_period) "
583
+ end
584
+ end
585
+
586
+ cond = let schedule = schedule
587
+ (u, t, integrator) -> schedule (integrator)
588
+ end
589
+ affect! = let output_dir = output_dir, model = model
590
+ (integrator) -> count_nans_state (integrator. u)
591
+ end
592
+
593
+ SciMLBase. DiscreteCallback (cond, affect!)
594
+ end
0 commit comments