23
23
UserPromptMessage ,
24
24
)
25
25
from dify_plugin .entities .tool import LogMetadata , ToolInvokeMessage , ToolProviderType
26
- from dify_plugin .interfaces .agent import AgentModelConfig , AgentStrategy , ToolEntity , ToolInvokeMeta
26
+ from dify_plugin .interfaces .agent import (
27
+ AgentModelConfig ,
28
+ AgentStrategy ,
29
+ ToolEntity ,
30
+ ToolInvokeMeta ,
31
+ )
27
32
28
33
29
34
class FunctionCallingParams (BaseModel ):
@@ -35,8 +40,8 @@ class FunctionCallingParams(BaseModel):
35
40
36
41
37
42
class FunctionCallingAgentStrategy (AgentStrategy ):
38
- def __init__ (self , session ):
39
- super ().__init__ (session )
43
+ def __init__ (self , runtime , session ):
44
+ super ().__init__ (runtime , session )
40
45
self .query = ""
41
46
self .instruction = ""
42
47
@@ -59,6 +64,8 @@ def _invoke(self, parameters: dict[str, Any]) -> Generator[AgentInvokeMessage]:
59
64
self .query = query
60
65
self .instruction = fc_params .instruction
61
66
history_prompt_messages = fc_params .model .history_prompt_messages
67
+ history_prompt_messages .insert (0 , self ._system_prompt_message )
68
+ history_prompt_messages .append (self ._user_prompt_message )
62
69
63
70
# convert tool messages
64
71
tools = fc_params .tools
@@ -96,7 +103,8 @@ def _invoke(self, parameters: dict[str, Any]) -> Generator[AgentInvokeMessage]:
96
103
)
97
104
yield round_log
98
105
99
- if iteration_step == max_iteration_steps :
106
+ # If max_iteration_steps=1, need to execute tool calls
107
+ if iteration_step == max_iteration_steps and max_iteration_steps > 1 :
100
108
# the last iteration, remove all tools
101
109
prompt_messages_tools = []
102
110
@@ -195,7 +203,7 @@ def _invoke(self, parameters: dict[str, Any]) -> Generator[AgentInvokeMessage]:
195
203
data = {
196
204
"output" : response ,
197
205
"tool_name" : tool_call_names ,
198
- "tool_input" : { tool_call [1 ]: tool_call [2 ] for tool_call in tool_calls } ,
206
+ "tool_input" : [{ "name" : tool_call [1 ], "args" : tool_call [2 ]} for tool_call in tool_calls ] ,
199
207
},
200
208
metadata = {
201
209
LogMetadata .STARTED_AT : model_started_at ,
@@ -208,28 +216,30 @@ def _invoke(self, parameters: dict[str, Any]) -> Generator[AgentInvokeMessage]:
208
216
},
209
217
)
210
218
assistant_message = AssistantPromptMessage (content = "" , tool_calls = [])
211
- if tool_calls :
212
- assistant_message .tool_calls = [
213
- AssistantPromptMessage .ToolCall (
214
- id = tool_call [0 ],
215
- type = "function" ,
216
- function = AssistantPromptMessage .ToolCall .ToolCallFunction (
217
- name = tool_call [1 ],
218
- arguments = json .dumps (tool_call [2 ], ensure_ascii = False ),
219
- ),
220
- )
221
- for tool_call in tool_calls
222
- ]
223
- else :
219
+ if not tool_calls :
224
220
assistant_message .content = response
225
-
226
- current_thoughts .append (assistant_message )
221
+ current_thoughts .append (assistant_message )
227
222
228
223
final_answer += response + "\n "
229
224
230
225
# call tools
231
226
tool_responses = []
232
227
for tool_call_id , tool_call_name , tool_call_args in tool_calls :
228
+ current_thoughts .append (
229
+ AssistantPromptMessage (
230
+ content = "" ,
231
+ tool_calls = [
232
+ AssistantPromptMessage .ToolCall (
233
+ id = tool_call_id ,
234
+ type = "function" ,
235
+ function = AssistantPromptMessage .ToolCall .ToolCallFunction (
236
+ name = tool_call_name ,
237
+ arguments = json .dumps (tool_call_args , ensure_ascii = False ),
238
+ ),
239
+ )
240
+ ],
241
+ )
242
+ )
233
243
tool_instance = tool_instances [tool_call_name ]
234
244
tool_call_started_at = time .perf_counter ()
235
245
tool_call_log = self .create_log_message (
@@ -341,8 +351,11 @@ def _invoke(self, parameters: dict[str, Any]) -> Generator[AgentInvokeMessage]:
341
351
LogMetadata .TOTAL_TOKENS : current_llm_usage .total_tokens if current_llm_usage else 0 ,
342
352
},
343
353
)
354
+ # If max_iteration_steps=1, need to return tool responses
355
+ if tool_responses and max_iteration_steps == 1 :
356
+ for resp in tool_responses :
357
+ yield self .create_text_message (resp ["tool_response" ])
344
358
iteration_step += 1
345
-
346
359
yield self .create_json_message (
347
360
{
348
361
"execution_metadata" : {
@@ -452,8 +465,6 @@ def _organize_prompt_messages(
452
465
current_thoughts : list [PromptMessage ],
453
466
history_prompt_messages : list [PromptMessage ],
454
467
) -> list [PromptMessage ]:
455
- history_prompt_messages .insert (0 , self ._system_prompt_message )
456
- history_prompt_messages .append (self ._user_prompt_message )
457
468
prompt_messages = [
458
469
* history_prompt_messages ,
459
470
* current_thoughts ,
0 commit comments