From e2dee346695ded0a955ad51a7c32747a0231cb87 Mon Sep 17 00:00:00 2001 From: Brian Flores Date: Thu, 22 May 2025 12:18:15 -0700 Subject: [PATCH] Adds Json Parsing to nested object during update Query step in ML Inference Request processor (#3856) * adds json parsing to modelOutputValue during query rewrite Signed-off-by: Brian Flores * add Unit Test for queryTemplate change Signed-off-by: Brian Flores * refactors and adds new UT Signed-off-by: Brian Flores * apply spotless Signed-off-by: Brian Flores * replace tab with spaces on multi-line comment Signed-off-by: Brian Flores --------- Signed-off-by: Brian Flores (cherry picked from commit 532284e3bb5ec11fa35550edc87848a27b226b08) --- .../MLInferenceSearchRequestProcessor.java | 3 + ...LInferenceSearchRequestProcessorTests.java | 208 +++++++++++++++++- 2 files changed, 210 insertions(+), 1 deletion(-) diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java index 8973082873..2a90e96cdf 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java @@ -360,6 +360,9 @@ private String updateQueryTemplate(String queryTemplate, Map out String newQueryField = outputMapEntry.getKey(); String modelOutputFieldName = outputMapEntry.getValue(); Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath); + if (modelOutputValue instanceof Map) { + modelOutputValue = toJson(modelOutputValue); + } valuesMap.put(newQueryField, modelOutputValue); } StringSubstitutor sub = new StringSubstitutor(valuesMap); diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java index 
0151b8b18e..50b4bea86a 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java @@ -42,10 +42,13 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.RangeQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.index.query.TermsQueryBuilder; +import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder; import org.opensearch.ingest.Processor; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; @@ -59,6 +62,8 @@ import org.opensearch.ml.searchext.MLInferenceRequestParameters; import org.opensearch.ml.searchext.MLInferenceRequestParametersExtBuilder; import org.opensearch.plugins.SearchPlugin; +import org.opensearch.script.Script; +import org.opensearch.script.ScriptType; import org.opensearch.search.SearchModule; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.pipeline.PipelineProcessingContext; @@ -1312,7 +1317,7 @@ public void onFailure(Exception e) { /** * Tests the successful rewriting of a complex nested array in query extension based on the model output. 
- * verify the pipelineConext is set from the extension + * verify the pipelineContext is set from the extension * @throws Exception if an error occurs during the test */ public void testExecute_rewriteTermQueryReadAndWriteComplexNestedArrayToExtensionSuccess() throws Exception { @@ -1499,6 +1504,207 @@ public void onFailure(Exception e) { } + /** + * Tests ML Processor can return a sparse vector correctly when performing a rewrite query. + * + * This simulates a real world scenario where user has a neural sparse model and attempts to parse + * it by asserting FullResponsePath to true. + * @throws Exception when an error occurs on the test + */ + public void testExecute_rewriteTermQueryWithSparseVectorSuccess() throws Exception { + String modelInputField = "inputs"; + String originalQueryField = "query.term.text.value"; + String newQueryField = "vector"; + String modelInferenceJsonPathInput = "$.inference_results[0].output[0].dataAsMap.response[0]"; + + String queryTemplate = "{\n" + + " \"query\": {\n" + + " \"script_score\": {\n" + + " \"query\": {\n" + + " \"match_all\": {}\n" + + " },\n" + + " \"script\": {\n" + + " \"source\": \"return 1;\",\n" + + " \"params\": {\n" + + " \"query_tokens\": ${vector}\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + Map sparseVector = Map.of("this", 1.3123, "which", 0.2447, "here", 0.6674); + + List> optionalInputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put(modelInputField, originalQueryField); + optionalInputMap.add(input); + + List> optionalOutputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newQueryField, modelInferenceJsonPathInput); + optionalOutputMap.add(output); + + MLInferenceSearchRequestProcessor requestProcessor = new MLInferenceSearchRequestProcessor( + "model1", + queryTemplate, + null, + null, + optionalInputMap, + optionalOutputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + true, + false, + "{ 
\"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY + ); + + /** + * { + * "inference_results" : [ { + * "output" : [ { + * "name" : "response", + * "dataAsMap" : { + * "response" : [ { + * "this" : 1.3123, + * "which" : 0.2447, + * "here" : 0.6674 + * } ] + * } + * } ] + * } ] + * } + */ + ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(Map.of("response", List.of(sparseVector))).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); + SearchRequest request = new SearchRequest().source(source); + + ActionListener Listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + Script script = new Script(ScriptType.INLINE, "painless", "return 1;", Map.of("query_tokens", sparseVector)); + + ScriptScoreQueryBuilder expectedQuery = new ScriptScoreQueryBuilder(QueryBuilders.matchAllQuery(), script); + assertEquals(expectedQuery, newSearchRequest.source().query()); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("Failed in executing processRequestAsync.", e); + } + }; + + requestProcessor.processRequestAsync(request, requestContext, Listener); + } + + /** + * Tests ML Processor can return an OpenSearch Query correctly when performing a rewrite query. 
+ * + * This simulates a real-world scenario where a user has an LLM return an OpenSearch query to help them generate a new + * query based on the context given in the prompt. + * + * @throws Exception when an error occurs on the test + */ + public void testExecute_rewriteTermQueryWithNewQuerySuccess() throws Exception { + String modelInputField = "inputs"; + String originalQueryField = "query.term.text.value"; + String newQueryField = "llm_query"; + String modelInferenceJsonPathInput = "$.inference_results[0].output[0].dataAsMap.content[0].text"; + + String queryTemplate = "${llm_query}"; + + String llmQuery = "{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " }\n" + "}"; + Map content = Map.of("content", List.of(Map.of("text", llmQuery))); + + List> optionalInputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put(modelInputField, originalQueryField); + optionalInputMap.add(input); + + List> optionalOutputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newQueryField, modelInferenceJsonPathInput); + optionalOutputMap.add(output); + + MLInferenceSearchRequestProcessor requestProcessor = new MLInferenceSearchRequestProcessor( + "model1", + queryTemplate, + null, + null, + optionalInputMap, + optionalOutputMap, + null, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + true, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY + ); + + /* + * { + * "inference_results" : [ { + * "output" : [ { + * "name" : "response", + * "dataAsMap" : { + * "content" : [ + * { "text" : "{\"query\": {\"match_all\": {}}}" } ] + * } + * } ] + * } ] + * } + */ + ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(content).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = 
ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); + SearchRequest sampleRequest = new SearchRequest().source(source); + + ActionListener Listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + MatchAllQueryBuilder expectedQuery = new MatchAllQueryBuilder(); + assertEquals(expectedQuery, newSearchRequest.source().query()); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("Failed in executing processRequestAsync.", e); + } + }; + + requestProcessor.processRequestAsync(sampleRequest, requestContext, Listener); + } + /** * Tests when there are two optional input fields * but only the second optional input is present in the query