1010import argparse
1111import logging
1212import subprocess
13+ import sys
1314from pathlib import Path
1415
1516import matplotlib .pyplot as plt
2122
2223_logger = logging .getLogger (__name__ )
2324
25+ DEFAULT_RUN_FILE = Path ("./config/runs_plot_train.yml" )
26+
2427
2528####################################################################################################
2629def _ensure_list (value ):
@@ -556,7 +559,7 @@ def plot_loss_per_run(
556559 plt .close ()
557560
558561
559- def plot_train ():
562+ def plot_train (args = None ):
560563 # Example usage:
561564 # When providing a YAML for configuring the run IDs:
562565 # python plot_training.py -rf eval_run.yml -m ./trained_models -o ./training_plots
@@ -622,27 +625,24 @@ def plot_train():
622625
623626 run_id_group = parser .add_mutually_exclusive_group ()
624627 run_id_group .add_argument (
625- "-rs " ,
626- "--run_ids_dict " ,
628+ "-fd " ,
629+ "--from_dict " ,
627630 type = _read_str_config ,
628- dest = "rs" ,
629- help = (
630- "Dictionary-string of form '{run_id: [job_id, experiment_name]}'" ,
631- " for training runs to plot" ,
632- ),
631+ dest = "fd" ,
632+ help = "Dictionary-string of form '{run_id: [job_id, experiment_name]}'"
633+ + "for training runs to plot" ,
633634 )
634635
635636 run_id_group .add_argument (
636- "-rf" ,
637- "--run_ids_file" ,
638- dest = "rf" ,
639- default = "./config/runs_plot_train.yml" ,
637+ "-fy" ,
638+ "--from_yaml" ,
639+ dest = "fy" ,
640640 type = _read_yaml_config ,
641641 help = "YAML file configuring the training run ids to plot" ,
642642 )
643643
644644 # parse the command line arguments
645- args = parser .parse_args ()
645+ args = parser .parse_args (args )
646646
647647 model_base_dir = Path (args .model_base_dir )
648648 out_dir = Path (args .output_dir )
@@ -651,7 +651,17 @@ def plot_train():
651651 if args .x_type not in x_types_valid :
652652 raise ValueError (f"x_type must be one of { x_types_valid } , but got { args .x_type } " )
653653
654- runs_ids = args .rs if args .rs is not None else args .rf
654+ # Post-processing default logic for config from YAML-file
655+ if args .fd is None and args .fy is None :
656+ if DEFAULT_RUN_FILE .exists ():
657+ args .fy = _read_yaml_config (DEFAULT_RUN_FILE )
658+ else :
659+ raise ValueError (
660+ f"Please provide a run_id dictionary or a YAML file with run_ids, "
661+ f"or create a default file at { DEFAULT_RUN_FILE } ."
662+ )
663+
664+ runs_ids = args .fd if args .fd is not None else args .fy
655665
656666 if args .delete == "True" :
657667 clean_plot_folder (out_dir )
@@ -725,3 +735,9 @@ def plot_train():
725735 get_stream_names (run_id , model_path = model_base_dir ), # limit to available streams
726736 plot_dir = out_dir ,
727737 )
738+
739+
740+ if __name__ == "__main__" :
741+ args = sys .argv [1 :] # get CLI args
742+
743+ plot_train (args )
0 commit comments