Skip to content

Commit f48beed

Browse files
committed
use existing ChatOpenAI instead of Unbound class and remove loadDotEnv
1 parent ebf9a06 commit f48beed

File tree

2 files changed

+3
-79
lines changed

2 files changed

+3
-79
lines changed

src/utils/llm.py

Lines changed: 0 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from openai import OpenAI
22
import pdb
3-
import os
4-
from dotenv import load_dotenv
53
from langchain_openai import ChatOpenAI
64
from langchain_core.globals import get_llm_cache
75
from langchain_core.language_models.base import (
@@ -31,9 +29,6 @@
3129
from langchain_core.output_parsers.base import OutputParserLike
3230
from langchain_core.runnables import Runnable, RunnableConfig
3331
from langchain_core.tools import BaseTool
34-
from pydantic import Field, PrivateAttr
35-
import requests
36-
import urllib3
3732

3833
from typing import (
3934
TYPE_CHECKING,
@@ -108,72 +103,6 @@ def invoke(
108103
return AIMessage(content=content, reasoning_content=reasoning_content)
109104

110105

111-
# Load environment variables
112-
load_dotenv()
113-
114-
class UnboundChatOpenAI(ChatOpenAI):
115-
"""Chat model that uses Unbound's API."""
116-
117-
_session: requests.Session = PrivateAttr()
118-
119-
def __init__(self, *args: Any, **kwargs: Any) -> None:
120-
kwargs["base_url"] = kwargs.get("base_url", os.getenv("UNBOUND_ENDPOINT", "https://api.getunbound.ai"))
121-
kwargs["api_key"] = kwargs.get("api_key", os.getenv("UNBOUND_API_KEY"))
122-
if not kwargs["api_key"]:
123-
raise ValueError("UNBOUND_API_KEY environment variable is not set")
124-
super().__init__(*args, **kwargs)
125-
126-
self.client = OpenAI(
127-
base_url=kwargs["base_url"],
128-
api_key=kwargs["api_key"]
129-
)
130-
131-
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
132-
133-
self._session = requests.Session()
134-
self._session.verify = False
135-
136-
def invoke(
137-
self,
138-
input: LanguageModelInput,
139-
config: Optional[RunnableConfig] = None,
140-
*,
141-
stop: Optional[list[str]] = None,
142-
**kwargs: Any,
143-
) -> AIMessage:
144-
message_history = []
145-
for input_ in input:
146-
if isinstance(input_, SystemMessage):
147-
message_history.append({"role": "system", "content": input_.content})
148-
elif isinstance(input_, AIMessage):
149-
message_history.append({"role": "assistant", "content": input_.content})
150-
else:
151-
message_history.append({"role": "user", "content": input_.content})
152-
153-
response = self._session.post(
154-
f"{self.client.base_url}/v1/chat/completions",
155-
headers={"Authorization": f"Bearer {self.client.api_key}", "Content-Type": "application/json"},
156-
json={
157-
"model": self.model_name or "gpt-4o-mini",
158-
"messages": message_history
159-
}
160-
)
161-
response.raise_for_status()
162-
data = response.json()
163-
content = data["choices"][0]["message"]["content"]
164-
return AIMessage(content=content)
165-
166-
async def ainvoke(
167-
self,
168-
input: LanguageModelInput,
169-
config: Optional[RunnableConfig] = None,
170-
*,
171-
stop: Optional[list[str]] = None,
172-
**kwargs: Any,
173-
) -> AIMessage:
174-
return self.invoke(input, config, stop=stop, **kwargs)
175-
176-
177106
class DeepSeekR1ChatOllama(ChatOllama):
178107

179108
async def ainvoke(

src/utils/utils.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from langchain_ollama import ChatOllama
1515
from langchain_openai import AzureChatOpenAI, ChatOpenAI
1616

17-
from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama, UnboundChatOpenAI
17+
from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama
1818

1919
PROVIDER_DISPLAY_NAMES = {
2020
"openai": "OpenAI",
@@ -162,15 +162,10 @@ def get_llm_model(provider: str, **kwargs):
162162
api_key=os.getenv("MOONSHOT_API_KEY"),
163163
)
164164
elif provider == "unbound":
165-
if not kwargs.get("base_url", ""):
166-
base_url = os.getenv("UNBOUND_ENDPOINT", "https://api.getunbound.ai")
167-
else:
168-
base_url = kwargs.get("base_url")
169-
170-
return UnboundChatOpenAI(
165+
return ChatOpenAI(
171166
model=kwargs.get("model_name", "gpt-4o-mini"),
172167
temperature=kwargs.get("temperature", 0.0),
173-
base_url=base_url,
168+
base_url = os.getenv("UNBOUND_ENDPOINT", "https://api.getunbound.ai"),
174169
api_key=api_key,
175170
)
176171
else:

0 commit comments

Comments
 (0)