Skip to content

Commit eafce12

Browse files
committed
refactor(base_agent): 优化工具调用处理和输出消息逻辑
简化工具调用处理流程,移除冗余的条件判断 重构输出消息处理逻辑,提高代码可读性 统一工具参数处理方式,避免重复代码
1 parent eda439c commit eafce12

File tree

1 file changed

+20
-40
lines changed

1 file changed

+20
-40
lines changed

src/agent/base_agent.py

Lines changed: 20 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
Callable,
44
Optional
55
)
6-
import json
6+
77
import yaml
8+
import json
89
from rich.panel import Panel
910
from rich.text import Text
1011

@@ -21,11 +22,9 @@
2122
from src.memory import (ActionStep,
2223
ToolCall,
2324
AgentMemory)
24-
from src.logger import (
25-
LogLevel,
26-
YELLOW_HEX,
27-
logger # AsyncMultiStepAgent creates its own self.logger
28-
)
25+
from src.logger import (LogLevel,
26+
YELLOW_HEX,
27+
logger)
2928
from src.models import Model, parse_json_if_needed, ChatMessage
3029
from src.utils.agent_types import (
3130
AgentAudio,
@@ -94,7 +93,6 @@ def __init__(
9493
system_prompt=self.system_prompt,
9594
user_prompt=self.user_prompt,
9695
)
97-
# self.logger is inherited from AsyncMultiStepAgent and uses agent_name_to_use
9896

9997
def initialize_system_prompt(self) -> str:
10098
"""Initialize the system prompt for the agent."""
@@ -115,7 +113,6 @@ def initialize_user_prompt(self) -> str:
115113

116114
def initialize_task_instruction(self) -> str:
117115
"""Initialize the task instruction for the agent."""
118-
# self.task is set by the __call__ method of AsyncMultiStepAgent at runtime
119116
task_instruction = populate_template(
120117
self.prompt_templates["task_instruction"],
121118
variables={"task": self.task},
@@ -220,42 +217,20 @@ async def step(self, memory_step: ActionStep) -> None | Any:
220217
title="Output message of the LLM:",
221218
level=LogLevel.DEBUG,
222219
)
220+
223221
memory_step.model_output_message.content = model_output
224222
memory_step.model_output = model_output
225223
except Exception as e:
226224
raise AgentGenerationError(f"Error while generating output:\n{e}", self.logger) from e
227225

228226
if chat_message.tool_calls is None or len(chat_message.tool_calls) == 0:
229227
try:
230-
# Attempt to parse tool calls if they were not automatically populated
231228
chat_message = self.model.parse_tool_calls(chat_message)
232229
except Exception as e:
233-
# If parsing failed and there is model_output, it can be considered a direct answer
234-
if model_output:
235-
self.logger.log(
236-
Text(f"Tool call not detected. Processing model output as final answer: {model_output}", style=f"bold {YELLOW_HEX}"),
237-
level=LogLevel.INFO,
238-
)
239-
memory_step.action_output = model_output
240-
return model_output
241230
raise AgentParsingError(f"Error while parsing tool call from model output: {e}", self.logger)
242-
243-
# If there are still no tool calls after attempting to parse
244-
if not chat_message.tool_calls:
245-
if model_output:
246-
self.logger.log(
247-
Text(f"Tool call not detected after parsing. Processing model output as final answer: {model_output}", style=f"bold {YELLOW_HEX}"),
248-
level=LogLevel.INFO,
249-
)
250-
memory_step.action_output = model_output
251-
return model_output
252-
else:
253-
# If there are no tool calls and no content, it's an error
254-
raise AgentParsingError("Tool call not found, and there is no content in the model output.", self.logger)
255-
256-
# Continue if there are tool calls
257-
for tool_call in chat_message.tool_calls:
258-
tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments)
231+
else:
232+
for tool_call in chat_message.tool_calls:
233+
tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments)
259234

260235
tool_call = chat_message.tool_calls[0]
261236
tool_name, tool_call_id = tool_call.function.name, tool_call.id
@@ -270,7 +245,10 @@ async def step(self, memory_step: ActionStep) -> None | Any:
270245
)
271246
if tool_name == "final_answer":
272247
if isinstance(tool_arguments, dict):
273-
result = tool_arguments.get("result", tool_arguments)
248+
if "result" in tool_arguments:
249+
result = tool_arguments["result"]
250+
else:
251+
result = tool_arguments
274252
else:
275253
result = tool_arguments
276254
if (
@@ -291,19 +269,21 @@ async def step(self, memory_step: ActionStep) -> None | Any:
291269
memory_step.action_output = final_result
292270
return final_result
293271
else:
294-
tool_args_to_pass = tool_arguments if tool_arguments is not None else {}
295-
observation = await self.execute_tool_call(tool_name, tool_args_to_pass)
272+
if tool_arguments is None:
273+
tool_arguments = {}
274+
observation = await self.execute_tool_call(tool_name, tool_arguments)
296275
observation_type = type(observation)
297-
298276
if observation_type in [AgentImage, AgentAudio]:
299-
observation_name = "image.png" if observation_type == AgentImage else "audio.mp3"
277+
if observation_type == AgentImage:
278+
observation_name = "image.png"
279+
elif observation_type == AgentAudio:
280+
observation_name = "audio.mp3"
300281
# TODO: observation naming could allow for different names of same type
301282

302283
self.state[observation_name] = observation
303284
updated_information = f"Stored '{observation_name}' in memory."
304285
else:
305286
updated_information = str(observation).strip()
306-
307287
self.logger.log(
308288
f"Observations: {updated_information.replace('[', '|')}", # escape potential rich-tag-like components
309289
level=LogLevel.INFO,

0 commit comments

Comments
 (0)