@@ -89,9 +89,13 @@ def _read_yaml_config(yaml_file_path):
8989 Expected structure in the YAML file:
9090 train:
9191 plot:
92- run_ids: [run_id1, run_id2, ...]
93- job_ids: [job_id1, job_id2, ...]
94- experiment_names: [experiment_name1, experiment_name2, ...]
92+ run_id1:
93+ slurm_id: SLURM_JOB (specify 0 if not available)
94+ description: job description
95+ run_id2:
96+ slurm_id: SLURM_JOB (specify 0 if not available)
97+ description: job description
98+ ...
9599
96100 Parameters
97101 ----------
@@ -106,23 +110,17 @@ def _read_yaml_config(yaml_file_path):
106110 data = yaml .safe_load (f )
107111
108112 # Extract configuration for plotting training diagnostics
109- config_plot = data .get ("train" , {}).get ("plot" , {})
110-
111- # Init lists
112- run_ids = _ensure_list (config_plot .get ("run_ids" , []))
113- job_ids = _ensure_list (config_plot .get ("job_ids" , []))
114- experiment_names = _ensure_list (config_plot .get ("experiment_names" , []))
113+ config_dict_temp = data .get ("train" , {}).get ("plot" , {})
115114
116115 # sanity checks
117- assert len (run_ids ) > 0 , "At least one run_id must be provided."
118- assert len (run_ids ) == len (job_ids ) == len (experiment_names ), (
119- "The lengths of run_ids, job_ids, and experiment_names must be equal."
120- )
116+ assert len (config_dict_temp ) > 0 , "At least one run must be specified."
121117
122- config_dict = {
123- run_id : [job_id , exp_name ]
124- for run_id , job_id , exp_name in zip (run_ids , job_ids , experiment_names , strict = False )
125- }
118+ # convert to legacy format
119+ config_dict = {}
120+ for k , v in config_dict_temp .items ():
121+ assert type (v ["slurm_id" ]) == int , "slurm_id has to be int."
122+ assert type (v ["description" ]) == str , "description has to be str."
123+ config_dict [k ] = [v ["slurm_id" ], v ["description" ]]
126124
127125 # Validate the structure: {run_id: [job_id, experiment_name]}
128126 _check_run_id_dict (config_dict )
@@ -559,16 +557,21 @@ def plot_loss_per_run(
559557 logging .basicConfig (level = logging .INFO , format = "%(asctime)s - %(levelname)s - %(message)s" )
560558 parser = argparse .ArgumentParser (
561559 description = """Plot training diagnostics from logged data during training.
562- An example YAML file looks like this:
563- train:
564- plot:
565- - run_ids: [abcde, fghij]
566- - job_ids: [123456, 654321]
567- - experiment_names: [experiment1, experiment2]
568- A dictionary-string can also be used, e.g.:
569- "{'abcde': ['123456', 'experiment1'],
570- 'fghij': ['654321', 'experiment2']}"
571- """
560+ An example YAML file looks like this:
561+ train:
562+ plot:
563+ run_id:
564+ slurm_id : SLURM_JOB (specify 0 if not available)
565+ description: job description
566+ run_id:
567+ slurm_id : SLURM_JOB (specify 0 if not available)
568+ description : job description
569+ ...
570+
571+ A dictionary-string can also be specified on the command line, e.g.:
572+ "{'abcde': ['123456', 'experiment1'],
573+ 'fghij': ['654321', 'experiment2']}"
574+ """
572575 )
573576
574577 parser .add_argument (
@@ -653,8 +656,8 @@ def plot_loss_per_run(
653656 # plot learning rate
654657 plot_lr (runs_ids , runs_data , runs_active , plot_dir = out_dir )
655658
656- # plot performance
657- plot_utilization (runs_ids , runs_data , runs_active , plot_dir = out_dir )
659+ # # plot performance
660+ # plot_utilization(runs_ids, runs_data, runs_active, plot_dir=out_dir)
658661
659662 # compare different runs
660663 plot_loss_per_stream (
0 commit comments