Skip to content

Commit 310f556

Browse files
authored
[Code Quality] Adding test cases for PlanExecuteReflect Agent (#3778)
* feat: adding UTs for plan execute reflect agent Signed-off-by: Pavan Yekbote <pybot@amazon.com> * wip Signed-off-by: Pavan Yekbote <pybot@amazon.com> * feat: add more test cases for per agent Signed-off-by: Pavan Yekbote <pybot@amazon.com> * fix: saveAndReturnFinalResult testcase Signed-off-by: Pavan Yekbote <pybot@amazon.com> * feat: adding test cases for agentutils, connectorutils, transportregisteragent Signed-off-by: Pavan Yekbote <pybot@amazon.com> * chore: remove max steps test post rebase Signed-off-by: Pavan Yekbote <pybot@amazon.com> * chore: add todo for max_steps reached Signed-off-by: Pavan Yekbote <pybot@amazon.com> --------- Signed-off-by: Pavan Yekbote <pybot@amazon.com>
1 parent 8fe39de commit 310f556

File tree

5 files changed

+950
-11
lines changed

5 files changed

+950
-11
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
import org.opensearch.remote.metadata.client.SdkClient;
6767
import org.opensearch.transport.client.Client;
6868

69+
import com.google.common.annotations.VisibleForTesting;
6970
import com.jayway.jsonpath.JsonPath;
7071

7172
import joptsimple.internal.Strings;
@@ -154,7 +155,8 @@ public MLPlanExecuteAndReflectAgentRunner(
154155
this.plannerWithHistoryPromptTemplate = DEFAULT_PLANNER_WITH_HISTORY_PROMPT_TEMPLATE;
155156
}
156157

157-
private void setupPromptParameters(Map<String, String> params) {
158+
@VisibleForTesting
159+
void setupPromptParameters(Map<String, String> params) {
158160
// populated depending on whether LLM is asked to plan or re-evaluate
159161
// removed here, so that error is thrown in case this field is not populated
160162
params.remove(PROMPT_FIELD);
@@ -203,22 +205,26 @@ private void setupPromptParameters(Map<String, String> params) {
203205
}
204206
}
205207

206-
private void usePlannerPromptTemplate(Map<String, String> params) {
208+
@VisibleForTesting
209+
void usePlannerPromptTemplate(Map<String, String> params) {
207210
params.put(PROMPT_TEMPLATE_FIELD, this.plannerPromptTemplate);
208211
populatePrompt(params);
209212
}
210213

211-
private void useReflectPromptTemplate(Map<String, String> params) {
214+
@VisibleForTesting
215+
void useReflectPromptTemplate(Map<String, String> params) {
212216
params.put(PROMPT_TEMPLATE_FIELD, this.reflectPromptTemplate);
213217
populatePrompt(params);
214218
}
215219

216-
private void usePlannerWithHistoryPromptTemplate(Map<String, String> params) {
220+
@VisibleForTesting
221+
void usePlannerWithHistoryPromptTemplate(Map<String, String> params) {
217222
params.put(PROMPT_TEMPLATE_FIELD, this.plannerWithHistoryPromptTemplate);
218223
populatePrompt(params);
219224
}
220225

221-
private void populatePrompt(Map<String, String> allParams) {
226+
@VisibleForTesting
227+
void populatePrompt(Map<String, String> allParams) {
222228
String promptTemplate = allParams.get(PROMPT_TEMPLATE_FIELD);
223229
StringSubstitutor promptSubstitutor = new StringSubstitutor(allParams, "${parameters.", "}");
224230
String prompt = promptSubstitutor.replace(promptTemplate);
@@ -475,7 +481,8 @@ private void executePlanningLoop(
475481
client.execute(MLPredictionTaskAction.INSTANCE, request, planListener);
476482
}
477483

478-
private Map<String, String> parseLLMOutput(Map<String, String> allParams, ModelTensorOutput modelTensorOutput) {
484+
@VisibleForTesting
485+
Map<String, String> parseLLMOutput(Map<String, String> allParams, ModelTensorOutput modelTensorOutput) {
479486
Map<String, String> modelOutput = new HashMap<>();
480487
Map<String, ?> dataAsMap = modelTensorOutput.getMlModelOutputs().getFirst().getMlModelTensors().getFirst().getDataAsMap();
481488
String llmResponse;
@@ -513,7 +520,8 @@ private Map<String, String> parseLLMOutput(Map<String, String> allParams, ModelT
513520
return modelOutput;
514521
}
515522

516-
private String extractJsonFromMarkdown(String response) {
523+
@VisibleForTesting
524+
String extractJsonFromMarkdown(String response) {
517525
response = response.trim();
518526
if (response.contains("```json")) {
519527
response = response.substring(response.indexOf("```json") + "```json".length());
@@ -530,7 +538,8 @@ private String extractJsonFromMarkdown(String response) {
530538
return response;
531539
}
532540

533-
private void addToolsToPrompt(Map<String, Tool> tools, Map<String, String> allParams) {
541+
@VisibleForTesting
542+
void addToolsToPrompt(Map<String, Tool> tools, Map<String, String> allParams) {
534543
StringBuilder toolsPrompt = new StringBuilder("In this environment, you have access to the below tools: \n");
535544
for (Map.Entry<String, Tool> entry : tools.entrySet()) {
536545
String toolName = entry.getKey();
@@ -543,11 +552,13 @@ private void addToolsToPrompt(Map<String, Tool> tools, Map<String, String> allPa
543552
cleanUpResource(tools);
544553
}
545554

546-
private void addSteps(List<String> steps, Map<String, String> allParams, String field) {
555+
@VisibleForTesting
556+
void addSteps(List<String> steps, Map<String, String> allParams, String field) {
547557
allParams.put(field, String.join(", ", steps));
548558
}
549559

550-
private void saveAndReturnFinalResult(
560+
@VisibleForTesting
561+
void saveAndReturnFinalResult(
551562
ConversationIndexMemory memory,
552563
String parentInteractionId,
553564
String reactAgentMemoryId,
@@ -586,7 +597,8 @@ private void saveAndReturnFinalResult(
586597
}));
587598
}
588599

589-
private static List<ModelTensors> createModelTensors(
600+
@VisibleForTesting
601+
static List<ModelTensors> createModelTensors(
590602
String sessionId,
591603
String parentInteractionId,
592604
String reactAgentMemoryId,

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,29 @@
99
import static org.junit.Assert.assertThrows;
1010
import static org.mockito.ArgumentMatchers.any;
1111
import static org.mockito.ArgumentMatchers.anyString;
12+
import static org.mockito.ArgumentMatchers.argThat;
1213
import static org.mockito.Mockito.doNothing;
1314
import static org.mockito.Mockito.mock;
1415
import static org.mockito.Mockito.verify;
1516
import static org.mockito.Mockito.when;
1617
import static org.opensearch.ml.common.CommonValue.MCP_CONNECTORS_FIELD;
1718
import static org.opensearch.ml.common.CommonValue.MCP_CONNECTOR_ID_FIELD;
19+
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
1820
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_PATH;
1921
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_TOOL_USE;
2022
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_GEN_INPUT;
23+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_EXCLUDE_PATH;
2124
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER;
2225
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX;
2326
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX;
27+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOLS;
2428
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_PATH;
2529
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_TOOL_INPUT;
2630
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_TOOL_NAME;
2731
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALL_ID;
2832
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALL_ID_PATH;
2933
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_FILTERS_FIELD;
34+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_TEMPLATE;
3035
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION;
3136
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION_INPUT;
3237
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY;
@@ -81,6 +86,8 @@
8186
import org.opensearch.threadpool.ThreadPool;
8287
import org.opensearch.transport.client.Client;
8388

89+
import com.google.gson.JsonSyntaxException;
90+
8491
public class AgentUtilsTest extends MLStaticMockBase {
8592

8693
@Mock
@@ -1197,6 +1204,48 @@ public void testParseLLMOutputWithDeepseekFormat() {
11971204
Assert.assertTrue(output3.get(FINAL_ANSWER).contains("This is a test response"));
11981205
}
11991206

1207+
@Test
1208+
public void testAddToolsToFunctionCalling() {
1209+
Map<String, Tool> tools = new HashMap<>();
1210+
tools.put("Tool1", tool1);
1211+
tools.put("Tool2", tool2);
1212+
1213+
when(tool1.getName()).thenReturn("Tool1");
1214+
when(tool1.getDescription()).thenReturn("Description of Tool1");
1215+
when(tool1.getAttributes()).thenReturn(Map.of("param1", "value1"));
1216+
1217+
when(tool2.getName()).thenReturn("Tool2");
1218+
when(tool2.getDescription()).thenReturn("Description of Tool2");
1219+
when(tool2.getAttributes()).thenReturn(Map.of("param2", "value2"));
1220+
1221+
Map<String, String> parameters = new HashMap<>();
1222+
String toolTemplate = "{\"name\": \"${tool.name}\", \"description\": \"${tool.description}\"}";
1223+
parameters.put(TOOL_TEMPLATE, toolTemplate);
1224+
1225+
List<String> inputTools = Arrays.asList("Tool1", "Tool2");
1226+
String prompt = "test prompt";
1227+
1228+
String expectedTool1 = "{\"name\": \"Tool1\", \"description\": \"Description of Tool1\"}";
1229+
String expectedTool2 = "{\"name\": \"Tool2\", \"description\": \"Description of Tool2\"}";
1230+
String expectedTools = expectedTool1 + ", " + expectedTool2;
1231+
1232+
AgentUtils.addToolsToFunctionCalling(tools, parameters, inputTools, prompt);
1233+
1234+
assertEquals(expectedTools, parameters.get(TOOLS));
1235+
}
1236+
1237+
@Test
1238+
public void testAddToolsToFunctionCalling_ToolNotRegistered() {
1239+
Map<String, Tool> tools = new HashMap<>();
1240+
tools.put("Tool1", tool1);
1241+
Map<String, String> parameters = new HashMap<>();
1242+
parameters.put(TOOL_TEMPLATE, "template");
1243+
List<String> inputTools = Arrays.asList("Tool1", "UnregisteredTool");
1244+
String prompt = "test prompt";
1245+
1246+
assertThrows(IllegalArgumentException.class, () -> AgentUtils.addToolsToFunctionCalling(tools, parameters, inputTools, prompt));
1247+
}
1248+
12001249
private static MLToolSpec buildTool(String name) {
12011250
return MLToolSpec.builder().type(McpSseTool.TYPE).name(name).description("mock").build();
12021251
}
@@ -1362,4 +1411,198 @@ private void verifyConstructToolParams(String question, String actionInput, Cons
13621411
Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput);
13631412
verify.accept(toolParams);
13641413
}
1414+
1415+
@Test
1416+
public void testParseLLMOutput_WithExcludePath() {
1417+
Map<String, String> parameters = new HashMap<>();
1418+
parameters.put(LLM_RESPONSE_EXCLUDE_PATH, "[\"$.exclude_field\"]");
1419+
1420+
Map<String, Object> dataAsMap = new HashMap<>();
1421+
dataAsMap.put("exclude_field", "should be excluded");
1422+
dataAsMap.put("keep_field", "should be kept");
1423+
1424+
ModelTensorOutput modelTensorOutput = ModelTensorOutput
1425+
.builder()
1426+
.mlModelOutputs(
1427+
List
1428+
.of(
1429+
ModelTensors
1430+
.builder()
1431+
.mlModelTensors(List.of(ModelTensor.builder().name("response").dataAsMap(dataAsMap).build()))
1432+
.build()
1433+
)
1434+
)
1435+
.build();
1436+
1437+
Map<String, String> output = AgentUtils.parseLLMOutput(parameters, modelTensorOutput, null, Set.of(), new ArrayList<>());
1438+
1439+
Assert.assertTrue(output.containsKey(THOUGHT_RESPONSE));
1440+
Assert.assertFalse(output.get(THOUGHT_RESPONSE).contains("exclude_field"));
1441+
Assert.assertTrue(output.get(THOUGHT_RESPONSE).contains("keep_field"));
1442+
}
1443+
1444+
@Test
1445+
public void testParseLLMOutput_EmptyDataAsMap() {
1446+
Map<String, Object> dataAsMap = new HashMap<>();
1447+
ModelTensorOutput modelTensorOutput = ModelTensorOutput
1448+
.builder()
1449+
.mlModelOutputs(
1450+
List
1451+
.of(
1452+
ModelTensors
1453+
.builder()
1454+
.mlModelTensors(List.of(ModelTensor.builder().name("response").dataAsMap(dataAsMap).build()))
1455+
.build()
1456+
)
1457+
)
1458+
.build();
1459+
1460+
Map<String, String> output = AgentUtils.parseLLMOutput(new HashMap<>(), modelTensorOutput, null, Set.of(), new ArrayList<>());
1461+
1462+
Assert.assertTrue(output.containsKey(THOUGHT_RESPONSE));
1463+
Assert.assertEquals("{}", output.get(THOUGHT_RESPONSE));
1464+
}
1465+
1466+
@Test
1467+
public void testParseLLMOutput_ToolUse() {
1468+
Map<String, String> parameters = new HashMap<>();
1469+
parameters.put(TOOL_CALLS_PATH, "$.tool_calls");
1470+
parameters.put(TOOL_CALLS_TOOL_NAME, "name");
1471+
parameters.put(TOOL_CALLS_TOOL_INPUT, "input");
1472+
parameters.put(TOOL_CALL_ID_PATH, "id");
1473+
parameters.put(LLM_RESPONSE_FILTER, "$.response");
1474+
parameters.put(LLM_FINISH_REASON_PATH, "$.finish_reason");
1475+
parameters.put(LLM_FINISH_REASON_TOOL_USE, "tool_use");
1476+
1477+
Map<String, Object> dataAsMap = new HashMap<>();
1478+
dataAsMap.put("tool_calls", List.of(Map.of("name", "test_tool", "input", "test_input", "id", "test_id")));
1479+
dataAsMap.put("response", "test response");
1480+
dataAsMap.put("finish_reason", "tool_use");
1481+
1482+
ModelTensorOutput modelTensorOutput = ModelTensorOutput
1483+
.builder()
1484+
.mlModelOutputs(
1485+
List
1486+
.of(
1487+
ModelTensors
1488+
.builder()
1489+
.mlModelTensors(List.of(ModelTensor.builder().name("response").dataAsMap(dataAsMap).build()))
1490+
.build()
1491+
)
1492+
)
1493+
.build();
1494+
1495+
Map<String, String> output = AgentUtils.parseLLMOutput(parameters, modelTensorOutput, null, Set.of("test_tool"), new ArrayList<>());
1496+
1497+
Assert.assertEquals("test_tool", output.get(ACTION));
1498+
Assert.assertEquals("test_input", output.get(ACTION_INPUT));
1499+
Assert.assertEquals("test_id", output.get(TOOL_CALL_ID));
1500+
}
1501+
1502+
@Test
1503+
public void testRemoveJsonPath_WithStringPaths() {
1504+
Map<String, Object> json = new HashMap<>();
1505+
json.put("field1", "value1");
1506+
json.put("field2", "value2");
1507+
json.put("nested", Map.of("field3", "value3"));
1508+
String excludePaths = "[\"$.field1\", \"$.nested.field3\"]";
1509+
Map<String, ?> result = AgentUtils.removeJsonPath(json, excludePaths, false);
1510+
Assert.assertFalse(result.containsKey("field1"));
1511+
Assert.assertTrue(result.containsKey("field2"));
1512+
Assert.assertTrue(result.containsKey("nested"));
1513+
Assert.assertFalse(((Map<?, ?>) result.get("nested")).containsKey("field3"));
1514+
}
1515+
1516+
@Test
1517+
public void testRemoveJsonPath_WithListPaths() {
1518+
Map<String, Object> json = new HashMap<>();
1519+
json.put("field1", "value1");
1520+
json.put("field2", "value2");
1521+
json.put("nested", Map.of("field3", "value3"));
1522+
List<String> excludePaths = java.util.Arrays.asList("$.field1", "$.nested.field3");
1523+
Map<String, ?> result = AgentUtils.removeJsonPath(json, excludePaths, false);
1524+
Assert.assertFalse(result.containsKey("field1"));
1525+
Assert.assertTrue(result.containsKey("field2"));
1526+
Assert.assertTrue(result.containsKey("nested"));
1527+
Assert.assertFalse(((Map<?, ?>) result.get("nested")).containsKey("field3"));
1528+
}
1529+
1530+
@Test
1531+
public void testRemoveJsonPath_InPlace() {
1532+
Map<String, Object> json = new HashMap<>();
1533+
json.put("field1", "value1");
1534+
json.put("field2", "value2");
1535+
json.put("nested", new HashMap<>(Map.of("field3", "value3")));
1536+
List<String> excludePaths = java.util.Arrays.asList("$.field1", "$.nested.field3");
1537+
Map<String, ?> result = AgentUtils.removeJsonPath(json, excludePaths, true);
1538+
Assert.assertFalse(json.containsKey("field1"));
1539+
Assert.assertTrue(json.containsKey("field2"));
1540+
Assert.assertTrue(json.containsKey("nested"));
1541+
Assert.assertFalse(((Map<?, ?>) json.get("nested")).containsKey("field3"));
1542+
Assert.assertSame(json, result);
1543+
}
1544+
1545+
@Test
1546+
public void testRemoveJsonPath_WithInvalidJsonPaths() {
1547+
Map<String, Object> json = new HashMap<>();
1548+
json.put("field1", "value1");
1549+
String invalidJsonPaths = "invalid json";
1550+
Assert.assertThrows(JsonSyntaxException.class, () -> AgentUtils.removeJsonPath(json, invalidJsonPaths, false));
1551+
}
1552+
1553+
@Test
1554+
public void testSubstitute() {
1555+
String template = "Hello ${parameters.name}! Welcome to ${parameters.place}.";
1556+
Map<String, String> params = new HashMap<>();
1557+
params.put("name", "AI");
1558+
params.put("place", "OpenSearch");
1559+
String prefix = "${parameters.";
1560+
1561+
String result = AgentUtils.substitute(template, params, prefix);
1562+
1563+
Assert.assertEquals("Hello AI! Welcome to OpenSearch.", result);
1564+
}
1565+
1566+
@Test
1567+
public void testCreateTool_Success() {
1568+
Map<String, Tool.Factory> toolFactories = new HashMap<>();
1569+
Tool.Factory factory = mock(Tool.Factory.class);
1570+
Tool mockTool = mock(Tool.class);
1571+
when(factory.create(any())).thenReturn(mockTool);
1572+
toolFactories.put("test_tool", factory);
1573+
1574+
MLToolSpec toolSpec = MLToolSpec
1575+
.builder()
1576+
.type("test_tool")
1577+
.name("TestTool")
1578+
.description("Original description")
1579+
.parameters(Map.of("param1", "value1"))
1580+
.runtimeResources(Map.of("resource1", "value2"))
1581+
.build();
1582+
1583+
Map<String, String> params = new HashMap<>();
1584+
params.put("TestTool.param2", "value3");
1585+
params.put("TestTool.description", "Custom description");
1586+
1587+
AgentUtils.createTool(toolFactories, params, toolSpec, "test_tenant");
1588+
1589+
verify(factory).create(argThat(toolParamsMap -> {
1590+
Map<String, Object> toolParams = (Map<String, Object>) toolParamsMap;
1591+
return toolParams.get("param1").equals("value1")
1592+
&& toolParams.get("param2").equals("value3")
1593+
&& toolParams.get("resource1").equals("value2")
1594+
&& toolParams.get(TENANT_ID_FIELD).equals("test_tenant");
1595+
}));
1596+
1597+
verify(mockTool).setName("TestTool");
1598+
verify(mockTool).setDescription("Custom description");
1599+
}
1600+
1601+
@Test
1602+
public void testCreateTool_ToolNotFound() {
1603+
Map<String, Tool.Factory> toolFactories = new HashMap<>();
1604+
MLToolSpec toolSpec = MLToolSpec.builder().type("non_existent_tool").name("TestTool").build();
1605+
1606+
assertThrows(IllegalArgumentException.class, () -> AgentUtils.createTool(toolFactories, new HashMap<>(), toolSpec, "test_tenant"));
1607+
}
13651608
}

0 commit comments

Comments
 (0)