|
9 | 9 | import static org.junit.Assert.assertThrows;
|
10 | 10 | import static org.mockito.ArgumentMatchers.any;
|
11 | 11 | import static org.mockito.ArgumentMatchers.anyString;
|
| 12 | +import static org.mockito.ArgumentMatchers.argThat; |
12 | 13 | import static org.mockito.Mockito.doNothing;
|
13 | 14 | import static org.mockito.Mockito.mock;
|
14 | 15 | import static org.mockito.Mockito.verify;
|
15 | 16 | import static org.mockito.Mockito.when;
|
16 | 17 | import static org.opensearch.ml.common.CommonValue.MCP_CONNECTORS_FIELD;
|
17 | 18 | import static org.opensearch.ml.common.CommonValue.MCP_CONNECTOR_ID_FIELD;
|
| 19 | +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; |
18 | 20 | import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_PATH;
|
19 | 21 | import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_TOOL_USE;
|
20 | 22 | import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_GEN_INPUT;
|
| 23 | +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_EXCLUDE_PATH; |
21 | 24 | import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER;
|
22 | 25 | import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX;
|
23 | 26 | import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX;
|
| 27 | +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOLS; |
24 | 28 | import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_PATH;
|
25 | 29 | import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_TOOL_INPUT;
|
26 | 30 | import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_TOOL_NAME;
|
27 | 31 | import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALL_ID;
|
28 | 32 | import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALL_ID_PATH;
|
29 | 33 | import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_FILTERS_FIELD;
|
| 34 | +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_TEMPLATE; |
30 | 35 | import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION;
|
31 | 36 | import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION_INPUT;
|
32 | 37 | import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY;
|
|
81 | 86 | import org.opensearch.threadpool.ThreadPool;
|
82 | 87 | import org.opensearch.transport.client.Client;
|
83 | 88 |
|
| 89 | +import com.google.gson.JsonSyntaxException; |
| 90 | + |
84 | 91 | public class AgentUtilsTest extends MLStaticMockBase {
|
85 | 92 |
|
86 | 93 | @Mock
|
@@ -1197,6 +1204,48 @@ public void testParseLLMOutputWithDeepseekFormat() {
|
1197 | 1204 | Assert.assertTrue(output3.get(FINAL_ANSWER).contains("This is a test response"));
|
1198 | 1205 | }
|
1199 | 1206 |
|
| 1207 | + @Test |
| 1208 | + public void testAddToolsToFunctionCalling() { |
| 1209 | + Map<String, Tool> tools = new HashMap<>(); |
| 1210 | + tools.put("Tool1", tool1); |
| 1211 | + tools.put("Tool2", tool2); |
| 1212 | + |
| 1213 | + when(tool1.getName()).thenReturn("Tool1"); |
| 1214 | + when(tool1.getDescription()).thenReturn("Description of Tool1"); |
| 1215 | + when(tool1.getAttributes()).thenReturn(Map.of("param1", "value1")); |
| 1216 | + |
| 1217 | + when(tool2.getName()).thenReturn("Tool2"); |
| 1218 | + when(tool2.getDescription()).thenReturn("Description of Tool2"); |
| 1219 | + when(tool2.getAttributes()).thenReturn(Map.of("param2", "value2")); |
| 1220 | + |
| 1221 | + Map<String, String> parameters = new HashMap<>(); |
| 1222 | + String toolTemplate = "{\"name\": \"${tool.name}\", \"description\": \"${tool.description}\"}"; |
| 1223 | + parameters.put(TOOL_TEMPLATE, toolTemplate); |
| 1224 | + |
| 1225 | + List<String> inputTools = Arrays.asList("Tool1", "Tool2"); |
| 1226 | + String prompt = "test prompt"; |
| 1227 | + |
| 1228 | + String expectedTool1 = "{\"name\": \"Tool1\", \"description\": \"Description of Tool1\"}"; |
| 1229 | + String expectedTool2 = "{\"name\": \"Tool2\", \"description\": \"Description of Tool2\"}"; |
| 1230 | + String expectedTools = expectedTool1 + ", " + expectedTool2; |
| 1231 | + |
| 1232 | + AgentUtils.addToolsToFunctionCalling(tools, parameters, inputTools, prompt); |
| 1233 | + |
| 1234 | + assertEquals(expectedTools, parameters.get(TOOLS)); |
| 1235 | + } |
| 1236 | + |
| 1237 | + @Test |
| 1238 | + public void testAddToolsToFunctionCalling_ToolNotRegistered() { |
| 1239 | + Map<String, Tool> tools = new HashMap<>(); |
| 1240 | + tools.put("Tool1", tool1); |
| 1241 | + Map<String, String> parameters = new HashMap<>(); |
| 1242 | + parameters.put(TOOL_TEMPLATE, "template"); |
| 1243 | + List<String> inputTools = Arrays.asList("Tool1", "UnregisteredTool"); |
| 1244 | + String prompt = "test prompt"; |
| 1245 | + |
| 1246 | + assertThrows(IllegalArgumentException.class, () -> AgentUtils.addToolsToFunctionCalling(tools, parameters, inputTools, prompt)); |
| 1247 | + } |
| 1248 | + |
1200 | 1249 | private static MLToolSpec buildTool(String name) {
|
1201 | 1250 | return MLToolSpec.builder().type(McpSseTool.TYPE).name(name).description("mock").build();
|
1202 | 1251 | }
|
@@ -1362,4 +1411,198 @@ private void verifyConstructToolParams(String question, String actionInput, Cons
|
1362 | 1411 | Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput);
|
1363 | 1412 | verify.accept(toolParams);
|
1364 | 1413 | }
|
| 1414 | + |
| 1415 | + @Test |
| 1416 | + public void testParseLLMOutput_WithExcludePath() { |
| 1417 | + Map<String, String> parameters = new HashMap<>(); |
| 1418 | + parameters.put(LLM_RESPONSE_EXCLUDE_PATH, "[\"$.exclude_field\"]"); |
| 1419 | + |
| 1420 | + Map<String, Object> dataAsMap = new HashMap<>(); |
| 1421 | + dataAsMap.put("exclude_field", "should be excluded"); |
| 1422 | + dataAsMap.put("keep_field", "should be kept"); |
| 1423 | + |
| 1424 | + ModelTensorOutput modelTensorOutput = ModelTensorOutput |
| 1425 | + .builder() |
| 1426 | + .mlModelOutputs( |
| 1427 | + List |
| 1428 | + .of( |
| 1429 | + ModelTensors |
| 1430 | + .builder() |
| 1431 | + .mlModelTensors(List.of(ModelTensor.builder().name("response").dataAsMap(dataAsMap).build())) |
| 1432 | + .build() |
| 1433 | + ) |
| 1434 | + ) |
| 1435 | + .build(); |
| 1436 | + |
| 1437 | + Map<String, String> output = AgentUtils.parseLLMOutput(parameters, modelTensorOutput, null, Set.of(), new ArrayList<>()); |
| 1438 | + |
| 1439 | + Assert.assertTrue(output.containsKey(THOUGHT_RESPONSE)); |
| 1440 | + Assert.assertFalse(output.get(THOUGHT_RESPONSE).contains("exclude_field")); |
| 1441 | + Assert.assertTrue(output.get(THOUGHT_RESPONSE).contains("keep_field")); |
| 1442 | + } |
| 1443 | + |
| 1444 | + @Test |
| 1445 | + public void testParseLLMOutput_EmptyDataAsMap() { |
| 1446 | + Map<String, Object> dataAsMap = new HashMap<>(); |
| 1447 | + ModelTensorOutput modelTensorOutput = ModelTensorOutput |
| 1448 | + .builder() |
| 1449 | + .mlModelOutputs( |
| 1450 | + List |
| 1451 | + .of( |
| 1452 | + ModelTensors |
| 1453 | + .builder() |
| 1454 | + .mlModelTensors(List.of(ModelTensor.builder().name("response").dataAsMap(dataAsMap).build())) |
| 1455 | + .build() |
| 1456 | + ) |
| 1457 | + ) |
| 1458 | + .build(); |
| 1459 | + |
| 1460 | + Map<String, String> output = AgentUtils.parseLLMOutput(new HashMap<>(), modelTensorOutput, null, Set.of(), new ArrayList<>()); |
| 1461 | + |
| 1462 | + Assert.assertTrue(output.containsKey(THOUGHT_RESPONSE)); |
| 1463 | + Assert.assertEquals("{}", output.get(THOUGHT_RESPONSE)); |
| 1464 | + } |
| 1465 | + |
| 1466 | + @Test |
| 1467 | + public void testParseLLMOutput_ToolUse() { |
| 1468 | + Map<String, String> parameters = new HashMap<>(); |
| 1469 | + parameters.put(TOOL_CALLS_PATH, "$.tool_calls"); |
| 1470 | + parameters.put(TOOL_CALLS_TOOL_NAME, "name"); |
| 1471 | + parameters.put(TOOL_CALLS_TOOL_INPUT, "input"); |
| 1472 | + parameters.put(TOOL_CALL_ID_PATH, "id"); |
| 1473 | + parameters.put(LLM_RESPONSE_FILTER, "$.response"); |
| 1474 | + parameters.put(LLM_FINISH_REASON_PATH, "$.finish_reason"); |
| 1475 | + parameters.put(LLM_FINISH_REASON_TOOL_USE, "tool_use"); |
| 1476 | + |
| 1477 | + Map<String, Object> dataAsMap = new HashMap<>(); |
| 1478 | + dataAsMap.put("tool_calls", List.of(Map.of("name", "test_tool", "input", "test_input", "id", "test_id"))); |
| 1479 | + dataAsMap.put("response", "test response"); |
| 1480 | + dataAsMap.put("finish_reason", "tool_use"); |
| 1481 | + |
| 1482 | + ModelTensorOutput modelTensorOutput = ModelTensorOutput |
| 1483 | + .builder() |
| 1484 | + .mlModelOutputs( |
| 1485 | + List |
| 1486 | + .of( |
| 1487 | + ModelTensors |
| 1488 | + .builder() |
| 1489 | + .mlModelTensors(List.of(ModelTensor.builder().name("response").dataAsMap(dataAsMap).build())) |
| 1490 | + .build() |
| 1491 | + ) |
| 1492 | + ) |
| 1493 | + .build(); |
| 1494 | + |
| 1495 | + Map<String, String> output = AgentUtils.parseLLMOutput(parameters, modelTensorOutput, null, Set.of("test_tool"), new ArrayList<>()); |
| 1496 | + |
| 1497 | + Assert.assertEquals("test_tool", output.get(ACTION)); |
| 1498 | + Assert.assertEquals("test_input", output.get(ACTION_INPUT)); |
| 1499 | + Assert.assertEquals("test_id", output.get(TOOL_CALL_ID)); |
| 1500 | + } |
| 1501 | + |
| 1502 | + @Test |
| 1503 | + public void testRemoveJsonPath_WithStringPaths() { |
| 1504 | + Map<String, Object> json = new HashMap<>(); |
| 1505 | + json.put("field1", "value1"); |
| 1506 | + json.put("field2", "value2"); |
| 1507 | + json.put("nested", Map.of("field3", "value3")); |
| 1508 | + String excludePaths = "[\"$.field1\", \"$.nested.field3\"]"; |
| 1509 | + Map<String, ?> result = AgentUtils.removeJsonPath(json, excludePaths, false); |
| 1510 | + Assert.assertFalse(result.containsKey("field1")); |
| 1511 | + Assert.assertTrue(result.containsKey("field2")); |
| 1512 | + Assert.assertTrue(result.containsKey("nested")); |
| 1513 | + Assert.assertFalse(((Map<?, ?>) result.get("nested")).containsKey("field3")); |
| 1514 | + } |
| 1515 | + |
| 1516 | + @Test |
| 1517 | + public void testRemoveJsonPath_WithListPaths() { |
| 1518 | + Map<String, Object> json = new HashMap<>(); |
| 1519 | + json.put("field1", "value1"); |
| 1520 | + json.put("field2", "value2"); |
| 1521 | + json.put("nested", Map.of("field3", "value3")); |
| 1522 | + List<String> excludePaths = java.util.Arrays.asList("$.field1", "$.nested.field3"); |
| 1523 | + Map<String, ?> result = AgentUtils.removeJsonPath(json, excludePaths, false); |
| 1524 | + Assert.assertFalse(result.containsKey("field1")); |
| 1525 | + Assert.assertTrue(result.containsKey("field2")); |
| 1526 | + Assert.assertTrue(result.containsKey("nested")); |
| 1527 | + Assert.assertFalse(((Map<?, ?>) result.get("nested")).containsKey("field3")); |
| 1528 | + } |
| 1529 | + |
| 1530 | + @Test |
| 1531 | + public void testRemoveJsonPath_InPlace() { |
| 1532 | + Map<String, Object> json = new HashMap<>(); |
| 1533 | + json.put("field1", "value1"); |
| 1534 | + json.put("field2", "value2"); |
| 1535 | + json.put("nested", new HashMap<>(Map.of("field3", "value3"))); |
| 1536 | + List<String> excludePaths = java.util.Arrays.asList("$.field1", "$.nested.field3"); |
| 1537 | + Map<String, ?> result = AgentUtils.removeJsonPath(json, excludePaths, true); |
| 1538 | + Assert.assertFalse(json.containsKey("field1")); |
| 1539 | + Assert.assertTrue(json.containsKey("field2")); |
| 1540 | + Assert.assertTrue(json.containsKey("nested")); |
| 1541 | + Assert.assertFalse(((Map<?, ?>) json.get("nested")).containsKey("field3")); |
| 1542 | + Assert.assertSame(json, result); |
| 1543 | + } |
| 1544 | + |
| 1545 | + @Test |
| 1546 | + public void testRemoveJsonPath_WithInvalidJsonPaths() { |
| 1547 | + Map<String, Object> json = new HashMap<>(); |
| 1548 | + json.put("field1", "value1"); |
| 1549 | + String invalidJsonPaths = "invalid json"; |
| 1550 | + Assert.assertThrows(JsonSyntaxException.class, () -> AgentUtils.removeJsonPath(json, invalidJsonPaths, false)); |
| 1551 | + } |
| 1552 | + |
| 1553 | + @Test |
| 1554 | + public void testSubstitute() { |
| 1555 | + String template = "Hello ${parameters.name}! Welcome to ${parameters.place}."; |
| 1556 | + Map<String, String> params = new HashMap<>(); |
| 1557 | + params.put("name", "AI"); |
| 1558 | + params.put("place", "OpenSearch"); |
| 1559 | + String prefix = "${parameters."; |
| 1560 | + |
| 1561 | + String result = AgentUtils.substitute(template, params, prefix); |
| 1562 | + |
| 1563 | + Assert.assertEquals("Hello AI! Welcome to OpenSearch.", result); |
| 1564 | + } |
| 1565 | + |
| 1566 | + @Test |
| 1567 | + public void testCreateTool_Success() { |
| 1568 | + Map<String, Tool.Factory> toolFactories = new HashMap<>(); |
| 1569 | + Tool.Factory factory = mock(Tool.Factory.class); |
| 1570 | + Tool mockTool = mock(Tool.class); |
| 1571 | + when(factory.create(any())).thenReturn(mockTool); |
| 1572 | + toolFactories.put("test_tool", factory); |
| 1573 | + |
| 1574 | + MLToolSpec toolSpec = MLToolSpec |
| 1575 | + .builder() |
| 1576 | + .type("test_tool") |
| 1577 | + .name("TestTool") |
| 1578 | + .description("Original description") |
| 1579 | + .parameters(Map.of("param1", "value1")) |
| 1580 | + .runtimeResources(Map.of("resource1", "value2")) |
| 1581 | + .build(); |
| 1582 | + |
| 1583 | + Map<String, String> params = new HashMap<>(); |
| 1584 | + params.put("TestTool.param2", "value3"); |
| 1585 | + params.put("TestTool.description", "Custom description"); |
| 1586 | + |
| 1587 | + AgentUtils.createTool(toolFactories, params, toolSpec, "test_tenant"); |
| 1588 | + |
| 1589 | + verify(factory).create(argThat(toolParamsMap -> { |
| 1590 | + Map<String, Object> toolParams = (Map<String, Object>) toolParamsMap; |
| 1591 | + return toolParams.get("param1").equals("value1") |
| 1592 | + && toolParams.get("param2").equals("value3") |
| 1593 | + && toolParams.get("resource1").equals("value2") |
| 1594 | + && toolParams.get(TENANT_ID_FIELD).equals("test_tenant"); |
| 1595 | + })); |
| 1596 | + |
| 1597 | + verify(mockTool).setName("TestTool"); |
| 1598 | + verify(mockTool).setDescription("Custom description"); |
| 1599 | + } |
| 1600 | + |
| 1601 | + @Test |
| 1602 | + public void testCreateTool_ToolNotFound() { |
| 1603 | + Map<String, Tool.Factory> toolFactories = new HashMap<>(); |
| 1604 | + MLToolSpec toolSpec = MLToolSpec.builder().type("non_existent_tool").name("TestTool").build(); |
| 1605 | + |
| 1606 | + assertThrows(IllegalArgumentException.class, () -> AgentUtils.createTool(toolFactories, new HashMap<>(), toolSpec, "test_tenant")); |
| 1607 | + } |
1365 | 1608 | }
|
0 commit comments