Skip to content

Commit d82b603

Browse files
committed
refactor(agent): 重构多个agent类继承自BaseAgent
将browser_use_agent、deep_analyzer_agent、deep_researcher_agent和planning_agent重构为继承自新创建的BaseAgent类,减少代码重复。BaseAgent类包含了这些agent共有的核心功能实现。 主要变更: 1. 创建BaseAgent基类包含共享逻辑 2. 简化各子类实现,仅保留特定配置 3. 统一初始化流程 4. 添加.gitignore忽略.idea目录
1 parent bb47165 commit d82b603

File tree

6 files changed

+565
-1129
lines changed

6 files changed

+565
-1129
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,5 @@ data/
180180

181181
# workdir
182182
workdir/
183+
184+
.idea

src/agent/base_agent.py

Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
from typing import (
2+
Any,
3+
Callable,
4+
Optional
5+
)
6+
import json
7+
import yaml
8+
from rich.panel import Panel
9+
from rich.text import Text
10+
11+
from src.tools import AsyncTool
12+
from src.exception import (
13+
AgentGenerationError,
14+
AgentParsingError,
15+
AgentToolExecutionError,
16+
AgentToolCallError,
17+
logger as exception_logger # Use exception logger if it differs
18+
)
19+
from src.base.async_multistep_agent import (PromptTemplates,
20+
populate_template,
21+
AsyncMultiStepAgent)
22+
from src.memory import (ActionStep,
23+
ToolCall,
24+
AgentMemory)
25+
from src.logger import (
26+
LogLevel,
27+
YELLOW_HEX
28+
# logger # AsyncMultiStepAgent creates its own self.logger
29+
)
30+
from src.models import Model, parse_json_if_needed, ChatMessage
31+
from src.utils.agent_types import (
32+
AgentAudio,
33+
AgentImage,
34+
)
35+
from src.utils import assemble_project_path
36+
37+
class BaseAgent(AsyncMultiStepAgent):
    """Shared base class implementing the common agent lifecycle.

    Concrete agents subclass this, override ``AGENT_NAME`` and pass their own
    configuration and prompt-template path; prompt loading, memory setup,
    tool execution and the ReAct step loop all live here.
    """

    # Subclasses must override this with their own unique agent name.
    AGENT_NAME = "base_agent"

    def __init__(
        self,
        config,  # agent-specific configuration object, kept for subclass access
        tools: list[AsyncTool],
        model: Model,
        prompt_templates_path: str,  # project-relative path to the YAML prompt templates
        prompt_templates: PromptTemplates | None = None,  # preloaded templates take precedence over the path
        max_steps: int = 20,
        add_base_tools: bool = False,
        verbosity_level: LogLevel = LogLevel.INFO,
        grammar: dict[str, str] | None = None,
        managed_agents: list | None = None,
        step_callbacks: list[Callable] | None = None,
        planning_interval: int | None = None,
        name: str | None = None,  # falls back to AGENT_NAME when omitted
        description: str | None = None,
        provide_run_summary: bool = False,
        final_answer_checks: list[Callable] | None = None,
        **kwargs
    ):
        # Keep the raw config around so subclasses can read their settings.
        self.config = config

        super().__init__(
            tools=tools,
            model=model,
            prompt_templates=None,  # templates are loaded below, after the parent is set up
            max_steps=max_steps,
            add_base_tools=add_base_tools,
            verbosity_level=verbosity_level,
            grammar=grammar,
            managed_agents=managed_agents,
            step_callbacks=step_callbacks,
            planning_interval=planning_interval,
            name=name if name is not None else self.AGENT_NAME,
            description=description,
            provide_run_summary=provide_run_summary,
            final_answer_checks=final_answer_checks,
            **kwargs,  # forward anything else to the parent class
        )

        # Prefer templates handed in by the caller; otherwise read the YAML file.
        if prompt_templates:
            self.prompt_templates = prompt_templates
        else:
            template_file = assemble_project_path(prompt_templates_path)
            with open(template_file, "r", encoding="utf-8") as fp:
                self.prompt_templates = yaml.safe_load(fp)

        self.system_prompt = self.initialize_system_prompt()
        self.user_prompt = self.initialize_user_prompt()

        self.memory = AgentMemory(
            system_prompt=self.system_prompt,
            user_prompt=self.user_prompt,
        )
        # self.logger is inherited from AsyncMultiStepAgent and uses the resolved name.
