Skip to content

Commit f7d0f24

Browse files
committed
refactors and adds new UT
Signed-off-by: Brian Flores <iflorbri@amazon.com>
1 parent 46db551 commit f7d0f24

File tree

2 files changed

+104
-3
lines changed

2 files changed

+104
-3
lines changed

plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
import org.opensearch.search.pipeline.SearchRequestProcessor;
4747
import org.opensearch.transport.client.Client;
4848

49-
import com.google.gson.Gson;
5049
import com.jayway.jsonpath.Configuration;
5150
import com.jayway.jsonpath.JsonPath;
5251
import com.jayway.jsonpath.Option;
@@ -362,7 +361,7 @@ private String updateQueryTemplate(String queryTemplate, Map<String, String> out
362361
String modelOutputFieldName = outputMapEntry.getValue();
363362
Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath);
364363
if (modelOutputValue instanceof Map) {
365-
modelOutputValue = new Gson().toJson(modelOutputValue);
364+
modelOutputValue = toJson(modelOutputValue);
366365
}
367366
valuesMap.put(newQueryField, modelOutputValue);
368367
}

plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import org.opensearch.core.action.ActionListener;
4242
import org.opensearch.core.xcontent.NamedXContentRegistry;
4343
import org.opensearch.core.xcontent.XContentParser;
44+
import org.opensearch.index.query.MatchAllQueryBuilder;
4445
import org.opensearch.index.query.QueryBuilder;
4546
import org.opensearch.index.query.QueryBuilders;
4647
import org.opensearch.index.query.RangeQueryBuilder;
@@ -1316,7 +1317,7 @@ public void onFailure(Exception e) {
13161317

13171318
/**
13181319
* Tests the successful rewriting of a complex nested array in query extension based on the model output.
1319-
* verify the pipelineConext is set from the extension
1320+
* verify the pipelineContext is set from the extension
13201321
* @throws Exception if an error occurs during the test
13211322
*/
13221323
public void testExecute_rewriteTermQueryReadAndWriteComplexNestedArrayToExtensionSuccess() throws Exception {
@@ -1612,6 +1613,107 @@ public void onFailure(Exception e) {
16121613
requestProcessor.processRequestAsync(request, requestContext, Listener);
16131614
}
16141615

1616+
/**
1617+
* Tests ML Processor can return a OpenSearch Query correctly when performing a rewrite query.
1618+
*
1619+
* This simulates a real world scenario where user has a llm return a OpenSearch Query to help them generate a new
1620+
* query based on the context given in the prompt.
1621+
*
1622+
* @throws Exception when an error occurs on the test
1623+
*/
1624+
public void testExecute_rewriteTermQueryWithNewQuerySuccess() throws Exception {
1625+
String modelInputField = "inputs";
1626+
String originalQueryField = "query.term.text.value";
1627+
String newQueryField = "llm_query";
1628+
String modelInferenceJsonPathInput = "$.inference_results[0].output[0].dataAsMap.content[0].text";
1629+
1630+
String queryTemplate = "${llm_query}";
1631+
1632+
String llmQuery = "{\n" +
1633+
" \"query\": {\n" +
1634+
" \"match_all\": {}\n" +
1635+
" }\n" +
1636+
"}";
1637+
Map content = Map.of("content", List.of(Map.of("text", llmQuery)));
1638+
1639+
1640+
List<Map<String, String>> optionalInputMap = new ArrayList<>();
1641+
Map<String, String> input = new HashMap<>();
1642+
input.put(modelInputField, originalQueryField);
1643+
optionalInputMap.add(input);
1644+
1645+
List<Map<String, String>> optionalOutputMap = new ArrayList<>();
1646+
Map<String, String> output = new HashMap<>();
1647+
output.put(newQueryField, modelInferenceJsonPathInput);
1648+
optionalOutputMap.add(output);
1649+
1650+
MLInferenceSearchRequestProcessor requestProcessor = new MLInferenceSearchRequestProcessor(
1651+
"model1",
1652+
queryTemplate,
1653+
null,
1654+
null,
1655+
optionalInputMap,
1656+
optionalOutputMap,
1657+
null,
1658+
DEFAULT_MAX_PREDICTION_TASKS,
1659+
PROCESSOR_TAG,
1660+
DESCRIPTION,
1661+
false,
1662+
"remote",
1663+
true,
1664+
false,
1665+
"{ \"parameters\": ${ml_inference.parameters} }",
1666+
client,
1667+
TEST_XCONTENT_REGISTRY_FOR_QUERY
1668+
);
1669+
1670+
/**
1671+
* {
1672+
* "inference_results" : [ {
1673+
* "output" : [ {
1674+
* "name" : "response",
1675+
* "dataAsMap" : {
1676+
* "content": [
1677+
* "text": "{\"query\": \"match_all\" : {}}"
1678+
* }
1679+
* } ]
1680+
* } ]
1681+
* }
1682+
*/
1683+
ModelTensor modelTensor = ModelTensor
1684+
.builder()
1685+
.name("response")
1686+
.dataAsMap(content)
1687+
.build();
1688+
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
1689+
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
1690+
1691+
doAnswer(invocation -> {
1692+
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
1693+
actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
1694+
return null;
1695+
}).when(client).execute(any(), any(), any());
1696+
1697+
QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo");
1698+
SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery);
1699+
SearchRequest sampleRequest = new SearchRequest().source(source);
1700+
1701+
ActionListener<SearchRequest> Listener = new ActionListener<>() {
1702+
@Override
1703+
public void onResponse(SearchRequest newSearchRequest) {
1704+
MatchAllQueryBuilder expectedQuery = new MatchAllQueryBuilder();
1705+
assertEquals(expectedQuery, newSearchRequest.source().query());
1706+
}
1707+
1708+
@Override
1709+
public void onFailure(Exception e) {
1710+
throw new RuntimeException("Failed in executing processRequestAsync.", e);
1711+
}
1712+
};
1713+
1714+
requestProcessor.processRequestAsync(sampleRequest, requestContext, Listener);
1715+
}
1716+
16151717
/**
16161718
* Tests when there are two optional input fields
16171719
* but only the second optional input is present in the query

0 commit comments

Comments
 (0)