Skip to content

Commit 887c90a

Browse files
authored
support customized message endpoint and addressing comments (#3810)
* support customized message endpoint and addressing comments Signed-off-by: zane-neo <zaniu@amazon.com> * fix UT failures Signed-off-by: zane-neo <zaniu@amazon.com> * add files to jacoco exception Signed-off-by: zane-neo <zaniu@amazon.com> * fix tool name issue and optimize register tool api Signed-off-by: zane-neo <zaniu@amazon.com> * fix schema not parsed correctly issue and NPE when parameters is null Signed-off-by: zane-neo <zaniu@amazon.com> * fix failure UT Signed-off-by: zane-neo <zaniu@amazon.com> --------- Signed-off-by: zane-neo <zaniu@amazon.com>
1 parent 913b033 commit 887c90a

File tree

19 files changed

+207
-38
lines changed

19 files changed

+207
-38
lines changed

common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/message/MLMcpMessageRequest.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import java.io.IOException;
1111
import java.io.UncheckedIOException;
1212

13+
import org.apache.commons.lang3.StringUtils;
1314
import org.opensearch.action.ActionRequest;
1415
import org.opensearch.action.ActionRequestValidationException;
1516
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
@@ -35,11 +36,17 @@ public MLMcpMessageRequest(StreamInput in) throws IOException {
3536
this.nodeId = in.readString();
3637
this.sessionId = in.readString();
3738
this.requestBody = in.readString();
39+
if (StringUtils.isEmpty(nodeId) || StringUtils.isEmpty(sessionId) || StringUtils.isEmpty(requestBody)) {
40+
throw new IllegalStateException("nodeId, sessionId and requestBody must not be null");
41+
}
3842
}
3943

4044
@Builder
4145
public MLMcpMessageRequest(String nodeId, String sessionId, String requestBody) {
4246
super();
47+
if (StringUtils.isEmpty(nodeId) || StringUtils.isEmpty(sessionId) || StringUtils.isEmpty(requestBody)) {
48+
throw new IllegalStateException("nodeId, sessionId and requestBody must not be null");
49+
}
4350
this.nodeId = nodeId;
4451
this.sessionId = sessionId;
4552
this.requestBody = requestBody;

common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/register/MLMcpToolsRegisterNodesRequest.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
/*
2+
*
3+
* * Copyright OpenSearch Contributors
4+
* * SPDX-License-Identifier: Apache-2.0
5+
*
6+
*/
7+
18
package org.opensearch.ml.common.transport.mcpserver.requests.register;
29

310
import java.io.ByteArrayInputStream;

common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/register/McpTool.java

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,13 @@
3131
@Data
3232
public class McpTool implements ToXContentObject, Writeable {
3333
private static final String TYPE_FIELD = "type";
34+
private static final String NAME_FIELD = "name";
3435
private static final String DESCRIPTION_FIELD = "description";
3536
private static final String PARAMS_FIELD = "parameters";
3637
private static final String ATTRIBUTES_FIELD = "attributes";
3738
public static final String SCHEMA_FIELD = "input_schema";
3839
private final String type;
40+
private String name;
3941
private final String description;
4042
private Map<String, Object> parameters;
4143
private Map<String, Object> attributes;
@@ -46,6 +48,7 @@ public McpTool(StreamInput streamInput) throws IOException {
4648
if (type == null) {
4749
throw new IllegalArgumentException(TYPE_NOT_SHOWN_EXCEPTION_MESSAGE);
4850
}
51+
name = streamInput.readOptionalString();
4952
description = streamInput.readOptionalString();
5053
if (streamInput.readBoolean()) {
5154
parameters = streamInput.readMap(StreamInput::readString, StreamInput::readGenericValue);
@@ -55,10 +58,8 @@ public McpTool(StreamInput streamInput) throws IOException {
5558
}
5659
}
5760

58-
public McpTool(String type, String description, Map<String, Object> parameters, Map<String, Object> attributes) {
59-
if (type == null) {
60-
throw new IllegalArgumentException(TYPE_NOT_SHOWN_EXCEPTION_MESSAGE);
61-
}
61+
public McpTool(String name, String type, String description, Map<String, Object> parameters, Map<String, Object> attributes) {
62+
this.name = name;
6263
this.type = type;
6364
this.description = description;
6465
this.parameters = parameters;
@@ -67,9 +68,10 @@ public McpTool(String type, String description, Map<String, Object> parameters,
6768

6869
public static McpTool parse(XContentParser parser) throws IOException {
6970
String type = null;
71+
String name = null;
7072
String description = null;
7173
Map<String, Object> params = null;
72-
Map<String, Object> schema = null;
74+
Map<String, Object> attrubutes = null;
7375
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
7476
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
7577
String fieldName = parser.currentName();
@@ -79,14 +81,17 @@ public static McpTool parse(XContentParser parser) throws IOException {
7981
case TYPE_FIELD:
8082
type = parser.text();
8183
break;
84+
case NAME_FIELD:
85+
name = parser.text();
86+
break;
8287
case DESCRIPTION_FIELD:
8388
description = parser.text();
8489
break;
8590
case PARAMS_FIELD:
8691
params = parser.map();
8792
break;
88-
case SCHEMA_FIELD:
89-
schema = parser.map();
93+
case ATTRIBUTES_FIELD:
94+
attrubutes = parser.map();
9095
break;
9196
default:
9297
parser.skipChildren();
@@ -96,12 +101,13 @@ public static McpTool parse(XContentParser parser) throws IOException {
96101
if (type == null) {
97102
throw new IllegalArgumentException(TYPE_NOT_SHOWN_EXCEPTION_MESSAGE);
98103
}
99-
return new McpTool(type, description, params, schema);
104+
return new McpTool(name, type, description, params, attrubutes);
100105
}
101106

102107
@Override
103108
public void writeTo(StreamOutput streamOutput) throws IOException {
104109
streamOutput.writeString(type);
110+
streamOutput.writeOptionalString(name);
105111
streamOutput.writeOptionalString(description);
106112
if (parameters != null) {
107113
streamOutput.writeBoolean(true);
@@ -122,6 +128,9 @@ public void writeTo(StreamOutput streamOutput) throws IOException {
122128
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params xcontentParams) throws IOException {
123129
builder.startObject();
124130
builder.field(TYPE_FIELD, type);
131+
if (name != null) {
132+
builder.field(NAME_FIELD, name);
133+
}
125134
if (description != null) {
126135
builder.field(DESCRIPTION_FIELD, description);
127136
}

common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/register/McpTools.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
/*
2+
*
3+
* * Copyright OpenSearch Contributors
4+
* * SPDX-License-Identifier: Apache-2.0
5+
*
6+
*/
7+
18
package org.opensearch.ml.common.transport.mcpserver.requests.register;
29

310
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/remove/MLMcpToolsRemoveNodeRequest.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ public class MLMcpToolsRemoveNodeRequest extends TransportRequest {
2020
private List<String> tools;
2121

2222
public MLMcpToolsRemoveNodeRequest(StreamInput in) throws IOException {
23+
super(in);
2324
if (in.readBoolean()) {
2425
this.tools = in.readList(StreamInput::readString);
2526
}
@@ -32,6 +33,7 @@ public MLMcpToolsRemoveNodeRequest(List<String> tools) {
3233

3334
@Override
3435
public void writeTo(StreamOutput out) throws IOException {
36+
super.writeTo(out);
3537
if (tools != null) {
3638
out.writeBoolean(true);
3739
out.writeStringArray(tools.toArray(new String[0]));

common/src/main/java/org/opensearch/ml/common/transport/mcpserver/responses/register/MLMcpRegisterNodesResponse.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
5656
for (FailedNodeException failedNodeException : failures()) {
5757
builder.startObject(failedNodeException.nodeId());
5858
builder.field("error");
59-
builder.value(failedNodeException.getMessage());
59+
builder.value(failedNodeException.getRootCause().getMessage());
6060
builder.endObject();
6161
}
6262
builder.endObject();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
*
3+
* * Copyright OpenSearch Contributors
4+
* * SPDX-License-Identifier: Apache-2.0
5+
*
6+
*/
7+
8+
package org.opensearch.ml.common.transport.mcpserver.requests.message;
9+
10+
import static org.junit.Assert.*;
11+
12+
import java.io.IOException;
13+
import java.io.UncheckedIOException;
14+
15+
import org.junit.Test;
16+
import org.opensearch.common.io.stream.BytesStreamOutput;
17+
import org.opensearch.core.common.io.stream.StreamInput;
18+
import org.opensearch.core.common.io.stream.StreamOutput;
19+
import org.opensearch.transport.TransportRequest;
20+
21+
public class MLMcpMessageRequestTest {
22+
23+
private final String testNodeId = "node-001";
24+
private final String testSessionId = "session-2023";
25+
private final String testRequestBody = "{ \"query\": { \"match_all\": {} } }";
26+
27+
@Test
28+
public void testBuilderPattern() {
29+
MLMcpMessageRequest request = MLMcpMessageRequest
30+
.builder()
31+
.nodeId(testNodeId)
32+
.sessionId(testSessionId)
33+
.requestBody(testRequestBody)
34+
.build();
35+
36+
assertEquals(testNodeId, request.getNodeId());
37+
assertEquals(testSessionId, request.getSessionId());
38+
assertEquals(testRequestBody, request.getRequestBody());
39+
}
40+
41+
@Test
42+
public void testStreamSerialization() throws IOException {
43+
MLMcpMessageRequest original = buildRequest();
44+
45+
BytesStreamOutput output = new BytesStreamOutput();
46+
original.writeTo(output);
47+
48+
StreamInput input = output.bytes().streamInput();
49+
MLMcpMessageRequest deserialized = new MLMcpMessageRequest(input);
50+
51+
assertEquals(original.getNodeId(), deserialized.getNodeId());
52+
assertEquals(original.getSessionId(), deserialized.getSessionId());
53+
assertEquals(original.getRequestBody(), deserialized.getRequestBody());
54+
}
55+
56+
@Test
57+
public void testFromActionRequestSameType() {
58+
MLMcpMessageRequest original = buildRequest();
59+
MLMcpMessageRequest converted = MLMcpMessageRequest.fromActionRequest(original);
60+
assertSame(original, converted);
61+
}
62+
63+
@Test
64+
public void testFromActionRequestDifferentType() throws IOException {
65+
TransportRequest transportRequest = new TransportRequest() {
66+
@Override
67+
public void writeTo(StreamOutput out) throws IOException {
68+
buildRequest().writeTo(out);
69+
}
70+
};
71+
72+
MLMcpMessageRequest converted = MLMcpMessageRequest.fromActionRequest(transportRequest);
73+
74+
assertEquals(testNodeId, converted.getNodeId());
75+
assertEquals(testRequestBody, converted.getRequestBody());
76+
}
77+
78+
@Test(expected = UncheckedIOException.class)
79+
public void testFromActionRequestIOException() {
80+
TransportRequest faultyRequest = new TransportRequest() {
81+
@Override
82+
public void writeTo(StreamOutput out) throws IOException {
83+
throw new IOException("IO exception");
84+
}
85+
};
86+
MLMcpMessageRequest.fromActionRequest(faultyRequest);
87+
}
88+
89+
@Test
90+
public void testValidationSuccess() {
91+
MLMcpMessageRequest request = buildRequest();
92+
assertNull(request.validate());
93+
}
94+
95+
@Test(expected = IllegalStateException.class)
96+
public void testEmptyFieldsHandling() throws IOException {
97+
MLMcpMessageRequest request = MLMcpMessageRequest.builder().nodeId("").sessionId("").requestBody("").build();
98+
99+
BytesStreamOutput output = new BytesStreamOutput();
100+
request.writeTo(output);
101+
new MLMcpMessageRequest(output.bytes().streamInput());
102+
}
103+
104+
private MLMcpMessageRequest buildRequest() {
105+
return MLMcpMessageRequest.builder().nodeId(testNodeId).sessionId(testSessionId).requestBody(testRequestBody).build();
106+
}
107+
108+
}

common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/register/MLMcpToolsRegisterNodeRequestTest.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ public void setUp() {
3636
Collections
3737
.singletonList(
3838
new McpTool(
39+
null,
3940
"test_tool",
4041
"Sample tool",
4142
Collections.singletonMap("param", "value"),

common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/register/MLMcpToolsRegisterNodesRequestTest.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@ public class MLMcpToolsRegisterNodesRequestTest {
3131
@Before
3232
public void setup() {
3333
sampleTools = new McpTools(
34-
Arrays.asList(new McpTool("metric_analyzer", "System monitoring tool", Map.of("interval", "60s"), Map.of("type", "object"))),
34+
Arrays
35+
.asList(
36+
new McpTool(null, "metric_analyzer", "System monitoring tool", Map.of("interval", "60s"), Map.of("type", "object"))
37+
),
3538
null,
3639
null
3740
);

common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/register/McpToolTest.java

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public class McpToolTest {
4646

4747
@Before
4848
public void setUp() {
49-
mcptool = new McpTool(toolName, description, params, schema);
49+
mcptool = new McpTool(toolName, toolName, description, params, schema);
5050
}
5151

5252
@Test
@@ -60,7 +60,7 @@ public void testConstructor_Success() {
6060
@Test
6161
public void testParse_AllFields() throws Exception {
6262
String jsonStr = "{\"type\":\"stock_tool\",\"description\":\"Stock data tool\","
63-
+ "\"parameters\":{\"exchange\":\"NYSE\"},\"input_schema\":{\"properties\":{\"symbol\":{\"type\":\"string\"}}}}";
63+
+ "\"parameters\":{\"exchange\":\"NYSE\"},\"attributes\": {\"input_schema\":{\"properties\":{\"symbol\":{\"type\":\"string\"}}}}}";
6464

6565
XContentParser parser = XContentType.JSON
6666
.xContent()
@@ -75,7 +75,7 @@ public void testParse_AllFields() throws Exception {
7575
assertEquals("stock_tool", parsed.getType());
7676
assertEquals("Stock data tool", parsed.getDescription());
7777
assertEquals(Collections.singletonMap("exchange", "NYSE"), parsed.getParameters());
78-
assertTrue(parsed.getAttributes().containsKey("properties"));
78+
assertTrue(parsed.getAttributes().containsKey("input_schema"));
7979
}
8080

8181
@Test
@@ -109,7 +109,7 @@ public void testToXContent_AllFields() throws Exception {
109109

110110
@Test
111111
public void testToXContent_MinimalFields() throws Exception {
112-
McpTool minimalTool = new McpTool("minimal_tool", null, null, null);
112+
McpTool minimalTool = new McpTool(null, "minimal_tool", null, null, null);
113113
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
114114
minimalTool.toXContent(builder, ToXContent.EMPTY_PARAMS);
115115
String jsonStr = builder.toString();
@@ -136,7 +136,7 @@ public void testStreamInputOutput_Success() throws IOException {
136136

137137
@Test
138138
public void testStreamInputOutput_WithNullFields() throws IOException {
139-
McpTool toolWithNulls = new McpTool("null_tool", null, null, null);
139+
McpTool toolWithNulls = new McpTool(null, "null_tool", null, null, null);
140140
BytesStreamOutput output = new BytesStreamOutput();
141141
toolWithNulls.writeTo(output);
142142

@@ -151,25 +151,22 @@ public void testStreamInputOutput_WithNullFields() throws IOException {
151151

152152
@Test
153153
public void testComplexParameters() throws Exception {
154-
// 测试嵌套参数结构
155154
Map<String, Object> complexParams = new HashMap<>();
156155
complexParams.put("config", Collections.singletonMap("timeout", 30));
157156

158157
Map<String, Object> complexSchema = new HashMap<>();
159158
complexSchema.put("type", "object");
160159
complexSchema.put("properties", Collections.singletonMap("location", Collections.singletonMap("type", "string")));
161160

162-
McpTool complexTool = new McpTool("complex_tool", null, complexParams, complexSchema);
161+
McpTool complexTool = new McpTool(null, "complex_tool", null, complexParams, complexSchema);
163162

164-
// 序列化测试
165163
BytesStreamOutput output = new BytesStreamOutput();
166164
complexTool.writeTo(output);
167165
McpTool parsed = new McpTool(output.bytes().streamInput());
168166

169167
assertEquals(complexParams, parsed.getParameters());
170168
assertEquals(complexSchema, parsed.getAttributes());
171169

172-
// XContent测试
173170
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
174171
complexTool.toXContent(builder, ToXContent.EMPTY_PARAMS);
175172
String jsonStr = builder.toString();

0 commit comments

Comments
 (0)