2828from weathergen .train .trainer_base import Trainer_Base
2929from weathergen .train .utils import get_run_id
3030from weathergen .utils .config import Config
31+ from weathergen .utils .distributed import is_root
3132from weathergen .utils .train_logger import TrainLogger
3233from weathergen .utils .validation_io import write_validation
3334
@@ -48,12 +49,12 @@ def init(
4849 cf : Config ,
4950 run_id_contd = None ,
5051 epoch_contd = None , # unused
51- run_id_new = False ,
52+ run_id_new : bool | str | None = False ,
5253 run_mode = "training" , # unused
5354 ):
5455 self .cf = cf
5556
56- if isinstance (run_id_new , str ):
57+ if run_id_new is not None and isinstance (run_id_new , str ):
5758 cf .run_id = run_id_new
5859 elif run_id_new or cf .run_id is None :
5960 cf .run_id = get_run_id ()
@@ -64,6 +65,7 @@ def init(
6465 assert cf .samples_per_epoch % cf .batch_size == 0
6566 assert cf .samples_per_validation % cf .batch_size_validation == 0
6667
68+ _logger .info (f"Starting run with id: { cf .run_id } " )
6769 self .devices = self .init_torch ()
6870
6971 self .init_ddp (cf )
@@ -82,7 +84,6 @@ def init(
8284 self .path_run = path_run
8385
8486 self .init_perf_monitoring ()
85-
8687 self .train_logger = TrainLogger (cf , self .path_run )
8788
8889 ###########################################
@@ -134,7 +135,7 @@ def evaluate(self, cf, run_id_trained, epoch, run_id_new=False):
134135 _logger .info (f"Finished evaluation run with id: { cf .run_id } " )
135136
136137 ###########################################
137- def run (self , cf , run_id_contd = None , epoch_contd = None , run_id_new = False ):
138+ def run (self , cf , run_id_contd = None , epoch_contd = None , run_id_new : bool | str = False ):
138139 # general initialization
139140 self .init (cf , run_id_contd , epoch_contd , run_id_new )
140141
@@ -169,6 +170,7 @@ def run(self, cf, run_id_contd=None, epoch_contd=None, run_id_new=False):
169170 self .model = Model (cf , sources_size , targets_num_channels , targets_coords_size ).create ()
170171 # load model if specified
171172 if run_id_contd is not None :
173+ _logger .info (f"Continuing run with id={ run_id_contd } at epoch { epoch_contd } ." )
172174 self .model .load (run_id_contd , epoch_contd )
173175 _logger .info (f"Loaded model id={ run_id_contd } ." )
174176
@@ -278,7 +280,7 @@ def run(self, cf, run_id_contd=None, epoch_contd=None, run_id_new=False):
278280 if cf .forecast_policy is not None :
279281 torch ._dynamo .config .optimize_ddp = False
280282
281- if self . cf . rank == 0 :
283+ if is_root () :
282284 config .save (self .cf , None )
283285 config .print_cf (self .cf )
284286
@@ -674,7 +676,7 @@ def save_model(self, epoch: int, name=None):
674676 else :
675677 state = self .ddp_model .state_dict ()
676678
677- if self . cf . rank == 0 :
679+ if is_root () :
678680 filename = "" .join (
679681 [
680682 self .cf .run_id ,
0 commit comments