2020
2121logger = logging .getLogger (__name__ )
2222
23- DATA_PATH = os .environ .get (' DATA_PATH' ) or ' psb2'
23+ DATA_PATH = os .environ .get (" DATA_PATH" ) or " psb2"
2424
2525task_descriptions = []
26- with open ('psb2-meta/tasks.txt' ) as f :
27- task_descriptions = {name .strip (): description .strip ()
28- for name , description in chunked (f .readlines (), 2 )}
26+ with open ("psb2-meta/tasks.txt" ) as f :
27+ task_descriptions = {
28+ name .strip (): description .strip ()
29+ for name , description in chunked (f .readlines (), 2 )
30+ }
2931
30- debug_templates = [line . split ( ' \t ' )
31- for line in get_template (' prompts.txt' ).splitlines ()]
32- debug_templates = { int ( ix . strip ()): prompt . strip ()
33- for ix , prompt in debug_templates }
32+ debug_templates = [
33+ line . split ( " \t " ) for line in get_template (" prompts.txt" ).splitlines ()
34+ ]
35+ debug_templates = { int ( ix . strip ()): prompt . strip () for ix , prompt in debug_templates }
3436
3537
3638def title2kebabcase (title : str ) -> str :
3739 """Replace spaces with hyphens"""
38- return '-' .join (word .lower () for word in title .split (' ' ))
40+ return "-" .join (word .lower () for word in title .split (" " ))
3941
4042
41- pushgp_success_rates = pd .read_csv ('psb2-meta/results.tsv' ,
42- sep = '\t ' , index_col = ['Problem' ])
43- pushgp_success_rates = pushgp_success_rates ['Succ.' ].rename (title2kebabcase )
43+ pushgp_success_rates = pd .read_csv (
44+ "psb2-meta/results.tsv" , sep = "\t " , index_col = ["Problem" ]
45+ )
46+ pushgp_success_rates = pushgp_success_rates ["Succ." ].rename (title2kebabcase )
4447
4548
4649def is_already_solved (
47- solutions_logger : FileLogger ,
48- test_data : Tuple [List [str ] | str , List [str ] | str ],
49- language : Language ) -> Program | bool :
50+ solutions_logger : FileLogger ,
51+ test_data : Tuple [List [str ] | str , List [str ] | str ],
52+ language : Language ,
53+ ) -> Program | bool :
5054 """Checks if the currently logged solution passes all tests in `test_data`.
5155 Returns False if a Program class instance cannot be created"""
5256 try :
53- return Program (workdir = solutions_logger .dir ,
54- name = solutions_logger .filename ,
55- language = language ).test (test_data )
57+ return Program (
58+ workdir = solutions_logger .dir ,
59+ name = solutions_logger .filename ,
60+ language = language ,
61+ ).test (test_data )
5662 except FileNotFoundError :
5763 return False
5864
5965
60- def run_benchmark (problem : str = 'fizz-buzz' ,
61- language : str = 'C++' ,
62- max_programs : int = 1000 ,
63- drafts_per_prompt : int = 10 ,
64- explanations_per_program : int = 10 ,
65- repairs_per_explanation : int = 2 ,
66- beam_width : int = 100 ,
67- seed : int = 42 ,
68- valid_examples : int = 100 ,
69- test_examples : int = 2000 ,
70- prompt_examples : int = 5 ,
71- log : str = 'ERROR' ,
72- model_name : str = 'gpt-3.5-turbo' ,
73- lexicase_selection : bool = False ,
74- ollama_url : Optional [str ] = "http://localhost:11434" ,
75- ** kwargs ):
66+ def run_benchmark (
67+ problem : str = "fizz-buzz" ,
68+ language : str = "C++" ,
69+ max_programs : int = 1000 ,
70+ drafts_per_prompt : int = 10 ,
71+ explanations_per_program : int = 10 ,
72+ repairs_per_explanation : int = 2 ,
73+ beam_width : int = 100 ,
74+ seed : int = 42 ,
75+ valid_examples : int = 100 ,
76+ test_examples : int = 2000 ,
77+ prompt_examples : int = 5 ,
78+ log : str = "ERROR" ,
79+ model_name : str = "gpt-3.5-turbo" ,
80+ lexicase_selection : bool = False ,
81+ ollama_url : Optional [str ] = "http://localhost:11434" ,
82+ experiment_id : int = 0 ,
83+ ** kwargs ,
84+ ):
7685 """Generate and repair programs in PSB2
7786
7887 Parameters
@@ -115,74 +124,91 @@ def run_benchmark(problem: str = 'fizz-buzz',
115124 link to the ollama cluster, default is localhost
116125 """
117126 # Setup logging
118- Path ('logs' ).mkdir (exist_ok = True )
119- logging .basicConfig (format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' ,
120- datefmt = '%m/%d/%Y %H:%M:%S' , level = log .upper ())
121- logging .info ('logging info' )
127+ Path ("logs" ).mkdir (exist_ok = True )
128+ logging .basicConfig (
129+ format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" ,
130+ datefmt = "%m/%d/%Y %H:%M:%S" ,
131+ level = log .upper (),
132+ )
133+ logging .info ("logging info" )
122134 baseline = pushgp_success_rates [problem ]
123135
124136 config = {
125- ' slurm_job_id' : os .environ .get (' SLURM_JOB_ID' ),
126- ' slurm_task_pid' : os .environ .get (' SLURM_TASK_PID' ),
127- ' slurm_array_task_id' : os .environ .get (' SLURM_ARRAY_TASK_ID' ),
128- ' slurm_array_job_id' : os .environ .get (' SLURM_ARRAY_JOB_ID' ),
129- ' task_id' : os .environ .get (' TASK_ID' ),
137+ " slurm_job_id" : os .environ .get (" SLURM_JOB_ID" ),
138+ " slurm_task_pid" : os .environ .get (" SLURM_TASK_PID" ),
139+ " slurm_array_task_id" : os .environ .get (" SLURM_ARRAY_TASK_ID" ),
140+ " slurm_array_job_id" : os .environ .get (" SLURM_ARRAY_JOB_ID" ),
141+ " task_id" : os .environ .get (" TASK_ID" ),
130142 ** kwargs ,
131- ** locals ()
143+ ** locals (),
132144 }
133145
134- del config [' kwargs' ]
135- model_name_tag = model_name .replace (':' , '_' )
146+ del config [" kwargs" ]
147+ model_name_tag = model_name .replace (":" , "_" )
136148 run = wandb .init (
137- entity = os .environ .get ('WANDB_ENTITY' ),
138- project = f'seidr-telo-psb2-{ model_name_tag } ' ,
139- dir = os .environ .get ('WANDB_DIR' ),
140- config = config )
141- logger .info (f'Run config { run .config } , W&B: { run .url } ' )
149+ entity = os .environ .get ("WANDB_ENTITY" ),
150+ project = f"seidr-telo-psb2-{ model_name_tag } -run{ experiment_id } " ,
151+ dir = os .environ .get ("WANDB_DIR" ),
152+ config = config ,
153+ )
154+ logger .info (f"Run config { run .config } , W&B: { run .url } " )
142155
143156 language = language_ (language )
144157
145- commit_msg_template = get_template (' commit.txt' ).format (
146- problem = problem ,
147- wandb_url = run . url )
158+ commit_msg_template = get_template (" commit.txt" ).format (
159+ problem = problem , wandb_url = run . url
160+ )
148161
149- lexicase_tag = ' _lexicase' if lexicase_selection else ""
150- attempts_branch = f' psb_{ model_name_tag } _{ drafts_per_prompt } x{ explanations_per_program } x{ repairs_per_explanation } { lexicase_tag } _dev'
151- solutions_branch = f' psb_{ model_name_tag } _{ drafts_per_prompt } x{ explanations_per_program } x{ repairs_per_explanation } { lexicase_tag } '
162+ lexicase_tag = " _lexicase" if lexicase_selection else ""
163+ attempts_branch = f" psb_{ model_name_tag } _{ drafts_per_prompt } x{ explanations_per_program } x{ repairs_per_explanation } { lexicase_tag } _run { experiment_id } _dev"
164+ solutions_branch = f" psb_{ model_name_tag } _{ drafts_per_prompt } x{ explanations_per_program } x{ repairs_per_explanation } { lexicase_tag } _run { experiment_id } "
152165
153- attempts_logger = FileLogger (branch = attempts_branch ,
154- filename = language .source .format (name = problem ),
155- commit_msg_template = commit_msg_template )
156- solutions_logger = FileLogger (branch = solutions_branch ,
157- filename = language .source .format (name = problem ),
158- commit_msg_template = commit_msg_template )
166+ attempts_logger = FileLogger (
167+ branch = attempts_branch ,
168+ filename = language .source .format (name = problem ),
169+ commit_msg_template = commit_msg_template ,
170+ )
171+ solutions_logger = FileLogger (
172+ branch = solutions_branch ,
173+ filename = language .source .format (name = problem ),
174+ commit_msg_template = commit_msg_template ,
175+ )
159176
160177 description = task_descriptions [problem ]
161178
162179 # ensure that the same I/O pairs are fetched for every experiment
163180 random .seed (seed )
164181
165182 train_data , test_data = psb2 .fetch_examples (
166- DATA_PATH , problem , max (valid_examples , prompt_examples ),
167- test_examples , format = 'competitive' )
183+ DATA_PATH ,
184+ problem ,
185+ max (valid_examples , prompt_examples ),
186+ test_examples ,
187+ format = "competitive" ,
188+ )
168189 prompt_data = train_data [:prompt_examples ]
169190 valid_data = train_data [:valid_examples ]
170191
171192 if is_already_solved (solutions_logger , test_data , language ):
172- logging .info (f' { problem } is already solved, shutting down' )
193+ logging .info (f" { problem } is already solved, shutting down" )
173194 return
174195
175196 call_count = 0
176197
177198 def log_llm_call (** kwargs ):
178199 """Update and log the number of LLM calls"""
179200 nonlocal call_count
180- wandb .log ({' llm_calls' : call_count })
201+ wandb .log ({" llm_calls" : call_count })
181202 call_count += 1
182203
183204 critics = [
184- lambda code : IOMatch (code = code , language = language , input = inp , output = out ,
185- task_description = description )
205+ lambda code : IOMatch (
206+ code = code ,
207+ language = language ,
208+ input = inp ,
209+ output = out ,
210+ task_description = description ,
211+ )
186212 for inp , out in valid_data
187213 ]
188214 prompt = initial_prompt (description , prompt_data )
@@ -204,31 +230,36 @@ def log_llm_call(**kwargs):
204230 log_solution = solutions_logger ,
205231 log_llm_call = log_llm_call ,
206232 max_programs = max_programs ,
207- ollama_url = ollama_url
233+ ollama_url = ollama_url ,
208234 )
209235
210236 solution = seidr .develop (start_code = start_code )
211237
212- logging .info (' Development done. Testing...' )
238+ logging .info (" Development done. Testing..." )
213239
214240 test_evals = [
215- IOMatch (solution ,
216- language = language ,
217- input = inp , output = out ,
218- task_description = description )
219- for inp , out in test_data ]
241+ IOMatch (
242+ solution ,
243+ language = language ,
244+ input = inp ,
245+ output = out ,
246+ task_description = description ,
247+ )
248+ for inp , out in test_data
249+ ]
220250 avg_score = sum (e .score () for e in test_evals ) / len (test_evals )
221251 test_pass_rate = sum (e .check () for e in test_evals ) / len (test_evals )
222252
223- logging .info (f'\n Test pass rate on test: { test_pass_rate } \n Test avg score on test: { avg_score } ' )
253+ logging .info (
254+ f"\n Test pass rate on test: { test_pass_rate } \n Test avg score on test: { avg_score } "
255+ )
224256
225- run .log ({'test_avg_score' : avg_score ,
226- 'test_pass_rate' : test_pass_rate })
257+ run .log ({"test_avg_score" : avg_score , "test_pass_rate" : test_pass_rate })
227258 # run.finish()
228259 wandb .finish ()
229260
230261
231- if __name__ == ' __main__' :
262+ if __name__ == " __main__" :
232263 try :
233264 Fire (run_benchmark )
234265 except :
0 commit comments