Skip to content

Commit 46db551

Browse files
committed
add Unit Test for queryTemplate change
Signed-off-by: Brian Flores <iflorbri@amazon.com>
1 parent 5608c63 commit 46db551

File tree

2 files changed

+137
-22
lines changed

2 files changed

+137
-22
lines changed

plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,23 @@
44
*/
55
package org.opensearch.ml.processor;
66

7-
import com.google.gson.Gson;
8-
import com.jayway.jsonpath.Configuration;
9-
import com.jayway.jsonpath.JsonPath;
10-
import com.jayway.jsonpath.Option;
11-
import com.jayway.jsonpath.PathNotFoundException;
12-
import com.jayway.jsonpath.ReadContext;
13-
import lombok.Getter;
7+
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
8+
import static org.opensearch.ml.common.utils.StringUtils.toJson;
9+
import static org.opensearch.ml.processor.InferenceProcessorAttributes.INPUT_MAP;
10+
import static org.opensearch.ml.processor.InferenceProcessorAttributes.MAX_PREDICTION_TASKS;
11+
import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_CONFIG;
12+
import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_ID;
13+
import static org.opensearch.ml.processor.InferenceProcessorAttributes.OUTPUT_MAP;
14+
import static org.opensearch.ml.processor.ModelExecutor.combineMaps;
15+
16+
import java.io.IOException;
17+
import java.util.Collection;
18+
import java.util.HashMap;
19+
import java.util.HashSet;
20+
import java.util.List;
21+
import java.util.Map;
22+
import java.util.Set;
23+
1424
import org.apache.commons.text.StringSubstitutor;
1525
import org.apache.logging.log4j.LogManager;
1626
import org.apache.logging.log4j.Logger;
@@ -36,22 +46,14 @@
3646
import org.opensearch.search.pipeline.SearchRequestProcessor;
3747
import org.opensearch.transport.client.Client;
3848

39-
import java.io.IOException;
40-
import java.util.Collection;
41-
import java.util.HashMap;
42-
import java.util.HashSet;
43-
import java.util.List;
44-
import java.util.Map;
45-
import java.util.Set;
49+
import com.google.gson.Gson;
50+
import com.jayway.jsonpath.Configuration;
51+
import com.jayway.jsonpath.JsonPath;
52+
import com.jayway.jsonpath.Option;
53+
import com.jayway.jsonpath.PathNotFoundException;
54+
import com.jayway.jsonpath.ReadContext;
4655

47-
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
48-
import static org.opensearch.ml.common.utils.StringUtils.toJson;
49-
import static org.opensearch.ml.processor.InferenceProcessorAttributes.INPUT_MAP;
50-
import static org.opensearch.ml.processor.InferenceProcessorAttributes.MAX_PREDICTION_TASKS;
51-
import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_CONFIG;
52-
import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_ID;
53-
import static org.opensearch.ml.processor.InferenceProcessorAttributes.OUTPUT_MAP;
54-
import static org.opensearch.ml.processor.ModelExecutor.combineMaps;
56+
import lombok.Getter;
5557

