Skip to content

Commit 969dc3d

Browse files
authored
Fix MCP server tool bug (opensearch-project#3912)
* Fix mcp tools list and remove bugs and add ITs Signed-off-by: zane-neo <zaniu@amazon.com> * Fix failure ITs Signed-off-by: zane-neo <zaniu@amazon.com> --------- Signed-off-by: zane-neo <zaniu@amazon.com>
1 parent a67b4da commit 969dc3d

File tree

8 files changed

+361
-4
lines changed

8 files changed

+361
-4
lines changed

common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/register/McpToolRegisterInput.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ public static McpToolRegisterInput parse(XContentParser parser) throws IOExcepti
7272
name = parser.text();
7373
break;
7474
case DESCRIPTION_FIELD:
75-
description = parser.text();
75+
description = parser.textOrNull();
7676
break;
7777
case PARAMS_FIELD:
7878
params = parser.map();

common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/remove/MLMcpToolsRemoveNodesRequest.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
import java.io.IOException;
1515
import java.io.UncheckedIOException;
1616
import java.util.ArrayList;
17+
import java.util.Arrays;
1718
import java.util.List;
1819

20+
import org.apache.commons.lang3.StringUtils;
1921
import org.opensearch.action.ActionRequest;
2022
import org.opensearch.action.ActionRequestValidationException;
2123
import org.opensearch.action.support.nodes.BaseNodesRequest;
@@ -63,7 +65,8 @@ public static MLMcpToolsRemoveNodesRequest parse(XContentParser parser, String[]
6365
List<String> tools = new ArrayList<>();
6466
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.nextToken(), parser);
6567
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
66-
tools.add(parser.text());
68+
String[] toolNames = StringUtils.split(parser.text(), ",");
69+
Arrays.stream(toolNames).forEach(x -> tools.add(StringUtils.trim(x)));
6770
}
6871
return new MLMcpToolsRemoveNodesRequest(allNodeIds, tools);
6972
}

common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/remove/MLMcpToolsRemoveNodesRequestTest.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ public void writeTo(StreamOutput out) throws IOException {
116116

117117
@Test
118118
public void testParse_AllFields() throws Exception {
119-
String jsonStr = "[\n" + " \"MyListIndexTool2\"\n" + "]";
119+
String jsonStr = "[\n" + " \"GoogleSearchTool1, GoogleSearchTool2\"\n" + "]";
120120

121121
XContentParser parser = XContentType.JSON
122122
.xContent()
@@ -127,6 +127,8 @@ public void testParse_AllFields() throws Exception {
127127
);
128128

129129
MLMcpToolsRemoveNodesRequest parsed = MLMcpToolsRemoveNodesRequest.parse(parser, new String[] { "nodeId" });
130-
assertEquals(1, parsed.getMcpTools().size());
130+
assertEquals(2, parsed.getMcpTools().size());
131+
assertEquals("GoogleSearchTool1", parsed.getMcpTools().getFirst());
132+
assertEquals("GoogleSearchTool2", parsed.getMcpTools().getLast());
131133
}
132134
}

plugin/src/main/java/org/opensearch/ml/action/mcpserver/TransportMcpToolsRegisterAction.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.opensearch.action.index.IndexRequest;
2828
import org.opensearch.action.support.ActionFilters;
2929
import org.opensearch.action.support.HandledTransportAction;
30+
import org.opensearch.action.support.WriteRequest;
3031
import org.opensearch.cluster.service.ClusterService;
3132
import org.opensearch.common.inject.Inject;
3233
import org.opensearch.common.util.concurrent.ThreadContext;
@@ -209,6 +210,7 @@ private void indexMcpTools(
209210
});
210211