100+
def initialize_system_prompt(self) -> str:
101+
"""Initialize the system prompt for the agent."""
102+
system_prompt = populate_template(
103+
self.prompt_templates["system_prompt"],
104+
variables={"tools": self.tools, "managed_agents": self.managed_agents},
105+
)
106+
return system_prompt
107+
108+
def initialize_user_prompt(self) -> str:
109+
110+
user_prompt = populate_template(
111+
self.prompt_templates["user_prompt"],
112+
variables={},
113+
)
114+
115+
return user_prompt
116+
117+
def initialize_task_instruction(self) -> str:
118+
"""Initialize the task instruction for the agent."""
119+
# self.task is set by the __call__ method of AsyncMultiStepAgent at runtime
120+
task_instruction = populate_template(
121+
self.prompt_templates["task_instruction"],
122+
variables={"task": self.task},
123+
)
124+
return task_instruction
125+
126+
def _substitute_state_variables(self, arguments: dict[str, str] | str) -> dict[str, Any] | str:
127+
"""Replace string values in arguments with their corresponding state values if they exist."""
128+
if isinstance(arguments, dict):
129+
return {
130+
key: self.state.get(value, value) if isinstance(value, str) else value
131+
for key, value in arguments.items()
132+
}
133+
return arguments
134+
135+
async def execute_tool_call(self, tool_name: str, arguments: dict[str, str] | str) -> Any:
136+
"""
137+
Execute a tool or managed agent with the provided arguments.
138+
139+
The arguments are replaced with the actual values from the state if they refer to state variables.
140+
141+
Args:
142+
tool_name (`str`): Name of the tool or managed agent to execute.
143+
arguments (dict[str, str] | str): Arguments passed to the tool call.
144+
"""
145+
# Check if the tool exists
146+
available_tools = {**self.tools, **self.managed_agents}
147+
if tool_name not in available_tools:
148+
raise AgentToolExecutionError(
149+
f"Unknown tool {tool_name}, should be one of: {', '.join(available_tools)}.", self.logger
150+
)
151+
152+
# Get the tool and substitute state variables in arguments
153+
tool = available_tools[tool_name]
154+
arguments = self._substitute_state_variables(arguments)
155+
is_managed_agent = tool_name in self.managed_agents
156+
157+
try:
158+
# Call tool with appropriate arguments
159+
if isinstance(arguments, dict):
160+
return await tool(**arguments) if is_managed_agent else await tool(**arguments, sanitize_inputs_outputs=True)
161+
elif isinstance(arguments, str):
162+
return await tool(arguments) if is_managed_agent else await tool(arguments, sanitize_inputs_outputs=True)
163+
else:
164+
raise TypeError(f"Unsupported arguments type: {type(arguments)}")
165+
166+
except TypeError as e:
167+
# Handle invalid arguments
168+
description = getattr(tool, "description", "No description")
169+
if is_managed_agent:
170+
error_msg = (
171+
f"Invalid request to team member '{tool_name}' with arguments {json.dumps(arguments, ensure_ascii=False)}: {e}\n"
172+
"You should call this team member with a valid request.\n"
173+
f"Team member description: {description}"
174+
)
175+
else:
176+
error_msg = (
177+
f"Invalid call to tool '{tool_name}' with arguments {json.dumps(arguments, ensure_ascii=False)}: {e}\n"
178+
"You should call this tool with correct input arguments.\n"
179+
f"Expected inputs: {json.dumps(tool.parameters)}\n"
180+
f"Returns output type: {tool.output_type}\n"
181+
f"Tool description: '{description}'"
182+
)
183+
raise AgentToolCallError(error_msg, self.logger) from e
184+
185+
except Exception as e:
186+
# Handle execution errors
187+
if is_managed_agent:
188+
error_msg = (
189+
f"Error executing request to team member '{tool_name}' with arguments {json.dumps(arguments)}: {e}\n"
190+
"Please try again or request to another team member"
191+
)
192+
else:
193+
error_msg = (
194+
f"Error executing tool '{tool_name}' with arguments {json.dumps(arguments)}: {type(e).__name__}: {e}\n"
195+
"Please try again or use another tool"
196+
)
197+
raise AgentToolExecutionError(error_msg, self.logger) from e
198+
199+
    async def step(self, memory_step: ActionStep) -> None | Any:
        """
        Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
        Returns None if the step is not final.

        A step is final when the model produces a direct answer with no tool
        call, or when it calls the special "final_answer" tool; the final
        value is also recorded on ``memory_step.action_output``.
        """
        memory_messages = await self.write_memory_to_messages()

        input_messages = memory_messages.copy()

        # Add new step in logs
        memory_step.model_input_messages = input_messages

        try:
            chat_message: ChatMessage = await self.model(
                input_messages,
                stop_sequences=["Observation:", "Calling tools:"],
                tools_to_call_from=list(self.tools.values()),
            )
            memory_step.model_output_message = chat_message
            model_output = chat_message.content
            self.logger.log_markdown(
                content=model_output if model_output else str(chat_message.raw),
                title="Output message of the LLM:",
                level=LogLevel.DEBUG,
            )
            memory_step.model_output_message.content = model_output
            memory_step.model_output = model_output
        except Exception as e:
            raise AgentGenerationError(f"Error while generating output:\n{e}", self.logger) from e

        if chat_message.tool_calls is None or len(chat_message.tool_calls) == 0:
            try:
                # Attempt to parse tool calls if they were not automatically populated
                chat_message = self.model.parse_tool_calls(chat_message)
            except Exception as e:
                # If parsing failed and there is model_output, it can be considered a direct answer
                if model_output:
                    self.logger.log(
                        Text(f"Tool call not detected. Processing model output as final answer: {model_output}", style=f"bold {YELLOW_HEX}"),
                        level=LogLevel.INFO,
                    )
                    memory_step.action_output = model_output
                    return model_output
                raise AgentParsingError(f"Error while parsing tool call from model output: {e}", self.logger)

        # If there are still no tool calls after attempting to parse
        if not chat_message.tool_calls:
            if model_output:
                self.logger.log(
                    Text(f"Tool call not detected after parsing. Processing model output as final answer: {model_output}", style=f"bold {YELLOW_HEX}"),
                    level=LogLevel.INFO,
                )
                memory_step.action_output = model_output
                return model_output
            else:
                # If there are no tool calls and no content, it's an error
                raise AgentParsingError("Tool call not found, and there is no content in the model output.", self.logger)

        # Continue if there are tool calls.
        # Normalize the arguments of every call, but note that only the FIRST
        # tool call is actually executed below.
        for tool_call in chat_message.tool_calls:
            tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments)

        tool_call = chat_message.tool_calls[0]
        tool_name, tool_call_id = tool_call.function.name, tool_call.id
        tool_arguments = tool_call.function.arguments
        # NOTE(review): this deliberately overwrites the raw LLM output recorded
        # above with a summary of the chosen tool call.
        memory_step.model_output = str(f"Called Tool: '{tool_name}' with arguments: {tool_arguments}")
        memory_step.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]

        # Execute
        self.logger.log(
            Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}")),
            level=LogLevel.INFO,
        )
        if tool_name == "final_answer":
            # "final_answer" terminates the run; unwrap the result from the
            # arguments if it was passed as {"result": ...}.
            if isinstance(tool_arguments, dict):
                result = tool_arguments.get("result", tool_arguments)
            else:
                result = tool_arguments
            if (
                isinstance(result, str) and result in self.state.keys()
            ):  # if the answer is a state variable, return the value
                final_result = self.state[result]
                self.logger.log(
                    f"[bold {YELLOW_HEX}]Final answer:[/bold {YELLOW_HEX}] Extracting key '{result}' from state to return value '{final_result}'.",
                    level=LogLevel.INFO,
                )
            else:
                final_result = result
                self.logger.log(
                    Text(f"Final result: {final_result}", style=f"bold {YELLOW_HEX}"),
                    level=LogLevel.INFO,
                )

            memory_step.action_output = final_result
            return final_result
        else:
            # Ordinary tool (or managed agent) call; None arguments become an
            # empty kwargs dict.
            tool_args_to_pass = tool_arguments if tool_arguments is not None else {}
            observation = await self.execute_tool_call(tool_name, tool_args_to_pass)
            observation_type = type(observation)

            if observation_type in [AgentImage, AgentAudio]:
                # Media observations are stashed in state under a fixed name and
                # only referenced from the textual observation.
                observation_name = "image.png" if observation_type == AgentImage else "audio.mp3"
                # TODO: observation naming could allow for different names of same type

                self.state[observation_name] = observation
                updated_information = f"Stored '{observation_name}' in memory."
            else:
                updated_information = str(observation).strip()

            self.logger.log(
                f"Observations: {updated_information.replace('[', '|')}",  # escape potential rich-tag-like components
                level=LogLevel.INFO,
            )
            memory_step.observations = updated_information
            return None

0 commit comments

Comments
 (0)