Skip to content

Commit d2947c6

Browse files
authored
[210] Handling NaNs and other corner cases in the metrics file (ecmwf#248)
* changes * changes * fixes
1 parent 63c5468 commit d2947c6

File tree

6 files changed

+86
-12
lines changed

6 files changed

+86
-12
lines changed

config/streams/streams_test/era5.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ ERA5 :
2121
net : transformer
2222
num_tokens : 1
2323
num_heads : 4
24-
dim_embed : 128
24+
dim_embed : 16
2525
num_blocks : 2
2626
embed_target_coords :
2727
net : linear
28-
dim_embed : 128
28+
dim_embed : 16
2929
target_readout :
3030
type : 'obs_value' # token or obs_value
3131
num_layers : 2

integration_tests/small1.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,42 +88,42 @@ def assert_missing_metrics_file(run_id):
8888

8989

9090
def assert_train_loss_below_threshold(run_id):
    """Assert that the most recent train 'stream.ERA5.loss_mse.loss_avg' metric is below a threshold.

    :param run_id: identifier of the run whose metrics file is inspected.
    :raises AssertionError: if the metric is missing or not below the threshold.
    """
    metrics = load_metrics(run_id)
    # Walk the metrics backwards so the most recent "train" entry wins.
    loss_metric = next(
        (
            metric.get("stream.ERA5.loss_mse.loss_avg", None)
            for metric in reversed(metrics)
            if metric.get("stage") == "train"
        ),
        None,
    )
    assert loss_metric is not None, (
        "'stream.ERA5.loss_mse.loss_avg' metric is missing in metrics file"
    )
    # Check that the loss does not explode in a single epoch
    # This is meant to be a quick test, not a convergence test
    # Bug fix: the message previously claimed "below 0.25" while the assert
    # checked 1.25; name the threshold once so they cannot diverge again.
    threshold = 1.25
    assert loss_metric < threshold, (
        f"'stream.ERA5.loss_mse.loss_avg' is {loss_metric}, expected to be below {threshold}"
    )
109109

110110

111111
def assert_val_loss_below_threshold(run_id):
    """Assert that the most recent val 'stream.ERA5.loss_mse.loss_avg' metric is below a threshold.

    :param run_id: identifier of the run whose metrics file is inspected.
    :raises AssertionError: if the metric is missing or not below the threshold.
    """
    metrics = load_metrics(run_id)
    # Walk the metrics backwards so the most recent "val" entry wins.
    loss_metric = next(
        (
            metric.get("stream.ERA5.loss_mse.loss_avg", None)
            for metric in reversed(metrics)
            if metric.get("stage") == "val"
        ),
        None,
    )
    assert loss_metric is not None, (
        "'stream.ERA5.loss_mse.loss_avg' metric is missing in metrics file"
    )
    # Check that the loss does not explode in a single epoch
    # This is meant to be a quick test, not a convergence test
    # Bug fix: the message previously claimed "below 0.25" while the assert
    # checked 1.25; name the threshold once so they cannot diverge again.
    threshold = 1.25
    assert loss_metric < threshold, (
        f"'stream.ERA5.loss_mse.loss_avg' is {loss_metric}, expected to be below {threshold}"
    )

integration_tests/small1.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,10 @@ lr_steps: 4
1010
lr_steps_warmup: 2
1111
lr_steps_cooldown: 2
1212
loader_num_workers: 1
13+
14+
forecast_offset : 0
15+
# len_hrs: 6
16+
# step_hrs: 6
17+
1318
train_log:
1419
log_interval: 1

src/weathergen/utils/metrics.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""
2+
Utilities related to reading and writing metrics.
3+
4+
We use our own simple json-based format to abstract away various backends (our own pipeline, mlflow, wandb, etc.).
5+
"""
6+
7+
import polars as pl
8+
9+
# Known columns that are not scalar metrics:
10+
_known_cols = {"weathergen.timestamp": pl.Int64, "weathergen.time": pl.Int64, "stage": pl.String}
11+
12+
13+
def read_metrics_file(f: str) -> pl.DataFrame:
14+
"""
15+
Loads a file of metrics.
16+
17+
The resulting dataframe has the following format:
18+
- all columns in known_cols (if they exist in the file) have the right type
19+
- all other columns are of type float64 (including NaN values)
20+
"""
21+
22+
# All values are scalar, except for known values
23+
# The following point needs to be taken into account:
24+
# 1. The schema is not known in advance
25+
# 2. NaN is encoded as string
26+
# 3. numbers are encoded as numbers
27+
# The file needs to be read 3 times:
28+
# 1. Get the name of all the columns
29+
# 2. Find all the NaN values
30+
# 3. Read the numbers
31+
# 4. Merge the two dataframes
32+
33+
# Find the list of all columns (read everything)
34+
df0 = pl.read_ndjson(f, infer_schema_length=None)
35+
# Read with the final schema:
36+
schema1 = dict([(n, _known_cols.get(n, pl.Float64)) for n in df0.columns])
37+
df1 = pl.read_ndjson(f, schema=schema1)
38+
# Read again as strings to find the NaN values:
39+
schema2 = dict([(n, _known_cols.get(n, pl.String)) for n in df0.columns])
40+
metrics_cols = [n for n in df0.columns if n not in _known_cols]
41+
df2 = pl.read_ndjson(f, schema=schema2).cast(dict([(n, pl.Float64) for n in metrics_cols]))
42+
43+
# Merge the two dataframes:
44+
for n in metrics_cols:
45+
df1 = df1.with_columns(
46+
pl.when(pl.col(n).is_not_nan()).then(df1[n]).otherwise(df2[n]).alias(n)
47+
)
48+
return df1
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from io import StringIO
from math import isnan

from weathergen.utils.metrics import (
    read_metrics_file,
)

s = """{"weathergen.timestamp":100, "m": "nan"}
{"weathergen.timestamp":101,"m": 1.3}
{"weathergen.timestamp":102,"a": 4}
"""


def test1():
    # Parse the three-line ndjson fixture above and check types and holes.
    frame = read_metrics_file(StringIO(s))
    assert frame.shape == (3, 3)

    # Known column keeps its integer values.
    timestamps = frame["weathergen.timestamp"].to_list()
    assert timestamps == [100, 101, 102]

    # "nan" strings become float NaN; missing keys become nulls.
    m_values = frame["m"].to_list()
    assert isnan(m_values[0])
    assert m_values[1:] == [1.3, None]
    assert frame["a"].to_list() == [None, None, 4]

src/weathergen/utils/train_logger.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import polars as pl
2121

2222
import weathergen.utils.config as config
23+
from weathergen.utils.metrics import read_metrics_file
2324

2425
_weathergen_timestamp = "weathergen.timestamp"
2526
_weathergen_reltime = "weathergen.reltime"
@@ -257,7 +258,7 @@ def read_metrics(
257258
run_id = cf.run_id
258259

259260
# TODO: this should be a config option
260-
df = pl.read_ndjson(f"./results/{run_id}/metrics.json")
261+
df = read_metrics_file(f"./results/{run_id}/metrics.json")
261262
if stage is not None:
262263
df = df.filter(pl.col("stage") == stage)
263264
df = df.drop("stage")
@@ -294,7 +295,7 @@ def clean_df(df, columns: list[str] | None):
294295

295296
def _clean_name(n: str) -> str:
296297
"""Cleans the stream name to only retain alphanumeric characters"""
297-
return "".join([c for c in n if c.isalnum()]).lower()
298+
return "".join([c for c in n if c.isalnum()])
298299

299300

300301
def _key_loss(st_name: str, lf_name: str) -> str:

0 commit comments

Comments
 (0)