Skip to content

Commit cbef085

Browse files
authored
[DRAFT][590] Rename metrics file (ecmwf#601)
* Implemented backward-compatible function to read and write `{RUN-ID}_train_metrics.json` (new) or `metrics.json` (old)
1 parent 3816abd commit cbef085

File tree

4 files changed

+230
-9
lines changed

4 files changed

+230
-9
lines changed

integration_tests/small1_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import weathergen.common.io as io
1919
import weathergen.utils.config as config
2020
from weathergen.run_train import inference_from_args, train_with_args
21+
from weathergen.utils.metrics import get_train_metrics_path
2122

2223
logger = logging.getLogger(__name__)
2324

@@ -100,7 +101,7 @@ def evaluate_results(run_id):
100101

101102
def load_metrics(run_id):
102103
"""Helper function to load metrics"""
103-
file_path = f"{weathergen_home}/results/{run_id}/metrics.json"
104+
file_path = get_train_metrics_path(base_path=weathergen_home / "results", run_id=run_id)
104105
if not os.path.exists(file_path):
105106
raise FileNotFoundError(f"Metrics file not found for run_id: {run_id}")
106107
with open(file_path) as f:
@@ -110,7 +111,7 @@ def load_metrics(run_id):
110111

111112
def assert_missing_metrics_file(run_id):
112113
"""Test that a missing metrics file raises FileNotFoundError."""
113-
file_path = f"{weathergen_home}/results/{run_id}/metrics.json"
114+
file_path = get_train_metrics_path(base_path=weathergen_home / "results", run_id=run_id)
114115
assert os.path.exists(file_path), f"Metrics file does not exist for run_id: {run_id}"
115116
metrics = load_metrics(run_id)
116117
logger.info(f"Loaded metrics for run_id: {run_id}: {metrics}")

src/weathergen/utils/metrics.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,16 @@ def read_metrics_file(f: str | Path) -> pl.DataFrame:
4949
pl.when(pl.col(n).is_not_nan()).then(df1[n]).otherwise(df2[n]).alias(n)
5050
)
5151
return df1
52+
53+
54+
def get_train_metrics_path(base_path: Path, run_id: str) -> Path:
    """
    Locate the training metrics file for *run_id* under *base_path*.

    Provides backwards compatibility after `results/{RUN-ID}/metrics.json` was renamed
    to `results/{RUN-ID}/{RUN-ID}_train_metrics.json` to disambiguate `metrics.json`.
    See https://github.com/ecmwf/WeatherGenerator/issues/590 for details.

    :param base_path: directory that contains the per-run result folders
    :param run_id: identifier of the run whose metrics file is wanted
    :return: the legacy `metrics.json` path if that file exists, otherwise the
        new-style `{run_id}_train_metrics.json` path
    """
    run_dir = base_path / run_id
    legacy_path = run_dir / "metrics.json"
    # Prefer the legacy location when present so existing runs keep working.
    return legacy_path if legacy_path.exists() else run_dir / f"{run_id}_train_metrics.json"

src/weathergen/utils/train_logger.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import polars as pl
2222

2323
import weathergen.utils.config as config
24-
from weathergen.utils.metrics import read_metrics_file
24+
from weathergen.utils.metrics import get_train_metrics_path, read_metrics_file
2525

2626
_weathergen_timestamp = "weathergen.timestamp"
2727
_weathergen_reltime = "weathergen.reltime"
@@ -66,7 +66,8 @@ def log_metrics(self, stage: Stage, metrics: dict[str, float]) -> None:
6666
# TODO: performance: we repeatedly open the file for each call. Better for multiprocessing
6767
# but we can probably do better and rely for example on the logging module.
6868

69-
with open(self.path_run / "metrics.json", "ab") as f:
69+
metrics_path = get_train_metrics_path(base_path=Path("results"), run_id=self.cf.run_id)
70+
with open(metrics_path, "ab") as f:
7071
s = json.dumps(clean_metrics) + "\n"
7172
f.write(s.encode("utf-8"))
7273

@@ -157,7 +158,12 @@ def read(run_id, model_path: str, epoch=-1):
157158

158159
# define cols for training
159160
cols_train = ["dtime", "samples", "mse", "lr"]
160-
cols1 = [_weathergen_timestamp, "num_samples", "loss_avg_0_mean", "learning_rate"]
161+
cols1 = [
162+
_weathergen_timestamp,
163+
"num_samples",
164+
"loss_avg_0_mean",
165+
"learning_rate",
166+
]
161167
for si in cf.streams:
162168
for _j, lf in enumerate(cf.loss_fcts):
163169
cols1 += [_key_loss(si["name"], lf[0])]
@@ -178,7 +184,13 @@ def read(run_id, model_path: str, epoch=-1):
178184
with open(fname_log_train, "rb") as f:
179185
log_train = np.loadtxt(f, delimiter=",")
180186
log_train = log_train.reshape((log_train.shape[0] // len(cols_train), len(cols_train)))
181-
except (TypeError, AttributeError, IndexError, ZeroDivisionError, ValueError) as e:
187+
except (
188+
TypeError,
189+
AttributeError,
190+
IndexError,
191+
ZeroDivisionError,
192+
ValueError,
193+
) as e:
182194
_logger.warning(
183195
(
184196
f"Warning: no training data loaded for run_id={run_id}",
@@ -230,7 +242,13 @@ def read(run_id, model_path: str, epoch=-1):
230242
with open(fname_log_val, "rb") as f:
231243
log_val = np.loadtxt(f, delimiter=",")
232244
log_val = log_val.reshape((log_val.shape[0] // len(cols_val), len(cols_val)))
233-
except (TypeError, AttributeError, IndexError, ZeroDivisionError, ValueError) as e:
245+
except (
246+
TypeError,
247+
AttributeError,
248+
IndexError,
249+
ZeroDivisionError,
250+
ValueError,
251+
) as e:
234252
_logger.warning(
235253
(
236254
f"Warning: no validation data loaded for run_id={run_id}",
@@ -265,7 +283,13 @@ def read(run_id, model_path: str, epoch=-1):
265283
with open(fname_perf_val, "rb") as f:
266284
log_perf = np.loadtxt(f, delimiter=",")
267285
log_perf = log_perf.reshape((log_perf.shape[0] // len(cols_perf), len(cols_perf)))
268-
except (TypeError, AttributeError, IndexError, ZeroDivisionError, ValueError) as e:
286+
except (
287+
TypeError,
288+
AttributeError,
289+
IndexError,
290+
ZeroDivisionError,
291+
ValueError,
292+
) as e:
269293
_logger.warning(
270294
(
271295
f"Warning: no validation data loaded for run_id={run_id}",
@@ -341,8 +365,9 @@ def read_metrics(
341365
run_id = cf.run_id
342366
assert run_id, "run_id must be provided"
343367

368+
metrics_path = get_train_metrics_path(base_path=results_path, run_id=run_id)
344369
# TODO: this should be a config option
345-
df = read_metrics_file(results_path / run_id / "metrics.json")
370+
df = read_metrics_file(metrics_path)
346371
if stage is not None:
347372
df = df.filter(pl.col("stage") == stage)
348373
df = df.drop("stage")

0 commit comments

Comments
 (0)