Skip to content

Commit b9d5201

Browse files
authored
function calling for openai v1, bedrock claude and deepseek (#3712)
Signed-off-by: Jing Zhang <jngz@amazon.com>
1 parent 517890e commit b9d5201

12 files changed

+699
-0
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES;
2525
import static org.opensearch.ml.engine.memory.ConversationIndexMemory.LAST_N_INTERACTIONS;
2626

27+
import java.lang.reflect.Type;
2728
import java.security.AccessController;
2829
import java.security.PrivilegedActionException;
2930
import java.security.PrivilegedExceptionAction;
@@ -48,6 +49,11 @@
4849
import org.opensearch.ml.common.spi.tools.Tool;
4950
import org.opensearch.ml.common.utils.StringUtils;
5051

52+
import com.google.gson.reflect.TypeToken;
53+
import com.jayway.jsonpath.DocumentContext;
54+
import com.jayway.jsonpath.JsonPath;
55+
import com.jayway.jsonpath.PathNotFoundException;
56+
5157
import lombok.extern.log4j.Log4j2;
5258

5359
@Log4j2
@@ -62,6 +68,13 @@ public class AgentUtils {
6268
public static final String DISABLE_TRACE = "disable_trace";
6369
public static final String VERBOSE = "verbose";
6470
public static final String LLM_GEN_INPUT = "llm_generated_input";
71+
public static final String LLM_RESPONSE_EXCLUDE_PATH = "llm_response_exclude_path";
72+
public static final String LLM_RESPONSE_FILTER = "llm_response_filter";
73+
public static final String TOOL_RESULT = "tool_result";
74+
public static final String TOOL_CALL_ID = "tool_call_id";
75+
public static final String LLM_INTERFACE_BEDROCK_CONVERSE_CLAUDE = "bedrock/converse/claude";
76+
public static final String LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS = "openai/v1/chat/completions";
77+
public static final String LLM_INTERFACE_BEDROCK_CONVERSE_DEEPSEEK_R1 = "bedrock/converse/deepseek_r1";
6578

6679
public static String addExamplesToPrompt(Map<String, String> parameters, String prompt) {
6780
Map<String, String> examplesMap = new HashMap<>();
@@ -503,4 +516,40 @@ public static Map<String, String> constructToolParams(
503516
}
504517
return toolParams;
505518
}
519+
520+
public static Map<String, ?> removeJsonPath(Map<String, ?> json, String excludePaths, boolean inPlace) {
521+
Type listType = new TypeToken<List<String>>() {
522+
}.getType();
523+
List<String> excludedPath = gson.fromJson(excludePaths, listType);
524+
return removeJsonPath(json, excludedPath, inPlace);
525+
}
526+
527+
private static Map<String, ?> removeJsonPath(Map<String, ?> json, List<String> excludePaths, boolean inPlace) {
528+
529+
if (json == null || excludePaths == null || excludePaths.isEmpty()) {
530+
return json;
531+
}
532+
if (inPlace) {
533+
DocumentContext context = JsonPath.parse(json);
534+
for (String path : excludePaths) {
535+
try {
536+
context.delete(path);
537+
} catch (PathNotFoundException e) {
538+
log.warn("can't find path: {}", path);
539+
}
540+
}
541+
return json;
542+
} else {
543+
Map<String, Object> copy = StringUtils.fromJson(gson.toJson(json), "response");
544+
DocumentContext context = JsonPath.parse(copy);
545+
for (String path : excludePaths) {
546+
try {
547+
context.delete(path);
548+
} catch (PathNotFoundException e) {
549+
log.warn("can't find path: {}", path);
550+
}
551+
}
552+
return context.json();
553+
}
554+
}
506555
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.engine.function_calling;
7+
8+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_EXCLUDE_PATH;
9+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER;
10+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALL_ID;
11+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_RESULT;
12+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.removeJsonPath;
13+
14+
import java.util.ArrayList;
15+
import java.util.List;
16+
import java.util.Map;
17+
18+
import org.opensearch.core.common.util.CollectionUtils;
19+
import org.opensearch.ml.common.output.model.ModelTensorOutput;
20+
import org.opensearch.ml.common.utils.StringUtils;
21+
22+
import com.jayway.jsonpath.JsonPath;
23+
24+
public class BedrockConverseDeepseekR1FunctionCalling implements FunctionCalling {
25+
public static final String FINISH_REASON_PATH = "stop_reason";
26+
public static final String FINISH_REASON = "tool_use";
27+
public static final String CALL_PATH = "tool_calls";
28+
public static final String NAME = "tool_name";
29+
public static final String INPUT = "input";
30+
public static final String ID_PATH = "id";
31+
public static final String TOOL_ERROR = "tool_error";
32+
public static final String BEDROCK_DEEPSEEK_R1_TOOL_TEMPLATE =
33+
"{\"toolSpec\":{\"name\":\"${tool.name}\",\"description\":\"${tool.description}\",\"inputSchema\": {\"json\": ${tool.attributes.input_schema} } }}";
34+
35+
@Override
36+
public void configure(Map<String, String> params) {
37+
params.put("tool_template", BEDROCK_DEEPSEEK_R1_TOOL_TEMPLATE);
38+
}
39+
40+
@Override
41+
public List<Map<String, String>> handle(ModelTensorOutput tmpModelTensorOutput, Map<String, String> parameters) {
42+
List<Map<String, String>> output = new ArrayList<>();
43+
Map<String, ?> dataAsMap = tmpModelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap();
44+
String llmResponseExcludePath = parameters.get(LLM_RESPONSE_EXCLUDE_PATH);
45+
if (llmResponseExcludePath != null) {
46+
dataAsMap = removeJsonPath(dataAsMap, llmResponseExcludePath, true);
47+
}
48+
Object response = JsonPath.read(dataAsMap, parameters.get(LLM_RESPONSE_FILTER));
49+
Map<String, Object> llmResponse = StringUtils.fromJson(response.toString(), "response");
50+
String llmFinishReason = JsonPath.read(llmResponse, FINISH_REASON_PATH);
51+
if (!llmFinishReason.contentEquals(FINISH_REASON)) {
52+
return output;
53+
}
54+
List toolCalls = JsonPath.read(llmResponse, CALL_PATH);
55+
if (CollectionUtils.isEmpty(toolCalls)) {
56+
return output;
57+
}
58+
for (Object call : toolCalls) {
59+
String toolName = JsonPath.read(call, NAME);
60+
String toolInput = StringUtils.toJson(JsonPath.read(call, INPUT));
61+
String toolCallId = JsonPath.read(call, ID_PATH);
62+
output.add(Map.of("tool_name", toolName, "tool_input", toolInput, "tool_call_id", toolCallId));
63+
}
64+
return output;
65+
}
66+
67+
@Override
68+
public List<LLMMessage> supply(List<Map<String, Object>> toolResults) {
69+
BedrockMessage toolMessage = new BedrockMessage();
70+
for (Map toolResult : toolResults) {
71+
String toolUseId = (String) toolResult.get(TOOL_CALL_ID);
72+
if (toolUseId == null) {
73+
continue;
74+
}
75+
toolMessage.getContent().add(Map.of("text", Map.of(TOOL_CALL_ID, toolUseId, TOOL_RESULT, toolResult.get(TOOL_RESULT))));
76+
}
77+
78+
return List.of(toolMessage);
79+
}
80+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.engine.function_calling;
7+
8+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_EXCLUDE_PATH;
9+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALL_ID;
10+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_RESULT;
11+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.removeJsonPath;
12+
13+
import java.util.ArrayList;
14+
import java.util.List;
15+
import java.util.Map;
16+
17+
import org.opensearch.core.common.util.CollectionUtils;
18+
import org.opensearch.ml.common.output.model.ModelTensorOutput;
19+
import org.opensearch.ml.common.utils.StringUtils;
20+
21+
import com.jayway.jsonpath.JsonPath;
22+
23+
import lombok.Data;
24+
25+
public class BedrockConverseFunctionCalling implements FunctionCalling {
26+
public static final String FINISH_REASON_PATH = "$.stopReason";
27+
public static final String FINISH_REASON = "tool_use";
28+
public static final String CALL_PATH = "$.output.message.content[*].toolUse";
29+
public static final String NAME = "name";
30+
public static final String INPUT = "input";
31+
public static final String ID_PATH = "toolUseId";
32+
public static final String TOOL_ERROR = "tool_error";
33+
public static final String BEDROCK_CONVERSE_TOOL_TEMPLATE =
34+
"{\"toolSpec\":{\"name\":\"${tool.name}\",\"description\":\"${tool.description}\",\"inputSchema\": {\"json\": ${tool.attributes.input_schema} } }}";
35+
36+
@Override
37+
public void configure(Map<String, String> params) {
38+
params.put("tool_template", BEDROCK_CONVERSE_TOOL_TEMPLATE);
39+
params.put("tool_configs", ", \"toolConfig\": {\"tools\": [${parameters._tools:-}]}");
40+
}
41+
42+
@Override
43+
public List<Map<String, String>> handle(ModelTensorOutput tmpModelTensorOutput, Map<String, String> parameters) {
44+
List<Map<String, String>> output = new ArrayList<>();
45+
Map<String, ?> dataAsMap = tmpModelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap();
46+
String llmResponseExcludePath = parameters.get(LLM_RESPONSE_EXCLUDE_PATH);
47+
if (llmResponseExcludePath != null) {
48+
dataAsMap = removeJsonPath(dataAsMap, llmResponseExcludePath, true);
49+
}
50+
String llmFinishReason = JsonPath.read(dataAsMap, FINISH_REASON_PATH);
51+
if (!llmFinishReason.contentEquals(FINISH_REASON)) {
52+
return output;
53+
}
54+
List toolCalls = JsonPath.read(dataAsMap, CALL_PATH);
55+
if (CollectionUtils.isEmpty(toolCalls)) {
56+
return output;
57+
}
58+
for (Object call : toolCalls) {
59+
String toolName = JsonPath.read(call, NAME);
60+
String toolInput = StringUtils.toJson(JsonPath.read(call, INPUT));
61+
String toolCallId = JsonPath.read(call, ID_PATH);
62+
output.add(Map.of("tool_name", toolName, "tool_input", toolInput, "tool_call_id", toolCallId));
63+
}
64+
return output;
65+
}
66+
67+
@Override
68+
public List<LLMMessage> supply(List<Map<String, Object>> toolResults) {
69+
BedrockMessage toolMessage = new BedrockMessage();
70+
for (Map toolResult : toolResults) {
71+
String toolUseId = (String) toolResult.get(TOOL_CALL_ID);
72+
if (toolUseId == null) {
73+
continue;
74+
}
75+
ToolResult result = new ToolResult();
76+
result.setToolUseId(toolUseId);
77+
result.getContent().add(toolResult.get(TOOL_RESULT));
78+
if (toolResult.containsKey(TOOL_ERROR)) {
79+
result.setStatus("error");
80+
}
81+
toolMessage.getContent().add(result);
82+
}
83+
84+
return List.of(toolMessage);
85+
}
86+
87+
@Data
88+
public static class ToolResult {
89+
private String toolUseId;
90+
private List<Object> content = new ArrayList<>();
91+
private String status;
92+
}
93+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.engine.function_calling;
7+
8+
import java.util.ArrayList;
9+
import java.util.List;
10+
11+
import lombok.Data;
12+
13+
@Data
14+
public class BedrockMessage implements LLMMessage {
15+
16+
private String role;
17+
private List<Object> content = new ArrayList<>();
18+
19+
BedrockMessage() {
20+
this("user");
21+
}
22+
23+
BedrockMessage(String role) {
24+
this(role, null);
25+
}
26+
27+
BedrockMessage(String role, List<Object> content) {
28+
this.role = role;
29+
if (content != null) {
30+
this.content = content;
31+
}
32+
}
33+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.engine.function_calling;
7+
8+
import java.util.List;
9+
import java.util.Map;
10+
11+
import org.opensearch.ml.common.output.model.ModelTensorOutput;
12+
13+
/**
14+
* A general LLM function calling interface.
15+
*/
16+
public interface FunctionCalling {
17+
18+
/**
19+
* Configure all parameters related to function calling.
20+
* @param params the parameters used to configure a request to LLM
21+
*/
22+
void configure(Map<String, String> params);
23+
24+
/**
25+
* Handle the response from LLM to get the function calling context.
26+
* @param modelTensorOutput the response from LLM
27+
* @param parameters some parameters
28+
* @return a list of tools with something like name, input, etc.
29+
*/
30+
List<Map<String, String>> handle(ModelTensorOutput modelTensorOutput, Map<String, String> parameters);
31+
32+
/**
33+
* According to results of tools to render a LLMMessage provided to LLM
34+
* @param toolResults results from tools
35+
* @return a LLMMessage containing tool results.
36+
*/
37+
List<LLMMessage> supply(List<Map<String, Object>> toolResults);
38+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.engine.function_calling;
7+
8+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_BEDROCK_CONVERSE_CLAUDE;
9+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_BEDROCK_CONVERSE_DEEPSEEK_R1;
10+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS;
11+
12+
import java.util.Locale;
13+
14+
import org.apache.commons.lang3.StringUtils;
15+
import org.opensearch.ml.common.exception.MLException;
16+
17+
public class FunctionCallingFactory {
18+
public static FunctionCalling create(String llmInterface) {
19+
if (StringUtils.isBlank(llmInterface)) {
20+
return null;
21+
}
22+
23+
switch (llmInterface.trim().toLowerCase(Locale.ROOT)) {
24+
case LLM_INTERFACE_BEDROCK_CONVERSE_CLAUDE:
25+
return new BedrockConverseFunctionCalling();
26+
case LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS:
27+
return new OpenaiV1ChatCompletionsFunctionCalling();
28+
case LLM_INTERFACE_BEDROCK_CONVERSE_DEEPSEEK_R1:
29+
return new BedrockConverseDeepseekR1FunctionCalling();
30+
default:
31+
throw new MLException(String.format("Unsupported llm interface: {}.", llmInterface));
32+
}
33+
}
34+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.engine.function_calling;
7+
8+
public interface LLMMessage {
9+
public String getRole();
10+
11+
public Object getContent();
12+
}

0 commit comments

Comments
 (0)