import static org.opensearch.ml.common.utils.StringUtils.processTextDoc;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DISABLE_TRACE;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.INTERACTIONS_PREFIX;
- import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_PATH;
- import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_TOOL_USE;
- import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER;
- import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.NO_ESCAPE_PARAMS;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_CHAT_HISTORY_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.RESPONSE_FORMAT_INSTRUCTION;
- import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_PATH;
- import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_TOOL_INPUT;
- import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_TOOL_NAME;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALL_ID;
- import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALL_ID_PATH;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_RESPONSE;
- import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_TEMPLATE;
+ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_RESULT;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.VERBOSE;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.cleanUpResource;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.constructToolParams;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.encryptor.Encryptor;
+ import org.opensearch.ml.engine.function_calling.FunctionCalling;
+ import org.opensearch.ml.engine.function_calling.FunctionCallingFactory;
+ import org.opensearch.ml.engine.function_calling.LLMMessage;
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
import org.opensearch.ml.engine.memory.ConversationIndexMessage;
import org.opensearch.ml.engine.tools.MLModelTool;
@@ -117,7 +112,6 @@ public class MLChatAgentRunner implements MLAgentRunner {
    public static final String FINAL_ANSWER = "final_answer";
    public static final String THOUGHT_RESPONSE = "thought_response";
    public static final String INTERACTIONS = "_interactions";
-     public static final String DEFAULT_NO_ESCAPE_PARAMS = "_chat_history,_tools,_interactions,tool_configs";
    public static final String INTERACTION_TEMPLATE_TOOL_RESPONSE = "interaction_template.tool_response";
    public static final String CHAT_HISTORY_QUESTION_TEMPLATE = "chat_history_template.user_question";
    public static final String CHAT_HISTORY_RESPONSE_TEMPLATE = "chat_history_template.ai_response";
@@ -170,116 +164,11 @@ public void run(MLAgent mlAgent, Map<String, String> inputParams, ActionListener
        params.putAll(inputParams);

        String llmInterface = params.get(LLM_INTERFACE);
-         // todo: introduce function calling
-         // handle parameters based on llmInterface
-         if ("openai/v1/chat/completions".equalsIgnoreCase(llmInterface)) {
-             if (!params.containsKey(NO_ESCAPE_PARAMS)) {
-                 params.put(NO_ESCAPE_PARAMS, DEFAULT_NO_ESCAPE_PARAMS);
-             }
-             params.put(LLM_RESPONSE_FILTER, "$.choices[0].message.content");
-
-             params
-                 .put(
-                     TOOL_TEMPLATE,
-                     "{\"type\": \"function\", \"function\": { \"name\": \"${tool.name}\", \"description\": \"${tool.description}\", \"parameters\": ${tool.attributes.input_schema}, \"strict\": ${tool.attributes.strict:-false} } }"
-                 );
-             params.put(TOOL_CALLS_PATH, "$.choices[0].message.tool_calls");
-             params.put(TOOL_CALLS_TOOL_NAME, "function.name");
-             params.put(TOOL_CALLS_TOOL_INPUT, "function.arguments");
-             params.put(TOOL_CALL_ID_PATH, "id");
-             params.put("tool_configs", ", \"tools\": [${parameters._tools:-}], \"parallel_tool_calls\": false");
-
-             params.put("tool_choice", "auto");
-             params.put("parallel_tool_calls", "false");
-
-             params.put("interaction_template.assistant_tool_calls_path", "$.choices[0].message");
-             params
-                 .put(
-                     "interaction_template.tool_response",
-                     "{ \"role\": \"tool\", \"tool_call_id\": \"${_interactions.tool_call_id}\", \"content\": \"${_interactions.tool_response}\" }"
-                 );
-
-             params.put("chat_history_template.user_question", "{\"role\": \"user\",\"content\": \"${_chat_history.message.question}\"}");
-             params.put("chat_history_template.ai_response", "{\"role\": \"assistant\",\"content\": \"${_chat_history.message.response}\"}");
-
-             params.put(LLM_FINISH_REASON_PATH, "$.choices[0].finish_reason");
-             params.put(LLM_FINISH_REASON_TOOL_USE, "tool_calls");
-         } else if ("bedrock/converse/claude".equalsIgnoreCase(llmInterface)) {
-             if (!params.containsKey(NO_ESCAPE_PARAMS)) {
-                 params.put(NO_ESCAPE_PARAMS, DEFAULT_NO_ESCAPE_PARAMS);
-             }
-             params.put(LLM_RESPONSE_FILTER, "$.output.message.content[0].text");
-
-             params
-                 .put(
-                     TOOL_TEMPLATE,
-                     "{\"toolSpec\":{\"name\":\"${tool.name}\",\"description\":\"${tool.description}\",\"inputSchema\": {\"json\": ${tool.attributes.input_schema} } }}"
-                 );
-             params.put(TOOL_CALLS_PATH, "$.output.message.content[*].toolUse");
-             params.put(TOOL_CALLS_TOOL_NAME, "name");
-             params.put(TOOL_CALLS_TOOL_INPUT, "input");
-             params.put(TOOL_CALL_ID_PATH, "toolUseId");
-             params.put("tool_configs", ", \"toolConfig\": {\"tools\": [${parameters._tools:-}]}");
-
-             params.put("interaction_template.assistant_tool_calls_path", "$.output.message");
-             params
-                 .put(
-                     "interaction_template.tool_response",
-                     "{\"role\":\"user\",\"content\":[{\"toolResult\":{\"toolUseId\":\"${_interactions.tool_call_id}\",\"content\":[{\"text\":\"${_interactions.tool_response}\"}]}}]}"
-                 );
-
-             params
-                 .put(
-                     "chat_history_template.user_question",
-                     "{\"role\":\"user\",\"content\":[{\"text\":\"${_chat_history.message.question}\"}]}"
-                 );
-             params
-                 .put(
-                     "chat_history_template.ai_response",
-                     "{\"role\":\"assistant\",\"content\":[{\"text\":\"${_chat_history.message.response}\"}]}"
-                 );
-
-             params.put(LLM_FINISH_REASON_PATH, "$.stopReason");
-             params.put(LLM_FINISH_REASON_TOOL_USE, "tool_use");
-         } else if ("bedrock/converse/deepseek_r1".equalsIgnoreCase(llmInterface)) {
-             if (!params.containsKey(NO_ESCAPE_PARAMS)) {
-                 params.put(NO_ESCAPE_PARAMS, "_chat_history,_interactions");
-             }
-             params.put(LLM_RESPONSE_FILTER, "$.output.message.content[0].text");
-             params.put("llm_final_response_post_filter", "$.message.content[0].text");
-
-             params
-                 .put(
-                     TOOL_TEMPLATE,
-                     "{\"toolSpec\":{\"name\":\"${tool.name}\",\"description\":\"${tool.description}\",\"inputSchema\": {\"json\": ${tool.attributes.input_schema} } }}"
-                 );
-             params.put(TOOL_CALLS_PATH, "_llm_response.tool_calls");
-             params.put(TOOL_CALLS_TOOL_NAME, "tool_name");
-             params.put(TOOL_CALLS_TOOL_INPUT, "input");
-             params.put(TOOL_CALL_ID_PATH, "id");
-
-             params.put("interaction_template.assistant_tool_calls_path", "$.output.message");
-             params.put("interaction_template.assistant_tool_calls_exclude_path", "[ \"$.output.message.content[?(@.reasoningContent)]\" ]");
-             params
-                 .put(
-                     "interaction_template.tool_response",
-                     "{\"role\":\"user\",\"content\":[ {\"text\":\"{\\\"tool_call_id\\\":\\\"${_interactions.tool_call_id}\\\",\\\"tool_result\\\": \\\"${_interactions.tool_response}\\\"\" } ]}"
-                 );
-
-             params
-                 .put(
-                     "chat_history_template.user_question",
-                     "{\"role\":\"user\",\"content\":[{\"text\":\"${_chat_history.message.question}\"}]}"
-                 );
-             params
-                 .put(
-                     "chat_history_template.ai_response",
-                     "{\"role\":\"assistant\",\"content\":[{\"text\":\"${_chat_history.message.response}\"}]}"
-                 );
-
-             params.put(LLM_FINISH_REASON_PATH, "_llm_response.stop_reason");
-             params.put(LLM_FINISH_REASON_TOOL_USE, "tool_use");
+         FunctionCalling functionCalling = FunctionCallingFactory.create(llmInterface);
+         if (functionCalling != null) {
+             functionCalling.configure(params);
        }
+
        String memoryType = mlAgent.getMemory().getType();
        String memoryId = params.get(MLAgentExecutor.MEMORY_ID);
        String appType = mlAgent.getAppType();
@@ -347,23 +236,30 @@ public void run(MLAgent mlAgent, Map<String, String> inputParams, ActionListener
                    }
                }

-                 runAgent(mlAgent, params, listener, memory, memory.getConversationId());
+                 runAgent(mlAgent, params, listener, memory, memory.getConversationId(), functionCalling);
            }, e -> {
                log.error("Failed to get chat history", e);
                listener.onFailure(e);
            }), messageHistoryLimit);
        }, listener::onFailure));
    }