211212
BulkRequest bulkRequest = new BulkRequest();
213+
bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
212214
for (McpToolRegisterInput mcpTool : registerNodesRequest.getMcpTools()) {
213215
IndexRequest indexRequest = new IndexRequest(MLIndex.MCP_TOOLS.getIndexName());
214216
// Set opType to create to avoid race condition when creating tools with same name.

plugin/src/main/java/org/opensearch/ml/action/mcpserver/TransportMcpToolsRemoveAction.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.opensearch.action.delete.DeleteRequest;
2020
import org.opensearch.action.support.ActionFilters;
2121
import org.opensearch.action.support.HandledTransportAction;
22+
import org.opensearch.action.support.WriteRequest;
2223
import org.opensearch.cluster.service.ClusterService;
2324
import org.opensearch.common.inject.Inject;
2425
import org.opensearch.common.util.concurrent.ThreadContext;
@@ -163,6 +164,7 @@ private void bulkDeleteMcpTools(
163164
restoreListener.onFailure(e);
164165
});
165166
BulkRequest bulkRequest = new BulkRequest();
167+
bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
166168
for (String name : foundTools) {
167169
DeleteRequest deleteRequest = new DeleteRequest(MLIndex.MCP_TOOLS.getIndexName(), name);
168170
bulkRequest.add(deleteRequest);
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Copyright 2023 Aryn
3+
* Copyright OpenSearch Contributors
4+
* SPDX-License-Identifier: Apache-2.0
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
package org.opensearch.ml.rest.mcpserver;
19+
20+
import java.io.IOException;
21+
22+
import org.apache.hc.core5.http.HttpEntity;
23+
import org.apache.hc.core5.http.HttpHeaders;
24+
import org.apache.hc.core5.http.message.BasicHeader;
25+
import org.junit.Before;
26+
import org.opensearch.client.Response;
27+
import org.opensearch.core.rest.RestStatus;
28+
import org.opensearch.ml.common.settings.MLCommonsSettings;
29+
import org.opensearch.ml.rest.MLCommonsRestTestCase;
30+
import org.opensearch.ml.utils.TestHelper;
31+
32+
import com.google.common.collect.ImmutableList;
33+
34+
public class RestMcpToolsRegisterActionIT extends MLCommonsRestTestCase {
35+
36+
@Before
37+
public void setupFeatureSettings() throws IOException {
38+
Response response = TestHelper
39+
.makeRequest(
40+
client(),
41+
"PUT",
42+
"_cluster/settings",
43+
null,
44+
"{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MCP_SERVER_ENABLED.getKey() + "\":true}}",
45+
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
46+
);
47+
assertEquals(200, response.getStatusLine().getStatusCode());
48+
}
49+
50+
public void testRegisterMcpTools() throws IOException {
51+
String requestBody =
52+
"""
53+
{
54+
"tools": [
55+
{
56+
"name": "ListIndexTool",
57+
"type": "ListIndexTool",
58+
"description": "initial description",
59+
"attributes": {
60+
"input_schema": {
61+
"type": "object",
62+
"properties": {
63+
"indices": {
64+
"type": "array",
65+
"items": {
66+
"type": "string"
67+
},
68+
"description": "OpenSearch index name list, separated by comma. for example: [\\"index1\\", \\"index2\\"], use empty array [] to list all indices in the cluster"
69+
}
70+
},
71+
"additionalProperties": false
72+
}
73+
}
74+
}
75+
]
76+
}
77+
""";
78+
Response response = TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/mcp/tools/_register", null, requestBody, null);
79+
assert (response != null);
80+
assert (TestHelper.restStatus(response) == RestStatus.OK);
81+
HttpEntity httpEntity = response.getEntity();
82+
String entityString = TestHelper.httpEntityToString(httpEntity);
83+
assertTrue(entityString.contains("created"));
84+
85+
Response listResponse = TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/mcp/tools/_list", null, "", null);
86+
assert (listResponse != null);
87+
assert (TestHelper.restStatus(listResponse) == RestStatus.OK);
88+
HttpEntity listHttpEntity = listResponse.getEntity();
89+
String listEntityString = TestHelper.httpEntityToString(listHttpEntity);
90+
assertTrue(listEntityString.contains("ListIndexTool"));
91+
assertTrue(listEntityString.contains("initial description"));
92+
}
93+
}
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
* Copyright 2023 Aryn
3+
* Copyright OpenSearch Contributors
4+
* SPDX-License-Identifier: Apache-2.0
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
package org.opensearch.ml.rest.mcpserver;
19+
20+
import java.io.IOException;
21+
22+
import org.apache.hc.core5.http.HttpEntity;
23+
import org.apache.hc.core5.http.HttpHeaders;
24+
import org.apache.hc.core5.http.message.BasicHeader;
25+
import org.junit.Before;
26+
import org.opensearch.client.Response;
27+
import org.opensearch.core.rest.RestStatus;
28+
import org.opensearch.ml.common.settings.MLCommonsSettings;
29+
import org.opensearch.ml.rest.MLCommonsRestTestCase;
30+
import org.opensearch.ml.utils.TestHelper;
31+
32+
import com.google.common.collect.ImmutableList;
33+
34+
public class RestMcpToolsRemoveActionIT extends MLCommonsRestTestCase {
35+
36+
@Before
37+
public void setupFeatureSettings() throws IOException {
38+
Response response = TestHelper
39+
.makeRequest(
40+
client(),
41+
"PUT",
42+
"_cluster/settings",
43+
null,
44+
"{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MCP_SERVER_ENABLED.getKey() + "\":true}}",
45+
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
46+
);
47+
assertEquals(200, response.getStatusLine().getStatusCode());
48+
}
49+
50+
public void testRemoveMcpTools() throws IOException {
51+
String registerRequestBody =
52+
"""
53+
{
54+
"tools": [
55+
{
56+
"name": "ListIndexTool1",
57+
"type": "ListIndexTool",
58+
"description": "initial description",
59+
"attributes": {
60+
"input_schema": {
61+
"type": "object",
62+
"properties": {
63+
"indices": {
64+
"type": "array",
65+
"items": {
66+
"type": "string"
67+
},
68+
"description": "OpenSearch index name list, separated by comma. for example: [\\"index1\\", \\"index2\\"], use empty array [] to list all indices in the cluster"
69+
}
70+
},
71+
"additionalProperties": false
72+
}
73+
}
74+
},
75+
{
76+
"name": "ListIndexTool2",
77+
"type": "ListIndexTool",
78+
"description": "initial description",
79+
"attributes": {
80+
"input_schema": {
81+
"type": "object",
82+
"properties": {
83+
"indices": {
84+
"type": "array",
85+
"items": {
86+
"type": "string"
87+
},
88+
"description": "OpenSearch index name list, separated by comma. for example: [\\"index1\\", \\"index2\\"], use empty array [] to list all indices in the cluster"
89+
}
90+
},
91+
"additionalProperties": false
92+
}
93+
}
94+
}
95+
]
96+
}
97+
""";
98+
Response registerResponse = TestHelper
99+
.makeRequest(client(), "POST", "/_plugins/_ml/mcp/tools/_register", null, registerRequestBody, null);
100+
assert (registerResponse != null);
101+
assert (TestHelper.restStatus(registerResponse) == RestStatus.OK);
102+
HttpEntity registerResponseEntity = registerResponse.getEntity();
103+
String registerResString = TestHelper.httpEntityToString(registerResponseEntity);
104+
assertTrue(registerResString.contains("created"));
105+
106+
String removeRequestBody = """
107+
[
108+
"ListIndexTool1, ListIndexTool2"
109+
]
110+
""";
111+
Response removeResponse = TestHelper
112+
.makeRequest(client(), "POST", "/_plugins/_ml/mcp/tools/_remove", null, removeRequestBody, null);
113+
assert (removeResponse != null);
114+
assert (TestHelper.restStatus(removeResponse) == RestStatus.OK);
115+
HttpEntity removeResponseEntity = removeResponse.getEntity();
116+
String removeResString = TestHelper.httpEntityToString(removeResponseEntity);
117+
assertTrue(removeResString.contains("removed"));
118+
119+
Response listResponse = TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/mcp/tools/_list", null, "", null);
120+
assert (listResponse != null);
121+
assert (TestHelper.restStatus(listResponse) == RestStatus.OK);
122+
HttpEntity listHttpEntity = listResponse.getEntity();
123+
String listEntityString = TestHelper.httpEntityToString(listHttpEntity);
124+
assertFalse(listEntityString.contains("ListIndexTool1"));
125+
assertFalse(listEntityString.contains("ListIndexTool2"));
126+
}
127+
}

0 commit comments

Comments
 (0)