33
44from langchain .chains import LLMChain
55from langchain .chat_models import ChatOpenAI , ChatOllama
6+ from langchain_anthropic import ChatAnthropic
67from collections .abc import Iterable
78from typing import Callable , Optional
89import re
@@ -22,9 +23,9 @@ def extract_codes(
2223 language : Language | str
2324) -> str :
2425 """Extract code out of a message and (if Python) format it with black"""
26+
2527 try :
2628 code_blocks = list (extract_from_buffer (StringIO (message_content )))
27- code_blocks = [code for code in code_blocks if not bool (code )]
2829 except RuntimeError as e :
2930 code_blocks = []
3031
@@ -46,6 +47,20 @@ def run_black(code: str) -> str:
4647 logging .info (e )
4748 return code
4849
def which_api(model_name):
    """Pick the LangChain chat-model class that serves *model_name*.

    Matching is case-insensitive substring search: "gpt"/"deepseek" map to
    the OpenAI-compatible client, "claude" to Anthropic, and anything else
    falls back to a local Ollama model.
    """
    lowered = model_name.lower()
    routing = (
        ("gpt", ChatOpenAI),
        ("deepseek", ChatOpenAI),
        ("claude", ChatAnthropic),
    )
    for marker, api_cls in routing:
        if marker in lowered:
            return api_cls
    return ChatOllama
58+
def default_batch_size(model_name):
    """Return the default number of prompts to send per batch for *model_name*.

    Local Ollama models are queried one prompt at a time (they process
    requests serially), while hosted APIs (OpenAI/DeepSeek, Anthropic)
    accept batches of 10.
    """
    # `is` rather than `==`: we are testing class identity, not equality.
    if which_api(model_name) is ChatOllama:
        return 1
    return 10
4964
5065def create_chain (
5166 temperature : float = 0. ,
@@ -55,14 +70,23 @@ def create_chain(
5570) -> LLMChain :
5671 """Set up a LangChain LLMChain"""
5772 chat_prompt_template = create_chat_prompt_template (mode )
58- if "gpt" in model_name .lower ():
73+ api = which_api (model_name )
74+
75+ if api == ChatOpenAI :
5976 chat_model = ChatOpenAI (
6077 model = model_name ,
6178 temperature = temperature ,
79+ openai_api_base = os .getenv ("OPENAI_API_BASE" ),
6280 openai_api_key = os .getenv ("OPENAI_API_KEY" ),
6381 openai_organization = os .getenv ("OPENAI_ORG" )
6482 )
65- elif "llama" in model_name .lower ():
83+ elif api == ChatAnthropic :
84+ chat_model = ChatAnthropic (
85+ model_name = model_name ,
86+ temperature = temperature ,
87+ anthropic_api_key = os .getenv ('ANTHROPIC_API_KEY' )
88+ )
89+ elif api == ChatOllama :
6690 chat_model = ChatOllama (
6791 base_url = base_url ,
6892 model = model_name ,
@@ -90,7 +114,7 @@ def query_llm(
90114 # Assistants are trained to respond with one message.
91115 # it is theoretically possible to get more than one message, but it is very unlikely.
92116 assert all (len (r ) == 1 for r in result .generations ), "The models are expected to respond with one message"
93- result = [r [0 ].message .content for r in result .generations if r [ 0 ]. message . content ]
117+ result = [r [0 ].message .content for r in result .generations ]
94118
95119 if mode == "repair" :
96120 logging .info (f"Generating repair candidates for bug summary: \n { kwargs ['bug_summary' ]} \n " )
0 commit comments