 from weathergen.utils.train_logger import TrainLogger
 from weathergen.utils.validation_io import write_validation

+_logger = logging.getLogger(__name__)
+

 class Trainer(Trainer_Base):
     ###########################################
@@ -47,7 +49,14 @@ def __init__(self, log_freq=20, checkpoint_freq=250, print_freq=10):
         self.print_freq = print_freq

     ###########################################
-    def init(self, cf, run_id_contd=None, epoch_contd=None, run_id_new=False, run_mode="training"):
+    def init(
+        self,
+        cf,
+        run_id_contd=None,
+        epoch_contd=None,
+        run_id_new=False,
+        run_mode="training",
+    ):
         self.cf = cf

         if isinstance(run_id_new, str):
@@ -284,7 +293,7 @@ def evaluate_jac(self, cf, run_id, epoch, mode="row", date=None, obs_id=0, sampl
         )

     ###########################################
-    def run(self, cf, run_id_contd=None, epoch_contd=None, run_id_new=False):
+    def run(self, cf, private_cf, run_id_contd=None, epoch_contd=None, run_id_new=False):
         # general initialization
         self.init(cf, run_id_contd, epoch_contd, run_id_new)

@@ -419,18 +428,23 @@ def run(self, cf, run_id_contd=None, epoch_contd=None, run_id_new=False):
         )
         self.grad_scaler = torch.amp.GradScaler("cuda")

+        assert len(self.dataset) > 0, f"No data found in {self.dataset}"
+
         # lr is updated after each batch so account for this
+        # TODO: conf should be read-only, do not modify the conf in flight
         cf.lr_steps = int((len(self.dataset) * cf.num_epochs) / cf.batch_size)
+
         steps_decay = cf.lr_steps - cf.lr_steps_warmup - cf.lr_steps_cooldown
+        _logger.debug(f"steps_decay={steps_decay} lr_steps={cf.lr_steps}")
         # ensure that steps_decay has a reasonable value
         if steps_decay < int(0.2 * cf.lr_steps):
             cf.lr_steps_warmup = int(0.1 * cf.lr_steps)
             cf.lr_steps_cooldown = int(0.05 * cf.lr_steps)
             steps_decay = cf.lr_steps - cf.lr_steps_warmup - cf.lr_steps_cooldown
430- str = f"cf.lr_steps_warmup and cf.lr_steps_cooldown were larger than cf.lr_steps={ cf .lr_steps } "
431- str += ". The value have been adjusted to cf.lr_steps_warmup={cf.lr_steps_warmup} and "
432- str += " cf.lr_steps_cooldown={cf.lr_steps_cooldown} so that steps_decay={steps_decay}."
433- logging .getLogger ("obslearn" ).warning ("" )
444+ s = f"cf.lr_steps_warmup and cf.lr_steps_cooldown were larger than cf.lr_steps={ cf .lr_steps } "
445+ s += f ". The value have been adjusted to cf.lr_steps_warmup={ cf .lr_steps_warmup } and "
446+ s += f " cf.lr_steps_cooldown={ cf .lr_steps_cooldown } so that steps_decay={ steps_decay } ."
447+ logging .getLogger ("obslearn" ).warning (s )
         self.lr_scheduler = LearningRateScheduler(
             self.optimizer,
             cf.batch_size,
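
For reference, the step bookkeeping in this hunk reduces to the self-contained sketch below. One optimizer step is taken per batch, so cf.lr_steps = len(dataset) * num_epochs / batch_size; warmup and cooldown are then rescaled if they would squeeze the decay phase. The helper adjust_lr_steps is hypothetical, named only to mirror the config fields above:

def adjust_lr_steps(lr_steps: int, warmup: int, cooldown: int) -> tuple[int, int, int]:
    """Return (warmup, cooldown, steps_decay), rescaling warmup/cooldown to
    10%/5% of lr_steps when they would leave under 20% for the decay phase."""
    steps_decay = lr_steps - warmup - cooldown
    if steps_decay < int(0.2 * lr_steps):
        warmup = int(0.1 * lr_steps)
        cooldown = int(0.05 * lr_steps)
        steps_decay = lr_steps - warmup - cooldown
    return warmup, cooldown, steps_decay

# e.g. adjust_lr_steps(1000, 600, 500) == (100, 50, 850)
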
@@ -558,7 +572,11 @@ def compute_loss(
                 )
                 if tro_type == "token":
                     pred = pred.reshape(
-                        [*pred.shape[:2], target.shape[-2], target.shape[-1] - gs]
+                        [
+                            *pred.shape[:2],
+                            target.shape[-2],
+                            target.shape[-1] - gs,
+                        ]
                     )
                     pred = torch.cat([pred[:, i, :l] for i, l in enumerate(sl)], 1)
                 else:
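
The reshape/cat pair above is behavior-preserving formatting; for clarity, here is a minimal sketch of the shapes involved, assuming gs counts trailing target columns that are not predicted and sl holds each token's number of valid (unpadded) points — both assumptions inferred from the call pattern, not confirmed by the source:

import torch

ens, n_tok, pts, ch, gs = 2, 3, 8, 5, 1    # toy dimensions (assumed)
pred = torch.randn(ens, n_tok, pts * ch)   # flattened per-token predictions
target = torch.randn(n_tok, pts, ch + gs)  # gs trailing columns not predicted
sl = [6, 8, 4]                             # valid points per token (<= pts)

pred = pred.reshape([*pred.shape[:2], target.shape[-2], target.shape[-1] - gs])
pred = torch.cat([pred[:, i, :l] for i, l in enumerate(sl)], 1)
assert pred.shape == (ens, sum(sl), ch)    # padding stripped, tokens joined
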
@@ -600,7 +618,7 @@ def compute_loss(
                         target_data[mask, i],
                         pred[:, mask, i],
                         pred[:, mask, i].mean(0),
-                        pred[:, mask, i].std(0) if ens else torch.zeros(1),
+                        (pred[:, mask, i].std(0) if ens else torch.zeros(1)),
                     )
                     val_uw += temp.item()
                     val = val + channel_loss_weight[i] * temp  # * tw[jj]
@@ -613,9 +631,11 @@ def compute_loss(
                         target_data[mask_nan[:, i], i],
                         pred[:, mask_nan[:, i], i],
                         pred[:, mask_nan[:, i], i].mean(0),
-                        pred[:, mask_nan[:, i], i].std(0)
-                        if ens
-                        else torch.zeros(1),
+                        (
+                            pred[:, mask_nan[:, i], i].std(0)
+                            if ens
+                            else torch.zeros(1)
+                        ),
                     )
                     val_uw += temp.item()
                     val = val + channel_loss_weight[i] * temp
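
Both loss hunks only re-wrap the ensemble-spread argument for readability; the surrounding pattern is a per-channel masked loss over an ensemble. A standalone sketch under assumed shapes (this loss_fn is a hypothetical stand-in whose signature is inferred from the call sites above):

import torch

def loss_fn(target, preds, mean, std):
    # stand-in statistic: MSE of the ensemble mean plus a spread term
    return ((mean - target) ** 2).mean() + std.mean()

pred = torch.randn(4, 100, 3)    # (ensemble, points, channels)
target_data = torch.randn(100, 3)
mask_nan = ~target_data.isnan()  # per-channel validity mask
ens = pred.shape[0] > 1
channel_loss_weight = torch.ones(target_data.shape[-1])

val, val_uw = 0.0, 0.0
for i in range(target_data.shape[-1]):
    m = mask_nan[:, i]
    temp = loss_fn(
        target_data[m, i],
        pred[:, m, i],
        pred[:, m, i].mean(0),
        (pred[:, m, i].std(0) if ens else torch.zeros(1)),
    )
    val_uw += temp.item()
    val = val + channel_loss_weight[i] * temp
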
@@ -1028,7 +1048,10 @@ def log_terminal(self, bidx, epoch):
         )
         print("\t", end="")
         for i_obs, rt in enumerate(self.cf.streams):
-            print("{}".format(rt["name"]) + f" : {l_avg[0, i_obs]:0.4E}\t", end="")
+            print(
+                "{}".format(rt["name"]) + f" : {l_avg[0, i_obs]:0.4E}\t",
+                end="",
+            )
         print("\n", flush=True)

         self.t_start = time.time()