1
+ from typing import (
2
+ Any ,
3
+ Callable ,
4
+ Optional
5
+ )
6
+
7
+ import yaml
8
+ import json
9
+ from rich .panel import Panel
10
+ from rich .text import Text
11
+
12
+ from src .tools import AsyncTool
13
+ from src .exception import (
14
+ AgentGenerationError ,
15
+ AgentParsingError ,
16
+ AgentToolExecutionError ,
17
+ AgentToolCallError
18
+ )
19
+ from src .base .async_multistep_agent import (PromptTemplates ,
20
+ populate_template ,
21
+ AsyncMultiStepAgent )
22
+ from src .memory import (ActionStep ,
23
+ ToolCall ,
24
+ AgentMemory )
25
+ from src .logger import (LogLevel ,
26
+ YELLOW_HEX ,
27
+ logger )
28
+ from src .models import Model , parse_json_if_needed , ChatMessage
29
+ from src .utils .agent_types import (
30
+ AgentAudio ,
31
+ AgentImage ,
32
+ )
33
+ from src .utils import assemble_project_path
34
+
35
+ class BaseAgent (AsyncMultiStepAgent ):
36
+ """Base class for agents with common logic."""
37
+ AGENT_NAME = "base_agent" # Must be overridden by subclasses
38
+
39
+ def __init__ (
40
+ self ,
41
+ config , # Specific configuration object for the agent
42
+ tools : list [AsyncTool ],
43
+ model : Model ,
44
+ prompt_templates_path : str , # Path to the prompt templates file
45
+ prompt_templates : PromptTemplates | None = None , # For preloaded templates
46
+ max_steps : int = 20 ,
47
+ add_base_tools : bool = False ,
48
+ verbosity_level : LogLevel = LogLevel .INFO ,
49
+ grammar : dict [str , str ] | None = None ,
50
+ managed_agents : list | None = None ,
51
+ step_callbacks : list [Callable ] | None = None ,
52
+ planning_interval : int | None = None ,
53
+ name : str | None = None , # AGENT_NAME will be used if not specified
54
+ description : str | None = None ,
55
+ provide_run_summary : bool = False ,
56
+ final_answer_checks : list [Callable ] | None = None ,
57
+ ** kwargs
58
+ ):
59
+ self .config = config # Save config for possible access by subclasses
60
+
61
+ agent_name_to_use = name if name is not None else self .AGENT_NAME
62
+
63
+ super ().__init__ (
64
+ tools = tools ,
65
+ model = model ,
66
+ prompt_templates = None , # Initialize as None, load later
67
+ max_steps = max_steps ,
68
+ add_base_tools = add_base_tools ,
69
+ verbosity_level = verbosity_level ,
70
+ grammar = grammar ,
71
+ managed_agents = managed_agents ,
72
+ step_callbacks = step_callbacks ,
73
+ planning_interval = planning_interval ,
74
+ name = agent_name_to_use , # Use the defined agent name
75
+ description = description ,
76
+ provide_run_summary = provide_run_summary ,
77
+ final_answer_checks = final_answer_checks ,
78
+ ** kwargs # Pass remaining arguments to the parent class
79
+ )
80
+
81
+ # Loading prompt_templates
82
+ if prompt_templates :
83
+ self .prompt_templates = prompt_templates
84
+ else :
85
+ abs_template_path = assemble_project_path (prompt_templates_path )
86
+ with open (abs_template_path , "r" , encoding = 'utf-8' ) as f :
87
+ self .prompt_templates = yaml .safe_load (f )
88
+
89
+ self .system_prompt = self .initialize_system_prompt ()
90
+ self .user_prompt = self .initialize_user_prompt ()
91
+
92
+ self .memory = AgentMemory (
93
+ system_prompt = self .system_prompt ,
94
+ user_prompt = self .user_prompt ,
95
+ )
96
+
97
+ def initialize_system_prompt (self ) -> str :
98
+ """Initialize the system prompt for the agent."""
99
+ system_prompt = populate_template (
100
+ self .prompt_templates ["system_prompt" ],
101
+ variables = {"tools" : self .tools , "managed_agents" : self .managed_agents },
102
+ )
103
+ return system_prompt
104
+
105
+ def initialize_user_prompt (self ) -> str :
106
+
107
+ user_prompt = populate_template (
108
+ self .prompt_templates ["user_prompt" ],
109
+ variables = {},
110
+ )
111
+
112
+ return user_prompt
113
+
114
+ def initialize_task_instruction (self ) -> str :
115
+ """Initialize the task instruction for the agent."""
116
+ task_instruction = populate_template (
117
+ self .prompt_templates ["task_instruction" ],
118
+ variables = {"task" : self .task },
119
+ )
120
+ return task_instruction
121
+
122
+ def _substitute_state_variables (self , arguments : dict [str , str ] | str ) -> dict [str , Any ] | str :
123
+ """Replace string values in arguments with their corresponding state values if they exist."""
124
+ if isinstance (arguments , dict ):
125
+ return {
126
+ key : self .state .get (value , value ) if isinstance (value , str ) else value
127
+ for key , value in arguments .items ()
128
+ }
129
+ return arguments
130
+
131
+ async def execute_tool_call (self , tool_name : str , arguments : dict [str , str ] | str ) -> Any :
132
+ """
133
+ Execute a tool or managed agent with the provided arguments.
134
+
135
+ The arguments are replaced with the actual values from the state if they refer to state variables.
136
+
137
+ Args:
138
+ tool_name (`str`): Name of the tool or managed agent to execute.
139
+ arguments (dict[str, str] | str): Arguments passed to the tool call.
140
+ """
141
+ # Check if the tool exists
142
+ available_tools = {** self .tools , ** self .managed_agents }
143
+ if tool_name not in available_tools :
144
+ raise AgentToolExecutionError (
145
+ f"Unknown tool { tool_name } , should be one of: { ', ' .join (available_tools )} ." , self .logger
146
+ )
147
+
148
+ # Get the tool and substitute state variables in arguments
149
+ tool = available_tools [tool_name ]
150
+ arguments = self ._substitute_state_variables (arguments )
151
+ is_managed_agent = tool_name in self .managed_agents
152
+
153
+ try :
154
+ # Call tool with appropriate arguments
155
+ if isinstance (arguments , dict ):
156
+ return await tool (** arguments ) if is_managed_agent else await tool (** arguments , sanitize_inputs_outputs = True )
157
+ elif isinstance (arguments , str ):
158
+ return await tool (arguments ) if is_managed_agent else await tool (arguments , sanitize_inputs_outputs = True )
159
+ else :
160
+ raise TypeError (f"Unsupported arguments type: { type (arguments )} " )
161
+
162
+ except TypeError as e :
163
+ # Handle invalid arguments
164
+ description = getattr (tool , "description" , "No description" )
165
+ if is_managed_agent :
166
+ error_msg = (
167
+ f"Invalid request to team member '{ tool_name } ' with arguments { json .dumps (arguments , ensure_ascii = False )} : { e } \n "
168
+ "You should call this team member with a valid request.\n "
169
+ f"Team member description: { description } "
170
+ )
171
+ else :
172
+ error_msg = (
173
+ f"Invalid call to tool '{ tool_name } ' with arguments { json .dumps (arguments , ensure_ascii = False )} : { e } \n "
174
+ "You should call this tool with correct input arguments.\n "
175
+ f"Expected inputs: { json .dumps (tool .parameters )} \n "
176
+ f"Returns output type: { tool .output_type } \n "
177
+ f"Tool description: '{ description } '"
178
+ )
179
+ raise AgentToolCallError (error_msg , self .logger ) from e
180
+
181
+ except Exception as e :
182
+ # Handle execution errors
183
+ if is_managed_agent :
184
+ error_msg = (
185
+ f"Error executing request to team member '{ tool_name } ' with arguments { json .dumps (arguments )} : { e } \n "
186
+ "Please try again or request to another team member"
187
+ )
188
+ else :
189
+ error_msg = (
190
+ f"Error executing tool '{ tool_name } ' with arguments { json .dumps (arguments )} : { type (e ).__name__ } : { e } \n "
191
+ "Please try again or use another tool"
192
+ )
193
+ raise AgentToolExecutionError (error_msg , self .logger ) from e
194
+
195
+ async def step (self , memory_step : ActionStep ) -> None | Any :
196
+ """
197
+ Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
198
+ Returns None if the step is not final.
199
+ """
200
+ memory_messages = await self .write_memory_to_messages ()
201
+
202
+ input_messages = memory_messages .copy ()
203
+
204
+ # Add new step in logs
205
+ memory_step .model_input_messages = input_messages
206
+
207
+ try :
208
+ chat_message : ChatMessage = await self .model (
209
+ input_messages ,
210
+ stop_sequences = ["Observation:" , "Calling tools:" ],
211
+ tools_to_call_from = list (self .tools .values ()),
212
+ )
213
+ memory_step .model_output_message = chat_message
214
+ model_output = chat_message .content
215
+ self .logger .log_markdown (
216
+ content = model_output if model_output else str (chat_message .raw ),
217
+ title = "Output message of the LLM:" ,
218
+ level = LogLevel .DEBUG ,
219
+ )
220
+
221
+ memory_step .model_output_message .content = model_output
222
+ memory_step .model_output = model_output
223
+ except Exception as e :
224
+ raise AgentGenerationError (f"Error while generating output:\n { e } " , self .logger ) from e
225
+
226
+ if chat_message .tool_calls is None or len (chat_message .tool_calls ) == 0 :
227
+ try :
228
+ chat_message = self .model .parse_tool_calls (chat_message )
229
+ except Exception as e :
230
+ raise AgentParsingError (f"Error while parsing tool call from model output: { e } " , self .logger )
231
+ else :
232
+ for tool_call in chat_message .tool_calls :
233
+ tool_call .function .arguments = parse_json_if_needed (tool_call .function .arguments )
234
+
235
+ tool_call = chat_message .tool_calls [0 ]
236
+ tool_name , tool_call_id = tool_call .function .name , tool_call .id
237
+ tool_arguments = tool_call .function .arguments
238
+ memory_step .model_output = str (f"Called Tool: '{ tool_name } ' with arguments: { tool_arguments } " )
239
+ memory_step .tool_calls = [ToolCall (name = tool_name , arguments = tool_arguments , id = tool_call_id )]
240
+
241
+ # Execute
242
+ self .logger .log (
243
+ Panel (Text (f"Calling tool: '{ tool_name } ' with arguments: { tool_arguments } " )),
244
+ level = LogLevel .INFO ,
245
+ )
246
+ if tool_name == "final_answer" :
247
+ if isinstance (tool_arguments , dict ):
248
+ if "result" in tool_arguments :
249
+ result = tool_arguments ["result" ]
250
+ else :
251
+ result = tool_arguments
252
+ else :
253
+ result = tool_arguments
254
+ if (
255
+ isinstance (result , str ) and result in self .state .keys ()
256
+ ): # if the answer is a state variable, return the value
257
+ final_result = self .state [result ]
258
+ self .logger .log (
259
+ f"[bold { YELLOW_HEX } ]Final answer:[/bold { YELLOW_HEX } ] Extracting key '{ result } ' from state to return value '{ final_result } '." ,
260
+ level = LogLevel .INFO ,
261
+ )
262
+ else :
263
+ final_result = result
264
+ self .logger .log (
265
+ Text (f"Final result: { final_result } " , style = f"bold { YELLOW_HEX } " ),
266
+ level = LogLevel .INFO ,
267
+ )
268
+
269
+ memory_step .action_output = final_result
270
+ return final_result
271
+ else :
272
+ if tool_arguments is None :
273
+ tool_arguments = {}
274
+ observation = await self .execute_tool_call (tool_name , tool_arguments )
275
+ observation_type = type (observation )
276
+ if observation_type in [AgentImage , AgentAudio ]:
277
+ if observation_type == AgentImage :
278
+ observation_name = "image.png"
279
+ elif observation_type == AgentAudio :
280
+ observation_name = "audio.mp3"
281
+ # TODO: observation naming could allow for different names of same type
282
+
283
+ self .state [observation_name ] = observation
284
+ updated_information = f"Stored '{ observation_name } ' in memory."
285
+ else :
286
+ updated_information = str (observation ).strip ()
287
+ self .logger .log (
288
+ f"Observations: { updated_information .replace ('[' , '|' )} " , # escape potential rich-tag-like components
289
+ level = LogLevel .INFO ,
290
+ )
291
+ memory_step .observations = updated_information
292
+ return None
0 commit comments