1
+ from typing import (
2
+ Any ,
3
+ Callable ,
4
+ Optional
5
+ )
6
+ import json
7
+ import yaml
8
+ from rich .panel import Panel
9
+ from rich .text import Text
10
+
11
+ from src .tools import AsyncTool
12
+ from src .exception import (
13
+ AgentGenerationError ,
14
+ AgentParsingError ,
15
+ AgentToolExecutionError ,
16
+ AgentToolCallError ,
17
+ logger as exception_logger # Use exception logger if it differs
18
+ )
19
+ from src .base .async_multistep_agent import (PromptTemplates ,
20
+ populate_template ,
21
+ AsyncMultiStepAgent )
22
+ from src .memory import (ActionStep ,
23
+ ToolCall ,
24
+ AgentMemory )
25
+ from src .logger import (
26
+ LogLevel ,
27
+ YELLOW_HEX
28
+ # logger # AsyncMultiStepAgent creates its own self.logger
29
+ )
30
+ from src .models import Model , parse_json_if_needed , ChatMessage
31
+ from src .utils .agent_types import (
32
+ AgentAudio ,
33
+ AgentImage ,
34
+ )
35
+ from src .utils import assemble_project_path
36
+
37
+ class BaseAgent (AsyncMultiStepAgent ):
38
+ """Base class for agents with common logic."""
39
+ AGENT_NAME = "base_agent" # Must be overridden by subclasses
40
+
41
+ def __init__ (
42
+ self ,
43
+ config , # Specific configuration object for the agent
44
+ tools : list [AsyncTool ],
45
+ model : Model ,
46
+ prompt_templates_path : str , # Path to the prompt templates file
47
+ prompt_templates : PromptTemplates | None = None , # For preloaded templates
48
+ max_steps : int = 20 ,
49
+ add_base_tools : bool = False ,
50
+ verbosity_level : LogLevel = LogLevel .INFO ,
51
+ grammar : dict [str , str ] | None = None ,
52
+ managed_agents : list | None = None ,
53
+ step_callbacks : list [Callable ] | None = None ,
54
+ planning_interval : int | None = None ,
55
+ name : str | None = None , # AGENT_NAME will be used if not specified
56
+ description : str | None = None ,
57
+ provide_run_summary : bool = False ,
58
+ final_answer_checks : list [Callable ] | None = None ,
59
+ ** kwargs
60
+ ):
61
+ self .config = config # Save config for possible access by subclasses
62
+
63
+ agent_name_to_use = name if name is not None else self .AGENT_NAME
64
+
65
+ super ().__init__ (
66
+ tools = tools ,
67
+ model = model ,
68
+ prompt_templates = None , # Initialize as None, load later
69
+ max_steps = max_steps ,
70
+ add_base_tools = add_base_tools ,
71
+ verbosity_level = verbosity_level ,
72
+ grammar = grammar ,
73
+ managed_agents = managed_agents ,
74
+ step_callbacks = step_callbacks ,
75
+ planning_interval = planning_interval ,
76
+ name = agent_name_to_use , # Use the defined agent name
77
+ description = description ,
78
+ provide_run_summary = provide_run_summary ,
79
+ final_answer_checks = final_answer_checks ,
80
+ ** kwargs # Pass remaining arguments to the parent class
81
+ )
82
+
83
+ # Loading prompt_templates
84
+ if prompt_templates :
85
+ self .prompt_templates = prompt_templates
86
+ else :
87
+ abs_template_path = assemble_project_path (prompt_templates_path )
88
+ with open (abs_template_path , "r" , encoding = 'utf-8' ) as f :
89
+ self .prompt_templates = yaml .safe_load (f )
90
+
91
+ self .system_prompt = self .initialize_system_prompt ()
92
+ self .user_prompt = self .initialize_user_prompt ()
93
+
94
+ self .memory = AgentMemory (
95
+ system_prompt = self .system_prompt ,
96
+ user_prompt = self .user_prompt ,
97
+ )
98
+ # self.logger is inherited from AsyncMultiStepAgent and uses agent_name_to_use
99
+
100
+ def initialize_system_prompt (self ) -> str :
101
+ """Initialize the system prompt for the agent."""
102
+ system_prompt = populate_template (
103
+ self .prompt_templates ["system_prompt" ],
104
+ variables = {"tools" : self .tools , "managed_agents" : self .managed_agents },
105
+ )
106
+ return system_prompt
107
+
108
+ def initialize_user_prompt (self ) -> str :
109
+
110
+ user_prompt = populate_template (
111
+ self .prompt_templates ["user_prompt" ],
112
+ variables = {},
113
+ )
114
+
115
+ return user_prompt
116
+
117
+ def initialize_task_instruction (self ) -> str :
118
+ """Initialize the task instruction for the agent."""
119
+ # self.task is set by the __call__ method of AsyncMultiStepAgent at runtime
120
+ task_instruction = populate_template (
121
+ self .prompt_templates ["task_instruction" ],
122
+ variables = {"task" : self .task },
123
+ )
124
+ return task_instruction
125
+
126
+ def _substitute_state_variables (self , arguments : dict [str , str ] | str ) -> dict [str , Any ] | str :
127
+ """Replace string values in arguments with their corresponding state values if they exist."""
128
+ if isinstance (arguments , dict ):
129
+ return {
130
+ key : self .state .get (value , value ) if isinstance (value , str ) else value
131
+ for key , value in arguments .items ()
132
+ }
133
+ return arguments
134
+
135
+ async def execute_tool_call (self , tool_name : str , arguments : dict [str , str ] | str ) -> Any :
136
+ """
137
+ Execute a tool or managed agent with the provided arguments.
138
+
139
+ The arguments are replaced with the actual values from the state if they refer to state variables.
140
+
141
+ Args:
142
+ tool_name (`str`): Name of the tool or managed agent to execute.
143
+ arguments (dict[str, str] | str): Arguments passed to the tool call.
144
+ """
145
+ # Check if the tool exists
146
+ available_tools = {** self .tools , ** self .managed_agents }
147
+ if tool_name not in available_tools :
148
+ raise AgentToolExecutionError (
149
+ f"Unknown tool { tool_name } , should be one of: { ', ' .join (available_tools )} ." , self .logger
150
+ )
151
+
152
+ # Get the tool and substitute state variables in arguments
153
+ tool = available_tools [tool_name ]
154
+ arguments = self ._substitute_state_variables (arguments )
155
+ is_managed_agent = tool_name in self .managed_agents
156
+
157
+ try :
158
+ # Call tool with appropriate arguments
159
+ if isinstance (arguments , dict ):
160
+ return await tool (** arguments ) if is_managed_agent else await tool (** arguments , sanitize_inputs_outputs = True )
161
+ elif isinstance (arguments , str ):
162
+ return await tool (arguments ) if is_managed_agent else await tool (arguments , sanitize_inputs_outputs = True )
163
+ else :
164
+ raise TypeError (f"Unsupported arguments type: { type (arguments )} " )
165
+
166
+ except TypeError as e :
167
+ # Handle invalid arguments
168
+ description = getattr (tool , "description" , "No description" )
169
+ if is_managed_agent :
170
+ error_msg = (
171
+ f"Invalid request to team member '{ tool_name } ' with arguments { json .dumps (arguments , ensure_ascii = False )} : { e } \n "
172
+ "You should call this team member with a valid request.\n "
173
+ f"Team member description: { description } "
174
+ )
175
+ else :
176
+ error_msg = (
177
+ f"Invalid call to tool '{ tool_name } ' with arguments { json .dumps (arguments , ensure_ascii = False )} : { e } \n "
178
+ "You should call this tool with correct input arguments.\n "
179
+ f"Expected inputs: { json .dumps (tool .parameters )} \n "
180
+ f"Returns output type: { tool .output_type } \n "
181
+ f"Tool description: '{ description } '"
182
+ )
183
+ raise AgentToolCallError (error_msg , self .logger ) from e
184
+
185
+ except Exception as e :
186
+ # Handle execution errors
187
+ if is_managed_agent :
188
+ error_msg = (
189
+ f"Error executing request to team member '{ tool_name } ' with arguments { json .dumps (arguments )} : { e } \n "
190
+ "Please try again or request to another team member"
191
+ )
192
+ else :
193
+ error_msg = (
194
+ f"Error executing tool '{ tool_name } ' with arguments { json .dumps (arguments )} : { type (e ).__name__ } : { e } \n "
195
+ "Please try again or use another tool"
196
+ )
197
+ raise AgentToolExecutionError (error_msg , self .logger ) from e
198
+
199
+ async def step (self , memory_step : ActionStep ) -> None | Any :
200
+ """
201
+ Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
202
+ Returns None if the step is not final.
203
+ """
204
+ memory_messages = await self .write_memory_to_messages ()
205
+
206
+ input_messages = memory_messages .copy ()
207
+
208
+ # Add new step in logs
209
+ memory_step .model_input_messages = input_messages
210
+
211
+ try :
212
+ chat_message : ChatMessage = await self .model (
213
+ input_messages ,
214
+ stop_sequences = ["Observation:" , "Calling tools:" ],
215
+ tools_to_call_from = list (self .tools .values ()),
216
+ )
217
+ memory_step .model_output_message = chat_message
218
+ model_output = chat_message .content
219
+ self .logger .log_markdown (
220
+ content = model_output if model_output else str (chat_message .raw ),
221
+ title = "Output message of the LLM:" ,
222
+ level = LogLevel .DEBUG ,
223
+ )
224
+ memory_step .model_output_message .content = model_output
225
+ memory_step .model_output = model_output
226
+ except Exception as e :
227
+ raise AgentGenerationError (f"Error while generating output:\n { e } " , self .logger ) from e
228
+
229
+ if chat_message .tool_calls is None or len (chat_message .tool_calls ) == 0 :
230
+ try :
231
+ # Attempt to parse tool calls if they were not automatically populated
232
+ chat_message = self .model .parse_tool_calls (chat_message )
233
+ except Exception as e :
234
+ # If parsing failed and there is model_output, it can be considered a direct answer
235
+ if model_output :
236
+ self .logger .log (
237
+ Text (f"Tool call not detected. Processing model output as final answer: { model_output } " , style = f"bold { YELLOW_HEX } " ),
238
+ level = LogLevel .INFO ,
239
+ )
240
+ memory_step .action_output = model_output
241
+ return model_output
242
+ raise AgentParsingError (f"Error while parsing tool call from model output: { e } " , self .logger )
243
+
244
+ # If there are still no tool calls after attempting to parse
245
+ if not chat_message .tool_calls :
246
+ if model_output :
247
+ self .logger .log (
248
+ Text (f"Tool call not detected after parsing. Processing model output as final answer: { model_output } " , style = f"bold { YELLOW_HEX } " ),
249
+ level = LogLevel .INFO ,
250
+ )
251
+ memory_step .action_output = model_output
252
+ return model_output
253
+ else :
254
+ # If there are no tool calls and no content, it's an error
255
+ raise AgentParsingError ("Tool call not found, and there is no content in the model output." , self .logger )
256
+
257
+ # Continue if there are tool calls
258
+ for tool_call in chat_message .tool_calls :
259
+ tool_call .function .arguments = parse_json_if_needed (tool_call .function .arguments )
260
+
261
+ tool_call = chat_message .tool_calls [0 ]
262
+ tool_name , tool_call_id = tool_call .function .name , tool_call .id
263
+ tool_arguments = tool_call .function .arguments
264
+ memory_step .model_output = str (f"Called Tool: '{ tool_name } ' with arguments: { tool_arguments } " )
265
+ memory_step .tool_calls = [ToolCall (name = tool_name , arguments = tool_arguments , id = tool_call_id )]
266
+
267
+ # Execute
268
+ self .logger .log (
269
+ Panel (Text (f"Calling tool: '{ tool_name } ' with arguments: { tool_arguments } " )),
270
+ level = LogLevel .INFO ,
271
+ )
272
+ if tool_name == "final_answer" :
273
+ if isinstance (tool_arguments , dict ):
274
+ result = tool_arguments .get ("result" , tool_arguments )
275
+ else :
276
+ result = tool_arguments
277
+ if (
278
+ isinstance (result , str ) and result in self .state .keys ()
279
+ ): # if the answer is a state variable, return the value
280
+ final_result = self .state [result ]
281
+ self .logger .log (
282
+ f"[bold { YELLOW_HEX } ]Final answer:[/bold { YELLOW_HEX } ] Extracting key '{ result } ' from state to return value '{ final_result } '." ,
283
+ level = LogLevel .INFO ,
284
+ )
285
+ else :
286
+ final_result = result
287
+ self .logger .log (
288
+ Text (f"Final result: { final_result } " , style = f"bold { YELLOW_HEX } " ),
289
+ level = LogLevel .INFO ,
290
+ )
291
+
292
+ memory_step .action_output = final_result
293
+ return final_result
294
+ else :
295
+ tool_args_to_pass = tool_arguments if tool_arguments is not None else {}
296
+ observation = await self .execute_tool_call (tool_name , tool_args_to_pass )
297
+ observation_type = type (observation )
298
+
299
+ if observation_type in [AgentImage , AgentAudio ]:
300
+ observation_name = "image.png" if observation_type == AgentImage else "audio.mp3"
301
+ # TODO: observation naming could allow for different names of same type
302
+
303
+ self .state [observation_name ] = observation
304
+ updated_information = f"Stored '{ observation_name } ' in memory."
305
+ else :
306
+ updated_information = str (observation ).strip ()
307
+
308
+ self .logger .log (
309
+ f"Observations: { updated_information .replace ('[' , '|' )} " , # escape potential rich-tag-like components
310
+ level = LogLevel .INFO ,
311
+ )
312
+ memory_step .observations = updated_information
313
+ return None
0 commit comments