Skip to content

Commit 092a8b3

Browse files
Add tests for MCP feature
Signed-off-by: rithin-pullela-aws <rithinp@amazon.com>
1 parent c3b1fd6 commit 092a8b3

File tree

4 files changed

+500
-0
lines changed

4 files changed

+500
-0
lines changed
Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
package org.opensearch.ml.common.connector;
2+
3+
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
4+
import static org.opensearch.ml.common.connector.ConnectorProtocols.MCP_SSE;
5+
import static org.opensearch.ml.common.connector.RetryBackoffPolicy.CONSTANT;
6+
7+
import java.io.IOException;
8+
import java.util.Arrays;
9+
import java.util.Collections;
10+
import java.util.HashMap;
11+
import java.util.List;
12+
import java.util.Locale;
13+
import java.util.Map;
14+
import java.util.function.BiFunction;
15+
16+
import org.junit.Assert;
17+
import org.junit.Before;
18+
import org.junit.Rule;
19+
import org.junit.Test;
20+
import org.junit.rules.ExpectedException;
21+
import org.opensearch.common.io.stream.BytesStreamOutput;
22+
import org.opensearch.common.settings.Settings;
23+
import org.opensearch.common.xcontent.XContentFactory;
24+
import org.opensearch.common.xcontent.XContentType;
25+
import org.opensearch.core.xcontent.NamedXContentRegistry;
26+
import org.opensearch.core.xcontent.ToXContent;
27+
import org.opensearch.core.xcontent.XContentBuilder;
28+
import org.opensearch.core.xcontent.XContentParser;
29+
import org.opensearch.ml.common.AccessMode;
30+
import org.opensearch.ml.common.TestHelper;
31+
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
32+
import org.opensearch.search.SearchModule;
33+
34+
public class McpConnectorTest {
35+
@Rule
36+
public ExpectedException exceptionRule = ExpectedException.none();
37+
38+
BiFunction<String, String, String> encryptFunction;
39+
BiFunction<String, String, String> decryptFunction;
40+
41+
String TEST_CONNECTOR_JSON_STRING =
42+
"{\"name\":\"test_mcp_connector_name\",\"version\":\"1\",\"description\":\"this is a test mcp connector\",\"protocol\":\"mcp_sse\",\"credential\":{\"key\":\"test_key_value\"},\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\",\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000,\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"},\"url\":\"https://test.com\",\"headers\":{\"api_key\":\"${credential.key}\"}}";
43+
44+
@Before
45+
public void setUp() {
46+
encryptFunction = (s, v) -> "encrypted: " + s.toLowerCase(Locale.ROOT);
47+
decryptFunction = (s, v) -> "decrypted: " + s.toUpperCase(Locale.ROOT);
48+
}
49+
50+
@Test
51+
public void constructor_InvalidProtocol() {
52+
exceptionRule.expect(IllegalArgumentException.class);
53+
exceptionRule.expectMessage("Unsupported connector protocol. Please use one of [aws_sigv4, http, mcp_sse]");
54+
55+
McpConnector.builder().protocol("wrong protocol").build();
56+
}
57+
58+
@Test
59+
public void writeTo() throws IOException {
60+
McpConnector connector = createMcpConnector();
61+
62+
BytesStreamOutput output = new BytesStreamOutput();
63+
connector.writeTo(output);
64+
65+
McpConnector connector2 = new McpConnector(output.bytes().streamInput());
66+
Assert.assertEquals(connector, connector2);
67+
}
68+
69+
@Test
70+
public void toXContent() throws IOException {
71+
McpConnector connector = createMcpConnector();
72+
73+
XContentBuilder builder = XContentFactory.jsonBuilder();
74+
connector.toXContent(builder, ToXContent.EMPTY_PARAMS);
75+
String content = TestHelper.xContentBuilderToString(builder);
76+
77+
Assert.assertEquals(TEST_CONNECTOR_JSON_STRING, content);
78+
}
79+
80+
@Test
81+
public void constructor_Parser() throws IOException {
82+
XContentParser parser = XContentType.JSON
83+
.xContent()
84+
.createParser(
85+
new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
86+
null,
87+
TEST_CONNECTOR_JSON_STRING
88+
);
89+
parser.nextToken();
90+
91+
McpConnector connector = new McpConnector("mcp_sse", parser);
92+
Assert.assertEquals("test_mcp_connector_name", connector.getName());
93+
Assert.assertEquals("1", connector.getVersion());
94+
Assert.assertEquals("this is a test mcp connector", connector.getDescription());
95+
Assert.assertEquals("mcp_sse", connector.getProtocol());
96+
Assert.assertEquals(AccessMode.PUBLIC, connector.getAccess());
97+
Assert.assertEquals("https://test.com", connector.getUrl());
98+
connector.decrypt(PREDICT.name(), decryptFunction, null);
99+
Map<String, String> decryptedCredential = connector.getDecryptedCredential();
100+
Assert.assertEquals(1, decryptedCredential.size());
101+
Assert.assertEquals("decrypted: TEST_KEY_VALUE", decryptedCredential.get("key"));
102+
Assert.assertNotNull(connector.getDecryptedHeaders());
103+
Assert.assertEquals(1, connector.getDecryptedHeaders().size());
104+
Assert.assertEquals("decrypted: TEST_KEY_VALUE", connector.getDecryptedHeaders().get("api_key"));
105+
}
106+
107+
@Test
108+
public void cloneConnector() {
109+
McpConnector connector = createMcpConnector();
110+
Connector connector2 = connector.cloneConnector();
111+
Assert.assertEquals(connector, connector2);
112+
}
113+
114+
@Test
115+
public void decrypt() {
116+
McpConnector connector = createMcpConnector();
117+
connector.decrypt("", decryptFunction, null);
118+
Map<String, String> decryptedCredential = connector.getDecryptedCredential();
119+
Assert.assertEquals(1, decryptedCredential.size());
120+
Assert.assertEquals("decrypted: TEST_KEY_VALUE", decryptedCredential.get("key"));
121+
Assert.assertNotNull(connector.getDecryptedHeaders());
122+
Assert.assertEquals(1, connector.getDecryptedHeaders().size());
123+
Assert.assertEquals("decrypted: TEST_KEY_VALUE", connector.getDecryptedHeaders().get("api_key"));
124+
125+
connector.removeCredential();
126+
Assert.assertNull(connector.getCredential());
127+
Assert.assertNull(connector.getDecryptedCredential());
128+
Assert.assertNull(connector.getDecryptedHeaders());
129+
}
130+
131+
@Test
132+
public void encrypt() {
133+
McpConnector connector = createMcpConnector();
134+
connector.encrypt(encryptFunction, null);
135+
Map<String, String> credential = connector.getCredential();
136+
Assert.assertEquals(1, credential.size());
137+
Assert.assertEquals("encrypted: test_key_value", credential.get("key"));
138+
139+
connector.removeCredential();
140+
Assert.assertNull(connector.getCredential());
141+
Assert.assertNull(connector.getDecryptedCredential());
142+
Assert.assertNull(connector.getDecryptedHeaders());
143+
}
144+
145+
@Test
146+
public void validateConnectorURL_Invalid() {
147+
exceptionRule.expect(IllegalArgumentException.class);
148+
exceptionRule.expectMessage("Connector URL is not matching the trusted connector endpoint regex");
149+
McpConnector connector = createMcpConnector();
150+
connector
151+
.validateConnectorURL(
152+
Arrays
153+
.asList(
154+
"^https://runtime\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
155+
"^https://api\\.openai\\.com/.*$",
156+
"^https://api\\.cohere\\.ai/.*$",
157+
"^https://bedrock-agent-runtime\\\\..*[a-z0-9-]\\\\.amazonaws\\\\.com/.*$"
158+
)
159+
);
160+
}
161+
162+
@Test
163+
public void validateConnectorURL() {
164+
McpConnector connector = createMcpConnector();
165+
connector
166+
.validateConnectorURL(
167+
Arrays
168+
.asList(
169+
"^https://runtime\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
170+
"^https://api\\.openai\\.com/.*$",
171+
"^https://bedrock-agent-runtime\\\\..*[a-z0-9-]\\\\.amazonaws\\\\.com/.*$",
172+
"^" + connector.getUrl()
173+
)
174+
);
175+
}
176+
177+
@Test
178+
public void testUpdate() {
179+
McpConnector connector = createMcpConnector();
180+
Map<String, String> initialCredential = new HashMap<>(connector.getCredential());
181+
182+
// Create update content
183+
String updatedName = "updated_name";
184+
String updatedDescription = "updated description";
185+
String updatedVersion = "2";
186+
Map<String, String> updatedCredential = new HashMap<>();
187+
updatedCredential.put("new_key", "new_value");
188+
List<String> updatedBackendRoles = List.of("role3", "role4");
189+
AccessMode updatedAccessMode = AccessMode.PRIVATE;
190+
ConnectorClientConfig updatedClientConfig = new ConnectorClientConfig(40, 40000, 40000, 20, 20, 5, CONSTANT);
191+
String updatedUrl = "https://updated.test.com";
192+
Map<String, String> updatedHeaders = new HashMap<>();
193+
updatedHeaders.put("new_header", "new_header_value");
194+
updatedHeaders.put("updated_api_key", "${credential.new_key}"); // Referencing new credential key
195+
196+
MLCreateConnectorInput updateInput = MLCreateConnectorInput.builder()
197+
.name(updatedName)
198+
.description(updatedDescription)
199+
.version(updatedVersion)
200+
.credential(updatedCredential)
201+
.backendRoles(updatedBackendRoles)
202+
.access(updatedAccessMode)
203+
.connectorClientConfig(updatedClientConfig)
204+
.url(updatedUrl)
205+
.headers(updatedHeaders)
206+
.protocol(MCP_SSE)
207+
.build();
208+
209+
// Call the update method
210+
connector.update(updateInput, encryptFunction);
211+
212+
// Assertions
213+
Assert.assertEquals(updatedName, connector.getName());
214+
Assert.assertEquals(updatedDescription, connector.getDescription());
215+
Assert.assertEquals(updatedVersion, connector.getVersion());
216+
Assert.assertEquals(MCP_SSE, connector.getProtocol()); // Should not change if not provided
217+
Assert.assertEquals(updatedBackendRoles, connector.getBackendRoles());
218+
Assert.assertEquals(updatedAccessMode, connector.getAccess());
219+
Assert.assertEquals(updatedClientConfig, connector.getConnectorClientConfig());
220+
Assert.assertEquals(updatedUrl, connector.getUrl());
221+
Assert.assertEquals(updatedHeaders, connector.getHeaders());
222+
223+
// Check encrypted credentials
224+
Map<String, String> currentCredential = connector.getCredential();
225+
Assert.assertNotNull(currentCredential);
226+
Assert.assertEquals(1, currentCredential.size()); // Should replace old credentials
227+
Assert.assertEquals("encrypted: new_value", currentCredential.get("new_key"));
228+
Assert.assertNotEquals(initialCredential, currentCredential);
229+
230+
// Check decrypted credentials and headers (need to explicitly decrypt after update)
231+
connector.decrypt("", decryptFunction, null); // Use decrypt function from setUp
232+
Map<String, String> decryptedCredential = connector.getDecryptedCredential();
233+
Assert.assertNotNull(decryptedCredential);
234+
Assert.assertEquals(1, decryptedCredential.size());
235+
Assert.assertEquals("decrypted: ENCRYPTED: NEW_VALUE", decryptedCredential.get("new_key")); // Uses the decrypt function logic
236+
237+
Map<String, String> decryptedHeaders = connector.getDecryptedHeaders();
238+
Assert.assertNotNull(decryptedHeaders);
239+
Assert.assertEquals(2, decryptedHeaders.size());
240+
Assert.assertEquals("new_header_value", decryptedHeaders.get("new_header"));
241+
Assert.assertEquals("decrypted: ENCRYPTED: NEW_VALUE", decryptedHeaders.get("updated_api_key")); // Check header substitution
242+
}
243+
244+
public static McpConnector createMcpConnector() {
245+
Map<String, String> credential = new HashMap<>();
246+
credential.put("key", "test_key_value");
247+
248+
Map<String, String> headers = new HashMap<>();
249+
headers.put("api_key", "${credential.key}");
250+
251+
ConnectorClientConfig clientConfig = new ConnectorClientConfig(30, 30000, 30000, 10, 10, -1, RetryBackoffPolicy.CONSTANT);
252+
253+
return McpConnector
254+
.builder()
255+
.name("test_mcp_connector_name")
256+
.version("1")
257+
.description("this is a test mcp connector")
258+
.protocol(MCP_SSE)
259+
.credential(credential)
260+
.backendRoles(List.of("role1", "role2"))
261+
.accessMode(AccessMode.PUBLIC)
262+
.connectorClientConfig(clientConfig)
263+
.url("https://test.com")
264+
.headers(headers)
265+
.build();
266+
}
267+
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
import static org.junit.Assert.assertEquals;
99
import static org.junit.Assert.assertThrows;
10+
import static org.mockito.Mockito.mock;
11+
import static org.mockito.Mockito.verify;
1012
import static org.mockito.Mockito.when;
1113
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_PATH;
1214
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_FINISH_REASON_TOOL_USE;
@@ -44,17 +46,32 @@
4446
import org.junit.Test;
4547
import org.mockito.Mock;
4648
import org.mockito.MockitoAnnotations;
49+
import org.opensearch.core.action.ActionListener;
4750
import org.opensearch.ml.common.agent.MLToolSpec;
4851
import org.opensearch.ml.common.output.model.ModelTensor;
4952
import org.opensearch.ml.common.output.model.ModelTensorOutput;
5053
import org.opensearch.ml.common.output.model.ModelTensors;
5154
import org.opensearch.ml.common.spi.tools.Tool;
55+
import org.opensearch.ml.common.agent.MLAgent;
56+
import org.opensearch.ml.engine.encryptor.Encryptor;
57+
import org.opensearch.remote.metadata.client.SdkClient;
58+
import org.opensearch.transport.client.Client;
59+
5260

