
Commit ccb264c

Merge pull request #20 from vadim0x60/vadim

Multiple quality of life improvements

2 parents: 037c0f2 + 30792da

File tree

5 files changed: +99 −18 lines changed

pyproject.toml
seidr/dev.py
seidr/eval.py
seidr/github.py
seidr/llm.py


pyproject.toml

Lines changed: 7 additions & 5 deletions

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "seidr"
-version = "3.1.1"
+version = "3.4.0"
 description = "Synthesize Execute Instruct Debug Rank"
 authors = ["Vadim Liventsev <v.liventsev@tue.nl>", "Anastasia Grishina <anastasiia@simula.no>"]
 license = "MIT"
@@ -16,17 +16,19 @@ python = "^3.9"
 psb2 = ">=1.1.1"
 openai = "<1.0.0"
 more-itertools = ">=8.0.0,<9.0.0"
-programlib = ">=9.0.2,<10.0.0"
+programlib = ">=12.0.4"
 wandb = "<1.0.0"
 gitpython = ">=3.0.0,<4.0.0"
 tenacity = ">=8.0.0,<9.0.0"
 pandas = ">=1.0.0,<2.0.0"
 fire = "<1.0.0"
 jsonlines = "^4.0.0"
-jupyterlab = "^4.0.7"
 black = "^23.10.1"
-langchain = "^0.0.326"
-pytest-codeblocks = "^0.17.0"
+langchain = "~=0.1"
+langchain-community = "~=0.2"
+langchain-anthropic = "~=0.1"
+pytest-codeblocks = "~=0.17"
+anthropic = "~=0.29"
 
 [build-system]
 requires = ["poetry-core"]
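
Most of the new constraints use the PEP 440 compatible-release operator (~=) instead of Poetry's caret (^). As a quick illustration of what ~= admits, here is a small sketch using the packaging library (not a seidr dependency; the version strings are invented for the example):

from packaging.specifiers import SpecifierSet

# "~=0.29" is shorthand for ">=0.29, ==0.*": it pins the prefix before
# the last given component and lets the remaining components float.
spec = SpecifierSet("~=0.29")
print("0.29.2" in spec)  # True: patch update within 0.x
print("0.30.0" in spec)  # True: only the leading "0" is pinned
print("1.0.0" in spec)   # False: a major bump is excluded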

seidr/dev.py

Lines changed: 9 additions & 7 deletions

@@ -3,8 +3,9 @@
 from programlib import Program, Language
 from typing import Callable, Optional, Iterable, Tuple, List, Generator
 import random
+import time
 
-from seidr.llm import explore_llm
+from seidr.llm import explore_llm, default_batch_size
 from seidr.eval import Evaluation
 
 
@@ -124,7 +125,8 @@ def __init__(self,
                  log_llm_call: Callable = lambda **kwargs: print(kwargs),
                  max_programs: Optional[int] = None,
                  batch_size: Optional[int] = None,
-                 ollama_url: Optional[str] = None) -> None:
+                 ollama_url: Optional[str] = None,
+                 delay: int = 0) -> None:
         self.task_name = task_name
         self.task_description = task_description
         self.critics = critics
@@ -141,13 +143,11 @@ def __init__(self,
         self.log_llm_call = log_llm_call
         self.max_programs = max_programs
         self.ollama_url = ollama_url
+        self.delay = delay
 
         if not batch_size:
-            if 'gpt' in model_name:
-                self.batch_size = 10
-            else:
-                # Because Ollama doesn't support batch inference
-                self.batch_size = 1
+            batch_size = default_batch_size(model_name)
+        self.batch_size = batch_size
 
     def draft(self, start_code: str = '') -> Iterable[str]:
         """Create a draft solution with the "generate" prompt template
@@ -286,4 +286,6 @@ def have_kids(
             if self.max_programs is not None and (idx == self.max_programs - 1):
                 break
 
+            time.sleep(self.delay)
+
         return best_code
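
Two behavioral changes here: a configurable delay (seconds slept after each candidate program, useful for staying under API rate limits) and a batch-size default that now comes from default_batch_size in seidr/llm.py (shown further down). A minimal sketch of the resulting defaults, assuming the package is installed:

from seidr.llm import default_batch_size

# OpenAI-style and Anthropic endpoints accept batched generations;
# Ollama does not, so it falls back to one program per request.
print(default_batch_size("gpt-4"))          # 10
print(default_batch_size("claude-3-opus"))  # 10
print(default_batch_size("codellama"))      # 1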

seidr/eval.py

Lines changed: 54 additions & 1 deletion

@@ -10,7 +10,7 @@ class Evaluation(ABC):
     Produces a binary pass/fail result, a float score, and a text report
     """
 
-    def __init__(self, SUT: Program, passing_score: float = 1.):
+    def __init__(self, SUT, passing_score: float = 1.):
         """
         SUT: System Under Test
         passing_score: float score required to pass the evaluation
@@ -97,3 +97,56 @@ def pen_report(self) -> str:
         else:
             self.output = "\n".join(self.output) if type(self.output) == list else self.output
         return self.output
+
+class Gymnasium(Evaluation):
+    def __init__(self, env, code, language, passing_score, error_reward=-1000):
+        self.action_mode = type(env.action_space).__name__.lower()
+        program = Program(code, language=language)
+        super().__init__(program, passing_score)
+
+        self.env = env
+        self.tot_reward = 0
+        self.tot_txt = ''
+        self.done = False
+        self.error_reward = error_reward
+
+    def play(self):
+        if self.done:
+            return
+
+        self.tot_reward = 0
+        self.tot_txt = ''
+        agent = self.SUT.spawn(action_mode=self.action_mode)
+
+        try:
+            observation, info = self.env.reset()
+            self.tot_txt += info.get('memos', '')
+            terminated = False
+            truncated = False
+
+            while not (terminated or truncated):
+                if 'ascii' in self.env.metadata.get('render.modes', []):
+                    ascii_render = self.env.render(mode='ascii')
+                    self.tot_txt += ascii_render
+
+                action, _ = agent.predict(observation, deterministic=True)
+
+                observation, reward, terminated, truncated, info = self.env.step(action)
+                self.tot_reward += reward
+                self.tot_txt += info.get('memos', '')
+        except RuntimeError as e:
+            self.tot_reward = self.error_reward
+            self.tot_txt += f'FATAL {e}'
+        finally:
+            agent.close()
+
+        self.done = True
+
+    def score(self):
+        self.play()
+        return self.tot_reward
+
+    def pen_report(self):
+        self.play()
+        self.tot_txt += f'\nFinal reward: {self.tot_reward}'
+        return self.tot_txt
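
The new Gymnasium evaluation plays one episode of a candidate agent program in a Gymnasium environment: it accumulates reward, collects any 'memos' text the environment puts in info, and substitutes error_reward if the program dies with a RuntimeError. A hypothetical usage sketch (the environment id, candidate code, and passing score are invented for illustration; spawn/predict are the programlib agent interface this class assumes):

import gymnasium as gym
from seidr.eval import Gymnasium

env = gym.make("CartPole-v1")
candidate_code = "..."  # source code of a candidate agent, e.g. a SEIDR draft

# error_reward is returned whenever the program under test crashes
evaluation = Gymnasium(env, code=candidate_code, language="Python",
                       passing_score=475, error_reward=-1000)

print(evaluation.score())       # total episode reward
print(evaluation.pen_report())  # transcript plus "Final reward: ..."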

seidr/github.py

Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@ def ensure_repo(remote: str, path: pathlib.Path | str, branch: str = None) -> Repo:
         if branch:
             repo.git.checkout(branch)
     except GitError as e:
-        logging.info(f'Git error in ensure repo {e}. \n{traceback.print_stack()}')
+        logging.info(f'Git error in ensure repo {e}.')
         shutil.rmtree(path, ignore_errors=True)
         repo = Repo.clone_from(remote, path)
 
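
This one-liner fixes a logging quirk rather than merely shortening the message: traceback.print_stack() writes directly to stderr and returns None, so interpolating it into an f-string always appended the literal text "None" to the log record. A quick demonstration:

import traceback

# print_stack() prints to stderr as a side effect and returns None,
# so the f-string below always ends in the string "None".
message = f"Git error in ensure repo. \n{traceback.print_stack()}"
print(message)  # -> "Git error in ensure repo. \nNone"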

seidr/llm.py

Lines changed: 28 additions & 4 deletions

@@ -3,6 +3,7 @@
 
 from langchain.chains import LLMChain
 from langchain.chat_models import ChatOpenAI, ChatOllama
+from langchain_anthropic import ChatAnthropic
 from collections.abc import Iterable
 from typing import Callable, Optional
 import re
@@ -22,9 +23,9 @@ def extract_codes(
         language: Language | str
 ) -> str:
     """Extract code out of a message and (if Python) format it with black"""
+
     try:
         code_blocks = list(extract_from_buffer(StringIO(message_content)))
-        code_blocks = [code for code in code_blocks if not bool(code)]
     except RuntimeError as e:
         code_blocks = []
 
@@ -46,6 +47,20 @@ def run_black(code: str) -> str:
         logging.info(e)
         return code
 
+def which_api(model_name):
+    model_name = model_name.lower()
+    if "gpt" in model_name or "deepseek" in model_name:
+        return ChatOpenAI
+    elif "claude" in model_name:
+        return ChatAnthropic
+    else:
+        return ChatOllama
+
+def default_batch_size(model_name):
+    if which_api(model_name) == ChatOllama:
+        return 1
+    else:
+        return 10
 
 def create_chain(
     temperature: float = 0.,
@@ -55,14 +70,23 @@
 ) -> LLMChain:
     """Set up a LangChain LLMChain"""
     chat_prompt_template = create_chat_prompt_template(mode)
-    if "gpt" in model_name.lower():
+    api = which_api(model_name)
+
+    if api == ChatOpenAI:
        chat_model = ChatOpenAI(
            model=model_name,
            temperature=temperature,
+           openai_api_base=os.getenv("OPENAI_API_BASE"),
            openai_api_key=os.getenv("OPENAI_API_KEY"),
            openai_organization=os.getenv("OPENAI_ORG")
        )
-    elif "llama" in model_name.lower():
+    elif api == ChatAnthropic:
+        chat_model = ChatAnthropic(
+            model_name=model_name,
+            temperature=temperature,
+            anthropic_api_key=os.getenv('ANTHROPIC_API_KEY')
+        )
+    elif api == ChatOllama:
        chat_model = ChatOllama(
            base_url=base_url,
            model=model_name,
@@ -90,7 +114,7 @@ def query_llm(
     # Assistants are trained to respond with one message.
     # it is theoretically possible to get more than one message, but it is very unlikely.
     assert all(len(r) == 1 for r in result.generations), "The models are expected to respond with one message"
-    result = [r[0].message.content for r in result.generations if r[0].message.content]
+    result = [r[0].message.content for r in result.generations]
 
     if mode == "repair":
         logging.info(f"Generating repair candidates for bug summary: \n{kwargs['bug_summary']}\n")
