Skip to content

Commit 91d8cf1

Browse files
brianf-aws and akolarkunnu
authored and committed
Adds Json Parsing to nested object during update Query step in ML Inference Request processor (opensearch-project#3856)
* adds json parsing to modelOutputValue during query rewrite Signed-off-by: Brian Flores <iflorbri@amazon.com> * add Unit Test for queryTemplate change Signed-off-by: Brian Flores <iflorbri@amazon.com> * refactors and adds new UT Signed-off-by: Brian Flores <iflorbri@amazon.com> * apply spotless Signed-off-by: Brian Flores <iflorbri@amazon.com> * replace tab with spaces on multi-line comment Signed-off-by: Brian Flores <iflorbri@amazon.com> --------- Signed-off-by: Brian Flores <iflorbri@amazon.com> Signed-off-by: Abdul Muneer Kolarkunnu <muneer.kolarkunnu@netapp.com>
1 parent 6651fb6 commit 91d8cf1

File tree

2 files changed

+210
-1
lines changed

2 files changed

+210
-1
lines changed

plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,9 @@ private String updateQueryTemplate(String queryTemplate, Map<String, String> out
360360
String newQueryField = outputMapEntry.getKey();
361361
String modelOutputFieldName = outputMapEntry.getValue();
362362
Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath);
363+
if (modelOutputValue instanceof Map) {
364+
modelOutputValue = toJson(modelOutputValue);
365+
}
363366
valuesMap.put(newQueryField, modelOutputValue);
364367
}
365368
StringSubstitutor sub = new StringSubstitutor(valuesMap);

plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java