5361
public class AgentUtilsTest {
5462

5563
@Mock
5664
private Tool tool1, tool2;
5765

66+
@Mock
67+
private MLAgent mlAgent;
68+
@Mock
69+
private Client client;
70+
@Mock
71+
private SdkClient sdkClient;
72+
@Mock
73+
private Encryptor encryptor;
74+
5875
private Map<String, Map<String, String>> llmResponseExpectedParseResults;
5976

6077
private String responseForAction = "---------------------\n{\n "
@@ -1152,6 +1169,16 @@ public void testParseLLMOutputWithDeepseekFormat() {
11521169
Assert.assertTrue(output3.get(FINAL_ANSWER).contains("This is a test response"));
11531170
}
11541171

1172+
@Test
1173+
public void testGetMcpToolSpecs_NoMcpJsonConfig() {
1174+
when(mlAgent.getParameters()).thenReturn(null);
1175+
1176+
ActionListener<List<MLToolSpec>> listener = mock(ActionListener.class);
1177+
AgentUtils.getMcpToolSpecs(mlAgent, client, sdkClient, encryptor, listener);
1178+
1179+
verify(listener).onResponse(Collections.emptyList());
1180+
}
1181+
11551182
private void verifyConstructToolParams(String question, String actionInput, Consumer<Map<String, String>> verify) {
11561183
Map<String, Tool> tools = Map.of("tool1", tool1);
11571184
Map<String, MLToolSpec> toolSpecMap = Map

0 commit comments

Comments
 (0)