|
1 | 1 | from openai import OpenAI
|
2 | 2 | import pdb
|
| 3 | +import os |
| 4 | +from dotenv import load_dotenv |
3 | 5 | from langchain_openai import ChatOpenAI
|
4 | 6 | from langchain_core.globals import get_llm_cache
|
5 | 7 | from langchain_core.language_models.base import (
|
|
29 | 31 | from langchain_core.output_parsers.base import OutputParserLike
|
30 | 32 | from langchain_core.runnables import Runnable, RunnableConfig
|
31 | 33 | from langchain_core.tools import BaseTool
|
| 34 | +from pydantic import Field, PrivateAttr |
| 35 | +import requests |
| 36 | +import urllib3 |
32 | 37 |
|
33 | 38 | from typing import (
|
34 | 39 | TYPE_CHECKING,
|
@@ -103,6 +108,72 @@ def invoke(
|
103 | 108 | return AIMessage(content=content, reasoning_content=reasoning_content)
|
104 | 109 |
|
105 | 110 |
|
# Load environment variables from a local .env file (if present) so that
# UNBOUND_ENDPOINT / UNBOUND_API_KEY can be configured without exporting them
# in the shell. Runs at import time; existing environment variables win.
load_dotenv()
| 113 | + |
class UnboundChatOpenAI(ChatOpenAI):
    """Chat model backed by Unbound's OpenAI-compatible chat-completions API.

    Endpoint and credentials come from the ``UNBOUND_ENDPOINT`` and
    ``UNBOUND_API_KEY`` environment variables unless overridden via kwargs.
    Requests are sent with a dedicated :class:`requests.Session`.
    """

    # Raw HTTP session used for the REST calls; PrivateAttr keeps it out of
    # the pydantic field set inherited from ChatOpenAI.
    _session: requests.Session = PrivateAttr()

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialize from kwargs, falling back to environment configuration.

        Raises:
            ValueError: If no API key is supplied and UNBOUND_API_KEY is unset.
        """
        # Use `or` rather than dict.get(key, default): an explicit
        # base_url=None / api_key=None kwarg should also fall back to the
        # environment instead of propagating None downstream.
        kwargs["base_url"] = kwargs.get("base_url") or os.getenv(
            "UNBOUND_ENDPOINT", "https://api.getunbound.ai"
        )
        kwargs["api_key"] = kwargs.get("api_key") or os.getenv("UNBOUND_API_KEY")
        if not kwargs["api_key"]:
            raise ValueError("UNBOUND_API_KEY environment variable is not set")
        super().__init__(*args, **kwargs)

        self.client = OpenAI(
            base_url=kwargs["base_url"],
            api_key=kwargs["api_key"],
        )

        # SECURITY NOTE(review): TLS certificate verification is disabled
        # below, which allows man-in-the-middle attacks. Presumably required
        # for a self-signed internal endpoint — confirm, and prefer pointing
        # `self._session.verify` at the CA bundle instead of False.
        urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

        self._session = requests.Session()
        self._session.verify = False

    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        """Send the conversation to Unbound's ``/v1/chat/completions`` endpoint.

        Args:
            input: The prompt. A plain string becomes a single user message;
                otherwise it is treated as a sequence of chat messages.
            config: Accepted for Runnable compatibility; not used here.
            stop: Optional stop sequences forwarded to the API.
            **kwargs: Accepted for interface compatibility; not forwarded.

        Returns:
            An ``AIMessage`` holding the content of the first choice.

        Raises:
            requests.HTTPError: If the API responds with an error status.
        """
        # Fix: a bare string input was previously iterated character by
        # character; wrap it as one user message instead.
        if isinstance(input, str):
            message_history = [{"role": "user", "content": input}]
        else:
            message_history = []
            for message in input:
                if isinstance(message, SystemMessage):
                    role = "system"
                elif isinstance(message, AIMessage):
                    role = "assistant"
                else:
                    role = "user"
                message_history.append({"role": role, "content": message.content})

        payload: dict[str, Any] = {
            "model": self.model_name or "gpt-4o-mini",
            "messages": message_history,
        }
        if stop:
            # Fix: stop sequences were accepted but silently dropped before.
            payload["stop"] = stop

        response = self._session.post(
            f"{self.client.base_url}/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {self.client.api_key}",
                "Content-Type": "application/json",
            },
            json=payload,
            # Fix: without a timeout, requests can block forever on an
            # unresponsive endpoint.
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()
        content = data["choices"][0]["message"]["content"]
        return AIMessage(content=content)

    async def ainvoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        """Async variant of :meth:`invoke`.

        Fix: the original called the blocking ``invoke`` directly from the
        coroutine, stalling the event loop for the whole HTTP round-trip;
        run it in a worker thread instead.
        """
        import asyncio  # local import keeps this fix self-contained

        return await asyncio.to_thread(self.invoke, input, config, stop=stop, **kwargs)
| 175 | + |
| 176 | + |
106 | 177 | class DeepSeekR1ChatOllama(ChatOllama):
|
107 | 178 |
|
108 | 179 | async def ainvoke(
|
|
0 commit comments