Lines changed: 207 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,13 @@
4141
import org.opensearch.core.action.ActionListener;
4242
import org.opensearch.core.xcontent.NamedXContentRegistry;
4343
import org.opensearch.core.xcontent.XContentParser;
44+
import org.opensearch.index.query.MatchAllQueryBuilder;
4445
import org.opensearch.index.query.QueryBuilder;
46+
import org.opensearch.index.query.QueryBuilders;
4547
import org.opensearch.index.query.RangeQueryBuilder;
4648
import org.opensearch.index.query.TermQueryBuilder;
4749
import org.opensearch.index.query.TermsQueryBuilder;
50+
import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder;
4851
import org.opensearch.ingest.Processor;
4952
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
5053
import org.opensearch.ml.common.input.MLInput;
@@ -58,6 +61,8 @@
5861
import org.opensearch.ml.searchext.MLInferenceRequestParameters;
5962
import org.opensearch.ml.searchext.MLInferenceRequestParametersExtBuilder;
6063
import org.opensearch.plugins.SearchPlugin;
64+
import org.opensearch.script.Script;
65+
import org.opensearch.script.ScriptType;
6166
import org.opensearch.search.SearchModule;
6267
import org.opensearch.search.builder.SearchSourceBuilder;
6368
import org.opensearch.search.pipeline.PipelineProcessingContext;
@@ -1312,7 +1317,7 @@ public void onFailure(Exception e) {
13121317

13131318
/**
13141319
* Tests the successful rewriting of a complex nested array in query extension based on the model output.
1315-
* verify the pipelineConext is set from the extension
1320+
* verify the pipelineContext is set from the extension
13161321
* @throws Exception if an error occurs during the test
13171322
*/
13181323
public void testExecute_rewriteTermQueryReadAndWriteComplexNestedArrayToExtensionSuccess() throws Exception {
@@ -1499,6 +1504,207 @@ public void onFailure(Exception e) {
14991504

15001505
}
15011506

1507+
/**
1508+
* Tests ML Processor can return a sparse vector correctly when performing a rewrite query.
1509+
*
1510+
* This simulates a real world scenario where user has a neural sparse model and attempts to parse
1511+
* it by asserting FullResponsePath to true.
1512+
* @throws Exception when an error occurs on the test
1513+
*/
1514+
public void testExecute_rewriteTermQueryWithSparseVectorSuccess() throws Exception {
1515+
String modelInputField = "inputs";
1516+
String originalQueryField = "query.term.text.value";
1517+
String newQueryField = "vector";
1518+
String modelInferenceJsonPathInput = "$.inference_results[0].output[0].dataAsMap.response[0]";
1519+
1520+
String queryTemplate = "{\n"
1521+
+ " \"query\": {\n"
1522+
+ " \"script_score\": {\n"
1523+
+ " \"query\": {\n"
1524+
+ " \"match_all\": {}\n"
1525+
+ " },\n"
1526+
+ " \"script\": {\n"
1527+
+ " \"source\": \"return 1;\",\n"
1528+
+ " \"params\": {\n"
1529+
+ " \"query_tokens\": ${vector}\n"
1530+
+ " }\n"
1531+
+ " }\n"
1532+
+ " }\n"
1533+
+ " }\n"
1534+
+ "}";
1535+
1536+
Map<String, Double> sparseVector = Map.of("this", 1.3123, "which", 0.2447, "here", 0.6674);
1537+
1538+
List<Map<String, String>> optionalInputMap = new ArrayList<>();
1539+
Map<String, String> input = new HashMap<>();
1540+
input.put(modelInputField, originalQueryField);
1541+
optionalInputMap.add(input);
1542+
1543+
List<Map<String, String>> optionalOutputMap = new ArrayList<>();
1544+
Map<String, String> output = new HashMap<>();
1545+
output.put(newQueryField, modelInferenceJsonPathInput);
1546+
optionalOutputMap.add(output);
1547+
1548+
MLInferenceSearchRequestProcessor requestProcessor = new MLInferenceSearchRequestProcessor(
1549+
"model1",
1550+
queryTemplate,
1551+
null,
1552+
null,
1553+
optionalInputMap,
1554+
optionalOutputMap,
1555+
null,
1556+
DEFAULT_MAX_PREDICTION_TASKS,
1557+
PROCESSOR_TAG,
1558+
DESCRIPTION,
1559+
false,
1560+
"remote",
1561+
true,
1562+
false,
1563+
"{ \"parameters\": ${ml_inference.parameters} }",
1564+
client,
1565+
TEST_XCONTENT_REGISTRY_FOR_QUERY
1566+
);
1567+
1568+
/**
1569+
* {
1570+
* "inference_results" : [ {
1571+
* "output" : [ {
1572+
* "name" : "response",
1573+
* "dataAsMap" : {
1574+
* "response" : [ {
1575+
* "this" : 1.3123,
1576+
* "which" : 0.2447,
1577+
* "here" : 0.6674
1578+
* } ]
1579+
* }
1580+
* } ]
1581+
* } ]
1582+
* }
1583+
*/
1584+
ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(Map.of("response", List.of(sparseVector))).build();
1585+
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
1586+
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
1587+
1588+
doAnswer(invocation -> {
1589+
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
1590+
actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
1591+
return null;
1592+
}).when(client).execute(any(), any(), any());
1593+
1594+
QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo");
1595+
SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery);
1596+
SearchRequest request = new SearchRequest().source(source);
1597+
1598+
ActionListener<SearchRequest> Listener = new ActionListener<>() {
1599+
@Override
1600+
public void onResponse(SearchRequest newSearchRequest) {
1601+
Script script = new Script(ScriptType.INLINE, "painless", "return 1;", Map.of("query_tokens", sparseVector));
1602+
1603+
ScriptScoreQueryBuilder expectedQuery = new ScriptScoreQueryBuilder(QueryBuilders.matchAllQuery(), script);
1604+
assertEquals(expectedQuery, newSearchRequest.source().query());
1605+
}
1606+
1607+
@Override
1608+
public void onFailure(Exception e) {
1609+
throw new RuntimeException("Failed in executing processRequestAsync.", e);
1610+
}
1611+
};
1612+
1613+
requestProcessor.processRequestAsync(request, requestContext, Listener);
1614+
}
1615+
1616+
/**
1617+
* Tests ML Processor can return a OpenSearch Query correctly when performing a rewrite query.
1618+
*
1619+
* This simulates a real world scenario where user has a llm return a OpenSearch Query to help them generate a new
1620+
* query based on the context given in the prompt.
1621+
*
1622+
* @throws Exception when an error occurs on the test
1623+
*/
1624+
public void testExecute_rewriteTermQueryWithNewQuerySuccess() throws Exception {
1625+
String modelInputField = "inputs";
1626+
String originalQueryField = "query.term.text.value";
1627+
String newQueryField = "llm_query";
1628+
String modelInferenceJsonPathInput = "$.inference_results[0].output[0].dataAsMap.content[0].text";
1629+
1630+
String queryTemplate = "${llm_query}";
1631+
1632+
String llmQuery = "{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " }\n" + "}";
1633+
Map content = Map.of("content", List.of(Map.of("text", llmQuery)));
1634+
1635+
List<Map<String, String>> optionalInputMap = new ArrayList<>();
1636+
Map<String, String> input = new HashMap<>();
1637+
input.put(modelInputField, originalQueryField);
1638+
optionalInputMap.add(input);
1639+
1640+
List<Map<String, String>> optionalOutputMap = new ArrayList<>();
1641+
Map<String, String> output = new HashMap<>();
1642+
output.put(newQueryField, modelInferenceJsonPathInput);
1643+
optionalOutputMap.add(output);
1644+
1645+
MLInferenceSearchRequestProcessor requestProcessor = new MLInferenceSearchRequestProcessor(
1646+
"model1",
1647+
queryTemplate,
1648+
null,
1649+
null,
1650+
optionalInputMap,
1651+
optionalOutputMap,
1652+
null,
1653+
DEFAULT_MAX_PREDICTION_TASKS,
1654+
PROCESSOR_TAG,
1655+
DESCRIPTION,
1656+
false,
1657+
"remote",
1658+
true,
1659+
false,
1660+
"{ \"parameters\": ${ml_inference.parameters} }",
1661+
client,
1662+
TEST_XCONTENT_REGISTRY_FOR_QUERY
1663+
);
1664+
1665+
/*
1666+
* {
1667+
* "inference_results" : [ {
1668+
* "output" : [ {
1669+
* "name" : "response",
1670+
* "dataAsMap" : {
1671+
* "content": [
1672+
* "text": "{\"query\": \"match_all\" : {}}"
1673+
* }
1674+
* } ]
1675+
* } ]
1676+
* }
1677+
*/
1678+
ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(content).build();
1679+
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
1680+
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
1681+
1682+
doAnswer(invocation -> {
1683+
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
1684+
actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
1685+
return null;
1686+
}).when(client).execute(any(), any(), any());
1687+
1688+
QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo");
1689+
SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery);
1690+
SearchRequest sampleRequest = new SearchRequest().source(source);
1691+
1692+
ActionListener<SearchRequest> Listener = new ActionListener<>() {
1693+
@Override
1694+
public void onResponse(SearchRequest newSearchRequest) {
1695+
MatchAllQueryBuilder expectedQuery = new MatchAllQueryBuilder();
1696+
assertEquals(expectedQuery, newSearchRequest.source().query());
1697+
}
1698+
1699+
@Override
1700+
public void onFailure(Exception e) {
1701+
throw new RuntimeException("Failed in executing processRequestAsync.", e);
1702+
}
1703+
};
1704+
1705+
requestProcessor.processRequestAsync(sampleRequest, requestContext, Listener);
1706+
}
1707+
15021708
/**
15031709
* Tests when there are two optional input fields
15041710
* but only the second optional input is present in the query

0 commit comments

Comments
 (0)