Skip to content

[Backport 3.0] Adds Json Parsing to nested object during update Query step in ML Inference Request processor #3867

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,9 @@ private String updateQueryTemplate(String queryTemplate, Map<String, String> out
String newQueryField = outputMapEntry.getKey();
String modelOutputFieldName = outputMapEntry.getValue();
Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath);
if (modelOutputValue instanceof Map) {
modelOutputValue = toJson(modelOutputValue);
}
valuesMap.put(newQueryField, modelOutputValue);
}
StringSubstitutor sub = new StringSubstitutor(valuesMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,13 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder;
import org.opensearch.ingest.Processor;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
Expand All @@ -58,6 +61,8 @@
import org.opensearch.ml.searchext.MLInferenceRequestParameters;
import org.opensearch.ml.searchext.MLInferenceRequestParametersExtBuilder;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.script.Script;
import org.opensearch.script.ScriptType;
import org.opensearch.search.SearchModule;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.pipeline.PipelineProcessingContext;
Expand Down Expand Up @@ -1312,7 +1317,7 @@ public void onFailure(Exception e) {

/**
* Tests the successful rewriting of a complex nested array in query extension based on the model output.
* verify the pipelineConext is set from the extension
* verify the pipelineContext is set from the extension
* @throws Exception if an error occurs during the test
*/
public void testExecute_rewriteTermQueryReadAndWriteComplexNestedArrayToExtensionSuccess() throws Exception {
Expand Down Expand Up @@ -1499,6 +1504,207 @@ public void onFailure(Exception e) {

}

/**
 * Tests that the ML inference request processor can substitute a sparse vector (a nested JSON object)
 * into the query template when rewriting a query.
 *
 * This simulates a real-world scenario where a user has a neural sparse model and reads its output
 * through the full response path ({@code $.inference_results[0].output[0].dataAsMap.response[0]}).
 *
 * @throws Exception when an error occurs in the test
 */
public void testExecute_rewriteTermQueryWithSparseVectorSuccess() throws Exception {
    String modelInputField = "inputs";
    String originalQueryField = "query.term.text.value";
    String newQueryField = "vector";
    String modelInferenceJsonPathInput = "$.inference_results[0].output[0].dataAsMap.response[0]";

    // ${vector} is intentionally unquoted: the substituted model output must itself be valid JSON
    // for the rewritten query to parse.
    String queryTemplate = "{\n"
        + " \"query\": {\n"
        + " \"script_score\": {\n"
        + " \"query\": {\n"
        + " \"match_all\": {}\n"
        + " },\n"
        + " \"script\": {\n"
        + " \"source\": \"return 1;\",\n"
        + " \"params\": {\n"
        + " \"query_tokens\": ${vector}\n"
        + " }\n"
        + " }\n"
        + " }\n"
        + " }\n"
        + "}";

    Map<String, Double> sparseVector = Map.of("this", 1.3123, "which", 0.2447, "here", 0.6674);

    // Input mapping: model input field <- field of the incoming query.
    List<Map<String, String>> optionalInputMap = new ArrayList<>();
    Map<String, String> input = new HashMap<>();
    input.put(modelInputField, originalQueryField);
    optionalInputMap.add(input);

    // Output mapping: template placeholder <- JSON path into the model response.
    List<Map<String, String>> optionalOutputMap = new ArrayList<>();
    Map<String, String> output = new HashMap<>();
    output.put(newQueryField, modelInferenceJsonPathInput);
    optionalOutputMap.add(output);

    MLInferenceSearchRequestProcessor requestProcessor = new MLInferenceSearchRequestProcessor(
        "model1",
        queryTemplate,
        null,
        null,
        optionalInputMap,
        optionalOutputMap,
        null,
        DEFAULT_MAX_PREDICTION_TASKS,
        PROCESSOR_TAG,
        DESCRIPTION,
        false,
        "remote",
        true,
        false,
        "{ \"parameters\": ${ml_inference.parameters} }",
        client,
        TEST_XCONTENT_REGISTRY_FOR_QUERY
    );

    /*
     * Mocked model response:
     * {
     *   "inference_results" : [ {
     *     "output" : [ {
     *       "name" : "response",
     *       "dataAsMap" : {
     *         "response" : [ {
     *           "this" : 1.3123,
     *           "which" : 0.2447,
     *           "here" : 0.6674
     *         } ]
     *       }
     *     } ]
     *   } ]
     * }
     */
    ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(Map.of("response", List.of(sparseVector))).build();
    ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
    ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();

    // Answer every client call with the canned prediction; the listener is the third argument.
    doAnswer(invocation -> {
        ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
        actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
        return null;
    }).when(client).execute(any(), any(), any());

    QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo");
    SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery);
    SearchRequest request = new SearchRequest().source(source);

    // NOTE(review): the assertion only runs if the processor invokes this listener; presumably the
    // mocked client makes the call synchronous — confirm, otherwise the test can pass vacuously.
    ActionListener<SearchRequest> listener = new ActionListener<>() {
        @Override
        public void onResponse(SearchRequest newSearchRequest) {
            Script script = new Script(ScriptType.INLINE, "painless", "return 1;", Map.of("query_tokens", sparseVector));

            ScriptScoreQueryBuilder expectedQuery = new ScriptScoreQueryBuilder(QueryBuilders.matchAllQuery(), script);
            assertEquals(expectedQuery, newSearchRequest.source().query());
        }

        @Override
        public void onFailure(Exception e) {
            throw new RuntimeException("Failed in executing processRequestAsync.", e);
        }
    };

    requestProcessor.processRequestAsync(request, requestContext, listener);
}

/**
 * Tests that the ML inference request processor can replace the whole query with an OpenSearch query
 * returned by the model when rewriting a query.
 *
 * This simulates a real-world scenario where an LLM returns an OpenSearch query, generated from the
 * context given in the prompt, and the query template consists solely of that output.
 *
 * @throws Exception when an error occurs in the test
 */
public void testExecute_rewriteTermQueryWithNewQuerySuccess() throws Exception {
    String modelInputField = "inputs";
    String originalQueryField = "query.term.text.value";
    String newQueryField = "llm_query";
    String modelInferenceJsonPathInput = "$.inference_results[0].output[0].dataAsMap.content[0].text";

    // The template is nothing but the placeholder, so the model output becomes the entire search source.
    String queryTemplate = "${llm_query}";

    String llmQuery = "{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " }\n" + "}";
    // Parameterized instead of a raw Map to avoid unchecked warnings.
    Map<String, Object> content = Map.of("content", List.of(Map.of("text", llmQuery)));

    // Input mapping: model input field <- field of the incoming query.
    List<Map<String, String>> optionalInputMap = new ArrayList<>();
    Map<String, String> input = new HashMap<>();
    input.put(modelInputField, originalQueryField);
    optionalInputMap.add(input);

    // Output mapping: template placeholder <- JSON path into the model response.
    List<Map<String, String>> optionalOutputMap = new ArrayList<>();
    Map<String, String> output = new HashMap<>();
    output.put(newQueryField, modelInferenceJsonPathInput);
    optionalOutputMap.add(output);

    MLInferenceSearchRequestProcessor requestProcessor = new MLInferenceSearchRequestProcessor(
        "model1",
        queryTemplate,
        null,
        null,
        optionalInputMap,
        optionalOutputMap,
        null,
        DEFAULT_MAX_PREDICTION_TASKS,
        PROCESSOR_TAG,
        DESCRIPTION,
        false,
        "remote",
        true,
        false,
        "{ \"parameters\": ${ml_inference.parameters} }",
        client,
        TEST_XCONTENT_REGISTRY_FOR_QUERY
    );

    /*
     * Mocked model response:
     * {
     *   "inference_results" : [ {
     *     "output" : [ {
     *       "name" : "response",
     *       "dataAsMap" : {
     *         "content" : [ {
     *           "text" : "{\"query\": {\"match_all\": {}}}"
     *         } ]
     *       }
     *     } ]
     *   } ]
     * }
     */
    ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(content).build();
    ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
    ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();

    // Answer every client call with the canned prediction; the listener is the third argument.
    doAnswer(invocation -> {
        ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
        actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
        return null;
    }).when(client).execute(any(), any(), any());

    QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo");
    SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery);
    SearchRequest sampleRequest = new SearchRequest().source(source);

    ActionListener<SearchRequest> listener = new ActionListener<>() {
        @Override
        public void onResponse(SearchRequest newSearchRequest) {
            // The original term query must have been replaced by the model-provided match_all query.
            MatchAllQueryBuilder expectedQuery = new MatchAllQueryBuilder();
            assertEquals(expectedQuery, newSearchRequest.source().query());
        }

        @Override
        public void onFailure(Exception e) {
            throw new RuntimeException("Failed in executing processRequestAsync.", e);
        }
    };

    requestProcessor.processRequestAsync(sampleRequest, requestContext, listener);
}

/**
* Tests when there are two optional input fields
* but only the second optional input is present in the query
Expand Down
Loading