Skip to content

Commit 4b27565

Browse files
authored
Changed format of yaml file (ecmwf#291)
1 parent 1e2dc0c commit 4b27565

File tree

1 file changed

+32
-29
lines changed

1 file changed

+32
-29
lines changed

src/weathergen/utils/plot_training.py

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,13 @@ def _read_yaml_config(yaml_file_path):
8989
Expected structure in the YAML file:
9090
train:
9191
plot:
92-
run_ids: [run_id1, run_id2, ...]
93-
job_ids: [job_id1, job_id2, ...]
94-
experiment_names: [experiment_name1, experiment_name2, ...]
92+
run_id:
93+
slurm_id : SLURM_JOB (specify 0 if not available)
94+
description: job description
95+
run_id:
96+
slurm_id : SLURM_JOB (specify 0 if not available)
97+
description : job description
98+
...
9599
96100
Parameters
97101
----------
@@ -106,23 +110,17 @@ def _read_yaml_config(yaml_file_path):
106110
data = yaml.safe_load(f)
107111

108112
# Extract configuration for plotting training diagnostics
109-
config_plot = data.get("train", {}).get("plot", {})
110-
111-
# Init lists
112-
run_ids = _ensure_list(config_plot.get("run_ids", []))
113-
job_ids = _ensure_list(config_plot.get("job_ids", []))
114-
experiment_names = _ensure_list(config_plot.get("experiment_names", []))
113+
config_dict_temp = data.get("train", {}).get("plot", {})
115114

116115
# sanity checks
117-
assert len(run_ids) > 0, "At least one run_id must be provided."
118-
assert len(run_ids) == len(job_ids) == len(experiment_names), (
119-
"The lengths of run_ids, job_ids, and experiment_names must be equal."
120-
)
116+
assert len(config_dict_temp) > 0, "At least one run must be specified."
121117

122-
config_dict = {
123-
run_id: [job_id, exp_name]
124-
for run_id, job_id, exp_name in zip(run_ids, job_ids, experiment_names, strict=False)
125-
}
118+
# convert to legacy format
119+
config_dict = {}
120+
for k, v in config_dict_temp.items():
121+
assert type(v["slurm_id"]) == int, "slurm_id has to be int."
122+
assert type(v["description"]) == str, "description has to be str."
123+
config_dict[k] = [v["slurm_id"], v["description"]]
126124

127125
# Validate the structure: {run_id: [job_id, experiment_name]}
128126
_check_run_id_dict(config_dict)
@@ -559,16 +557,21 @@ def plot_loss_per_run(
559557
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
560558
parser = argparse.ArgumentParser(
561559
description="""Plot training diagnostics from logged data during training.
562-
An example YAML file looks like this:
563-
train:
564-
plot:
565-
- run_ids: [abcde, fghij]
566-
- job_ids: [123456, 654321]
567-
- experiment_names: [experiment1, experiment2]
568-
A dictionary-string can also be used, e.g.:
569-
"{'abcde': ['123456', 'experiment1'],
570-
'fghij': ['654321', 'experiment2']}"
571-
"""
560+
An example YAML file looks like this:
561+
train:
562+
plot:
563+
run_id:
564+
slurm_id : SLURM_JOB (specify 0 if not available)
565+
description: job description
566+
run_id:
567+
slurm_id : SLURM_JOB (specify 0 if not available)
568+
description : job description
569+
...
570+
571+
A dictionary-string can also be specified on the command line, e.g.:
572+
"{'abcde': ['123456', 'experiment1'],
573+
'fghij': ['654321', 'experiment2']}"
574+
"""
572575
)
573576

574577
parser.add_argument(
@@ -653,8 +656,8 @@ def plot_loss_per_run(
653656
# plot learning rate
654657
plot_lr(runs_ids, runs_data, runs_active, plot_dir=out_dir)
655658

656-
# plot performance
657-
plot_utilization(runs_ids, runs_data, runs_active, plot_dir=out_dir)
659+
# # plot performance
660+
# plot_utilization(runs_ids, runs_data, runs_active, plot_dir=out_dir)
658661

659662
# compare different runs
660663
plot_loss_per_stream(

0 commit comments

Comments
 (0)