-     private void runAgent(MLAgent mlAgent, Map<String, String> params, ActionListener<Object> listener, Memory memory, String sessionId) {
+     private void runAgent(
+         MLAgent mlAgent,
+         Map<String, String> params,
+         ActionListener<Object> listener,
+         Memory memory,
+         String sessionId,
+         FunctionCalling functionCalling
+     ) {
        List<MLToolSpec> toolSpecs = getMlToolSpecs(mlAgent, params);

        // Create a common method to handle both success and failure cases
        Consumer<List<MLToolSpec>> processTools = (allToolSpecs) -> {
            Map<String, Tool> tools = new HashMap<>();
            Map<String, MLToolSpec> toolSpecMap = new HashMap<>();
            createTools(toolFactories, params, allToolSpecs, tools, toolSpecMap, mlAgent);
-             runReAct(mlAgent.getLlm(), tools, toolSpecMap, params, memory, sessionId, mlAgent.getTenantId(), listener);
+             runReAct(mlAgent.getLlm(), tools, toolSpecMap, params, memory, sessionId, mlAgent.getTenantId(), listener, functionCalling);
        };

        // Fetch MCP tools and handle both success and failure cases
@@ -384,7 +280,8 @@ private void runReAct(
        Memory memory,
        String sessionId,
        String tenantId,
-         ActionListener<Object> listener
+         ActionListener<Object> listener,
+         FunctionCalling functionCalling
    ) {
        Map<String, String> tmpParameters = constructLLMParams(llm, parameters);
        String prompt = constructLLMPrompt(tools, tmpParameters);
@@ -437,7 +334,8 @@ private void runReAct(
                    tmpModelTensorOutput,
                    llmResponsePatterns,
                    tools.keySet(),
-                     interactions
+                     interactions,
+                     functionCalling
                );

                String thought = String.valueOf(modelOutput.get(THOUGHT));
@@ -510,7 +408,8 @@ private void runReAct(
                        actionInput,
                        toolParams,
                        interactions,
-                         toolCallId
+                         toolCallId,
+                         functionCalling
                    );
                } else {
                    String res = String.format(Locale.ROOT, "Failed to run the tool %s which is unsupported.", action);
@@ -675,20 +574,28 @@ private static void runTool(
        String actionInput,
        Map<String, String> toolParams,
        List<String> interactions,
-         String toolCallId
+         String toolCallId,
+         FunctionCalling functionCalling
    ) {
        if (tools.get(action).validate(toolParams)) {
            try {
                String finalAction = action;
                ActionListener<Object> toolListener = ActionListener.wrap(r -> {
-                     interactions
-                         .add(
-                             substitute(
-                                 tmpParameters.get(INTERACTION_TEMPLATE_TOOL_RESPONSE),
-                                 Map.of(TOOL_CALL_ID, toolCallId, "tool_response", processTextDoc(StringUtils.toJson(r))),
-                                 INTERACTIONS_PREFIX
-                             )
-                         );
+                     if (functionCalling != null) {
+                         List<Map<String, Object>> toolResults = List.of(Map.of(TOOL_CALL_ID, toolCallId, TOOL_RESULT, Map.of("text", r)));
+                         List<LLMMessage> llmMessages = functionCalling.supply(toolResults);
+                         // TODO: support multiple tool calls at the same time so that multiple LLMMessages can be generated here
+                         interactions.add(llmMessages.getFirst().getResponse());
+                     } else {
+                         interactions
+                             .add(
+                                 substitute(
+                                     tmpParameters.get(INTERACTION_TEMPLATE_TOOL_RESPONSE),
+                                     Map.of(TOOL_CALL_ID, toolCallId, "tool_response", processTextDoc(StringUtils.toJson(r))),
+                                     INTERACTIONS_PREFIX
+                                 )
+                             );
+                     }
                    nextStepListener.onResponse(r);
                }, e -> {
                    interactions