@@ -78,23 +78,48 @@ def parse_args():
7878 default = True ,
7979 help = "Whether to download dataset for the first time" ,
8080 )
81+ parser .add_argument (
82+ "--opt_model_name" ,
83+ type = str ,
84+ default = "claude-3-5-sonnet-20240620" ,
85+ help = "Specifies the name of the model used for optimization tasks." ,
86+ )
87+ parser .add_argument (
88+ "--exec_model_name" ,
89+ type = str ,
90+ default = "gpt-4o-mini" ,
91+ help = "Specifies the name of the model used for execution tasks." ,
92+ )
8193 return parser .parse_args ()
8294
8395
8496if __name__ == "__main__" :
8597 args = parse_args ()
8698
87- download (["datasets" , "initial_rounds" ], if_first_download = args .if_first_optimize )
8899 config = EXPERIMENT_CONFIGS [args .dataset ]
89100
90- mini_llm_config = ModelsConfig .default ().get ("gpt-4o-mini" )
91- claude_llm_config = ModelsConfig .default ().get ("claude-3-5-sonnet-20240620" )
101+ models_config = ModelsConfig .default ()
102+ opt_llm_config = models_config .get (args .opt_model_name )
103+ if opt_llm_config is None :
104+ raise ValueError (
105+ f"The optimization model '{ args .opt_model_name } ' was not found in the 'models' section of the configuration file. "
106+ "Please add it to the configuration file or specify a valid model using the --opt_model_name flag. "
107+ )
108+
109+ exec_llm_config = models_config .get (args .exec_model_name )
110+ if exec_llm_config is None :
111+ raise ValueError (
112+ f"The execution model '{ args .exec_model_name } ' was not found in the 'models' section of the configuration file. "
113+ "Please add it to the configuration file or specify a valid model using the --exec_model_name flag. "
114+ )
115+
116+ download (["datasets" , "initial_rounds" ], if_first_download = args .if_first_optimize )
92117
93118 optimizer = Optimizer (
94119 dataset = config .dataset ,
95120 question_type = config .question_type ,
96- opt_llm_config = claude_llm_config ,
97- exec_llm_config = mini_llm_config ,
121+ opt_llm_config = opt_llm_config ,
122+ exec_llm_config = exec_llm_config ,
98123 check_convergence = args .check_convergence ,
99124 operators = config .operators ,
100125 optimized_path = args .optimized_path ,
0 commit comments