2121import polars as pl
2222
2323import weathergen .utils .config as config
24- from weathergen .utils .metrics import read_metrics_file
24+ from weathergen .utils .metrics import get_train_metrics_path , read_metrics_file
2525
2626_weathergen_timestamp = "weathergen.timestamp"
2727_weathergen_reltime = "weathergen.reltime"
@@ -66,7 +66,8 @@ def log_metrics(self, stage: Stage, metrics: dict[str, float]) -> None:
6666 # TODO: performance: we repeatedly open the file for each call. Better for multiprocessing
6767 # but we can probably do better and rely for example on the logging module.
6868
69- with open (self .path_run / "metrics.json" , "ab" ) as f :
69+ metrics_path = get_train_metrics_path (base_path = Path ("results" ), run_id = self .cf .run_id )
70+ with open (metrics_path , "ab" ) as f :
7071 s = json .dumps (clean_metrics ) + "\n "
7172 f .write (s .encode ("utf-8" ))
7273
@@ -157,7 +158,12 @@ def read(run_id, model_path: str, epoch=-1):
157158
158159 # define cols for training
159160 cols_train = ["dtime" , "samples" , "mse" , "lr" ]
160- cols1 = [_weathergen_timestamp , "num_samples" , "loss_avg_0_mean" , "learning_rate" ]
161+ cols1 = [
162+ _weathergen_timestamp ,
163+ "num_samples" ,
164+ "loss_avg_0_mean" ,
165+ "learning_rate" ,
166+ ]
161167 for si in cf .streams :
162168 for _j , lf in enumerate (cf .loss_fcts ):
163169 cols1 += [_key_loss (si ["name" ], lf [0 ])]
@@ -178,7 +184,13 @@ def read(run_id, model_path: str, epoch=-1):
178184 with open (fname_log_train , "rb" ) as f :
179185 log_train = np .loadtxt (f , delimiter = "," )
180186 log_train = log_train .reshape ((log_train .shape [0 ] // len (cols_train ), len (cols_train )))
181- except (TypeError , AttributeError , IndexError , ZeroDivisionError , ValueError ) as e :
187+ except (
188+ TypeError ,
189+ AttributeError ,
190+ IndexError ,
191+ ZeroDivisionError ,
192+ ValueError ,
193+ ) as e :
182194 _logger .warning (
183195 (
184196 f"Warning: no training data loaded for run_id={ run_id } " ,
@@ -230,7 +242,13 @@ def read(run_id, model_path: str, epoch=-1):
230242 with open (fname_log_val , "rb" ) as f :
231243 log_val = np .loadtxt (f , delimiter = "," )
232244 log_val = log_val .reshape ((log_val .shape [0 ] // len (cols_val ), len (cols_val )))
233- except (TypeError , AttributeError , IndexError , ZeroDivisionError , ValueError ) as e :
245+ except (
246+ TypeError ,
247+ AttributeError ,
248+ IndexError ,
249+ ZeroDivisionError ,
250+ ValueError ,
251+ ) as e :
234252 _logger .warning (
235253 (
236254 f"Warning: no validation data loaded for run_id={ run_id } " ,
@@ -265,7 +283,13 @@ def read(run_id, model_path: str, epoch=-1):
265283 with open (fname_perf_val , "rb" ) as f :
266284 log_perf = np .loadtxt (f , delimiter = "," )
267285 log_perf = log_perf .reshape ((log_perf .shape [0 ] // len (cols_perf ), len (cols_perf )))
268- except (TypeError , AttributeError , IndexError , ZeroDivisionError , ValueError ) as e :
286+ except (
287+ TypeError ,
288+ AttributeError ,
289+ IndexError ,
290+ ZeroDivisionError ,
291+ ValueError ,
292+ ) as e :
269293 _logger .warning (
270294 (
271295 f"Warning: no validation data loaded for run_id={ run_id } " ,
@@ -341,8 +365,9 @@ def read_metrics(
341365 run_id = cf .run_id
342366 assert run_id , "run_id must be provided"
343367
368+ metrics_path = get_train_metrics_path (base_path = results_path , run_id = run_id )
344369 # TODO: this should be a config option
345- df = read_metrics_file (results_path / run_id / "metrics.json" )
370+ df = read_metrics_file (metrics_path )
346371 if stage is not None :
347372 df = df .filter (pl .col ("stage" ) == stage )
348373 df = df .drop ("stage" )
0 commit comments