Skip to content

Commit 08f0a88

Browse files
rithin-pullela-awsmingshlZhangxunmt
authored
Add custom SSE endpoint for the MCP Client (#3891)
* Add custom SSE endpoint for the MCP Client Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> * Address comment, use sse_endpoint as field name Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> --------- Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> Co-authored-by: Mingshi Liu <mingshl@amazon.com> Co-authored-by: Xun Zhang <xunzh@amazon.com>
1 parent 561db0f commit 08f0a88

File tree

5 files changed

+58
-9
lines changed

5 files changed

+58
-9
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ public class CommonValue {
100100
public static final String MCP_TOOLS_FIELD = "tools";
101101
public static final String MCP_CONNECTORS_FIELD = "mcp_connectors";
102102
public static final String MCP_CONNECTOR_ID_FIELD = "mcp_connector_id";
103+
public static final String MCP_DEFAULT_SSE_ENDPOINT = "/sse";
104+
public static final String SSE_ENDPOINT_FILED = "sse_endpoint";
103105

104106
// TOOL Constants
105107
public static final String TOOL_INPUT_SCHEMA_FIELD = "input_schema";

common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ public abstract class AbstractConnector implements Connector {
3737
public static final String DESCRIPTION_FIELD = "description";
3838
public static final String PROTOCOL_FIELD = "protocol";
3939
public static final String ACTIONS_FIELD = "actions";
40-
public static final String CREDENTIAL_FIELD = "credential";
4140
public static final String PARAMETERS_FIELD = "parameters";
4241
public static final String CREATED_TIME_FIELD = "created_time";
4342
public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time";

common/src/main/java/org/opensearch/ml/common/connector/McpConnector.java

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
import static org.opensearch.ml.common.CommonValue.LAST_UPDATED_TIME_FIELD;
1717
import static org.opensearch.ml.common.CommonValue.NAME_FIELD;
1818
import static org.opensearch.ml.common.CommonValue.OWNER_FIELD;
19+
import static org.opensearch.ml.common.CommonValue.PARAMETERS_FIELD;
1920
import static org.opensearch.ml.common.CommonValue.PROTOCOL_FIELD;
2021
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
2122
import static org.opensearch.ml.common.CommonValue.URL_FIELD;
23+
import static org.opensearch.ml.common.CommonValue.VERSION_3_1_0;
2224
import static org.opensearch.ml.common.CommonValue.VERSION_FIELD;
2325
import static org.opensearch.ml.common.connector.ConnectorProtocols.MCP_SSE;
2426
import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol;
@@ -35,6 +37,7 @@
3537
import java.util.regex.Pattern;
3638

3739
import org.apache.commons.text.StringSubstitutor;
40+
import org.opensearch.Version;
3841
import org.opensearch.common.io.stream.BytesStreamOutput;
3942
import org.opensearch.commons.authuser.User;
4043
import org.opensearch.core.common.io.stream.StreamInput;
@@ -66,6 +69,7 @@ public class McpConnector implements Connector {
6669

6770
protected Map<String, String> credential;
6871
protected Map<String, String> decryptedHeaders;
72+
protected Map<String, String> parameters;
6973
@Setter
7074
protected Map<String, String> decryptedCredential;
7175
@Setter
@@ -101,7 +105,8 @@ public McpConnector(
101105
ConnectorClientConfig connectorClientConfig,
102106
String tenantId,
103107
String url,
104-
Map<String, String> headers
108+
Map<String, String> headers,
109+
Map<String, String> parameters
105110
) {
106111
validateProtocol(protocol);
107112
this.name = name;
@@ -116,6 +121,7 @@ public McpConnector(
116121
this.tenantId = tenantId;
117122
this.url = url;
118123
this.headers = headers;
124+
this.parameters = parameters;
119125
}
120126

121127
public McpConnector(String protocol, XContentParser parser) throws IOException {
@@ -175,6 +181,10 @@ public McpConnector(String protocol, XContentParser parser) throws IOException {
175181
headers = new HashMap<>();
176182
headers.putAll(parser.mapStrings());
177183
break;
184+
case PARAMETERS_FIELD:
185+
parameters = new HashMap<>();
186+
parameters.putAll(parser.mapStrings());
187+
break;
178188
default:
179189
parser.skipChildren();
180190
break;
@@ -229,6 +239,7 @@ public McpConnector(StreamInput input) throws IOException {
229239
}
230240

231241
private void parseFromStream(StreamInput input) throws IOException {
242+
Version streamInputVersion = input.getVersion();
232243
this.name = input.readOptionalString();
233244
this.version = input.readOptionalString();
234245
this.description = input.readOptionalString();
@@ -252,7 +263,11 @@ private void parseFromStream(StreamInput input) throws IOException {
252263
if (input.readBoolean()) {
253264
this.headers = input.readMap(s -> s.readString(), s -> s.readString());
254265
}
255-
266+
if (streamInputVersion.onOrAfter(VERSION_3_1_0)) {
267+
if (input.readBoolean()) {
268+
this.parameters = input.readMap(s -> s.readString(), s -> s.readString());
269+
}
270+
}
256271
}
257272

258273
@Override
@@ -264,6 +279,7 @@ public void removeCredential() {
264279

265280
@Override
266281
public void writeTo(StreamOutput out) throws IOException {
282+
Version streamOutputVersion = out.getVersion();
267283
out.writeString(protocol);
268284
out.writeOptionalString(name);
269285
out.writeOptionalString(version);
@@ -305,7 +321,14 @@ public void writeTo(StreamOutput out) throws IOException {
305321
} else {
306322
out.writeBoolean(false);
307323
}
308-
324+
if (streamOutputVersion.onOrAfter(VERSION_3_1_0)) {
325+
if (parameters != null) {
326+
out.writeBoolean(true);
327+
out.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString);
328+
} else {
329+
out.writeBoolean(false);
330+
}
331+
}
309332
}
310333

311334
@Override
@@ -341,6 +364,9 @@ public void update(MLCreateConnectorInput updateContent, BiFunction<String, Stri
341364
if (updateContent.getHeaders() != null) {
342365
this.headers = updateContent.getHeaders();
343366
}
367+
if (updateContent.getParameters() != null) {
368+
this.parameters = updateContent.getParameters();
369+
}
344370
}
345371

346372
@Override
@@ -393,6 +419,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
393419
if (headers != null) {
394420
builder.field(HEADERS_FIELD, headers);
395421
}
422+
if (parameters != null) {
423+
builder.field(PARAMETERS_FIELD, parameters);
424+
}
396425
builder.endObject();
397426
return builder;
398427
}
@@ -415,7 +444,7 @@ public void validateConnectorURL(List<String> urlRegexes) {
415444

416445
@Override
417446
public Map<String, String> getParameters() {
418-
throw new UnsupportedOperationException("Not implemented.");
447+
return parameters;
419448
}
420449

421450
@Override

common/src/test/java/org/opensearch/ml/common/connector/McpConnectorTest.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public class McpConnectorTest {
4444
BiFunction<String, String, String> decryptFunction;
4545

4646
String TEST_CONNECTOR_JSON_STRING =
47-
"{\"name\":\"test_mcp_connector_name\",\"version\":\"1\",\"description\":\"this is a test mcp connector\",\"protocol\":\"mcp_sse\",\"credential\":{\"key\":\"test_key_value\"},\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\",\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000,\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"},\"url\":\"https://test.com\",\"headers\":{\"api_key\":\"${credential.key}\"}}";
47+
"{\"name\":\"test_mcp_connector_name\",\"version\":\"1\",\"description\":\"this is a test mcp connector\",\"protocol\":\"mcp_sse\",\"credential\":{\"key\":\"test_key_value\"},\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\",\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000,\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"},\"url\":\"https://test.com\",\"headers\":{\"api_key\":\"${credential.key}\"},\"parameters\":{\"sse_endpoint\":\"/custom/sse\"}}";
4848

4949
@Before
5050
public void setUp() {
@@ -100,6 +100,7 @@ public void constructor_Parser() throws IOException {
100100
Assert.assertEquals("mcp_sse", connector.getProtocol());
101101
Assert.assertEquals(AccessMode.PUBLIC, connector.getAccess());
102102
Assert.assertEquals("https://test.com", connector.getUrl());
103+
Assert.assertEquals("/custom/sse", connector.getParameters().get("sse_endpoint"));
103104
connector.decrypt(PREDICT.name(), decryptFunction, null);
104105
Map<String, String> decryptedCredential = connector.getDecryptedCredential();
105106
Assert.assertEquals(1, decryptedCredential.size());
@@ -197,6 +198,8 @@ public void testUpdate() {
197198
Map<String, String> updatedHeaders = new HashMap<>();
198199
updatedHeaders.put("new_header", "new_header_value");
199200
updatedHeaders.put("updated_api_key", "${credential.new_key}"); // Referencing new credential key
201+
Map<String, String> updatedParameters = new HashMap<>();
202+
updatedParameters.put("sse_endpoint", "/updated/sse");
200203

201204
MLCreateConnectorInput updateInput = MLCreateConnectorInput
202205
.builder()
@@ -209,6 +212,7 @@ public void testUpdate() {
209212
.connectorClientConfig(updatedClientConfig)
210213
.url(updatedUrl)
211214
.headers(updatedHeaders)
215+
.parameters(updatedParameters)
212216
.protocol(MCP_SSE)
213217
.build();
214218

@@ -220,6 +224,7 @@ public void testUpdate() {
220224
Assert.assertEquals(updatedDescription, connector.getDescription());
221225
Assert.assertEquals(updatedVersion, connector.getVersion());
222226
Assert.assertEquals(MCP_SSE, connector.getProtocol()); // Should not change if not provided
227+
Assert.assertEquals(updatedParameters, connector.getParameters());
223228
Assert.assertEquals(updatedBackendRoles, connector.getBackendRoles());
224229
Assert.assertEquals(updatedAccessMode, connector.getAccess());
225230
Assert.assertEquals(updatedClientConfig, connector.getConnectorClientConfig());
@@ -254,6 +259,9 @@ public static McpConnector createMcpConnector() {
254259
Map<String, String> headers = new HashMap<>();
255260
headers.put("api_key", "${credential.key}");
256261

262+
Map<String, String> parameters = new HashMap<>();
263+
parameters.put("sse_endpoint", "/custom/sse");
264+
257265
ConnectorClientConfig clientConfig = new ConnectorClientConfig(30, 30000, 30000, 10, 10, -1, RetryBackoffPolicy.CONSTANT);
258266

259267
return McpConnector
@@ -268,6 +276,7 @@ public static McpConnector createMcpConnector() {
268276
.connectorClientConfig(clientConfig)
269277
.url("https://test.com")
270278
.headers(headers)
279+
.parameters(parameters)
271280
.build();
272281
}
273282
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutor.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55

66
package org.opensearch.ml.engine.algorithms.remote;
77

8+
import static org.opensearch.ml.common.CommonValue.MCP_DEFAULT_SSE_ENDPOINT;
89
import static org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT;
910
import static org.opensearch.ml.common.CommonValue.MCP_TOOLS_FIELD;
1011
import static org.opensearch.ml.common.CommonValue.MCP_TOOL_DESCRIPTION_FIELD;
1112
import static org.opensearch.ml.common.CommonValue.MCP_TOOL_INPUT_SCHEMA_FIELD;
1213
import static org.opensearch.ml.common.CommonValue.MCP_TOOL_NAME_FIELD;
14+
import static org.opensearch.ml.common.CommonValue.SSE_ENDPOINT_FILED;
1315
import static org.opensearch.ml.common.CommonValue.TOOL_INPUT_SCHEMA_FIELD;
1416
import static org.opensearch.ml.common.connector.ConnectorProtocols.MCP_SSE;
1517

@@ -73,6 +75,9 @@ public McpConnectorExecutor(Connector connector) {
7375

7476
public List<MLToolSpec> getMcpToolSpecs() {
7577
String mcpServerUrl = connector.getUrl();
78+
String sseEndpoint = connector.getParameters() != null && connector.getParameters().containsKey(SSE_ENDPOINT_FILED)
79+
? connector.getParameters().get(SSE_ENDPOINT_FILED)
80+
: MCP_DEFAULT_SSE_ENDPOINT;
7681
if (mcpServerUrl == null) {
7782
return Collections.emptyList();
7883
}
@@ -88,9 +93,14 @@ public List<MLToolSpec> getMcpToolSpecs() {
8893
};
8994

9095
// Create transport
91-
McpClientTransport transport = HttpClientSseClientTransport.builder(mcpServerUrl).customizeClient(clientBuilder -> {
92-
clientBuilder.connectTimeout(connectionTimeout);
93-
}).customizeRequest(headerConfig).build();
96+
McpClientTransport transport = HttpClientSseClientTransport
97+
.builder(mcpServerUrl)
98+
.sseEndpoint(sseEndpoint)
99+
.customizeClient(clientBuilder -> {
100+
clientBuilder.connectTimeout(connectionTimeout);
101+
})
102+
.customizeRequest(headerConfig)
103+
.build();
94104

95105
// Create and initialize client
96106
McpSyncClient client = McpClient

0 commit comments

Comments
 (0)