Adds Json Parsing to nested object during update Query step in ML Inference Request processor #3856


Merged 5 commits on May 22, 2025
@@ -360,6 +360,9 @@ private String updateQueryTemplate(String queryTemplate, Map<String, String> out
String newQueryField = outputMapEntry.getKey();
String modelOutputFieldName = outputMapEntry.getValue();
Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath);
if (modelOutputValue instanceof Map) {
modelOutputValue = toJson(modelOutputValue);
}
Collaborator:
This is the right place to fix the issue: the conversion to a JSON string should happen before string substitution. So the next question is: should this toJson conversion be applied only when the value is a Map?

Contributor Author:

A Map should be a catch-all for most scenarios I can think of off the top of my head.

I still need to test that this solution doesn't break the query from a remote LLM. I also wonder whether a list should be supported as well, for example a list of maps; this instanceof check wouldn't cover that scenario.
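One way to cover both cases would be to broaden the check to serialize any Map or List before substitution. The sketch below is illustrative only: the hand-rolled `toJson` stands in for the plugin's actual JSON utility, and `normalizeForTemplate` is a hypothetical helper name, not a method in the processor.

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class JsonSubstitutionSketch {

    // Stand-in for the plugin's JSON utility: serializes Maps, Lists, Strings,
    // and numbers. Illustration only -- not the processor's actual toJson.
    static String toJson(Object value) {
        if (value instanceof Map) {
            return ((Map<?, ?>) value).entrySet().stream()
                .map(e -> "\"" + e.getKey() + "\":" + toJson(e.getValue()))
                .collect(Collectors.joining(",", "{", "}"));
        }
        if (value instanceof List) {
            return ((List<?>) value).stream()
                .map(JsonSubstitutionSketch::toJson)
                .collect(Collectors.joining(",", "[", "]"));
        }
        if (value instanceof String) {
            return "\"" + value + "\"";
        }
        return String.valueOf(value); // numbers, booleans
    }

    // Hypothetical helper: broadens the instanceof Map check so a List
    // (e.g. a list of maps) is also serialized before string substitution.
    static Object normalizeForTemplate(Object modelOutputValue) {
        if (modelOutputValue instanceof Map || modelOutputValue instanceof List) {
            return toJson(modelOutputValue);
        }
        return modelOutputValue;
    }

    public static void main(String[] args) {
        Object listOfMaps = List.of(Map.of("name", "electronics"), Map.of("name", "phones"));
        // A list of maps now lands in the template as valid JSON:
        System.out.println(normalizeForTemplate(listOfMaps));
        // [{"name":"electronics"},{"name":"phones"}]
    }
}
```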

Contributor Author:

I added a UT where an LLM returns a query and a query rewrite occurs. It succeeds with this code change, e.g.:

        /**
         * {
         *   "inference_results" : [ {
         *     "output" : [ {
         *       "name" : "response",
         *       "dataAsMap" : {
         *         "content": [ {
         *           "text": "{\"query\": {\"match_all\" : {}}}"
         *         } ]
         *       }
         *     } ]
         *   } ]
         * }
         */

Off the top of my head I can't think of another scenario where this parsing issue occurs, since I see that your UTs already cover a lot of edge cases, such as a list.

Collaborator:

We should cover more use cases in testing, because the model output can come in various formats.

I have two cases in mind that you can try adding and verifying in the UT:

1. List of maps:
{
  "query": {
    "bool": {
      "must": [
        {
          "terms": {
            "categories": [
              {"name": "electronics", "id": 1},
              {"name": "computers", "id": 2},
              {"name": "phones", "id": 3}
            ]
          }
        }
      ]
    }
  }
}
2. List of nested maps (map of maps):
{
  "query": {
    "bool": {
      "must": [
        {
          "nested": {
            "path": "products",
            "query": {
              "bool": {
                "must": [
                  {
                    "terms": {
                      "products.variants": [
                        {
                          "color": {
                            "primary": {"hex": "#FF0000", "name": "red"},
                            "secondary": {"hex": "#000000", "name": "black"}
                          },
                          "size": {
                            "dimensions": {"width": 10, "height": 20},
                            "label": {"us": "M", "eu": "38"}
                          }
                        },
                        {
                          "color": {
                            "primary": {"hex": "#0000FF", "name": "blue"},
                            "secondary": {"hex": "#FFFFFF", "name": "white"}
                          },
                          "size": {
                            "dimensions": {"width": 12, "height": 22},
                            "label": {"us": "L", "eu": "40"}
                          }
                        }
                      ]
                    }
                  }
                ]
              }
            }
          }
        }
      ]
    }
  }
}

Contributor Author:

For scenario one (list of maps), I don't think that is a valid OpenSearch query; consequently, if it runs through the processor, the query rewrite fails.

However, I did try this, taken from the docs:

query_template=
{
  "query": {
    "bool": {
      "must": [
       ${llm_response}
      ]
    }
  }
}


llm_response = 
 {
          "bool": {
            "should": [
              {
                "match": {
                  "text_entry": "love"
                }
              },
              {
                "match": {
                  "text": "hate"
                }
              }
            ]
          }
        },
        {
          "bool": {
            "should": [
              {
                "match": {
                  "text_entry": "life"
                }
              },
              {
                "match": {
                  "text": "grace"
                }
              }
            ]
          }
        }
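The example above also shows why the Map-to-JSON conversion has to happen before StringSubstitutor runs: substituting a raw Java Map into the template would go through Map.toString(), which produces Java's {k=v} notation, not JSON. A minimal standalone sketch (plain String.replace stands in for Apache Commons Text's StringSubstitutor, and the hard-coded JSON string stands in for the plugin's toJson utility):

```java
import java.util.Map;

// Sketch: why the processor serializes a Map to a JSON string before
// template substitution. Illustrative stand-ins only.
public class WhyToJsonMatters {
    static final String TEMPLATE = "{ \"query\": { \"bool\": { \"must\": [ ${llm_response} ] } } }";

    // Without conversion: Map.toString() yields {k=v} notation, not JSON.
    static String substituteRaw(Map<String, Object> modelOutput) {
        return TEMPLATE.replace("${llm_response}", modelOutput.toString());
    }

    // With conversion: a JSON string drops cleanly into the template.
    static String substituteJson(String json) {
        return TEMPLATE.replace("${llm_response}", json);
    }

    public static void main(String[] args) {
        Map<String, Object> modelOutput = Map.of("match", Map.of("text_entry", "love"));
        // The substituted fragment reads {match={text_entry=love}} -- unparseable as JSON:
        System.out.println(substituteRaw(modelOutput));
        // Valid JSON drops in cleanly:
        System.out.println(substituteJson("{\"match\": {\"text_entry\": \"love\"}}"));
    }
}
```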

valuesMap.put(newQueryField, modelOutputValue);
}
StringSubstitutor sub = new StringSubstitutor(valuesMap);
@@ -41,10 +41,13 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder;
import org.opensearch.ingest.Processor;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
@@ -58,6 +61,8 @@
import org.opensearch.ml.searchext.MLInferenceRequestParameters;
import org.opensearch.ml.searchext.MLInferenceRequestParametersExtBuilder;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.script.Script;
import org.opensearch.script.ScriptType;
import org.opensearch.search.SearchModule;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.pipeline.PipelineProcessingContext;
@@ -1312,7 +1317,7 @@ public void onFailure(Exception e) {

/**
* Tests the successful rewriting of a complex nested array in query extension based on the model output.
* verify the pipelineConext is set from the extension
* verify the pipelineContext is set from the extension
* @throws Exception if an error occurs during the test
*/
public void testExecute_rewriteTermQueryReadAndWriteComplexNestedArrayToExtensionSuccess() throws Exception {
@@ -1499,6 +1504,207 @@ public void onFailure(Exception e) {

}

/**
* Tests that the ML Processor can return a sparse vector correctly when performing a query rewrite.
*
* This simulates a real-world scenario where the user has a neural sparse model and parses
* its output by setting fullResponsePath to true.
* @throws Exception when an error occurs in the test
*/
public void testExecute_rewriteTermQueryWithSparseVectorSuccess() throws Exception {
String modelInputField = "inputs";
String originalQueryField = "query.term.text.value";
String newQueryField = "vector";
String modelInferenceJsonPathInput = "$.inference_results[0].output[0].dataAsMap.response[0]";

String queryTemplate = "{\n"
+ " \"query\": {\n"
+ " \"script_score\": {\n"
+ " \"query\": {\n"
+ " \"match_all\": {}\n"
+ " },\n"
+ " \"script\": {\n"
+ " \"source\": \"return 1;\",\n"
+ " \"params\": {\n"
+ " \"query_tokens\": ${vector}\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ "}";

Map<String, Double> sparseVector = Map.of("this", 1.3123, "which", 0.2447, "here", 0.6674);

List<Map<String, String>> optionalInputMap = new ArrayList<>();
Map<String, String> input = new HashMap<>();
input.put(modelInputField, originalQueryField);
optionalInputMap.add(input);

List<Map<String, String>> optionalOutputMap = new ArrayList<>();
Map<String, String> output = new HashMap<>();
output.put(newQueryField, modelInferenceJsonPathInput);
optionalOutputMap.add(output);

MLInferenceSearchRequestProcessor requestProcessor = new MLInferenceSearchRequestProcessor(
"model1",
queryTemplate,
null,
null,
optionalInputMap,
optionalOutputMap,
null,
DEFAULT_MAX_PREDICTION_TASKS,
PROCESSOR_TAG,
DESCRIPTION,
false,
"remote",
true,
false,
"{ \"parameters\": ${ml_inference.parameters} }",
client,
TEST_XCONTENT_REGISTRY_FOR_QUERY
);

/**
* {
* "inference_results" : [ {
* "output" : [ {
* "name" : "response",
* "dataAsMap" : {
* "response" : [ {
* "this" : 1.3123,
* "which" : 0.2447,
* "here" : 0.6674
* } ]
* }
* } ]
* } ]
* }
*/
ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(Map.of("response", List.of(sparseVector))).build();
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();

doAnswer(invocation -> {
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
return null;
}).when(client).execute(any(), any(), any());

QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo");
SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery);
SearchRequest request = new SearchRequest().source(source);

ActionListener<SearchRequest> Listener = new ActionListener<>() {
@Override
public void onResponse(SearchRequest newSearchRequest) {
Script script = new Script(ScriptType.INLINE, "painless", "return 1;", Map.of("query_tokens", sparseVector));

ScriptScoreQueryBuilder expectedQuery = new ScriptScoreQueryBuilder(QueryBuilders.matchAllQuery(), script);
assertEquals(expectedQuery, newSearchRequest.source().query());
}

@Override
public void onFailure(Exception e) {
throw new RuntimeException("Failed in executing processRequestAsync.", e);
}
};

requestProcessor.processRequestAsync(request, requestContext, Listener);
}

/**
* Tests that the ML Processor can return an OpenSearch query correctly when performing a query rewrite.
*
* This simulates a real-world scenario where the user has an LLM return an OpenSearch query to help
* generate a new query based on the context given in the prompt.
*
* @throws Exception when an error occurs in the test
*/
public void testExecute_rewriteTermQueryWithNewQuerySuccess() throws Exception {
String modelInputField = "inputs";
String originalQueryField = "query.term.text.value";
String newQueryField = "llm_query";
String modelInferenceJsonPathInput = "$.inference_results[0].output[0].dataAsMap.content[0].text";

String queryTemplate = "${llm_query}";

String llmQuery = "{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " }\n" + "}";
Map<String, Object> content = Map.of("content", List.of(Map.of("text", llmQuery)));

List<Map<String, String>> optionalInputMap = new ArrayList<>();
Map<String, String> input = new HashMap<>();
input.put(modelInputField, originalQueryField);
optionalInputMap.add(input);

List<Map<String, String>> optionalOutputMap = new ArrayList<>();
Map<String, String> output = new HashMap<>();
output.put(newQueryField, modelInferenceJsonPathInput);
optionalOutputMap.add(output);

MLInferenceSearchRequestProcessor requestProcessor = new MLInferenceSearchRequestProcessor(
"model1",
queryTemplate,
null,
null,
optionalInputMap,
optionalOutputMap,
null,
DEFAULT_MAX_PREDICTION_TASKS,
PROCESSOR_TAG,
DESCRIPTION,
false,
"remote",
true,
false,
"{ \"parameters\": ${ml_inference.parameters} }",
client,
TEST_XCONTENT_REGISTRY_FOR_QUERY
);

/**
* {
*   "inference_results" : [ {
*     "output" : [ {
*       "name" : "response",
*       "dataAsMap" : {
*         "content": [ {
*           "text": "{\"query\": {\"match_all\" : {}}}"
*         } ]
*       }
*     } ]
*   } ]
* }
*/

Contributor Author: Not sure why it's displayed out of place on GitHub, but it looks normal in IntelliJ.
ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(content).build();
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();

doAnswer(invocation -> {
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
return null;
}).when(client).execute(any(), any(), any());

QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo");
SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery);
SearchRequest sampleRequest = new SearchRequest().source(source);

ActionListener<SearchRequest> Listener = new ActionListener<>() {
@Override
public void onResponse(SearchRequest newSearchRequest) {
MatchAllQueryBuilder expectedQuery = new MatchAllQueryBuilder();
assertEquals(expectedQuery, newSearchRequest.source().query());
}

@Override
public void onFailure(Exception e) {
throw new RuntimeException("Failed in executing processRequestAsync.", e);
}
};

requestProcessor.processRequestAsync(sampleRequest, requestContext, Listener);
}

/**
* Tests when there are two optional input fields
* but only the second optional input is present in the query