Skip to content

Commit 713824a

Browse files
committed
refactor(agent): 重构多个agent类继承自BaseAgent
将browser_use_agent、deep_analyzer_agent、deep_researcher_agent和planning_agent重构为继承自新创建的BaseAgent类,减少代码重复。BaseAgent类包含了这些agent共有的核心功能实现。 主要变更: 1. 创建BaseAgent基类包含共享逻辑 2. 简化各子类实现,仅保留特定配置 3. 统一初始化流程 4. 添加.gitignore忽略.idea目录
1 parent 07b7eea commit 713824a

File tree

6 files changed

+550
-1129
lines changed

6 files changed

+550
-1129
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,5 @@ data/
180180

181181
# workdir
182182
workdir/
183+
184+
.idea

src/agent/base_agent.py

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
from typing import (
2+
Any,
3+
Callable,
4+
Optional
5+
)
6+
7+
import yaml
8+
import json
9+
from rich.panel import Panel
10+
from rich.text import Text
11+
12+
from src.tools import AsyncTool
13+
from src.exception import (
14+
AgentGenerationError,
15+
AgentParsingError,
16+
AgentToolExecutionError,
17+
AgentToolCallError
18+
)
19+
from src.base.async_multistep_agent import (PromptTemplates,
20+
populate_template,
21+
AsyncMultiStepAgent)
22+
from src.memory import (ActionStep,
23+
ToolCall,
24+
AgentMemory)
25+
from src.logger import (LogLevel,
26+
YELLOW_HEX,
27+
logger)
28+
from src.models import Model, parse_json_if_needed, ChatMessage
29+
from src.utils.agent_types import (
30+
AgentAudio,
31+
AgentImage,
32+
)
33+
from src.utils import assemble_project_path
34+
35+
class BaseAgent(AsyncMultiStepAgent):
    """Base class for agents sharing the common ReAct tool-calling loop.

    Subclasses are expected to override ``AGENT_NAME`` and pass their own
    configuration object and prompt-template path; prompt initialization,
    tool dispatch and the think/act/observe ``step`` all live here.
    """

    # Must be overridden by subclasses with a unique agent identifier.
    AGENT_NAME = "base_agent"

    def __init__(
        self,
        config,  # Specific configuration object for the agent
        tools: list[AsyncTool],
        model: Model,
        prompt_templates_path: str,  # Path to the prompt templates YAML file
        prompt_templates: PromptTemplates | None = None,  # For preloaded templates
        max_steps: int = 20,
        add_base_tools: bool = False,
        verbosity_level: LogLevel = LogLevel.INFO,
        grammar: dict[str, str] | None = None,
        managed_agents: list | None = None,
        step_callbacks: list[Callable] | None = None,
        planning_interval: int | None = None,
        name: str | None = None,  # AGENT_NAME will be used if not specified
        description: str | None = None,
        provide_run_summary: bool = False,
        final_answer_checks: list[Callable] | None = None,
        **kwargs
    ):
        """Initialize the agent, load prompt templates and build its memory.

        Args:
            config: Agent-specific configuration object, stored for subclasses.
            tools: Tools the agent may call.
            model: The chat model used to generate steps.
            prompt_templates_path: Project-relative path to a YAML file with
                ``system_prompt``, ``user_prompt`` and ``task_instruction``
                templates; only read when ``prompt_templates`` is not given.
            prompt_templates: Already-loaded templates, taking precedence over
                ``prompt_templates_path``.
            name: Agent name; falls back to the class-level ``AGENT_NAME``.

        Remaining parameters are forwarded unchanged to
        ``AsyncMultiStepAgent.__init__``.
        """
        self.config = config  # Saved for possible access by subclasses

        agent_name_to_use = name if name is not None else self.AGENT_NAME

        super().__init__(
            tools=tools,
            model=model,
            prompt_templates=None,  # Templates are loaded below, after base init
            max_steps=max_steps,
            add_base_tools=add_base_tools,
            verbosity_level=verbosity_level,
            grammar=grammar,
            managed_agents=managed_agents,
            step_callbacks=step_callbacks,
            planning_interval=planning_interval,
            name=agent_name_to_use,  # Use the resolved agent name
            description=description,
            provide_run_summary=provide_run_summary,
            final_answer_checks=final_answer_checks,
            **kwargs  # Pass remaining arguments to the parent class
        )

        # Load prompt templates: prefer preloaded ones, otherwise read the YAML file.
        if prompt_templates:
            self.prompt_templates = prompt_templates
        else:
            abs_template_path = assemble_project_path(prompt_templates_path)
            with open(abs_template_path, "r", encoding="utf-8") as f:
                self.prompt_templates = yaml.safe_load(f)

        self.system_prompt = self.initialize_system_prompt()
        self.user_prompt = self.initialize_user_prompt()

        # Rebuild memory with the freshly rendered prompts (the base class was
        # initialized with prompt_templates=None, so its memory lacks them).
        self.memory = AgentMemory(
            system_prompt=self.system_prompt,
            user_prompt=self.user_prompt,
        )

    def initialize_system_prompt(self) -> str:
        """Render the system prompt from the loaded templates."""
        system_prompt = populate_template(
            self.prompt_templates["system_prompt"],
            variables={"tools": self.tools, "managed_agents": self.managed_agents},
        )
        return system_prompt

    def initialize_user_prompt(self) -> str:
        """Render the user prompt from the loaded templates."""
        user_prompt = populate_template(
            self.prompt_templates["user_prompt"],
            variables={},
        )
        return user_prompt

    def initialize_task_instruction(self) -> str:
        """Render the task instruction for the agent's current task."""
        task_instruction = populate_template(
            self.prompt_templates["task_instruction"],
            variables={"task": self.task},
        )
        return task_instruction

    @staticmethod
    def _dump_args(arguments: Any) -> str:
        """Render tool-call arguments for error messages without raising.

        ``default=str`` plus the ``repr`` fallback ensure that building an
        error message never masks the original tool failure with a
        serialization error.
        """
        try:
            return json.dumps(arguments, ensure_ascii=False, default=str)
        except (TypeError, ValueError):
            return repr(arguments)

    def _substitute_state_variables(self, arguments: dict[str, str] | str) -> dict[str, Any] | str:
        """Replace string values in arguments with their corresponding state values if they exist."""
        if isinstance(arguments, dict):
            return {
                key: self.state.get(value, value) if isinstance(value, str) else value
                for key, value in arguments.items()
            }
        return arguments

    async def execute_tool_call(self, tool_name: str, arguments: dict[str, str] | str) -> Any:
        """
        Execute a tool or managed agent with the provided arguments.

        The arguments are replaced with the actual values from the state if they refer to state variables.

        Args:
            tool_name (`str`): Name of the tool or managed agent to execute.
            arguments (dict[str, str] | str): Arguments passed to the tool call.

        Raises:
            AgentToolExecutionError: Unknown tool, or the tool failed at runtime.
            AgentToolCallError: The tool was called with invalid arguments.
        """
        # Check if the tool exists
        available_tools = {**self.tools, **self.managed_agents}
        if tool_name not in available_tools:
            raise AgentToolExecutionError(
                f"Unknown tool {tool_name}, should be one of: {', '.join(available_tools)}.", self.logger
            )

        # Get the tool and substitute state variables in arguments
        tool = available_tools[tool_name]
        arguments = self._substitute_state_variables(arguments)
        is_managed_agent = tool_name in self.managed_agents

        try:
            # Call tool with appropriate arguments; plain tools additionally
            # get their inputs/outputs sanitized.
            if isinstance(arguments, dict):
                return await tool(**arguments) if is_managed_agent else await tool(**arguments, sanitize_inputs_outputs=True)
            elif isinstance(arguments, str):
                return await tool(arguments) if is_managed_agent else await tool(arguments, sanitize_inputs_outputs=True)
            else:
                raise TypeError(f"Unsupported arguments type: {type(arguments)}")

        except TypeError as e:
            # Handle invalid arguments
            description = getattr(tool, "description", "No description")
            if is_managed_agent:
                error_msg = (
                    f"Invalid request to team member '{tool_name}' with arguments {self._dump_args(arguments)}: {e}\n"
                    "You should call this team member with a valid request.\n"
                    f"Team member description: {description}"
                )
            else:
                error_msg = (
                    f"Invalid call to tool '{tool_name}' with arguments {self._dump_args(arguments)}: {e}\n"
                    "You should call this tool with correct input arguments.\n"
                    f"Expected inputs: {json.dumps(tool.parameters)}\n"
                    f"Returns output type: {tool.output_type}\n"
                    f"Tool description: '{description}'"
                )
            raise AgentToolCallError(error_msg, self.logger) from e

        except Exception as e:
            # Handle execution errors
            if is_managed_agent:
                error_msg = (
                    f"Error executing request to team member '{tool_name}' with arguments {self._dump_args(arguments)}: {e}\n"
                    "Please try again or request to another team member"
                )
            else:
                error_msg = (
                    f"Error executing tool '{tool_name}' with arguments {self._dump_args(arguments)}: {type(e).__name__}: {e}\n"
                    "Please try again or use another tool"
                )
            raise AgentToolExecutionError(error_msg, self.logger) from e

    async def step(self, memory_step: ActionStep) -> None | Any:
        """
        Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
        Returns None if the step is not final, otherwise the final answer.
        """
        memory_messages = await self.write_memory_to_messages()

        input_messages = memory_messages.copy()

        # Add new step in logs
        memory_step.model_input_messages = input_messages

        # --- Think: ask the model for the next action ---
        try:
            chat_message: ChatMessage = await self.model(
                input_messages,
                stop_sequences=["Observation:", "Calling tools:"],
                tools_to_call_from=list(self.tools.values()),
            )
            memory_step.model_output_message = chat_message
            model_output = chat_message.content
            self.logger.log_markdown(
                content=model_output if model_output else str(chat_message.raw),
                title="Output message of the LLM:",
                level=LogLevel.DEBUG,
            )

            memory_step.model_output_message.content = model_output
            memory_step.model_output = model_output
        except Exception as e:
            raise AgentGenerationError(f"Error while generating output:\n{e}", self.logger) from e

        # --- Parse: extract tool calls, falling back to parsing raw output ---
        if chat_message.tool_calls is None or len(chat_message.tool_calls) == 0:
            try:
                chat_message = self.model.parse_tool_calls(chat_message)
            except Exception as e:
                # Chain the cause so the original parse failure is not lost.
                raise AgentParsingError(f"Error while parsing tool call from model output: {e}", self.logger) from e
        else:
            for tool_call in chat_message.tool_calls:
                tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments)

        # Only the first tool call of a step is executed.
        tool_call = chat_message.tool_calls[0]
        tool_name, tool_call_id = tool_call.function.name, tool_call.id
        tool_arguments = tool_call.function.arguments
        memory_step.model_output = f"Called Tool: '{tool_name}' with arguments: {tool_arguments}"
        memory_step.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]

        # --- Act: execute the chosen tool ---
        self.logger.log(
            Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}")),
            level=LogLevel.INFO,
        )
        if tool_name == "final_answer":
            # Unwrap {"result": ...} payloads; anything else is the answer as-is.
            if isinstance(tool_arguments, dict):
                if "result" in tool_arguments:
                    result = tool_arguments["result"]
                else:
                    result = tool_arguments
            else:
                result = tool_arguments
            if (
                isinstance(result, str) and result in self.state.keys()
            ):  # if the answer is a state variable, return the value
                final_result = self.state[result]
                self.logger.log(
                    f"[bold {YELLOW_HEX}]Final answer:[/bold {YELLOW_HEX}] Extracting key '{result}' from state to return value '{final_result}'.",
                    level=LogLevel.INFO,
                )
            else:
                final_result = result
                self.logger.log(
                    Text(f"Final result: {final_result}", style=f"bold {YELLOW_HEX}"),
                    level=LogLevel.INFO,
                )

            memory_step.action_output = final_result
            return final_result
        else:
            if tool_arguments is None:
                tool_arguments = {}
            observation = await self.execute_tool_call(tool_name, tool_arguments)
            # --- Observe: store media observations in state, text in memory ---
            observation_type = type(observation)
            if observation_type in [AgentImage, AgentAudio]:
                if observation_type == AgentImage:
                    observation_name = "image.png"
                elif observation_type == AgentAudio:
                    observation_name = "audio.mp3"
                # TODO: observation naming could allow for different names of same type

                self.state[observation_name] = observation
                updated_information = f"Stored '{observation_name}' in memory."
            else:
                updated_information = str(observation).strip()
            self.logger.log(
                f"Observations: {updated_information.replace('[', '|')}",  # escape potential rich-tag-like components
                level=LogLevel.INFO,
            )
            memory_step.observations = updated_information
            return None

0 commit comments

Comments
 (0)