Skip to content

Commit 3534862

Browse files
jiapingzeng, dhrubo-os, Zhangxunmt
authored
Use function calling for existing LLM interfaces (#3888)
* function calling in agent draft
  Signed-off-by: Jiaping Zeng <jpz@amazon.com>
* added function calling in runTool
  Signed-off-by: Jiaping Zeng <jpz@amazon.com>
* UT fix for function calling
  Signed-off-by: Jiaping Zeng <jpz@amazon.com>
* fixed openai llm interface
  Signed-off-by: Jiaping Zeng <jpz@amazon.com>
* fixed test + reuse existing static where possible
  Signed-off-by: Jiaping Zeng <jpz@amazon.com>
---------
Signed-off-by: Jiaping Zeng <jpz@amazon.com>
Co-authored-by: Dhrubo Saha <dhrubo@amazon.com>
Co-authored-by: Xun Zhang <xunzh@amazon.com>
1 parent fdbe3b4 commit 3534862

12 files changed

+296
-171
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
import org.opensearch.ml.engine.MLEngineClassLoader;
7676
import org.opensearch.ml.engine.algorithms.remote.McpConnectorExecutor;
7777
import org.opensearch.ml.engine.encryptor.Encryptor;
78+
import org.opensearch.ml.engine.function_calling.FunctionCalling;
7879
import org.opensearch.ml.engine.tools.McpSseTool;
7980
import org.opensearch.remote.metadata.client.GetDataObjectRequest;
8081
import org.opensearch.remote.metadata.client.SdkClient;
@@ -129,6 +130,9 @@ public class AgentUtils {
129130
public static final String LLM_FINISH_REASON_TOOL_USE = "llm_finish_reason_tool_use";
130131
public static final String TOOL_FILTERS_FIELD = "tool_filters";
131132

133+
// For function calling, do not escape the below params in connector by default
134+
public static final String DEFAULT_NO_ESCAPE_PARAMS = "_chat_history,_tools,_interactions,tool_configs";
135+
132136
public static String addExamplesToPrompt(Map<String, String> parameters, String prompt) {
133137
Map<String, String> examplesMap = new HashMap<>();
134138
if (parameters.containsKey(EXAMPLES)) {
@@ -299,7 +303,8 @@ public static Map<String, String> parseLLMOutput(
299303
ModelTensorOutput tmpModelTensorOutput,
300304
List<String> llmResponsePatterns,
301305
Set<String> inputTools,
302-
List<String> interactions
306+
List<String> interactions,
307+
FunctionCalling functionCalling
303308
) {
304309
Map<String, String> modelOutput = new HashMap<>();
305310
Map<String, ?> dataAsMap = tmpModelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap();
@@ -339,20 +344,33 @@ public static Map<String, String> parseLLMOutput(
339344
llmFinishReason = JsonPath.read(dataAsMap, llmFinishReasonPath);
340345
}
341346
if (parameters.get(LLM_FINISH_REASON_TOOL_USE).equalsIgnoreCase(llmFinishReason) || isToolUseResponse) {
342-
List toolCalls = null;
347+
List<Map<String, String>> toolCalls = null;
343348
try {
344-
String toolCallsPath = parameters.get(TOOL_CALLS_PATH);
345-
if (toolCallsPath.startsWith("_llm_response.")) {
346-
Map<String, Object> llmResponse = StringUtils.fromJson(response.toString(), RESPONSE_FIELD);
347-
toolCalls = JsonPath.read(llmResponse, toolCallsPath.substring("_llm_response.".length()));
349+
String toolName = "";
350+
String toolInput = "";
351+
String toolCallId = "";
352+
if (functionCalling != null) {
353+
toolCalls = functionCalling.handle(tmpModelTensorOutput, parameters);
354+
// TODO: support multiple tool calls here
355+
toolName = toolCalls.getFirst().get("tool_name");
356+
toolInput = toolCalls.getFirst().get("tool_input");
357+
toolCallId = toolCalls.getFirst().get("tool_call_id");
348358
} else {
349-
toolCalls = JsonPath.read(dataAsMap, toolCallsPath);
359+
String toolCallsPath = parameters.get(TOOL_CALLS_PATH);
360+
if (toolCallsPath.startsWith("_llm_response.")) {
361+
Map<String, Object> llmResponse = StringUtils.fromJson(response.toString(), RESPONSE_FIELD);
362+
toolCalls = JsonPath.read(llmResponse, toolCallsPath.substring("_llm_response.".length()));
363+
} else {
364+
toolCalls = JsonPath.read(dataAsMap, toolCallsPath);
365+
}
366+
toolName = JsonPath.read(toolCalls.get(0), parameters.get(TOOL_CALLS_TOOL_NAME));
367+
toolInput = StringUtils.toJson(JsonPath.read(toolCalls.get(0), parameters.get(TOOL_CALLS_TOOL_INPUT)));
368+
toolCallId = JsonPath.read(toolCalls.get(0), parameters.get(TOOL_CALL_ID_PATH));
350369
}
351370
String toolCallsMsgPath = parameters.get(INTERACTION_TEMPLATE_ASSISTANT_TOOL_CALLS_PATH);
352371
String toolCallsMsgExcludePath = parameters.get(INTERACTION_TEMPLATE_ASSISTANT_TOOL_CALLS_EXCLUDE_PATH);
353372
if (toolCallsMsgPath != null) {
354373
if (toolCallsMsgExcludePath != null) {
355-
356374
Map<String, ?> newDataAsMap = removeJsonPath(dataAsMap, toolCallsMsgExcludePath, false);
357375
Object toolCallsMsg = JsonPath.read(newDataAsMap, toolCallsMsgPath);
358376
interactions.add(StringUtils.toJson(toolCallsMsg));
@@ -371,9 +389,6 @@ public static Map<String, String> parseLLMOutput(
371389
)
372390
);
373391
}
374-
String toolName = JsonPath.read(toolCalls.get(0), parameters.get(TOOL_CALLS_TOOL_NAME));
375-
String toolInput = StringUtils.toJson(JsonPath.read(toolCalls.get(0), parameters.get(TOOL_CALLS_TOOL_INPUT)));
376-
String toolCallId = JsonPath.read(toolCalls.get(0), parameters.get(TOOL_CALL_ID_PATH));
377392
modelOutput.put(THOUGHT, "");
378393
modelOutput.put(ACTION, toolName);
379394
modelOutput.put(ACTION_INPUT, toolInput);

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

Lines changed: 41 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,13 @@
1111
import static org.opensearch.ml.common.utils.StringUtils.processTextDoc;
1212
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DISABLE_TRACE;
1313
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.INTERACTIONS_PREFIX;
14-
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_PATH;
15-
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_TOOL_USE;
16-
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER;
17-
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.NO_ESCAPE_PARAMS;
1814
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_CHAT_HISTORY_PREFIX;
1915
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX;
2016
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX;
2117
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.RESPONSE_FORMAT_INSTRUCTION;
22-
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_PATH;
23-
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_TOOL_INPUT;
24-
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_TOOL_NAME;
2518
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALL_ID;
26-
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALL_ID_PATH;
2719
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_RESPONSE;
28-
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_TEMPLATE;
20+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_RESULT;
2921
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.VERBOSE;
3022
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.cleanUpResource;
3123
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.constructToolParams;
@@ -79,6 +71,9 @@
7971
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
8072
import org.opensearch.ml.common.utils.StringUtils;
8173
import org.opensearch.ml.engine.encryptor.Encryptor;
74+
import org.opensearch.ml.engine.function_calling.FunctionCalling;
75+
import org.opensearch.ml.engine.function_calling.FunctionCallingFactory;
76+
import org.opensearch.ml.engine.function_calling.LLMMessage;
8277
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
8378
import org.opensearch.ml.engine.memory.ConversationIndexMessage;
8479
import org.opensearch.ml.engine.tools.MLModelTool;
@@ -117,7 +112,6 @@ public class MLChatAgentRunner implements MLAgentRunner {
117112
public static final String FINAL_ANSWER = "final_answer";
118113
public static final String THOUGHT_RESPONSE = "thought_response";
119114
public static final String INTERACTIONS = "_interactions";
120-
public static final String DEFAULT_NO_ESCAPE_PARAMS = "_chat_history,_tools,_interactions,tool_configs";
121115
public static final String INTERACTION_TEMPLATE_TOOL_RESPONSE = "interaction_template.tool_response";
122116
public static final String CHAT_HISTORY_QUESTION_TEMPLATE = "chat_history_template.user_question";
123117
public static final String CHAT_HISTORY_RESPONSE_TEMPLATE = "chat_history_template.ai_response";
@@ -170,116 +164,11 @@ public void run(MLAgent mlAgent, Map<String, String> inputParams, ActionListener
170164
params.putAll(inputParams);
171165

172166
String llmInterface = params.get(LLM_INTERFACE);
173-
// todo: introduce function calling
174-
// handle parameters based on llmInterface
175-
if ("openai/v1/chat/completions".equalsIgnoreCase(llmInterface)) {
176-
if (!params.containsKey(NO_ESCAPE_PARAMS)) {
177-
params.put(NO_ESCAPE_PARAMS, DEFAULT_NO_ESCAPE_PARAMS);
178-
}
179-
params.put(LLM_RESPONSE_FILTER, "$.choices[0].message.content");
180-
181-
params
182-
.put(
183-
TOOL_TEMPLATE,
184-
"{\"type\": \"function\", \"function\": { \"name\": \"${tool.name}\", \"description\": \"${tool.description}\", \"parameters\": ${tool.attributes.input_schema}, \"strict\": ${tool.attributes.strict:-false} } }"
185-
);
186-
params.put(TOOL_CALLS_PATH, "$.choices[0].message.tool_calls");
187-
params.put(TOOL_CALLS_TOOL_NAME, "function.name");
188-
params.put(TOOL_CALLS_TOOL_INPUT, "function.arguments");
189-
params.put(TOOL_CALL_ID_PATH, "id");
190-
params.put("tool_configs", ", \"tools\": [${parameters._tools:-}], \"parallel_tool_calls\": false");
191-
192-
params.put("tool_choice", "auto");
193-
params.put("parallel_tool_calls", "false");
194-
195-
params.put("interaction_template.assistant_tool_calls_path", "$.choices[0].message");
196-
params
197-
.put(
198-
"interaction_template.tool_response",
199-
"{ \"role\": \"tool\", \"tool_call_id\": \"${_interactions.tool_call_id}\", \"content\": \"${_interactions.tool_response}\" }"
200-
);
201-
202-
params.put("chat_history_template.user_question", "{\"role\": \"user\",\"content\": \"${_chat_history.message.question}\"}");
203-
params.put("chat_history_template.ai_response", "{\"role\": \"assistant\",\"content\": \"${_chat_history.message.response}\"}");
204-
205-
params.put(LLM_FINISH_REASON_PATH, "$.choices[0].finish_reason");
206-
params.put(LLM_FINISH_REASON_TOOL_USE, "tool_calls");
207-
} else if ("bedrock/converse/claude".equalsIgnoreCase(llmInterface)) {
208-
if (!params.containsKey(NO_ESCAPE_PARAMS)) {
209-
params.put(NO_ESCAPE_PARAMS, DEFAULT_NO_ESCAPE_PARAMS);
210-
}
211-
params.put(LLM_RESPONSE_FILTER, "$.output.message.content[0].text");
212-
213-
params
214-
.put(
215-
TOOL_TEMPLATE,
216-
"{\"toolSpec\":{\"name\":\"${tool.name}\",\"description\":\"${tool.description}\",\"inputSchema\": {\"json\": ${tool.attributes.input_schema} } }}"
217-
);
218-
params.put(TOOL_CALLS_PATH, "$.output.message.content[*].toolUse");
219-
params.put(TOOL_CALLS_TOOL_NAME, "name");
220-
params.put(TOOL_CALLS_TOOL_INPUT, "input");
221-
params.put(TOOL_CALL_ID_PATH, "toolUseId");
222-
params.put("tool_configs", ", \"toolConfig\": {\"tools\": [${parameters._tools:-}]}");
223-
224-
params.put("interaction_template.assistant_tool_calls_path", "$.output.message");
225-
params
226-
.put(
227-
"interaction_template.tool_response",
228-
"{\"role\":\"user\",\"content\":[{\"toolResult\":{\"toolUseId\":\"${_interactions.tool_call_id}\",\"content\":[{\"text\":\"${_interactions.tool_response}\"}]}}]}"
229-
);
230-
231-
params
232-
.put(
233-
"chat_history_template.user_question",
234-
"{\"role\":\"user\",\"content\":[{\"text\":\"${_chat_history.message.question}\"}]}"
235-
);
236-
params
237-
.put(
238-
"chat_history_template.ai_response",
239-
"{\"role\":\"assistant\",\"content\":[{\"text\":\"${_chat_history.message.response}\"}]}"
240-
);
241-
242-
params.put(LLM_FINISH_REASON_PATH, "$.stopReason");
243-
params.put(LLM_FINISH_REASON_TOOL_USE, "tool_use");
244-
} else if ("bedrock/converse/deepseek_r1".equalsIgnoreCase(llmInterface)) {
245-
if (!params.containsKey(NO_ESCAPE_PARAMS)) {
246-
params.put(NO_ESCAPE_PARAMS, "_chat_history,_interactions");
247-
}
248-
params.put(LLM_RESPONSE_FILTER, "$.output.message.content[0].text");
249-
params.put("llm_final_response_post_filter", "$.message.content[0].text");
250-
251-
params
252-
.put(
253-
TOOL_TEMPLATE,
254-
"{\"toolSpec\":{\"name\":\"${tool.name}\",\"description\":\"${tool.description}\",\"inputSchema\": {\"json\": ${tool.attributes.input_schema} } }}"
255-
);
256-
params.put(TOOL_CALLS_PATH, "_llm_response.tool_calls");
257-
params.put(TOOL_CALLS_TOOL_NAME, "tool_name");
258-
params.put(TOOL_CALLS_TOOL_INPUT, "input");
259-
params.put(TOOL_CALL_ID_PATH, "id");
260-
261-
params.put("interaction_template.assistant_tool_calls_path", "$.output.message");
262-
params.put("interaction_template.assistant_tool_calls_exclude_path", "[ \"$.output.message.content[?(@.reasoningContent)]\" ]");
263-
params
264-
.put(
265-
"interaction_template.tool_response",
266-
"{\"role\":\"user\",\"content\":[ {\"text\":\"{\\\"tool_call_id\\\":\\\"${_interactions.tool_call_id}\\\",\\\"tool_result\\\": \\\"${_interactions.tool_response}\\\"\"} ]}"
267-
);
268-
269-
params
270-
.put(
271-
"chat_history_template.user_question",
272-
"{\"role\":\"user\",\"content\":[{\"text\":\"${_chat_history.message.question}\"}]}"
273-
);
274-
params
275-
.put(
276-
"chat_history_template.ai_response",
277-
"{\"role\":\"assistant\",\"content\":[{\"text\":\"${_chat_history.message.response}\"}]}"
278-
);
279-
280-
params.put(LLM_FINISH_REASON_PATH, "_llm_response.stop_reason");
281-
params.put(LLM_FINISH_REASON_TOOL_USE, "tool_use");
167+
FunctionCalling functionCalling = FunctionCallingFactory.create(llmInterface);
168+
if (functionCalling != null) {
169+
functionCalling.configure(params);
282170
}
171+
283172
String memoryType = mlAgent.getMemory().getType();
284173
String memoryId = params.get(MLAgentExecutor.MEMORY_ID);
285174
String appType = mlAgent.getAppType();
@@ -347,23 +236,30 @@ public void run(MLAgent mlAgent, Map<String, String> inputParams, ActionListener
347236
}
348237
}
349238

350-
runAgent(mlAgent, params, listener, memory, memory.getConversationId());
239+
runAgent(mlAgent, params, listener, memory, memory.getConversationId(), functionCalling);
351240
}, e -> {
352241
log.error("Failed to get chat history", e);
353242
listener.onFailure(e);
354243
}), messageHistoryLimit);
355244
}, listener::onFailure));
356245
}
357246

358-
private void runAgent(MLAgent mlAgent, Map<String, String> params, ActionListener<Object> listener, Memory memory, String sessionId) {
247+
private void runAgent(
248+
MLAgent mlAgent,
249+
Map<String, String> params,
250+
ActionListener<Object> listener,
251+
Memory memory,
252+
String sessionId,
253+
FunctionCalling functionCalling
254+
) {
359255
List<MLToolSpec> toolSpecs = getMlToolSpecs(mlAgent, params);
360256

361257
// Create a common method to handle both success and failure cases
362258
Consumer<List<MLToolSpec>> processTools = (allToolSpecs) -> {
363259
Map<String, Tool> tools = new HashMap<>();
364260
Map<String, MLToolSpec> toolSpecMap = new HashMap<>();
365261
createTools(toolFactories, params, allToolSpecs, tools, toolSpecMap, mlAgent);
366-
runReAct(mlAgent.getLlm(), tools, toolSpecMap, params, memory, sessionId, mlAgent.getTenantId(), listener);
262+
runReAct(mlAgent.getLlm(), tools, toolSpecMap, params, memory, sessionId, mlAgent.getTenantId(), listener, functionCalling);
367263
};
368264

369265
// Fetch MCP tools and handle both success and failure cases
@@ -384,7 +280,8 @@ private void runReAct(
384280
Memory memory,
385281
String sessionId,
386282
String tenantId,
387-
ActionListener<Object> listener
283+
ActionListener<Object> listener,
284+
FunctionCalling functionCalling
388285
) {
389286
Map<String, String> tmpParameters = constructLLMParams(llm, parameters);
390287
String prompt = constructLLMPrompt(tools, tmpParameters);
@@ -437,7 +334,8 @@ private void runReAct(
437334
tmpModelTensorOutput,
438335
llmResponsePatterns,
439336
tools.keySet(),
440-
interactions
337+
interactions,
338+
functionCalling
441339
);
442340

443341
String thought = String.valueOf(modelOutput.get(THOUGHT));
@@ -510,7 +408,8 @@ private void runReAct(
510408
actionInput,
511409
toolParams,
512410
interactions,
513-
toolCallId
411+
toolCallId,
412+
functionCalling
514413
);
515414
} else {
516415
String res = String.format(Locale.ROOT, "Failed to run the tool %s which is unsupported.", action);
@@ -675,20 +574,28 @@ private static void runTool(
675574
String actionInput,
676575
Map<String, String> toolParams,
677576
List<String> interactions,
678-
String toolCallId
577+
String toolCallId,
578+
FunctionCalling functionCalling
679579
) {
680580
if (tools.get(action).validate(toolParams)) {
681581
try {
682582
String finalAction = action;
683583
ActionListener<Object> toolListener = ActionListener.wrap(r -> {
684-
interactions
685-
.add(
686-
substitute(
687-
tmpParameters.get(INTERACTION_TEMPLATE_TOOL_RESPONSE),
688-
Map.of(TOOL_CALL_ID, toolCallId, "tool_response", processTextDoc(StringUtils.toJson(r))),
689-
INTERACTIONS_PREFIX
690-
)
691-
);
584+
if (functionCalling != null) {
585+
List<Map<String, Object>> toolResults = List.of(Map.of(TOOL_CALL_ID, toolCallId, TOOL_RESULT, Map.of("text", r)));
586+
List<LLMMessage> llmMessages = functionCalling.supply(toolResults);
587+
// TODO: support multiple tool calls at the same time so that multiple LLMMessages can be generated here
588+
interactions.add(llmMessages.getFirst().getResponse());
589+
} else {
590+
interactions
591+
.add(
592+
substitute(
593+
tmpParameters.get(INTERACTION_TEMPLATE_TOOL_RESPONSE),
594+
Map.of(TOOL_CALL_ID, toolCallId, "tool_response", processTextDoc(StringUtils.toJson(r))),
595+
INTERACTIONS_PREFIX
596+
)
597+
);
598+
}
692599
nextStepListener.onResponse(r);
693600
}, e -> {
694601
interactions

0 commit comments

Comments
 (0)