From 092a8b375ba326a06fd65e12e1bdd35ebabd5d5c Mon Sep 17 00:00:00 2001 From: rithin-pullela-aws Date: Mon, 28 Apr 2025 11:46:03 -0700 Subject: [PATCH 1/4] Add tests for MCP feature Signed-off-by: rithin-pullela-aws --- .../ml/common/connector/McpConnectorTest.java | 267 ++++++++++++++++++ .../algorithms/agent/AgentUtilsTest.java | 27 ++ .../remote/McpConnectorExecutorTest.java | 92 ++++++ .../ml/engine/tools/McpSseToolTests.java | 114 ++++++++ 4 files changed, 500 insertions(+) create mode 100644 common/src/test/java/org/opensearch/ml/common/connector/McpConnectorTest.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutorTest.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/McpSseToolTests.java diff --git a/common/src/test/java/org/opensearch/ml/common/connector/McpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/McpConnectorTest.java new file mode 100644 index 0000000000..103ac0f7c5 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/connector/McpConnectorTest.java @@ -0,0 +1,267 @@ +package org.opensearch.ml.common.connector; + +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; +import static org.opensearch.ml.common.connector.ConnectorProtocols.MCP_SSE; +import static org.opensearch.ml.common.connector.RetryBackoffPolicy.CONSTANT; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.function.BiFunction; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; +import org.opensearch.search.SearchModule; + +public class McpConnectorTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + BiFunction encryptFunction; + BiFunction decryptFunction; + + String TEST_CONNECTOR_JSON_STRING = + "{\"name\":\"test_mcp_connector_name\",\"version\":\"1\",\"description\":\"this is a test mcp connector\",\"protocol\":\"mcp_sse\",\"credential\":{\"key\":\"test_key_value\"},\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\",\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000,\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"},\"url\":\"https://test.com\",\"headers\":{\"api_key\":\"${credential.key}\"}}"; + + @Before + public void setUp() { + encryptFunction = (s, v) -> "encrypted: " + s.toLowerCase(Locale.ROOT); + decryptFunction = (s, v) -> "decrypted: " + s.toUpperCase(Locale.ROOT); + } + + @Test + public void constructor_InvalidProtocol() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Unsupported connector protocol. Please use one of [aws_sigv4, http, mcp_sse]"); + + McpConnector.builder().protocol("wrong protocol").build(); + } + + @Test + public void writeTo() throws IOException { + McpConnector connector = createMcpConnector(); + + BytesStreamOutput output = new BytesStreamOutput(); + connector.writeTo(output); + + McpConnector connector2 = new McpConnector(output.bytes().streamInput()); + Assert.assertEquals(connector, connector2); + } + + @Test + public void toXContent() throws IOException { + McpConnector connector = createMcpConnector(); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + connector.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + Assert.assertEquals(TEST_CONNECTOR_JSON_STRING, content); + } + + @Test + public void constructor_Parser() throws IOException { + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + TEST_CONNECTOR_JSON_STRING + ); + parser.nextToken(); + + McpConnector connector = new McpConnector("mcp_sse", parser); + Assert.assertEquals("test_mcp_connector_name", connector.getName()); + Assert.assertEquals("1", connector.getVersion()); + Assert.assertEquals("this is a test mcp connector", connector.getDescription()); + Assert.assertEquals("mcp_sse", connector.getProtocol()); + Assert.assertEquals(AccessMode.PUBLIC, connector.getAccess()); + Assert.assertEquals("https://test.com", connector.getUrl()); + connector.decrypt(PREDICT.name(), decryptFunction, null); + Map decryptedCredential = connector.getDecryptedCredential(); + Assert.assertEquals(1, decryptedCredential.size()); + Assert.assertEquals("decrypted: TEST_KEY_VALUE", decryptedCredential.get("key")); + Assert.assertNotNull(connector.getDecryptedHeaders()); + Assert.assertEquals(1, connector.getDecryptedHeaders().size()); + Assert.assertEquals("decrypted: TEST_KEY_VALUE", connector.getDecryptedHeaders().get("api_key")); + } + + @Test + public void cloneConnector() { + McpConnector connector = createMcpConnector(); + Connector connector2 = connector.cloneConnector(); + Assert.assertEquals(connector, connector2); + } + + @Test + public void decrypt() { + McpConnector connector = createMcpConnector(); + connector.decrypt("", decryptFunction, null); + Map decryptedCredential = connector.getDecryptedCredential(); + Assert.assertEquals(1, decryptedCredential.size()); + Assert.assertEquals("decrypted: TEST_KEY_VALUE", decryptedCredential.get("key")); + Assert.assertNotNull(connector.getDecryptedHeaders()); + Assert.assertEquals(1, connector.getDecryptedHeaders().size()); + Assert.assertEquals("decrypted: TEST_KEY_VALUE", connector.getDecryptedHeaders().get("api_key")); + + connector.removeCredential(); + Assert.assertNull(connector.getCredential()); + Assert.assertNull(connector.getDecryptedCredential()); + Assert.assertNull(connector.getDecryptedHeaders()); + } + + @Test + public void encrypt() { + McpConnector connector = createMcpConnector(); + connector.encrypt(encryptFunction, null); + Map credential = connector.getCredential(); + Assert.assertEquals(1, credential.size()); + Assert.assertEquals("encrypted: test_key_value", credential.get("key")); + + connector.removeCredential(); + Assert.assertNull(connector.getCredential()); + Assert.assertNull(connector.getDecryptedCredential()); + Assert.assertNull(connector.getDecryptedHeaders()); + } + + @Test + public void validateConnectorURL_Invalid() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Connector URL is not matching the trusted connector endpoint regex"); + McpConnector connector = createMcpConnector(); + connector + .validateConnectorURL( + Arrays + .asList( + "^https://runtime\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$", + "^https://api\\.openai\\.com/.*$", + "^https://api\\.cohere\\.ai/.*$", + "^https://bedrock-agent-runtime\\\\..*[a-z0-9-]\\\\.amazonaws\\\\.com/.*$" + ) + ); + } + + @Test + public void validateConnectorURL() { + McpConnector connector = createMcpConnector(); + connector + .validateConnectorURL( + Arrays + .asList( + "^https://runtime\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$", + "^https://api\\.openai\\.com/.*$", + "^https://bedrock-agent-runtime\\\\..*[a-z0-9-]\\\\.amazonaws\\\\.com/.*$", + "^" + connector.getUrl() + ) + ); + } + + @Test + public void testUpdate() { + McpConnector connector = createMcpConnector(); + Map initialCredential = new HashMap<>(connector.getCredential()); + + // Create update content + String updatedName = "updated_name"; + String updatedDescription = "updated description"; + String updatedVersion = "2"; + Map updatedCredential = new HashMap<>(); + updatedCredential.put("new_key", "new_value"); + List updatedBackendRoles = List.of("role3", "role4"); + AccessMode updatedAccessMode = AccessMode.PRIVATE; + ConnectorClientConfig updatedClientConfig = new ConnectorClientConfig(40, 40000, 40000, 20, 20, 5, CONSTANT); + String updatedUrl = "https://updated.test.com"; + Map updatedHeaders = new HashMap<>(); + updatedHeaders.put("new_header", "new_header_value"); + updatedHeaders.put("updated_api_key", "${credential.new_key}"); // Referencing new credential key + + MLCreateConnectorInput updateInput = MLCreateConnectorInput.builder() + .name(updatedName) + .description(updatedDescription) + .version(updatedVersion) + .credential(updatedCredential) + .backendRoles(updatedBackendRoles) + .access(updatedAccessMode) + .connectorClientConfig(updatedClientConfig) + .url(updatedUrl) + .headers(updatedHeaders) + .protocol(MCP_SSE) + .build(); + + // Call the update method + connector.update(updateInput, encryptFunction); + + // Assertions + Assert.assertEquals(updatedName, connector.getName()); + Assert.assertEquals(updatedDescription, connector.getDescription()); + Assert.assertEquals(updatedVersion, connector.getVersion()); + Assert.assertEquals(MCP_SSE, connector.getProtocol()); // Should not change if not provided + Assert.assertEquals(updatedBackendRoles, connector.getBackendRoles()); + Assert.assertEquals(updatedAccessMode, connector.getAccess()); + Assert.assertEquals(updatedClientConfig, connector.getConnectorClientConfig()); + Assert.assertEquals(updatedUrl, connector.getUrl()); + Assert.assertEquals(updatedHeaders, connector.getHeaders()); + + // Check encrypted credentials + Map currentCredential = connector.getCredential(); + Assert.assertNotNull(currentCredential); + Assert.assertEquals(1, currentCredential.size()); // Should replace old credentials + Assert.assertEquals("encrypted: new_value", currentCredential.get("new_key")); + Assert.assertNotEquals(initialCredential, currentCredential); + + // Check decrypted credentials and headers (need to explicitly decrypt after update) + connector.decrypt("", decryptFunction, null); // Use decrypt function from setUp + Map decryptedCredential = connector.getDecryptedCredential(); + Assert.assertNotNull(decryptedCredential); + Assert.assertEquals(1, decryptedCredential.size()); + Assert.assertEquals("decrypted: ENCRYPTED: NEW_VALUE", decryptedCredential.get("new_key")); // Uses the decrypt function logic + + Map decryptedHeaders = connector.getDecryptedHeaders(); + Assert.assertNotNull(decryptedHeaders); + Assert.assertEquals(2, decryptedHeaders.size()); + Assert.assertEquals("new_header_value", decryptedHeaders.get("new_header")); + Assert.assertEquals("decrypted: ENCRYPTED: NEW_VALUE", decryptedHeaders.get("updated_api_key")); // Check header substitution + } + + public static McpConnector createMcpConnector() { + Map credential = new HashMap<>(); + credential.put("key", "test_key_value"); + + Map headers = new HashMap<>(); + headers.put("api_key", "${credential.key}"); + + ConnectorClientConfig clientConfig = new ConnectorClientConfig(30, 30000, 30000, 10, 10, -1, RetryBackoffPolicy.CONSTANT); + + return McpConnector + .builder() + .name("test_mcp_connector_name") + .version("1") + .description("this is a test mcp connector") + .protocol(MCP_SSE) + .credential(credential) + .backendRoles(List.of("role1", "role2")) + .accessMode(AccessMode.PUBLIC) + .connectorClientConfig(clientConfig) + .url("https://test.com") + .headers(headers) + .build(); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java index e56cb71559..2b8487d1ec 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java @@ -7,6 +7,8 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_PATH; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_TOOL_USE; @@ -44,17 +46,32 @@ import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.engine.encryptor.Encryptor; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.transport.client.Client; + public class AgentUtilsTest { @Mock private Tool tool1, tool2; + @Mock + private MLAgent mlAgent; + @Mock + private Client client; + @Mock + private SdkClient sdkClient; + @Mock + private Encryptor encryptor; + private Map> llmResponseExpectedParseResults; private String responseForAction = "---------------------\n{\n " @@ -1152,6 +1169,16 @@ public void testParseLLMOutputWithDeepseekFormat() { Assert.assertTrue(output3.get(FINAL_ANSWER).contains("This is a test response")); } + @Test + public void testGetMcpToolSpecs_NoMcpJsonConfig() { + when(mlAgent.getParameters()).thenReturn(null); + + ActionListener> listener = mock(ActionListener.class); + AgentUtils.getMcpToolSpecs(mlAgent, client, sdkClient, encryptor, listener); + + verify(listener).onResponse(Collections.emptyList()); + } + private void verifyConstructToolParams(String question, String actionInput, Consumer> verify) { Map tools = Map.of("tool1", tool1); Map toolSpecMap = Map diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutorTest.java new file mode 100644 index 0000000000..66808cb9bd --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutorTest.java @@ -0,0 +1,92 @@ +package org.opensearch.ml.engine.algorithms.remote; + +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.MockitoAnnotations; +import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.connector.McpConnector; +import org.opensearch.ml.engine.MLStaticMockBase; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; + +public class McpConnectorExecutorTest extends MLStaticMockBase { + + @Mock + private McpConnector mockConnector; + @Mock + private McpSyncClient mcpClient; + @Mock + private McpClient.SyncSpec builder; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + Map decryptedHeaders = Map.of("Authorization", "Bearer secret-token"); + + when(mockConnector.getUrl()).thenReturn("http://random-url"); + when(mockConnector.getDecryptedHeaders()).thenReturn(decryptedHeaders); + + /* ---------- stub the fluent builder chain ------------------------ */ + when(builder.requestTimeout(any())).thenReturn(builder); + when(builder.capabilities(any())).thenReturn(builder); + when(builder.build()).thenReturn(mcpClient); + } + + @Test + public void getMcpToolSpecs_returnsExpectedSpecs() { + + String inputSchemaJSON = + "{\"type\":\"object\",\"properties\":{\"state\":{\"title\":\"State\",\"type\":\"string\"}},\"required\":[\"state\"],\"additionalProperties\":false}"; + + McpSchema.Tool tool = new McpSchema.Tool("tool1", "desc1", inputSchemaJSON); + McpSchema.ListToolsResult mockTools = new McpSchema.ListToolsResult(List.of(tool), null); + + when(mcpClient.listTools()).thenReturn(mockTools); + when(mcpClient.initialize()).thenReturn(null); + + try (MockedStatic mocked = mockStatic(McpClient.class)) { + mocked.when(() -> McpClient.sync(any(McpClientTransport.class))).thenReturn(builder); + McpConnectorExecutor exec = new McpConnectorExecutor(mockConnector); + List specs = exec.getMcpToolSpecs(); + + Assert.assertEquals(1, specs.size()); + MLToolSpec spec = specs.get(0); + Assert.assertEquals("tool1", spec.getName()); + Assert.assertEquals("desc1", spec.getDescription()); + Assert.assertEquals(inputSchemaJSON, spec.getAttributes().get("input_schema")); + Assert.assertSame(mcpClient, spec.getRuntimeResources().get("mcp_sync_client")); + mocked.verify(() -> McpClient.sync(any(McpClientTransport.class))); + verify(builder, times(1)).build(); + verify(mcpClient, times(1)).initialize(); + verify(mcpClient, times(1)).listTools(); + } + } + + @Test + public void getMcpToolSpecs_throwsOnInitError() { + + when(mcpClient.initialize()).thenThrow(new RuntimeException("Error initializing")); + try (MockedStatic mocked = mockStatic(McpClient.class)) { + mocked.when(() -> McpClient.sync(any(McpClientTransport.class))).thenReturn(builder); + McpConnectorExecutor exec = new McpConnectorExecutor(mockConnector); + + assertThrows(RuntimeException.class, () -> exec.getMcpToolSpecs()); + } + } + +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/McpSseToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/McpSseToolTests.java new file mode 100644 index 0000000000..b1b76615e6 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/McpSseToolTests.java @@ -0,0 +1,114 @@ +package org.opensearch.ml.engine.tools; + +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT; + +import java.util.Collections; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.spi.tools.Tool; + +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.spec.McpSchema; + +public class McpSseToolTests { + + @Mock + private McpSyncClient mcpSyncClient; + + @Mock + private ActionListener listener; + + private Tool tool; + private Map validParams; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + // Initialize the tool with the mocked client + tool = McpSseTool.Factory.getInstance().create( + Map.of(MCP_SYNC_CLIENT, mcpSyncClient) + ); + validParams = Map.of("input", "{\"foo\":\"bar\"}"); + } + + @Test + public void testRunSuccess() { + // Arrange: create a CallToolResult wrapping a JSON string + McpSchema.CallToolResult result = new McpSchema.CallToolResult("{\"foo\":\"bar\"}", false); + when(mcpSyncClient.callTool(any(McpSchema.CallToolRequest.class))) + .thenReturn(result); + + // Act + tool.run(validParams, listener); + + // Assert: ensure onResponse is called with the JSON string + verify(listener).onResponse( + "[{\"text\":\"{\\\"foo\\\":\\\"bar\\\"}\"}]" + ); + verify(listener, never()).onFailure(any()); + } + + @Test + public void testRunInvalidJsonInput() { + // Passing a non-JSON string should trigger failure in parsing + Map badParams = Map.of("input", "not-json"); + tool.run(badParams, listener); + + verify(listener).onFailure(any(Exception.class)); + verify(listener, never()).onResponse(any()); + } + + @Test + public void testRunClientThrows() { + // Simulate the MCP client throwing an exception + when(mcpSyncClient.callTool(any())).thenThrow(new RuntimeException("client error")); + + tool.run(validParams, listener); + + verify(listener).onFailure(any(RuntimeException.class)); + verify(listener, never()).onResponse(any()); + } + + @Test + public void testRunMissingInputParam() { + // No "input" key in parameters should also be caught + tool.run(Collections.emptyMap(), listener); + + verify(listener).onFailure(any(Exception.class)); + verify(listener, never()).onResponse(any()); + } + + @Test + public void testValidateAndMetadata() { + // validate + assertTrue(tool.validate(validParams)); + assertFalse(tool.validate(Collections.emptyMap())); + // metadata + assertEquals(McpSseTool.TYPE, tool.getName()); + assertEquals(McpSseTool.TYPE, tool.getType()); + assertNull(tool.getVersion()); + assertEquals(McpSseTool.DEFAULT_DESCRIPTION, tool.getDescription()); + } + + @Test + public void testFactoryDefaults() { + McpSseTool.Factory factory = McpSseTool.Factory.getInstance(); + assertEquals(McpSseTool.DEFAULT_DESCRIPTION, factory.getDefaultDescription()); + assertEquals(McpSseTool.TYPE, factory.getDefaultType()); + assertNull(factory.getDefaultVersion()); + assertTrue(factory.getAllModelKeys().isEmpty()); + } +} From a3383d2003b957368f575b3108684de85d0c7434 Mon Sep 17 00:00:00 2001 From: rithin-pullela-aws Date: Wed, 14 May 2025 14:44:44 -0700 Subject: [PATCH 2/4] Add more UTs Signed-off-by: rithin-pullela-aws --- .../ml/common/connector/McpConnectorTest.java | 25 +-- .../algorithms/agent/AgentUtilsTest.java | 180 +++++++++++++++++- .../remote/McpConnectorExecutorTest.java | 2 +- .../ml/engine/tools/McpSseToolTests.java | 15 +- 4 files changed, 196 insertions(+), 26 deletions(-) diff --git a/common/src/test/java/org/opensearch/ml/common/connector/McpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/McpConnectorTest.java index 103ac0f7c5..df13014cdb 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/McpConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/McpConnectorTest.java @@ -193,18 +193,19 @@ public void testUpdate() { updatedHeaders.put("new_header", "new_header_value"); updatedHeaders.put("updated_api_key", "${credential.new_key}"); // Referencing new credential key - MLCreateConnectorInput updateInput = MLCreateConnectorInput.builder() - .name(updatedName) - .description(updatedDescription) - .version(updatedVersion) - .credential(updatedCredential) - .backendRoles(updatedBackendRoles) - .access(updatedAccessMode) - .connectorClientConfig(updatedClientConfig) - .url(updatedUrl) - .headers(updatedHeaders) - .protocol(MCP_SSE) - .build(); + MLCreateConnectorInput updateInput = MLCreateConnectorInput + .builder() + .name(updatedName) + .description(updatedDescription) + .version(updatedVersion) + .credential(updatedCredential) + .backendRoles(updatedBackendRoles) + .access(updatedAccessMode) + .connectorClientConfig(updatedClientConfig) + .url(updatedUrl) + .headers(updatedHeaders) + .protocol(MCP_SSE) + .build(); // Call the update method connector.update(updateInput, encryptFunction); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java index 2b8487d1ec..310c80a944 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java @@ -7,9 +7,14 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.MCP_CONNECTORS_FIELD; +import static org.opensearch.ml.common.CommonValue.MCP_CONNECTOR_ID_FIELD; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_PATH; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_TOOL_USE; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_GEN_INPUT; @@ -21,6 +26,7 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALLS_TOOL_NAME; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALL_ID; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_CALL_ID_PATH; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_FILTERS_FIELD; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION_INPUT; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY; @@ -38,27 +44,44 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CompletionStage; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; import java.util.function.Consumer; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; +import org.mockito.MockedStatic; import org.mockito.MockitoAnnotations; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.connector.AwsConnector; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.McpConnector; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.spi.tools.Tool; -import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.engine.MLEngineClassLoader; +import org.opensearch.ml.engine.MLStaticMockBase; +import org.opensearch.ml.engine.algorithms.remote.McpConnectorExecutor; import org.opensearch.ml.engine.encryptor.Encryptor; +import org.opensearch.ml.engine.tools.McpSseTool; +import org.opensearch.remote.metadata.client.GetDataObjectRequest; +import org.opensearch.remote.metadata.client.GetDataObjectResponse; import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.client.Client; - -public class AgentUtilsTest { +public class AgentUtilsTest extends MLStaticMockBase { @Mock private Tool tool1, tool2; @@ -72,6 +95,11 @@ public class AgentUtilsTest { @Mock private Encryptor encryptor; + ThreadContext threadContext; + + @Mock + ThreadPool threadPool; + private Map> llmResponseExpectedParseResults; private String responseForAction = "---------------------\n{\n " @@ -1169,6 +1197,49 @@ public void testParseLLMOutputWithDeepseekFormat() { Assert.assertTrue(output3.get(FINAL_ANSWER).contains("This is a test response")); } + private static MLToolSpec tool(String name) { + return MLToolSpec.builder().type(McpSseTool.TYPE).name(name).description("mock").build(); + } + + private void stubGetConnector() { + threadContext = new ThreadContext(Settings.builder().build()); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + when(sdkClient.getDataObjectAsync(any(GetDataObjectRequest.class))).thenAnswer(inv -> { + String json = "{\"_index\":\"i\",\"_id\":\"j\",\"found\":true,\"_source\":{}}"; + XContentParser parser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, null, json); + + GetDataObjectResponse resp = mock(GetDataObjectResponse.class); + when(resp.parser()).thenReturn(parser); + + CompletionStage stage = mock(CompletionStage.class); + when(stage.whenComplete(any())).thenAnswer(cbInv -> { + BiConsumer cb = cbInv.getArgument(0); + cb.accept(resp, null); + return stage; + }); + + return stage; + }); + } + + // create + register a mock McpConnector with Connector.createConnector + private void mockMcpConnector(MockedStatic connectorStatic) { + McpConnector mockConnector = mock(McpConnector.class); + when(mockConnector.getProtocol()).thenReturn("mcp_sse"); + doNothing().when(mockConnector).decrypt(anyString(), any(), anyString()); + connectorStatic.when(() -> Connector.createConnector(any(XContentParser.class))).thenReturn(mockConnector); + } + + // create a mock MLAgent with connector-config JSON + private MLAgent mockAgent(String json, String tenant) { + MLAgent mockAgent = mock(MLAgent.class); + when(mockAgent.getParameters()).thenReturn(Map.of(MCP_CONNECTORS_FIELD, json)); + when(mockAgent.getTenantId()).thenReturn(tenant); + return mockAgent; + } + @Test public void testGetMcpToolSpecs_NoMcpJsonConfig() { when(mlAgent.getParameters()).thenReturn(null); @@ -1179,6 +1250,109 @@ public void testGetMcpToolSpecs_NoMcpJsonConfig() { verify(listener).onResponse(Collections.emptyList()); } + @Test + public void testGetMcpToolSpecs_SingleConnectorSuccess() throws Exception { + stubGetConnector(); + List expected = List.of(tool("Demo")); + + try ( + MockedStatic connStatic = mockStatic(Connector.class); + MockedStatic loadStatic = mockStatic(MLEngineClassLoader.class) + ) { + // mock McpConnector, McpConnectorExecutor, agent, and listener + mockMcpConnector(connStatic); + McpConnectorExecutor exec = mock(McpConnectorExecutor.class); + when(exec.getMcpToolSpecs()).thenReturn(expected); + loadStatic.when(() -> MLEngineClassLoader.initInstance(anyString(), any(), any())).thenReturn(exec); + + MLAgent mlAgent = mockAgent("[{\"" + MCP_CONNECTOR_ID_FIELD + "\":\"c1\"}]", "tenant"); + ActionListener> listener = mock(ActionListener.class); + + // run and verify + AgentUtils.getMcpToolSpecs(mlAgent, client, sdkClient, null, listener); + verify(listener).onResponse(expected); + } + } + + @Test + public void testGetMcpToolSpecs_ToolFilterApplied() throws Exception { + stubGetConnector(); + List repo = List.of(tool("FilterTool"), tool("TempTool")); + List expected = List.of(tool("FilterTool")); + + try ( + MockedStatic connStatic = mockStatic(Connector.class); + MockedStatic loadStatic = mockStatic(MLEngineClassLoader.class) + ) { + // mock McpConnector, McpConnectorExecutor, agent, and listener + mockMcpConnector(connStatic); + + McpConnectorExecutor exec = mock(McpConnectorExecutor.class); + when(exec.getMcpToolSpecs()).thenReturn(repo); + loadStatic.when(() -> MLEngineClassLoader.initInstance(anyString(), any(), any())).thenReturn(exec); + + String mcpJsonConfig = "[{\"" + + MCP_CONNECTOR_ID_FIELD + + "\":\"c1\",\"" + + TOOL_FILTERS_FIELD + + "\":[\"^Filter.*\", \"SecondDemoFilter\"]}]"; + MLAgent agent = mockAgent(mcpJsonConfig, "tenant"); + + ActionListener> listener = mock(ActionListener.class); + + // run and verify + AgentUtils.getMcpToolSpecs(agent, client, sdkClient, null, listener); + verify(listener).onResponse(expected); + } + } + + @Test + public void testGetMcpToolSpecs_MultipleConnectorsMerged() throws Exception { + stubGetConnector(); // now safe + + List aTools = List.of(tool("A1")); + List bTools = List.of(tool("B1"), tool("B2")); + List expected = new ArrayList<>(); + expected.addAll(aTools); + expected.addAll(bTools); + + try ( + MockedStatic connStatic = mockStatic(Connector.class); + MockedStatic loadStatic = mockStatic(MLEngineClassLoader.class) + ) { + // mock McpConnector, McpConnectorExecutor, agent, and listener + mockMcpConnector(connStatic); + + McpConnectorExecutor exec = mock(McpConnectorExecutor.class); + when(exec.getMcpToolSpecs()).thenReturn(aTools, bTools); + loadStatic.when(() -> MLEngineClassLoader.initInstance(anyString(), any(), any())).thenReturn(exec); + + String mcpJsonConfig = "[{\"" + MCP_CONNECTOR_ID_FIELD + "\":\"A\"}," + "{\"" + MCP_CONNECTOR_ID_FIELD + "\":\"B\"}]"; + MLAgent agent = mockAgent(mcpJsonConfig, "tenant"); + + ActionListener> listener = mock(ActionListener.class); + + // run and verify + AgentUtils.getMcpToolSpecs(agent, client, sdkClient, null, listener); + verify(listener).onResponse(expected); + } + } + + @Test + public void testGetMcpToolSpecs_NonMcpConnectorReturnsEmpty() throws Exception { + stubGetConnector(); + try (MockedStatic connStatic = mockStatic(Connector.class)) { + + connStatic.when(() -> Connector.createConnector(any(XContentParser.class))).thenReturn(mock(AwsConnector.class)); + MLAgent agent = mockAgent("[{\"" + MCP_CONNECTOR_ID_FIELD + "\":\"c1\"}]", "tenant"); + + ActionListener> listener = mock(ActionListener.class); + AgentUtils.getMcpToolSpecs(agent, client, sdkClient, null, listener); + + verify(listener).onResponse(Collections.emptyList()); + } + } + private void verifyConstructToolParams(String question, String actionInput, Consumer> verify) { Map tools = Map.of("tool1", tool1); Map toolSpecMap = Map diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutorTest.java index 66808cb9bd..596c218021 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutorTest.java @@ -51,7 +51,7 @@ public void setUp() { public void getMcpToolSpecs_returnsExpectedSpecs() { String inputSchemaJSON = - "{\"type\":\"object\",\"properties\":{\"state\":{\"title\":\"State\",\"type\":\"string\"}},\"required\":[\"state\"],\"additionalProperties\":false}"; + "{\"type\":\"object\",\"properties\":{\"state\":{\"title\":\"State\",\"type\":\"string\"}},\"required\":[\"state\"],\"additionalProperties\":false}"; McpSchema.Tool tool = new McpSchema.Tool("tool1", "desc1", inputSchemaJSON); McpSchema.ListToolsResult mockTools = new McpSchema.ListToolsResult(List.of(tool), null); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/McpSseToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/McpSseToolTests.java index b1b76615e6..7c050f9a45 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/McpSseToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/McpSseToolTests.java @@ -1,9 +1,9 @@ package org.opensearch.ml.engine.tools; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; @@ -38,9 +38,7 @@ public class McpSseToolTests { public void setup() { MockitoAnnotations.openMocks(this); // Initialize the tool with the mocked client - tool = McpSseTool.Factory.getInstance().create( - Map.of(MCP_SYNC_CLIENT, mcpSyncClient) - ); + tool = McpSseTool.Factory.getInstance().create(Map.of(MCP_SYNC_CLIENT, mcpSyncClient)); validParams = Map.of("input", "{\"foo\":\"bar\"}"); } @@ -48,16 +46,13 @@ public void setup() { public void testRunSuccess() { // Arrange: create a CallToolResult wrapping a JSON string McpSchema.CallToolResult result = new McpSchema.CallToolResult("{\"foo\":\"bar\"}", false); - when(mcpSyncClient.callTool(any(McpSchema.CallToolRequest.class))) - .thenReturn(result); + when(mcpSyncClient.callTool(any(McpSchema.CallToolRequest.class))).thenReturn(result); // Act tool.run(validParams, listener); // Assert: ensure onResponse is called with the JSON string - verify(listener).onResponse( - "[{\"text\":\"{\\\"foo\\\":\\\"bar\\\"}\"}]" - ); + verify(listener).onResponse("[{\"text\":\"{\\\"foo\\\":\\\"bar\\\"}\"}]"); verify(listener, never()).onFailure(any()); } From a5ab589d6e0719c122e5fdcd9d23c3cfe686a474 Mon Sep 17 00:00:00 2001 From: rithin-pullela-aws Date: Wed, 14 May 2025 15:01:11 -0700 Subject: [PATCH 3/4] cleanup code Signed-off-by: rithin-pullela-aws --- .../ml/engine/algorithms/agent/AgentUtilsTest.java | 4 ++-- .../org/opensearch/ml/engine/tools/McpSseToolTests.java | 9 ++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java index 310c80a944..5e164238b5 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java @@ -95,10 +95,10 @@ public class AgentUtilsTest extends MLStaticMockBase { @Mock private Encryptor encryptor; - ThreadContext threadContext; + private ThreadContext threadContext; @Mock - ThreadPool threadPool; + private ThreadPool threadPool; private Map> llmResponseExpectedParseResults; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/McpSseToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/McpSseToolTests.java index 7c050f9a45..2990aaff15 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/McpSseToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/McpSseToolTests.java @@ -37,21 +37,20 @@ public class McpSseToolTests { @Before public void setup() { MockitoAnnotations.openMocks(this); - // Initialize the tool with the mocked client + // Initialize the tool with the mocked mcp client tool = McpSseTool.Factory.getInstance().create(Map.of(MCP_SYNC_CLIENT, mcpSyncClient)); validParams = Map.of("input", "{\"foo\":\"bar\"}"); } @Test public void testRunSuccess() { - // Arrange: create a CallToolResult wrapping a JSON string + // create a CallToolResult wrapping a JSON string McpSchema.CallToolResult result = new McpSchema.CallToolResult("{\"foo\":\"bar\"}", false); when(mcpSyncClient.callTool(any(McpSchema.CallToolRequest.class))).thenReturn(result); - // Act tool.run(validParams, listener); - // Assert: ensure onResponse is called with the JSON string + // Assert verify(listener).onResponse("[{\"text\":\"{\\\"foo\\\":\\\"bar\\\"}\"}]"); verify(listener, never()).onFailure(any()); } @@ -79,7 +78,7 @@ public void testRunClientThrows() { @Test public void testRunMissingInputParam() { - // No "input" key in parameters should also be caught + // No "input" key in parameters should be caught tool.run(Collections.emptyMap(), listener); verify(listener).onFailure(any(Exception.class)); From 9d0b9a751b5551eac287073080f2532e9d3bf50b Mon Sep 17 00:00:00 2001 From: rithin-pullela-aws Date: Wed, 14 May 2025 18:20:40 -0700 Subject: [PATCH 4/4] Add license header, rename helper function Signed-off-by: rithin-pullela-aws --- .../ml/common/connector/McpConnectorTest.java | 5 +++++ .../ml/engine/algorithms/agent/AgentUtilsTest.java | 12 ++++++------ .../algorithms/remote/McpConnectorExecutorTest.java | 5 +++++ .../opensearch/ml/engine/tools/McpSseToolTests.java | 5 +++++ 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/common/src/test/java/org/opensearch/ml/common/connector/McpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/McpConnectorTest.java index df13014cdb..e77d6ebe44 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/McpConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/McpConnectorTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.ml.common.connector; import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java index 5e164238b5..0a39383e17 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java @@ -1197,7 +1197,7 @@ public void testParseLLMOutputWithDeepseekFormat() { Assert.assertTrue(output3.get(FINAL_ANSWER).contains("This is a test response")); } - private static MLToolSpec tool(String name) { + private static MLToolSpec buildTool(String name) { return MLToolSpec.builder().type(McpSseTool.TYPE).name(name).description("mock").build(); } @@ -1253,7 +1253,7 @@ public void testGetMcpToolSpecs_NoMcpJsonConfig() { @Test public void testGetMcpToolSpecs_SingleConnectorSuccess() throws Exception { stubGetConnector(); - List expected = List.of(tool("Demo")); + List expected = List.of(buildTool("Demo")); try ( MockedStatic connStatic = mockStatic(Connector.class); @@ -1277,8 +1277,8 @@ public void testGetMcpToolSpecs_SingleConnectorSuccess() throws Exception { @Test public void testGetMcpToolSpecs_ToolFilterApplied() throws Exception { stubGetConnector(); - List repo = List.of(tool("FilterTool"), tool("TempTool")); - List expected = List.of(tool("FilterTool")); + List repo = List.of(buildTool("FilterTool"), buildTool("TempTool")); + List expected = List.of(buildTool("FilterTool")); try ( MockedStatic connStatic = mockStatic(Connector.class); @@ -1310,8 +1310,8 @@ public void testGetMcpToolSpecs_ToolFilterApplied() throws Exception { public void testGetMcpToolSpecs_MultipleConnectorsMerged() throws Exception { stubGetConnector(); // now safe - List aTools = List.of(tool("A1")); - List bTools = List.of(tool("B1"), tool("B2")); + List aTools = List.of(buildTool("A1")); + List bTools = List.of(buildTool("B1"), buildTool("B2")); List expected = new ArrayList<>(); expected.addAll(aTools); expected.addAll(bTools); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutorTest.java index 596c218021..bf15071646 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutorTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.ml.engine.algorithms.remote; import static org.junit.Assert.assertThrows; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/McpSseToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/McpSseToolTests.java index 2990aaff15..b45bc09b36 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/McpSseToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/McpSseToolTests.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.ml.engine.tools; import static org.junit.Assert.assertEquals;