5658
/**
5759
* MLInferenceSearchRequestProcessor requires a modelId string to call model inferences

plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,11 @@
4242
import org.opensearch.core.xcontent.NamedXContentRegistry;
4343
import org.opensearch.core.xcontent.XContentParser;
4444
import org.opensearch.index.query.QueryBuilder;
45+
import org.opensearch.index.query.QueryBuilders;
4546
import org.opensearch.index.query.RangeQueryBuilder;
4647
import org.opensearch.index.query.TermQueryBuilder;
4748
import org.opensearch.index.query.TermsQueryBuilder;
49+
import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder;
4850
import org.opensearch.ingest.Processor;
4951
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
5052
import org.opensearch.ml.common.input.MLInput;
@@ -58,6 +60,8 @@
5860
import org.opensearch.ml.searchext.MLInferenceRequestParameters;
5961
import org.opensearch.ml.searchext.MLInferenceRequestParametersExtBuilder;
6062
import org.opensearch.plugins.SearchPlugin;
63+
import org.opensearch.script.Script;
64+
import org.opensearch.script.ScriptType;
6165
import org.opensearch.search.SearchModule;
6266
import org.opensearch.search.builder.SearchSourceBuilder;
6367
import org.opensearch.search.pipeline.PipelineProcessingContext;
@@ -1499,6 +1503,115 @@ public void onFailure(Exception e) {
14991503

15001504
}
15011505

1506+
/**
1507+
* Tests ML Processor can return a sparse vector correctly when performing a rewrite query.
1508+
*
1509+
* This simulates a real world scenario where user has a neural sparse model and attempts to parse
1510+
* it by asserting FullResponsePath to true.
1511+
* @throws Exception when an error occurs on the test
1512+
*/
1513+
public void testExecute_rewriteTermQueryWithSparseVectorSuccess() throws Exception {
1514+
String modelInputField = "inputs";
1515+
String originalQueryField = "query.term.text.value";
1516+
String newQueryField = "vector";
1517+
String modelInferenceJsonPathInput = "$.inference_results[0].output[0].dataAsMap.response[0]";
1518+
1519+
String queryTemplate = "{\n"
1520+
+ " \"query\": {\n"
1521+
+ " \"script_score\": {\n"
1522+
+ " \"query\": {\n"
1523+
+ " \"match_all\": {}\n"
1524+
+ " },\n"
1525+
+ " \"script\": {\n"
1526+
+ " \"source\": \"return 1;\",\n"
1527+
+ " \"params\": {\n"
1528+
+ " \"query_tokens\": ${vector}\n"
1529+
+ " }\n"
1530+
+ " }\n"
1531+
+ " }\n"
1532+
+ " }\n"
1533+
+ "}";
1534+
1535+
Map<String, Double> sparseVector = Map.of("this", 1.3123, "which", 0.2447, "here", 0.6674);
1536+
1537+
List<Map<String, String>> optionalInputMap = new ArrayList<>();
1538+
Map<String, String> input = new HashMap<>();
1539+
input.put(modelInputField, originalQueryField);
1540+
optionalInputMap.add(input);
1541+
1542+
List<Map<String, String>> optionalOutputMap = new ArrayList<>();
1543+
Map<String, String> output = new HashMap<>();
1544+
output.put(newQueryField, modelInferenceJsonPathInput);
1545+
optionalOutputMap.add(output);
1546+
1547+
MLInferenceSearchRequestProcessor requestProcessor = new MLInferenceSearchRequestProcessor(
1548+
"model1",
1549+
queryTemplate,
1550+
null,
1551+
null,
1552+
optionalInputMap,
1553+
optionalOutputMap,
1554+
null,
1555+
DEFAULT_MAX_PREDICTION_TASKS,
1556+
PROCESSOR_TAG,
1557+
DESCRIPTION,
1558+
false,
1559+
"remote",
1560+
true,
1561+
false,
1562+
"{ \"parameters\": ${ml_inference.parameters} }",
1563+
client,
1564+
TEST_XCONTENT_REGISTRY_FOR_QUERY
1565+
);
1566+
1567+
/**
1568+
* {
1569+
* "inference_results" : [ {
1570+
* "output" : [ {
1571+
* "name" : "response",
1572+
* "dataAsMap" : {
1573+
* "response" : [ {
1574+
* "this" : 1.3123,
1575+
* "which" : 0.2447,
1576+
* "here" : 0.6674
1577+
* } ]
1578+
* }
1579+
* } ]
1580+
* } ]
1581+
* }
1582+
*/
1583+
ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(Map.of("response", List.of(sparseVector))).build();
1584+
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
1585+
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
1586+
1587+
doAnswer(invocation -> {
1588+
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
1589+
actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
1590+
return null;
1591+
}).when(client).execute(any(), any(), any());
1592+
1593+
QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo");
1594+
SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery);
1595+
SearchRequest request = new SearchRequest().source(source);
1596+
1597+
ActionListener<SearchRequest> Listener = new ActionListener<>() {
1598+
@Override
1599+
public void onResponse(SearchRequest newSearchRequest) {
1600+
Script script = new Script(ScriptType.INLINE, "painless", "return 1;", Map.of("query_tokens", sparseVector));
1601+
1602+
ScriptScoreQueryBuilder expectedQuery = new ScriptScoreQueryBuilder(QueryBuilders.matchAllQuery(), script);
1603+
assertEquals(expectedQuery, newSearchRequest.source().query());
1604+
}
1605+
1606+
@Override
1607+
public void onFailure(Exception e) {
1608+
throw new RuntimeException("Failed in executing processRequestAsync.", e);
1609+
}
1610+
};
1611+
1612+
requestProcessor.processRequestAsync(request, requestContext, Listener);
1613+
}
1614+
15021615
/**
15031616
* Tests when there are two optional input fields
15041617
* but only the second optional input is present in the query

0 commit comments

Comments
 (0)