Skip to content

Commit a31b40c

Browse files
authored
Re-enabled option to run plot_training as script and fixed -rf argument (ecmwf#444)
* Re-enabled option to runplot_training as script and removed relative path as default from mutually-exclusive argument -rf. * Ruffed code. * Ruff check fix. * Rename flags for parsing configuration and fixed default handling for standard config YAML-file.
1 parent 419f7dc commit a31b40c

File tree

1 file changed

+30
-14
lines changed

1 file changed

+30
-14
lines changed

src/weathergen/utils/plot_training.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import argparse
1111
import logging
1212
import subprocess
13+
import sys
1314
from pathlib import Path
1415

1516
import matplotlib.pyplot as plt
@@ -21,6 +22,8 @@
2122

2223
_logger = logging.getLogger(__name__)
2324

25+
DEFAULT_RUN_FILE = Path("./config/runs_plot_train.yml")
26+
2427

2528
####################################################################################################
2629
def _ensure_list(value):
@@ -556,7 +559,7 @@ def plot_loss_per_run(
556559
plt.close()
557560

558561

559-
def plot_train():
562+
def plot_train(args=None):
560563
# Example usage:
561564
# When providing a YAML for configuring the run IDs:
562565
# python plot_training.py -rf eval_run.yml -m ./trained_models -o ./training_plots
@@ -622,27 +625,24 @@ def plot_train():
622625

623626
run_id_group = parser.add_mutually_exclusive_group()
624627
run_id_group.add_argument(
625-
"-rs",
626-
"--run_ids_dict",
628+
"-fd",
629+
"--from_dict",
627630
type=_read_str_config,
628-
dest="rs",
629-
help=(
630-
"Dictionary-string of form '{run_id: [job_id, experiment_name]}'",
631-
" for training runs to plot",
632-
),
631+
dest="fd",
632+
help="Dictionary-string of form '{run_id: [job_id, experiment_name]}'"
633+
+ "for training runs to plot",
633634
)
634635

635636
run_id_group.add_argument(
636-
"-rf",
637-
"--run_ids_file",
638-
dest="rf",
639-
default="./config/runs_plot_train.yml",
637+
"-fy",
638+
"--from_yaml",
639+
dest="fy",
640640
type=_read_yaml_config,
641641
help="YAML file configuring the training run ids to plot",
642642
)
643643

644644
# parse the command line arguments
645-
args = parser.parse_args()
645+
args = parser.parse_args(args)
646646

647647
model_base_dir = Path(args.model_base_dir)
648648
out_dir = Path(args.output_dir)
@@ -651,7 +651,17 @@ def plot_train():
651651
if args.x_type not in x_types_valid:
652652
raise ValueError(f"x_type must be one of {x_types_valid}, but got {args.x_type}")
653653

654-
runs_ids = args.rs if args.rs is not None else args.rf
654+
# Post-processing default logic for config from YAML-file
655+
if args.fd is None and args.fy is None:
656+
if DEFAULT_RUN_FILE.exists():
657+
args.fy = _read_yaml_config(DEFAULT_RUN_FILE)
658+
else:
659+
raise ValueError(
660+
f"Please provide a run_id dictionary or a YAML file with run_ids, "
661+
f"or create a default file at {DEFAULT_RUN_FILE}."
662+
)
663+
664+
runs_ids = args.fd if args.fd is not None else args.fy
655665

656666
if args.delete == "True":
657667
clean_plot_folder(out_dir)
@@ -725,3 +735,9 @@ def plot_train():
725735
get_stream_names(run_id, model_path=model_base_dir), # limit to available streams
726736
plot_dir=out_dir,
727737
)
738+
739+
740+
if __name__ == "__main__":
741+
args = sys.argv[1:] # get CLI args
742+
743+
plot_train(args)

0 commit comments

Comments
 (0)