
Commit 27d8202

Merge pull request #22 from vadim0x60/add/ollama-python-package
Add/ollama python package
2 parents ccb264c + e306b2a · commit 27d8202

23 files changed: +4087 −687 lines

README.md

Lines changed: 24 additions & 17 deletions
````diff
@@ -44,12 +44,12 @@ set to your OpenAI API access token.
 
 #### Set up Ollama
 
-Run [Ollama](https://ollama.ai/) with CodeLlama or [another model](https://ollama.ai/library) locally
+Run [Ollama](https://ollama.ai/) with Llama 3-8B or [another model](https://ollama.ai/library) locally
 or on a server.
 In the latter case, start the Ollama server with the following commands and note the `URL:PORT` pair:
 ```
 OLLAMA_HOST=URL:PORT ollama serve &
-OLLAMA_HOST=URL:PORT ollama pull codellama:34b-instruct &
+OLLAMA_HOST=URL:PORT ollama pull llama3 &
 ```
 
 Example `.config` file layout:
````
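Note: the `ollama` Python package that this PR adds can be used to smoke-test the server before launching any benchmark runs. A minimal sketch, assuming the `ollama` pip package is installed, `llama3` has already been pulled, and `URL:PORT` is the placeholder from the README text above:

```python
# Smoke test for an Ollama server; URL:PORT is a placeholder, and
# dict-style access to the reply follows the ollama-python README.
from ollama import Client

client = Client(host="http://URL:PORT")  # defaults to http://localhost:11434
reply = client.chat(
    model="llama3",
    messages=[{"role": "user", "content": "Write hello-world in Python."}],
)
print(reply["message"]["content"])
```

If this round-trips, `benchmark.py --ollama_url "http://URL:PORT"` should reach the same server.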
````diff
@@ -75,41 +75,48 @@ export WANDB_DIR=...
 
 If you're using [Slurm](https://slurm.schedmd.com/), write a `run.sh` file with `python benchmark.py`
 and run it with `sbatch run.sh --array=1-500`.
-If not, run `TASK_ID=n python benchmark.py` to re-run one of our experiments exactly, or set the parameters yourself:
+If not, run `TASK_ID=n python benchmark.py` to re-run one of our experiments exactly,
+or set the parameters yourself as below.
 
 For example, for basement problem in PSB2, run SEIDR without lexicase selection as follows:
 ```
 python3 benchmark.py \
---task_id 202 \
---problem basement \
---language C++ \
+--task_id 0 \
+--problem bowling \
+--language Python \
+--branching_factor 2 \
 --max_programs 100 \
 --drafts_per_prompt 2 \
 --explanations_per_program 2 \
 --repairs_per_explanation 2 \
 --beam_width 2 \
 --log INFO \
 --lexicase_selection False \
---dataset psb2 \
---model_name gpt-3.5-turbo
+--dataset humaneval \
+--model_name gpt-3.5-turbo \
+--valid_examples 50 \
+--experiment_id 0
 ```
 
-To run an example with SEIDR with CodeLlama served by Ollama at `URL:PORT`, run the following:
+To run an example with SEIDR with Llama 3 served by Ollama at `URL:PORT` on HumanEval with lexicase, run the following:
 ```
-python3 benchmark.py \
---task_id 2202 \
---problem basement \
---language C++ \
+python3 benchmark_humaneval.py \
+--task_id 0 \
+--problem Python/0 \
+--language Python \
+--branching_factor 2 \
 --max_programs 100 \
 --drafts_per_prompt 2 \
 --explanations_per_program 2 \
 --repairs_per_explanation 2 \
 --beam_width 2 \
 --log INFO \
---lexicase_selection False \
---dataset psb2 \
---model_name codellama:34b-instruct \
+--lexicase_selection True \
+--dataset humaneval \
+--model_name llama3 \
+--experiment_id 0 \
 --ollama_url "http://URL:PORT"
+
 ```
 
-Example Slurm scripts are stored in `example_scripts/` and tables with hyperparameters in `/config`
+Example Slurm scripts are stored in `scripts/` and tables with hyperparameters in `/config`
````
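Note: how `TASK_ID=n` selects a stored experiment is not shown in this diff. A plausible sketch of the pattern, with the file name `config/params.csv` and its column names invented for illustration (the real tables in `/config` may differ):

```python
# Hypothetical TASK_ID-based replay: pick one row of a hyperparameter
# table and turn it into the benchmark.py flags shown above.
import csv
import os
import shlex
import subprocess

task_id = int(os.environ["TASK_ID"])  # e.g. TASK_ID=42 python replay.py
with open("config/params.csv") as f:  # invented file name
    row = list(csv.DictReader(f))[task_id]

cmd = ["python3", "benchmark.py"]
for key, value in row.items():  # column names become --flags
    cmd += [f"--{key}", str(value)]

print("running:", shlex.join(cmd))
subprocess.run(cmd, check=True)
```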

benchmark.py

Lines changed: 111 additions & 80 deletions
```diff
@@ -20,59 +20,68 @@
 
 logger = logging.getLogger(__name__)
 
-DATA_PATH = os.environ.get('DATA_PATH') or 'psb2'
+DATA_PATH = os.environ.get("DATA_PATH") or "psb2"
 
 task_descriptions = []
-with open('psb2-meta/tasks.txt') as f:
-    task_descriptions = {name.strip(): description.strip()
-                         for name, description in chunked(f.readlines(), 2)}
+with open("psb2-meta/tasks.txt") as f:
+    task_descriptions = {
+        name.strip(): description.strip()
+        for name, description in chunked(f.readlines(), 2)
+    }
 
-debug_templates = [line.split('\t')
-                   for line in get_template('prompts.txt').splitlines()]
-debug_templates = {int(ix.strip()): prompt.strip()
-                   for ix, prompt in debug_templates}
+debug_templates = [
+    line.split("\t") for line in get_template("prompts.txt").splitlines()
+]
+debug_templates = {int(ix.strip()): prompt.strip() for ix, prompt in debug_templates}
 
 
 def title2kebabcase(title: str) -> str:
     """Replace spaces with hyphens"""
-    return '-'.join(word.lower() for word in title.split(' '))
+    return "-".join(word.lower() for word in title.split(" "))
 
 
-pushgp_success_rates = pd.read_csv('psb2-meta/results.tsv',
-                                   sep='\t', index_col=['Problem'])
-pushgp_success_rates = pushgp_success_rates['Succ.'].rename(title2kebabcase)
+pushgp_success_rates = pd.read_csv(
+    "psb2-meta/results.tsv", sep="\t", index_col=["Problem"]
+)
+pushgp_success_rates = pushgp_success_rates["Succ."].rename(title2kebabcase)
 
 
 def is_already_solved(
-        solutions_logger: FileLogger,
-        test_data: Tuple[List[str] | str, List[str] | str],
-        language: Language) -> Program | bool:
+    solutions_logger: FileLogger,
+    test_data: Tuple[List[str] | str, List[str] | str],
+    language: Language,
+) -> Program | bool:
     """Checks if the currently logged solution passes all tests in `test_data`.
     Returns False if a Program class instance cannot be created"""
     try:
-        return Program(workdir=solutions_logger.dir,
-                       name=solutions_logger.filename,
-                       language=language).test(test_data)
+        return Program(
+            workdir=solutions_logger.dir,
+            name=solutions_logger.filename,
+            language=language,
+        ).test(test_data)
     except FileNotFoundError:
         return False
 
 
-def run_benchmark(problem: str = 'fizz-buzz',
-                  language: str = 'C++',
-                  max_programs: int = 1000,
-                  drafts_per_prompt: int = 10,
-                  explanations_per_program: int = 10,
-                  repairs_per_explanation: int = 2,
-                  beam_width: int = 100,
-                  seed: int = 42,
-                  valid_examples: int = 100,
-                  test_examples: int = 2000,
-                  prompt_examples: int = 5,
-                  log: str = 'ERROR',
-                  model_name: str = 'gpt-3.5-turbo',
-                  lexicase_selection: bool = False,
-                  ollama_url: Optional[str] = "http://localhost:11434",
-                  **kwargs):
+def run_benchmark(
+    problem: str = "fizz-buzz",
+    language: str = "C++",
+    max_programs: int = 1000,
+    drafts_per_prompt: int = 10,
+    explanations_per_program: int = 10,
+    repairs_per_explanation: int = 2,
+    beam_width: int = 100,
+    seed: int = 42,
+    valid_examples: int = 100,
+    test_examples: int = 2000,
+    prompt_examples: int = 5,
+    log: str = "ERROR",
+    model_name: str = "gpt-3.5-turbo",
+    lexicase_selection: bool = False,
+    ollama_url: Optional[str] = "http://localhost:11434",
+    experiment_id: int = 0,
+    **kwargs,
+):
     """Generate and repair programs in PSB2
 
```
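Note on the hunk above: `psb2-meta/tasks.txt` evidently alternates task names and descriptions line by line, and `chunked` (presumably `more_itertools.chunked`; the import is outside this hunk) pairs them up. A standalone sketch with sample lines standing in for the real file:

```python
# Sketch of the tasks.txt parsing pattern; the sample lines are invented.
from more_itertools import chunked

lines = [
    "fizz-buzz\n", "Given an integer x, print Fizz, Buzz or the number...\n",
    "basement\n", "Given a list of integers, find the first index...\n",
]
task_descriptions = {
    name.strip(): description.strip() for name, description in chunked(lines, 2)
}
print(task_descriptions["basement"])  # -> "Given a list of integers, ..."
```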
```diff
@@ -115,74 +124,91 @@ def run_benchmark(problem: str = 'fizz-buzz',
         link to the ollama cluster, default is localhost
     """
     # Setup logging
-    Path('logs').mkdir(exist_ok=True)
-    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
-                        datefmt='%m/%d/%Y %H:%M:%S', level=log.upper())
-    logging.info('logging info')
+    Path("logs").mkdir(exist_ok=True)
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=log.upper(),
+    )
+    logging.info("logging info")
     baseline = pushgp_success_rates[problem]
 
     config = {
-        'slurm_job_id': os.environ.get('SLURM_JOB_ID'),
-        'slurm_task_pid': os.environ.get('SLURM_TASK_PID'),
-        'slurm_array_task_id': os.environ.get('SLURM_ARRAY_TASK_ID'),
-        'slurm_array_job_id': os.environ.get('SLURM_ARRAY_JOB_ID'),
-        'task_id': os.environ.get('TASK_ID'),
+        "slurm_job_id": os.environ.get("SLURM_JOB_ID"),
+        "slurm_task_pid": os.environ.get("SLURM_TASK_PID"),
+        "slurm_array_task_id": os.environ.get("SLURM_ARRAY_TASK_ID"),
+        "slurm_array_job_id": os.environ.get("SLURM_ARRAY_JOB_ID"),
+        "task_id": os.environ.get("TASK_ID"),
         **kwargs,
-        **locals()
+        **locals(),
     }
 
-    del config['kwargs']
-    model_name_tag = model_name.replace(':', '_')
+    del config["kwargs"]
+    model_name_tag = model_name.replace(":", "_")
     run = wandb.init(
-        entity=os.environ.get('WANDB_ENTITY'),
-        project=f'seidr-telo-psb2-{model_name_tag}',
-        dir=os.environ.get('WANDB_DIR'),
-        config=config)
-    logger.info(f'Run config {run.config}, W&B: {run.url}')
+        entity=os.environ.get("WANDB_ENTITY"),
+        project=f"seidr-telo-psb2-{model_name_tag}-run{experiment_id}",
+        dir=os.environ.get("WANDB_DIR"),
+        config=config,
+    )
+    logger.info(f"Run config {run.config}, W&B: {run.url}")
 
     language = language_(language)
 
-    commit_msg_template = get_template('commit.txt').format(
-        problem=problem,
-        wandb_url=run.url)
+    commit_msg_template = get_template("commit.txt").format(
+        problem=problem, wandb_url=run.url
+    )
 
-    lexicase_tag = '_lexicase' if lexicase_selection else ""
-    attempts_branch = f'psb_{model_name_tag}_{drafts_per_prompt}x{explanations_per_program}x{repairs_per_explanation}{lexicase_tag}_dev'
-    solutions_branch = f'psb_{model_name_tag}_{drafts_per_prompt}x{explanations_per_program}x{repairs_per_explanation}{lexicase_tag}'
+    lexicase_tag = "_lexicase" if lexicase_selection else ""
+    attempts_branch = f"psb_{model_name_tag}_{drafts_per_prompt}x{explanations_per_program}x{repairs_per_explanation}{lexicase_tag}_run{experiment_id}_dev"
+    solutions_branch = f"psb_{model_name_tag}_{drafts_per_prompt}x{explanations_per_program}x{repairs_per_explanation}{lexicase_tag}_run{experiment_id}"
 
-    attempts_logger = FileLogger(branch=attempts_branch,
-                                 filename=language.source.format(name=problem),
-                                 commit_msg_template=commit_msg_template)
-    solutions_logger = FileLogger(branch=solutions_branch,
-                                  filename=language.source.format(name=problem),
-                                  commit_msg_template=commit_msg_template)
+    attempts_logger = FileLogger(
+        branch=attempts_branch,
+        filename=language.source.format(name=problem),
+        commit_msg_template=commit_msg_template,
+    )
+    solutions_logger = FileLogger(
+        branch=solutions_branch,
+        filename=language.source.format(name=problem),
+        commit_msg_template=commit_msg_template,
+    )
 
     description = task_descriptions[problem]
 
     # ensure that the same I/O pairs are fetched for every experiment
     random.seed(seed)
 
     train_data, test_data = psb2.fetch_examples(
-        DATA_PATH, problem, max(valid_examples, prompt_examples),
-        test_examples, format='competitive')
+        DATA_PATH,
+        problem,
+        max(valid_examples, prompt_examples),
+        test_examples,
+        format="competitive",
+    )
     prompt_data = train_data[:prompt_examples]
     valid_data = train_data[:valid_examples]
 
     if is_already_solved(solutions_logger, test_data, language):
-        logging.info(f'{problem} is already solved, shutting down')
+        logging.info(f"{problem} is already solved, shutting down")
         return
 
     call_count = 0
 
     def log_llm_call(**kwargs):
         """Update and log the number of LLM calls"""
         nonlocal call_count
-        wandb.log({'llm_calls': call_count})
+        wandb.log({"llm_calls": call_count})
         call_count += 1
 
     critics = [
-        lambda code: IOMatch(code=code, language=language, input=inp, output=out,
-                             task_description=description)
+        lambda code: IOMatch(
+            code=code,
+            language=language,
+            input=inp,
+            output=out,
+            task_description=description,
+        )
         for inp, out in valid_data
     ]
     prompt = initial_prompt(description, prompt_data)
```
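Note on the `critics` list above: the lambdas defer each `IOMatch` check until a candidate program arrives, and lambdas in a comprehension close over the loop variables rather than their current values. Whether that matters here depends on how SEIDR calls the critics, but the general Python behavior, and the usual default-argument fix, look like this (a standalone illustration, not repo code):

```python
# Late binding: every lambda sees the *last* loop values once the
# comprehension finishes; default arguments freeze each pair instead.
pairs = [(1, 2), (3, 4)]

late = [lambda: (inp, out) for inp, out in pairs]
print([f() for f in late])    # [(3, 4), (3, 4)]

bound = [lambda inp=inp, out=out: (inp, out) for inp, out in pairs]
print([f() for f in bound])   # [(1, 2), (3, 4)]
```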
```diff
@@ -204,31 +230,36 @@ def log_llm_call(**kwargs):
         log_solution=solutions_logger,
         log_llm_call=log_llm_call,
         max_programs=max_programs,
-        ollama_url=ollama_url
+        ollama_url=ollama_url,
     )
 
     solution = seidr.develop(start_code=start_code)
 
-    logging.info('Development done. Testing...')
+    logging.info("Development done. Testing...")
 
     test_evals = [
-        IOMatch(solution,
-                language=language,
-                input=inp, output=out,
-                task_description=description)
-        for inp, out in test_data]
+        IOMatch(
+            solution,
+            language=language,
+            input=inp,
+            output=out,
+            task_description=description,
+        )
+        for inp, out in test_data
+    ]
     avg_score = sum(e.score() for e in test_evals) / len(test_evals)
     test_pass_rate = sum(e.check() for e in test_evals) / len(test_evals)
 
-    logging.info(f'\nTest pass rate on test: {test_pass_rate}\nTest avg score on test: {avg_score}')
+    logging.info(
+        f"\nTest pass rate on test: {test_pass_rate}\nTest avg score on test: {avg_score}"
+    )
 
-    run.log({'test_avg_score': avg_score,
-             'test_pass_rate': test_pass_rate})
+    run.log({"test_avg_score": avg_score, "test_pass_rate": test_pass_rate})
     # run.finish()
     wandb.finish()
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     try:
         Fire(run_benchmark)
     except:
```
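Note: the entry point hands `run_benchmark` to [Python Fire](https://github.com/google/python-fire), which is what makes the README's `--flag value` invocations work: every keyword parameter of the function becomes a command-line flag. The hunk ends on a bare `except:` truncated by the diff view, so the handler that follows is not shown here. A minimal sketch of the same dispatch, assuming only the `fire` package:

```python
# fire_demo.py - minimal sketch of the Fire dispatch used by benchmark.py.
from fire import Fire


def run_benchmark(problem: str = "fizz-buzz", max_programs: int = 1000):
    """Stand-in for the real run_benchmark."""
    print(f"problem={problem}, max_programs={max_programs}")


if __name__ == "__main__":
    # `python fire_demo.py --problem bowling --max_programs 100`
    # calls run_benchmark(problem="bowling", max_programs=100).
    Fire(run_benchmark)
```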
