Skip to content

Commit 91a28c4

Browse files
[Bug] ListTools call does not return tool attributes (#3785) (#3790)
(cherry picked from commit ed4f09f) Co-authored-by: Pavan Yekbote <pybot@amazon.com>
1 parent 6a1c7eb commit 91a28c4

File tree

10 files changed

+70
-4
lines changed

10 files changed

+70
-4
lines changed

common/src/main/java/org/opensearch/ml/common/ToolMetadata.java

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
88

99
import java.io.IOException;
10+
import java.util.Map;
1011

12+
import org.opensearch.Version;
1113
import org.opensearch.core.common.io.stream.StreamInput;
1214
import org.opensearch.core.common.io.stream.StreamOutput;
1315
import org.opensearch.core.common.io.stream.Writeable;
@@ -25,6 +27,9 @@ public class ToolMetadata implements ToXContentObject, Writeable {
2527
public static final String TOOL_DESCRIPTION_FIELD = "description";
2628
public static final String TOOL_TYPE_FIELD = "type";
2729
public static final String TOOL_VERSION_FIELD = "version";
30+
public static final String TOOL_ATTRIBUTES_FIELD = "attributes";
31+
32+
private static final Version MINIMUM_VERSION_FOR_TOOL_ATTRIBUTES = Version.V_3_0_0;
2833

2934
@Getter
3035
private String name;
@@ -34,27 +39,43 @@ public class ToolMetadata implements ToXContentObject, Writeable {
3439
private String type;
3540
@Getter
3641
private String version;
42+
@Getter
43+
private Map<String, Object> attributes;
3744

3845
@Builder(toBuilder = true)
39-
public ToolMetadata(String name, String description, String type, String version) {
46+
public ToolMetadata(String name, String description, String type, String version, Map<String, Object> attributes) {
4047
this.name = name;
4148
this.description = description;
4249
this.type = type;
4350
this.version = version;
51+
this.attributes = attributes;
4452
}
4553

4654
public ToolMetadata(StreamInput input) throws IOException {
55+
Version byteStreamVersion = input.getVersion();
4756
name = input.readString();
4857
description = input.readString();
4958
type = input.readString();
5059
version = input.readOptionalString();
60+
if (byteStreamVersion.onOrAfter(MINIMUM_VERSION_FOR_TOOL_ATTRIBUTES) && input.readBoolean()) {
61+
attributes = input.readMap(StreamInput::readString, StreamInput::readGenericValue);
62+
}
5163
}
5264

5365
public void writeTo(StreamOutput output) throws IOException {
66+
Version byteStreamVersion = output.getVersion();
5467
output.writeString(name);
5568
output.writeString(description);
5669
output.writeString(type);
5770
output.writeOptionalString(version);
71+
if (byteStreamVersion.onOrAfter(MINIMUM_VERSION_FOR_TOOL_ATTRIBUTES)) {
72+
if (attributes != null) {
73+
output.writeBoolean(true);
74+
output.writeMap(attributes, StreamOutput::writeString, StreamOutput::writeGenericValue);
75+
} else {
76+
output.writeBoolean(false);
77+
}
78+
}
5879
}
5980

6081
@Override
@@ -70,6 +91,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
7091
builder.field(TOOL_TYPE_FIELD, type);
7192
}
7293
builder.field(TOOL_VERSION_FIELD, version != null ? version : "undefined");
94+
if (attributes != null) {
95+
builder.field(TOOL_ATTRIBUTES_FIELD, attributes);
96+
}
7397
builder.endObject();
7498
return builder;
7599
}
@@ -79,6 +103,7 @@ public static ToolMetadata parse(XContentParser parser) throws IOException {
79103
String description = null;
80104
String type = null;
81105
String version = null;
106+
Map<String, Object> attributes = null;
82107

83108
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
84109
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -97,12 +122,14 @@ public static ToolMetadata parse(XContentParser parser) throws IOException {
97122
break;
98123
case TOOL_VERSION_FIELD:
99124
version = parser.text();
125+
case TOOL_ATTRIBUTES_FIELD:
126+
attributes = parser.map();
100127
default:
101128
parser.skipChildren();
102129
break;
103130
}
104131
}
105-
return ToolMetadata.builder().name(name).description(description).type(type).version(version).build();
132+
return ToolMetadata.builder().name(name).description(description).type(type).version(version).attributes(attributes).build();
106133
}
107134

108135
public static ToolMetadata fromStream(StreamInput in) throws IOException {

common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolsListResponse.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, ToXContent.Pa
5555
xContentBuilder.field(ToolMetadata.TOOL_TYPE_FIELD, toolMetadata.getType());
5656
xContentBuilder
5757
.field(ToolMetadata.TOOL_VERSION_FIELD, toolMetadata.getVersion() != null ? toolMetadata.getVersion() : "undefined");
58+
if (toolMetadata.getAttributes() != null) {
59+
xContentBuilder.field(ToolMetadata.TOOL_ATTRIBUTES_FIELD, toolMetadata.getAttributes());
60+
}
5861
xContentBuilder.endObject();
5962
}
6063
xContentBuilder.endArray();

common/src/test/java/org/opensearch/ml/common/ToolMetadataTests.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ public void setUp() {
3434
.description("Use this tool to calculate any math problem.")
3535
.type("MathTool")
3636
.version("test")
37+
.attributes(null)
3738
.build();
3839

3940
function = parser -> {
@@ -51,8 +52,8 @@ public void toXContent() throws IOException {
5152
toolMetadata.toXContent(builder, EMPTY_PARAMS);
5253
String toolMetadataString = TestHelper.xContentBuilderToString(builder);
5354
assertEquals(
54-
toolMetadataString,
55-
"{\"name\":\"MathTool\",\"description\":\"Use this tool to calculate any math problem.\",\"type\":\"MathTool\",\"version\":\"test\"}"
55+
"{\"name\":\"MathTool\",\"description\":\"Use this tool to calculate any math problem.\",\"type\":\"MathTool\",\"version\":\"test\"}",
56+
toolMetadataString
5657
);
5758
}
5859

@@ -92,5 +93,6 @@ private void readInputStream(ToolMetadata toolMetadata) throws IOException {
9293
assertEquals(toolMetadata.getDescription(), parsedToolMetadata.getDescription());
9394
assertEquals(toolMetadata.getType(), parsedToolMetadata.getType());
9495
assertEquals(toolMetadata.getVersion(), parsedToolMetadata.getVersion());
96+
assertEquals(toolMetadata.getAttributes(), parsedToolMetadata.getAttributes());
9597
}
9698
}

common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolGetResponseTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ public void setUp() {
3333
.description("Use this tool to calculate any math problem.")
3434
.type("MathTool")
3535
.version(null)
36+
.attributes(null)
3637
.build();
3738

3839
mlToolGetResponse = MLToolGetResponse.builder().toolMetadata(toolMetadata).build();

common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolsListResponseTests.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,15 @@ public void setUp() {
3737
.description("Useful when you need to use this tool to search general knowledge on wikipedia.")
3838
.type("SearchWikipediaTool")
3939
.version(null)
40+
.attributes(null)
4041
.build();
4142
ToolMetadata toolMetadata = ToolMetadata
4243
.builder()
4344
.name("MathTool")
4445
.description("Use this tool to calculate any math problem.")
4546
.type("MathTool")
4647
.version("test")
48+
.attributes(null)
4749
.build();
4850

4951
toolMetadataList.add(searchWikipediaTool);

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/IndexMappingTool.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ public class IndexMappingTool implements Tool {
5353
+ "\"items\":{\"type\":\"string\"}}},"
5454
+ "\"required\":[\"index\"],"
5555
+ "\"additionalProperties\":false}";
56+
public static final Map<String, Object> DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, STRICT_FIELD, true);
5657

5758
@Setter
5859
@Getter
@@ -228,5 +229,10 @@ public String getDefaultType() {
228229
public String getDefaultVersion() {
229230
return null;
230231
}
232+
233+
@Override
234+
public Map<String, Object> getDefaultAttributes() {
235+
return DEFAULT_ATTRIBUTES;
236+
}
231237
}
232238
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ListIndexTool.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ public class ListIndexTool implements Tool {
8585
+ "\"description\":\"OpenSearch index name list, separated by comma. "
8686
+ "for example: [\\\"index1\\\", \\\"index2\\\"], use empty array [] to list all indices in the cluster\"}},"
8787
+ "\"additionalProperties\":false}";
88+
public static final Map<String, Object> DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, STRICT_FIELD, false);
8889

8990
@Setter
9091
@Getter
@@ -478,6 +479,11 @@ public String getDefaultType() {
478479
public String getDefaultVersion() {
479480
return null;
480481
}
482+
483+
@Override
484+
public Map<String, Object> getDefaultAttributes() {
485+
return DEFAULT_ATTRIBUTES;
486+
}
481487
}
482488

483489
private Table getTableWithHeader() {

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ public class SearchIndexTool implements Tool {
6565

6666
private static final Gson GSON = new GsonBuilder().serializeSpecialFloatingPointValues().create();
6767

68+
public static final Map<String, Object> DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, STRICT_FIELD, false);
69+
6870
private String name = TYPE;
6971
private Map<String, Object> attributes;
7072
private String description = DEFAULT_DESCRIPTION;
@@ -211,5 +213,10 @@ public String getDefaultType() {
211213
public String getDefaultVersion() {
212214
return null;
213215
}
216+
217+
@Override
218+
public Map<String, Object> getDefaultAttributes() {
219+
return DEFAULT_ATTRIBUTES;
220+
}
214221
}
215222
}

plugin/src/main/java/org/opensearch/ml/rest/RestMLListToolsAction.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ MLToolsListRequest getRequest(RestRequest request) throws IOException {
7676
.description(value.getDefaultDescription())
7777
.type(value.getDefaultType())
7878
.version(value.getDefaultVersion())
79+
.attributes(value.getDefaultAttributes())
7980
.build()
8081
)
8182
);

spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
package org.opensearch.ml.common.spi.tools;
77

88
import org.opensearch.core.action.ActionListener;
9+
10+
import java.util.Collections;
11+
import java.util.HashMap;
912
import java.util.Map;
1013

1114
/**
@@ -129,5 +132,13 @@ interface Factory<T extends Tool> {
129132
* @return the default tool version
130133
*/
131134
String getDefaultVersion();
135+
136+
/**
137+
* Get the default attributes of this tool
138+
* @return the default attributes
139+
*/
140+
default Map<String, Object> getDefaultAttributes() {
141+
return null;
142+
}
132143
}
133144
}

0 commit comments

Comments
 (0)