Skip to content

Commit 61de4e8

Browse files
committed
修改为直接使用OpenAIChat (Changed to use OpenAIChat directly)
1 parent d711c85 commit 61de4e8

File tree

2 files changed

+2
-92
lines changed

2 files changed

+2
-92
lines changed

src/utils/llm.py

Lines changed: 0 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -136,91 +136,3 @@ def invoke(
136136
if "**JSON Response:**" in content:
137137
content = content.split("**JSON Response:**")[-1]
138138
return AIMessage(content=content, reasoning_content=reasoning_content)
139-
140-
141-
class SiliconFlowChat(ChatOpenAI):
142-
"""Wrapper for SiliconFlow Chat API, fully compatible with OpenAI-spec format."""
143-
144-
def __init__(self, *args: Any, **kwargs: Any) -> None:
145-
super().__init__(*args, **kwargs)
146-
147-
# Ensure the API client is initialized with SiliconFlow's endpoint and key
148-
self.client = OpenAI(
149-
api_key=kwargs.get("api_key"),
150-
base_url=kwargs.get("base_url")
151-
)
152-
153-
async def ainvoke(
154-
self,
155-
input: LanguageModelInput,
156-
config: Optional[RunnableConfig] = None,
157-
*,
158-
stop: Optional[List[str]] = None,
159-
**kwargs: Any,
160-
) -> AIMessage:
161-
"""Async call SiliconFlow API."""
162-
163-
# Convert input messages into OpenAI-compatible format
164-
message_history = []
165-
for input_msg in input:
166-
if isinstance(input_msg, SystemMessage):
167-
message_history.append({"role": "system", "content": input_msg.content})
168-
elif isinstance(input_msg, AIMessage):
169-
message_history.append({"role": "assistant", "content": input_msg.content})
170-
else: # HumanMessage or similar
171-
message_history.append({"role": "user", "content": input_msg.content})
172-
173-
# Send request to SiliconFlow API (OpenAI-spec endpoint)
174-
response = await self.client.chat.completions.create(
175-
model=self.model_name,
176-
messages=message_history,
177-
stop=stop,
178-
**kwargs,
179-
)
180-
181-
# Extract the AI response (SiliconFlow's response must match OpenAI format)
182-
if hasattr(response.choices[0].message, "reasoning_content"):
183-
reasoning_content = response.choices[0].message.reasoning_content
184-
else:
185-
reasoning_content = None
186-
187-
content = response.choices[0].message.content
188-
return AIMessage(content=content, reasoning_content=reasoning_content) # Return reasoning_content if needed
189-
190-
def invoke(
191-
self,
192-
input: LanguageModelInput,
193-
config: Optional[RunnableConfig] = None,
194-
*,
195-
stop: Optional[List[str]] = None,
196-
**kwargs: Any,
197-
) -> AIMessage:
198-
"""Sync call SiliconFlow API."""
199-
200-
# Same conversion as async version
201-
message_history = []
202-
for input_msg in input:
203-
if isinstance(input_msg, SystemMessage):
204-
message_history.append({"role": "system", "content": input_msg.content})
205-
elif isinstance(input_msg, AIMessage):
206-
message_history.append({"role": "assistant", "content": input_msg.content})
207-
else:
208-
message_history.append({"role": "user", "content": input_msg.content})
209-
210-
# Sync call
211-
response = self.client.chat.completions.create(
212-
model=self.model_name,
213-
messages=message_history,
214-
stop=stop,
215-
**kwargs,
216-
)
217-
218-
# Handle reasoning_content (if supported)
219-
reasoning_content = None
220-
if hasattr(response.choices[0].message, "reasoning_content"):
221-
reasoning_content = response.choices[0].message.reasoning_content
222-
223-
return AIMessage(
224-
content=response.choices[0].message.content,
225-
reasoning_content=reasoning_content, # Only if SiliconFlow supports it
226-
)

src/utils/utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from langchain_ollama import ChatOllama
1515
from langchain_openai import AzureChatOpenAI, ChatOpenAI
1616

17-
from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama,SiliconFlowChat
17+
from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama
1818

1919
PROVIDER_DISPLAY_NAMES = {
2020
"openai": "OpenAI",
@@ -177,13 +177,11 @@ def get_llm_model(provider: str, **kwargs):
177177
base_url = os.getenv("SiliconFLOW_ENDPOINT", "")
178178
else:
179179
base_url = kwargs.get("base_url")
180-
return SiliconFlowChat(
180+
return ChatOpenAI(
181181
api_key=api_key,
182182
base_url=base_url,
183183
model_name=kwargs.get("model_name", "Qwen/QwQ-32B"),
184184
temperature=kwargs.get("temperature", 0.0),
185-
max_tokens=kwargs.get("max_tokens", 512),
186-
frequency_penalty=kwargs.get("frequency_penalty", 0.5),
187185
)
188186
else:
189187
raise ValueError(f"Unsupported provider: {provider}")

0 commit comments

Comments (0)