Skip to content

Commit b53a7e5

Browse files
Onboard MCP (#3721)
* Onboard MCP Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> * Handle failing tests Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> * Fix empty resources issue Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> * Address Comments Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> * Add Authorization support via headers Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> * Add MCP tools to plan and execute agent Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> * Resolve failing test by using correct constant Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> * Address comments Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> * Address comments Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> * Apply spotless Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> --------- Signed-off-by: rithin-pullela-aws <rithinp@amazon.com>
1 parent 9ffe7ef commit b53a7e5

31 files changed

+1273
-47
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,31 @@ public class CommonValue {
7070
public static final Version VERSION_2_18_0 = Version.fromString("2.18.0");
7171
public static final Version VERSION_2_19_0 = Version.fromString("2.19.0");
7272
public static final Version VERSION_3_0_0 = Version.fromString("3.0.0");
73+
74+
// Connector Constants
75+
public static final String NAME_FIELD = "name";
76+
public static final String VERSION_FIELD = "version";
77+
public static final String DESCRIPTION_FIELD = "description";
78+
public static final String PROTOCOL_FIELD = "protocol";
79+
public static final String CREDENTIAL_FIELD = "credential";
80+
public static final String PARAMETERS_FIELD = "parameters";
81+
public static final String CREATED_TIME_FIELD = "created_time";
82+
public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time";
83+
public static final String BACKEND_ROLES_FIELD = "backend_roles";
84+
public static final String OWNER_FIELD = "owner";
85+
public static final String ACCESS_FIELD = "access";
86+
public static final String CLIENT_CONFIG_FIELD = "client_config";
87+
public static final String URL_FIELD = "url";
88+
public static final String HEADERS_FIELD = "headers";
89+
90+
// MCP Constants
91+
public static final String MCP_TOOL_NAME_FIELD = "name";
92+
public static final String MCP_TOOL_DESCRIPTION_FIELD = "description";
93+
public static final String MCP_TOOL_INPUT_SCHEMA_FIELD = "inputSchema";
94+
public static final String MCP_SYNC_CLIENT = "mcp_sync_client";
95+
public static final String MCP_EXECUTOR_SERVICE = "mcp_executor_service";
96+
public static final String MCP_TOOLS_FIELD = "tools";
97+
98+
// TOOL Constants
99+
public static final String TOOL_INPUT_SCHEMA_FIELD = "input_schema";
73100
}

common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
1313

1414
import java.io.IOException;
15+
import java.util.HashMap;
1516
import java.util.Map;
1617

1718
import org.opensearch.Version;
@@ -39,6 +40,7 @@ public class MLToolSpec implements ToXContentObject {
3940
public static final String ATTRIBUTES_FIELD = "attributes";
4041
public static final String INCLUDE_OUTPUT_IN_AGENT_RESPONSE = "include_output_in_agent_response";
4142
public static final String CONFIG_FIELD = "config";
43+
public static final String RUN_TIME_RESOURCES_FIELD = "runtime_resources";
4244

4345
private String type;
4446
private String name;
@@ -49,6 +51,7 @@ public class MLToolSpec implements ToXContentObject {
4951
private Map<String, String> configMap;
5052
@Setter
5153
private String tenantId;
54+
private Map<String, Object> runtimeResources;
5255

5356
@Builder(toBuilder = true)
5457
public MLToolSpec(
@@ -59,7 +62,8 @@ public MLToolSpec(
5962
Map<String, String> attributes,
6063
boolean includeOutputInAgentResponse,
6164
Map<String, String> configMap,
62-
String tenantId
65+
String tenantId,
66+
Map<String, Object> runtimeResources
6367
) {
6468
if (type == null) {
6569
throw new IllegalArgumentException("tool type is null");
@@ -72,6 +76,7 @@ public MLToolSpec(
7276
this.includeOutputInAgentResponse = includeOutputInAgentResponse;
7377
this.configMap = configMap;
7478
this.tenantId = tenantId;
79+
this.runtimeResources = runtimeResources;
7580
}
7681

7782
public MLToolSpec(StreamInput input) throws IOException {
@@ -87,8 +92,13 @@ public MLToolSpec(StreamInput input) throws IOException {
8792
configMap = input.readMap(StreamInput::readString, StreamInput::readOptionalString);
8893
}
8994
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
90-
if (input.getVersion().onOrAfter(VERSION_3_0_0) && input.available() > 0 && input.readBoolean()) {
91-
attributes = input.readMap(StreamInput::readString, StreamInput::readOptionalString);
95+
if (input.getVersion().onOrAfter(VERSION_3_0_0)) {
96+
if (input.available() > 0 && input.readBoolean()) {
97+
attributes = input.readMap(StreamInput::readString, StreamInput::readOptionalString);
98+
}
99+
if (input.available() > 0 && input.readBoolean()) {
100+
runtimeResources = input.readMap(StreamInput::readString, StreamInput::readGenericValue);
101+
}
92102
}
93103
}
94104

@@ -122,6 +132,12 @@ public void writeTo(StreamOutput out) throws IOException {
122132
} else {
123133
out.writeBoolean(false);
124134
}
135+
if (runtimeResources != null && !runtimeResources.isEmpty()) {
136+
out.writeBoolean(true);
137+
out.writeMap(runtimeResources, StreamOutput::writeString, StreamOutput::writeGenericValue);
138+
} else {
139+
out.writeBoolean(false);
140+
}
125141
}
126142
}
127143

@@ -150,6 +166,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
150166
if (tenantId != null) {
151167
builder.field(TENANT_ID_FIELD, tenantId);
152168
}
169+
if (runtimeResources != null && !runtimeResources.isEmpty()) {
170+
builder.field(RUN_TIME_RESOURCES_FIELD, runtimeResources);
171+
}
153172
builder.endObject();
154173
return builder;
155174
}
@@ -163,6 +182,7 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
163182
boolean includeOutputInAgentResponse = false;
164183
Map<String, String> configMap = null;
165184
String tenantId = null;
185+
Map<String, Object> runtimeResources = null;
166186

167187
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
168188
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -194,6 +214,8 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
194214
case TENANT_ID_FIELD:
195215
tenantId = parser.textOrNull();
196216
break;
217+
case RUN_TIME_RESOURCES_FIELD:
218+
runtimeResources = parser.map();
197219
default:
198220
parser.skipChildren();
199221
break;
@@ -209,10 +231,22 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
209231
.includeOutputInAgentResponse(includeOutputInAgentResponse)
210232
.configMap(configMap)
211233
.tenantId(tenantId)
234+
.runtimeResources(runtimeResources)
212235
.build();
213236
}
214237

215238
public static MLToolSpec fromStream(StreamInput in) throws IOException {
216239
return new MLToolSpec(in);
217240
}
241+
242+
public void addRuntimeResource(String key, Object value) {
243+
if (this.runtimeResources == null) {
244+
this.runtimeResources = new HashMap<>();
245+
}
246+
this.runtimeResources.put(key, value);
247+
}
248+
249+
public Object getRuntimeResource(String key) {
250+
return this.runtimeResources.get(key);
251+
}
218252
}

common/src/main/java/org/opensearch/ml/common/connector/ConnectorProtocols.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ public class ConnectorProtocols {
1212

1313
public static final String HTTP = "http";
1414
public static final String AWS_SIGV4 = "aws_sigv4";
15+
public static final String MCP_SSE = "mcp_sse";
1516

16-
public static final List<String> VALID_PROTOCOLS = Arrays.asList(AWS_SIGV4, HTTP);
17+
public static final List<String> VALID_PROTOCOLS = Arrays.asList(AWS_SIGV4, HTTP, MCP_SSE);
1718

1819
public static void validateProtocol(String protocol) {
1920
if (protocol == null) {

0 commit comments

Comments
 (0)