From 5608c63093e06c55e4c929c569349a37a8d94fb1 Mon Sep 17 00:00:00 2001 From: Brian Flores Date: Fri, 16 May 2025 15:57:12 -0700 Subject: [PATCH 1/5] adds json parsing to modelOutputValue during query rewrite Signed-off-by: Brian Flores --- .../MLInferenceSearchRequestProcessor.java | 48 ++++++++++--------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java index 02b9c331c7..3a43399f79 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java @@ -4,23 +4,13 @@ */ package org.opensearch.ml.processor; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.utils.StringUtils.toJson; -import static org.opensearch.ml.processor.InferenceProcessorAttributes.INPUT_MAP; -import static org.opensearch.ml.processor.InferenceProcessorAttributes.MAX_PREDICTION_TASKS; -import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_CONFIG; -import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_ID; -import static org.opensearch.ml.processor.InferenceProcessorAttributes.OUTPUT_MAP; -import static org.opensearch.ml.processor.ModelExecutor.combineMaps; - -import java.io.IOException; -import java.util.Collection; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - +import com.google.gson.Gson; +import com.jayway.jsonpath.Configuration; +import com.jayway.jsonpath.JsonPath; +import com.jayway.jsonpath.Option; +import com.jayway.jsonpath.PathNotFoundException; +import com.jayway.jsonpath.ReadContext; +import lombok.Getter; import org.apache.commons.text.StringSubstitutor; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -46,13 +36,22 @@ import org.opensearch.search.pipeline.SearchRequestProcessor; import org.opensearch.transport.client.Client; -import com.jayway.jsonpath.Configuration; -import com.jayway.jsonpath.JsonPath; -import com.jayway.jsonpath.Option; -import com.jayway.jsonpath.PathNotFoundException; -import com.jayway.jsonpath.ReadContext; +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; -import lombok.Getter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.toJson; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.INPUT_MAP; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.MAX_PREDICTION_TASKS; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_CONFIG; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_ID; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.OUTPUT_MAP; +import static org.opensearch.ml.processor.ModelExecutor.combineMaps; /** * MLInferenceSearchRequestProcessor requires a modelId string to call model inferences @@ -360,6 +359,9 @@ private String updateQueryTemplate(String queryTemplate, Map out String newQueryField = outputMapEntry.getKey(); String modelOutputFieldName = outputMapEntry.getValue(); Object 
modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath); + if (modelOutputValue instanceof Map) { + modelOutputValue = new Gson().toJson(modelOutputValue); + } valuesMap.put(newQueryField, modelOutputValue); } StringSubstitutor sub = new StringSubstitutor(valuesMap); From 46db551664e677dd9677ab67096591462060b6ae Mon Sep 17 00:00:00 2001 From: Brian Flores Date: Fri, 16 May 2025 17:56:49 -0700 Subject: [PATCH 2/5] add Unit Test for queryTemplate change Signed-off-by: Brian Flores --- .../MLInferenceSearchRequestProcessor.java | 46 +++---- ...LInferenceSearchRequestProcessorTests.java | 113 ++++++++++++++++++ 2 files changed, 137 insertions(+), 22 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java index 3a43399f79..053d92725a 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java @@ -4,13 +4,23 @@ */ package org.opensearch.ml.processor; -import com.google.gson.Gson; -import com.jayway.jsonpath.Configuration; -import com.jayway.jsonpath.JsonPath; -import com.jayway.jsonpath.Option; -import com.jayway.jsonpath.PathNotFoundException; -import com.jayway.jsonpath.ReadContext; -import lombok.Getter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.toJson; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.INPUT_MAP; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.MAX_PREDICTION_TASKS; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_CONFIG; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_ID; +import static org.opensearch.ml.processor.InferenceProcessorAttributes.OUTPUT_MAP; +import static org.opensearch.ml.processor.ModelExecutor.combineMaps; + +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + import org.apache.commons.text.StringSubstitutor; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -36,22 +46,14 @@ import org.opensearch.search.pipeline.SearchRequestProcessor; import org.opensearch.transport.client.Client; -import java.io.IOException; -import java.util.Collection; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; +import com.google.gson.Gson; +import com.jayway.jsonpath.Configuration; +import com.jayway.jsonpath.JsonPath; +import com.jayway.jsonpath.Option; +import com.jayway.jsonpath.PathNotFoundException; +import com.jayway.jsonpath.ReadContext; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.utils.StringUtils.toJson; -import static org.opensearch.ml.processor.InferenceProcessorAttributes.INPUT_MAP; -import static org.opensearch.ml.processor.InferenceProcessorAttributes.MAX_PREDICTION_TASKS; -import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_CONFIG; -import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_ID; -import static org.opensearch.ml.processor.InferenceProcessorAttributes.OUTPUT_MAP; 
-import static org.opensearch.ml.processor.ModelExecutor.combineMaps;
+import lombok.Getter;
 /**
 * MLInferenceSearchRequestProcessor requires a modelId string to call model inferences
diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java
index e41b4b1a3d..8cce79c4a3 100644
--- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java
@@ -42,9 +42,11 @@ import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.index.query.QueryBuilder;
+import org.opensearch.index.query.QueryBuilders;
 import org.opensearch.index.query.RangeQueryBuilder;
 import org.opensearch.index.query.TermQueryBuilder;
 import org.opensearch.index.query.TermsQueryBuilder;
+import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder;
 import org.opensearch.ingest.Processor;
 import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
 import org.opensearch.ml.common.input.MLInput;
@@ -58,6 +60,8 @@ import org.opensearch.ml.searchext.MLInferenceRequestParameters;
 import org.opensearch.ml.searchext.MLInferenceRequestParametersExtBuilder;
 import org.opensearch.plugins.SearchPlugin;
+import org.opensearch.script.Script;
+import org.opensearch.script.ScriptType;
 import org.opensearch.search.SearchModule;
 import org.opensearch.search.builder.SearchSourceBuilder;
 import org.opensearch.search.pipeline.PipelineProcessingContext;
@@ -1499,6 +1503,115 @@ public void onFailure(Exception e) {
 }
+ /**
+ * Tests that the ML processor can return a sparse vector correctly when performing a query rewrite.
+ *
+ * This simulates a real-world scenario where a user has a neural sparse model and parses its
+ * response by setting fullResponsePath to true.
+ * @throws Exception when an error occurs during the test
+ */
+ public void testExecute_rewriteTermQueryWithSparseVectorSuccess() throws Exception {
+ String modelInputField = "inputs";
+ String originalQueryField = "query.term.text.value";
+ String newQueryField = "vector";
+ String modelInferenceJsonPathInput = "$.inference_results[0].output[0].dataAsMap.response[0]";
+
+ String queryTemplate = "{\n" +
+ " \"query\": {\n" +
+ " \"script_score\": {\n" +
+ " \"query\": {\n" +
+ " \"match_all\": {}\n" +
+ " },\n" +
+ " \"script\": {\n" +
+ " \"source\": \"return 1;\",\n" +
+ " \"params\": {\n" +
+ " \"query_tokens\": ${vector}\n" +
+ " }\n" +
+ " }\n" +
+ " }\n" +
+ " }\n" +
+ "}";
+
+ Map<String, Double> sparseVector = Map.of("this", 1.3123, "which", 0.2447, "here", 0.6674);
+
+ List<Map<String, String>> optionalInputMap = new ArrayList<>();
+ Map<String, String> input = new HashMap<>();
+ input.put(modelInputField, originalQueryField);
+ optionalInputMap.add(input);
+
+ List<Map<String, String>> optionalOutputMap = new ArrayList<>();
+ Map<String, String> output = new HashMap<>();
+ output.put(newQueryField, modelInferenceJsonPathInput);
+ optionalOutputMap.add(output);
+
+ MLInferenceSearchRequestProcessor requestProcessor = new MLInferenceSearchRequestProcessor(
+ "model1",
+ queryTemplate,
+ null,
+ null,
+ optionalInputMap,
+ optionalOutputMap,
+ null,
+ DEFAULT_MAX_PREDICTION_TASKS,
+ PROCESSOR_TAG,
+ DESCRIPTION,
+ false,
+ "remote",
+ true,
+ false,
+ "{ \"parameters\": ${ml_inference.parameters} }",
+ client,
+ TEST_XCONTENT_REGISTRY_FOR_QUERY
+ );
+
+ /**
+ * {
+ * "inference_results" : [ {
+ * "output" : [ {
+ * "name" : "response",
+ * "dataAsMap" : {
+ * "response" : [ {
+ * "this" : 1.3123,
+ * "which" : 0.2447,
+ * "here" : 0.6674
+ * } ]
+ * }
+ * } ]
+ * } ]
+ * }
+ */
+ ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(Map.of("response", List.of(sparseVector))).build();
+ ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
+ ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
+
+ doAnswer(invocation -> {
+ ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
+ actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
+ return null;
+ }).when(client).execute(any(), any(), any());
+
+ QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo");
+ SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery);
+ SearchRequest request = new SearchRequest().source(source);
+
+ ActionListener<SearchRequest> Listener = new ActionListener<>() {
+ @Override
+ public void onResponse(SearchRequest newSearchRequest) {
+ Script script = new Script(ScriptType.INLINE, "painless", "return 1;", Map.of("query_tokens", sparseVector));
+
+ ScriptScoreQueryBuilder expectedQuery = new ScriptScoreQueryBuilder(QueryBuilders.matchAllQuery(), script);
+ assertEquals(expectedQuery, newSearchRequest.source().query());
+ }
+
+ @Override
+ public void onFailure(Exception e) {
+ throw new RuntimeException("Failed in executing processRequestAsync.", e);
+ }
+ };
+
+ requestProcessor.processRequestAsync(request, requestContext, Listener);
+ }
+
 /**
 * Tests when there are two optional input fields
 * but only the second optional input is present in the query
From f7d0f24a2f569b9e6141bbb4fa83fe904cd3584d Mon Sep 17 00:00:00 2001
From: Brian Flores
Date: Wed, 21 May 2025 11:17:45 -0700
Subject: [PATCH 3/5] refactors and adds new UT

Signed-off-by: Brian Flores
---
 .../MLInferenceSearchRequestProcessor.java | 3 +-
 ...LInferenceSearchRequestProcessorTests.java | 104 +++++++++++++++++-
 2 files changed, 104 insertions(+), 3 deletions(-)

diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java
index 053d92725a..9d19214c34 100644
--- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java
+++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java
@@ -46,7 +46,6 @@ import org.opensearch.search.pipeline.SearchRequestProcessor;
 import org.opensearch.transport.client.Client;
-import com.google.gson.Gson;
 import com.jayway.jsonpath.Configuration;
 import com.jayway.jsonpath.JsonPath;
 import com.jayway.jsonpath.Option;
 import com.jayway.jsonpath.PathNotFoundException;
 import com.jayway.jsonpath.ReadContext;
@@ -362,7 +361,7 @@ private String updateQueryTemplate(String queryTemplate, Map out
 String modelOutputFieldName = outputMapEntry.getValue();
 Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath);
 if (modelOutputValue instanceof Map) {
- modelOutputValue = new Gson().toJson(modelOutputValue);
+ modelOutputValue = toJson(modelOutputValue);
 }
 valuesMap.put(newQueryField, modelOutputValue);
 }
diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java
index 8cce79c4a3..6d1d2f5cdf 100644
--- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java
@@ -41,6 +41,7 @@ import org.opensearch.core.action.ActionListener;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.index.query.MatchAllQueryBuilder;
 import org.opensearch.index.query.QueryBuilder;
 import org.opensearch.index.query.QueryBuilders;
 import org.opensearch.index.query.RangeQueryBuilder;
 import org.opensearch.index.query.TermQueryBuilder;
 import org.opensearch.index.query.TermsQueryBuilder;
@@ -1316,7 +1317,7 @@ public void onFailure(Exception e) {
 /**
 * Tests the successful rewriting of a complex nested array in query extension based on the model output.
- * verify the pipelineConext is set from the extension
+ * verify the pipelineContext is set from the extension
 * @throws Exception if an error occurs during the test
 */
 public void testExecute_rewriteTermQueryReadAndWriteComplexNestedArrayToExtensionSuccess() throws Exception {
@@ -1612,6 +1613,107 @@ public void onFailure(Exception e) {
 requestProcessor.processRequestAsync(request, requestContext, Listener);
 }
+ /**
+ * Tests that the ML processor can return an OpenSearch query correctly when performing a query rewrite.
+ *
+ * This simulates a real-world scenario where a user has an LLM return an OpenSearch query to help them generate a
+ * new query based on the context given in the prompt.
+ *
+ * @throws Exception when an error occurs during the test
+ */
+ public void testExecute_rewriteTermQueryWithNewQuerySuccess() throws Exception {
+ String modelInputField = "inputs";
+ String originalQueryField = "query.term.text.value";
+ String newQueryField = "llm_query";
+ String modelInferenceJsonPathInput = "$.inference_results[0].output[0].dataAsMap.content[0].text";
+
+ String queryTemplate = "${llm_query}";
+
+ String llmQuery = "{\n" +
+ " \"query\": {\n" +
+ " \"match_all\": {}\n" +
+ " }\n" +
+ "}";
+ Map<String, Object> content = Map.of("content", List.of(Map.of("text", llmQuery)));
+
+
+ List<Map<String, String>> optionalInputMap = new ArrayList<>();
+ Map<String, String> input = new HashMap<>();
+ input.put(modelInputField, originalQueryField);
+ optionalInputMap.add(input);
+
+ List<Map<String, String>> optionalOutputMap = new ArrayList<>();
+ Map<String, String> output = new HashMap<>();
+ output.put(newQueryField, modelInferenceJsonPathInput);
+ optionalOutputMap.add(output);
+
+ MLInferenceSearchRequestProcessor requestProcessor = new MLInferenceSearchRequestProcessor(
+ "model1",
+ queryTemplate,
+ null,
+ null,
+ optionalInputMap,
+ optionalOutputMap,
+ null,
+ DEFAULT_MAX_PREDICTION_TASKS,
+ PROCESSOR_TAG,
+ DESCRIPTION,
+ false,
+ "remote",
+ true,
+ false,
+ "{ \"parameters\": ${ml_inference.parameters} }",
+ client,
+ TEST_XCONTENT_REGISTRY_FOR_QUERY
+ );
+
+ /**
+ * {
+ * "inference_results" : [ {
+ * "output" : [ {
+ * "name" : "response",
+ * "dataAsMap" : {
+ * "content": [
+ * "text": "{\"query\": \"match_all\" : {}}"
+ * }
+ * } ]
+ * } ]
+ * }
+ */
+ ModelTensor modelTensor = ModelTensor
+ .builder()
+ .name("response")
+ .dataAsMap(content)
+ .build();
+ ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
+ ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
+
+ doAnswer(invocation -> {
+ ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
+ actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
+ return null;
+ }).when(client).execute(any(), any(), any());
+
+ QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo");
+ SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery);
+ SearchRequest sampleRequest = new SearchRequest().source(source);
+
+ ActionListener<SearchRequest> Listener = new ActionListener<>() {
+ @Override
+ public void onResponse(SearchRequest newSearchRequest) {
+ MatchAllQueryBuilder expectedQuery = new MatchAllQueryBuilder();
+ assertEquals(expectedQuery, newSearchRequest.source().query());
+ }
+
+ @Override
+ public void onFailure(Exception e) {
+ throw new RuntimeException("Failed in executing processRequestAsync.", e);
+ }
+ };
+
+ requestProcessor.processRequestAsync(sampleRequest, requestContext, Listener);
+ }
+
 /**
 * Tests when there are two optional input fields
 * but only the second optional input is present in the query
From de23341b034f6c6974ffda2cfb56b716869bfb4f Mon Sep 17 00:00:00 2001
From: Brian Flores
Date: Wed, 21 May 2025 13:30:47 -0700
Subject: [PATCH 4/5] apply spotless

Signed-off-by: Brian Flores
---
 ...LInferenceSearchRequestProcessorTests.java | 47 ++++++++----------
 1 file changed, 19 insertions(+), 28 deletions(-)

diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java
index 6d1d2f5cdf..4ff22bcad3 100644
--- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java
@@ -1629,14 +1629,9 @@ public void testExecute_rewriteTermQueryWithNewQuerySuccess() throws Exception {
 String queryTemplate = "${llm_query}";
- String llmQuery = "{\n" +
- " \"query\": {\n" +
- " \"match_all\": {}\n" +
- " }\n" +
- "}";
+ String llmQuery = "{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " }\n" + "}";
 Map<String, Object> content = Map.of("content", List.of(Map.of("text", llmQuery)));
-
 List<Map<String, String>> optionalInputMap = new ArrayList<>();
 Map<String, String> input = new HashMap<>();
 input.put(modelInputField, originalQueryField);
@@ -1648,23 +1643,23 @@ public void testExecute_rewriteTermQueryWithNewQuerySuccess() throws Exception {
 optionalOutputMap.add(output);
 MLInferenceSearchRequestProcessor requestProcessor = new MLInferenceSearchRequestProcessor(
- "model1",
- queryTemplate,
- null,
- null,
- optionalInputMap,
- optionalOutputMap,
- null,
- DEFAULT_MAX_PREDICTION_TASKS,
- PROCESSOR_TAG,
- DESCRIPTION,
- false,
- "remote",
- true,
- false,
- "{ \"parameters\": ${ml_inference.parameters} }",
- client,
- TEST_XCONTENT_REGISTRY_FOR_QUERY
+ "model1",
+ queryTemplate,
+ null,
+ null,
+ optionalInputMap,
+ optionalOutputMap,
+ null,
+ DEFAULT_MAX_PREDICTION_TASKS,
+ PROCESSOR_TAG,
+ DESCRIPTION,
+ false,
+ "remote",
+ true,
+ false,
+ "{ \"parameters\": ${ml_inference.parameters} }",
+ client,
+ TEST_XCONTENT_REGISTRY_FOR_QUERY
 );
@@ -1680,11 +1675,7 @@ public void testExecute_rewriteTermQueryWithNewQuerySuccess() throws Exception {
 * } ]
 * }
 */
- ModelTensor modelTensor = ModelTensor
- .builder()
- .name("response")
- .dataAsMap(content)
- .build();
+ ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(content).build();
 ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
 ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();

From 29c6745f66a1bfe0551ba4d94e09dec635e172b1 Mon Sep 17 00:00:00 2001
From: Brian Flores
Date: Wed, 21 May 2025 14:53:29 -0700
Subject: [PATCH 5/5] replace tab with spaces on multi-line comment

Signed-off-by: Brian Flores
---
 .../processor/MLInferenceSearchRequestProcessorTests.java | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java
index 4ff22bcad3..ea0646b02d 100644
--- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java
@@ -1662,14 +1662,14 @@ public void testExecute_rewriteTermQueryWithNewQuerySuccess() throws Exception {
 TEST_XCONTENT_REGISTRY_FOR_QUERY
 );
- /**
+ /*
 * {
 * "inference_results" : [ {
 * "output" : [ {
 * "name" : "response",
 * "dataAsMap" : {
- * "content": [
- * "text": "{\"query\": \"match_all\" : {}}"
+ * "content": [
+ * "text": "{\"query\": \"match_all\" : {}}"
 * }
 * } ]
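
Note on the core change (illustrative commentary, not part of the patch series above): PATCH 1/5 serializes a Map-typed model output to JSON before substituting it into the query template, and PATCH 3/5 swaps new Gson().toJson(...) for the statically imported StringUtils.toJson(...). The sketch below shows why the serialization step matters. It is a minimal, self-contained example, not the processor code: the class name and the simplified template are hypothetical, the sparse-vector values are borrowed from the unit test, and it uses the same commons-text StringSubstitutor and Gson calls the processor relies on.

import java.util.Map;

import org.apache.commons.text.StringSubstitutor;

import com.google.gson.Gson;

public class QueryTemplateSubstitutionSketch {
    public static void main(String[] args) {
        // Simplified stand-in for the script_score query template used in the sparse-vector test.
        String queryTemplate = "{ \"params\": { \"query_tokens\": ${vector} } }";
        Map<String, Double> sparseVector = Map.of("this", 1.3123, "which", 0.2447, "here", 0.6674);

        // Without the instanceof-Map check, StringSubstitutor falls back to Map.toString(),
        // yielding something like {this=1.3123, which=0.2447, here=0.6674} -- not valid JSON,
        // so the rewritten search request cannot be parsed.
        String broken = new StringSubstitutor(Map.of("vector", sparseVector)).replace(queryTemplate);

        // With the fix, the Map is serialized to JSON first, yielding
        // {"this":1.3123,"which":0.2447,"here":0.6674} (key order may vary) and a parseable query.
        String fixed = new StringSubstitutor(Map.of("vector", new Gson().toJson(sparseVector))).replace(queryTemplate);

        System.out.println(broken);
        System.out.println(fixed);
    }
}

Moving from new Gson().toJson(...) to the already-imported StringUtils.toJson(...) in PATCH 3/5 keeps this behavior while dropping the extra Gson instance and the direct Gson import.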