Skip to content

Commit fdbe3b4

Browse files
authored
change SearchIndexTool arguments parsing logic (#3883)
Signed-off-by: zane-neo <zaniu@amazon.com>
1 parent b57c776 commit fdbe3b4

File tree

2 files changed

+47
-11
lines changed

2 files changed

+47
-11
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
import static org.opensearch.ml.common.CommonValue.*;
99

1010
import java.io.IOException;
11-
import java.security.AccessController;
12-
import java.security.PrivilegedExceptionAction;
1311
import java.util.HashMap;
1412
import java.util.Map;
1513
import java.util.Objects;
@@ -96,7 +94,12 @@ public String getVersion() {
9694

9795
@Override
9896
public boolean validate(Map<String, String> parameters) {
99-
return parameters != null && parameters.containsKey(INPUT_FIELD) && parameters.get(INPUT_FIELD) != null;
97+
if (parameters == null || parameters.isEmpty()) {
98+
return false;
99+
}
100+
boolean argumentsFromInput = parameters.containsKey(INPUT_FIELD) && parameters.get(INPUT_FIELD) != null;
101+
boolean argumentsFromParameters = parameters.containsKey(INDEX_FIELD) && parameters.containsKey(QUERY_FIELD);
102+
return argumentsFromInput || argumentsFromParameters;
100103
}
101104

102105
private SearchRequest getSearchRequest(String index, String query) throws IOException {
@@ -120,8 +123,16 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
120123
try {
121124
String input = parameters.get(INPUT_FIELD);
122125
JsonObject jsonObject = GSON.fromJson(input, JsonObject.class);
123-
String index = Optional.ofNullable(jsonObject).map(x -> x.get(INDEX_FIELD)).map(JsonElement::getAsString).orElse(null);
124-
String query = Optional.ofNullable(jsonObject).map(x -> x.get(QUERY_FIELD)).map(JsonElement::toString).orElse(null);
126+
String index = Optional
127+
.ofNullable(jsonObject)
128+
.map(x -> x.get(INDEX_FIELD))
129+
.map(JsonElement::getAsString)
130+
.orElse(parameters.getOrDefault(INDEX_FIELD, null));
131+
String query = Optional
132+
.ofNullable(jsonObject)
133+
.map(x -> x.get(QUERY_FIELD))
134+
.map(JsonElement::toString)
135+
.orElse(parameters.getOrDefault(QUERY_FIELD, null));
125136
if (index == null || query == null) {
126137
listener.onFailure(new IllegalArgumentException("SearchIndexTool's two parameter: index and query are required!"));
127138
return;
@@ -134,10 +145,8 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
134145
if (hits != null && hits.length > 0) {
135146
StringBuilder contextBuilder = new StringBuilder();
136147
for (SearchHit hit : hits) {
137-
String doc = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> {
138-
Map<String, Object> docContent = processResponse(hit);
139-
return GSON.toJson(docContent);
140-
});
148+
Map<String, Object> docContent = processResponse(hit);
149+
String doc = GSON.toJson(docContent);
141150
contextBuilder.append(doc).append("\n");
142151
}
143152
listener.onResponse((T) contextBuilder.toString());

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/SearchIndexToolTests.java

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,24 @@ public void testDefaultAttributes() {
9595

9696
@Test
9797
@SneakyThrows
98-
public void testValidate() {
98+
public void testValidateWithInputKey() {
9999
Map<String, String> parameters = Map.of("input", "{}");
100100
assertTrue(mockedSearchIndexTool.validate(parameters));
101101
}
102102

103+
@Test
104+
@SneakyThrows
105+
public void testValidateWithActualKeys() {
106+
Map<String, String> parameters = Map
107+
.of(
108+
SearchIndexTool.INDEX_FIELD,
109+
"test-index",
110+
SearchIndexTool.QUERY_FIELD,
111+
"{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " }\n" + "}"
112+
);
113+
assertTrue(mockedSearchIndexTool.validate(parameters));
114+
}
115+
103116
@Test
104117
@SneakyThrows
105118
public void testValidateWithEmptyInput() {
@@ -108,14 +121,28 @@ public void testValidateWithEmptyInput() {
108121
}
109122

110123
@Test
111-
public void testRunWithNormalIndex() {
124+
@SneakyThrows
125+
public void testValidateWithNullInput() {
126+
assertFalse(mockedSearchIndexTool.validate(null));
127+
}
128+
129+
@Test
130+
public void testRunWithInputKey() {
112131
String inputString = "{\"index\": \"test-index\", \"query\": {\"query\": {\"match_all\": {}}}}";
113132
Map<String, String> parameters = Map.of("input", inputString);
114133
mockedSearchIndexTool.run(parameters, null);
115134
Mockito.verify(client, times(1)).search(any(), any());
116135
Mockito.verify(client, Mockito.never()).execute(any(), any(), any());
117136
}
118137

138+
@Test
139+
public void testRunWithActualKeys() {
140+
Map<String, String> parameters = Map.of("index", "test-index", "query", "{\"query\": {\"match_all\": {}}}");
141+
mockedSearchIndexTool.run(parameters, null);
142+
Mockito.verify(client, times(1)).search(any(), any());
143+
Mockito.verify(client, Mockito.never()).execute(any(), any(), any());
144+
}
145+
119146
@Test
120147
public void testRunWithConnectorIndex() {
121148
String inputString = "{\"index\": \".plugins-ml-connector\", \"query\": {\"query\": {\"match_all\": {}}}}";

0 commit comments

Comments
 (0)