-
Notifications
You must be signed in to change notification settings - Fork 158
Adds Json Parsing to nested object during update Query step in ML Inference Request processor #3856
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+210
−1
Merged
Changes from 4 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
5608c63
adds json parsing to modelOutputValue during query rewrite
brianf-aws 46db551
add Unit Test for queryTemplate change
brianf-aws f7d0f24
refactors and adds new UT
brianf-aws de23341
apply spotless
brianf-aws 29c6745
replace tab with spaces on multi-line comment
brianf-aws File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,10 +41,13 @@ | |
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.index.query.MatchAllQueryBuilder; | ||
import org.opensearch.index.query.QueryBuilder; | ||
import org.opensearch.index.query.QueryBuilders; | ||
import org.opensearch.index.query.RangeQueryBuilder; | ||
import org.opensearch.index.query.TermQueryBuilder; | ||
import org.opensearch.index.query.TermsQueryBuilder; | ||
import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder; | ||
import org.opensearch.ingest.Processor; | ||
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
|
@@ -58,6 +61,8 @@ | |
import org.opensearch.ml.searchext.MLInferenceRequestParameters; | ||
import org.opensearch.ml.searchext.MLInferenceRequestParametersExtBuilder; | ||
import org.opensearch.plugins.SearchPlugin; | ||
import org.opensearch.script.Script; | ||
import org.opensearch.script.ScriptType; | ||
import org.opensearch.search.SearchModule; | ||
import org.opensearch.search.builder.SearchSourceBuilder; | ||
import org.opensearch.search.pipeline.PipelineProcessingContext; | ||
|
@@ -1312,7 +1317,7 @@ public void onFailure(Exception e) { | |
|
||
/** | ||
* Tests the successful rewriting of a complex nested array in query extension based on the model output. | ||
* verify the pipelineConext is set from the extension | ||
* verify the pipelineContext is set from the extension | ||
* @throws Exception if an error occurs during the test | ||
*/ | ||
public void testExecute_rewriteTermQueryReadAndWriteComplexNestedArrayToExtensionSuccess() throws Exception { | ||
|
@@ -1499,6 +1504,207 @@ public void onFailure(Exception e) { | |
|
||
} | ||
|
||
/** | ||
* Tests ML Processor can return a sparse vector correctly when performing a rewrite query. | ||
* | ||
* This simulates a real world scenario where user has a neural sparse model and attempts to parse | ||
* it by asserting FullResponsePath to true. | ||
* @throws Exception when an error occurs on the test | ||
*/ | ||
public void testExecute_rewriteTermQueryWithSparseVectorSuccess() throws Exception { | ||
String modelInputField = "inputs"; | ||
String originalQueryField = "query.term.text.value"; | ||
String newQueryField = "vector"; | ||
String modelInferenceJsonPathInput = "$.inference_results[0].output[0].dataAsMap.response[0]"; | ||
|
||
String queryTemplate = "{\n" | ||
+ " \"query\": {\n" | ||
+ " \"script_score\": {\n" | ||
+ " \"query\": {\n" | ||
+ " \"match_all\": {}\n" | ||
+ " },\n" | ||
+ " \"script\": {\n" | ||
+ " \"source\": \"return 1;\",\n" | ||
+ " \"params\": {\n" | ||
+ " \"query_tokens\": ${vector}\n" | ||
+ " }\n" | ||
+ " }\n" | ||
+ " }\n" | ||
+ " }\n" | ||
+ "}"; | ||
|
||
Map<String, Double> sparseVector = Map.of("this", 1.3123, "which", 0.2447, "here", 0.6674); | ||
|
||
List<Map<String, String>> optionalInputMap = new ArrayList<>(); | ||
Map<String, String> input = new HashMap<>(); | ||
input.put(modelInputField, originalQueryField); | ||
optionalInputMap.add(input); | ||
|
||
List<Map<String, String>> optionalOutputMap = new ArrayList<>(); | ||
Map<String, String> output = new HashMap<>(); | ||
output.put(newQueryField, modelInferenceJsonPathInput); | ||
optionalOutputMap.add(output); | ||
|
||
MLInferenceSearchRequestProcessor requestProcessor = new MLInferenceSearchRequestProcessor( | ||
"model1", | ||
queryTemplate, | ||
null, | ||
null, | ||
optionalInputMap, | ||
optionalOutputMap, | ||
null, | ||
DEFAULT_MAX_PREDICTION_TASKS, | ||
PROCESSOR_TAG, | ||
DESCRIPTION, | ||
false, | ||
"remote", | ||
true, | ||
false, | ||
"{ \"parameters\": ${ml_inference.parameters} }", | ||
client, | ||
TEST_XCONTENT_REGISTRY_FOR_QUERY | ||
); | ||
|
||
/** | ||
* { | ||
* "inference_results" : [ { | ||
* "output" : [ { | ||
* "name" : "response", | ||
* "dataAsMap" : { | ||
* "response" : [ { | ||
* "this" : 1.3123, | ||
* "which" : 0.2447, | ||
* "here" : 0.6674 | ||
* } ] | ||
* } | ||
* } ] | ||
* } ] | ||
* } | ||
*/ | ||
ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(Map.of("response", List.of(sparseVector))).build(); | ||
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); | ||
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); | ||
|
||
doAnswer(invocation -> { | ||
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2); | ||
actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); | ||
return null; | ||
}).when(client).execute(any(), any(), any()); | ||
|
||
QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); | ||
SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); | ||
SearchRequest request = new SearchRequest().source(source); | ||
|
||
ActionListener<SearchRequest> Listener = new ActionListener<>() { | ||
@Override | ||
public void onResponse(SearchRequest newSearchRequest) { | ||
Script script = new Script(ScriptType.INLINE, "painless", "return 1;", Map.of("query_tokens", sparseVector)); | ||
|
||
ScriptScoreQueryBuilder expectedQuery = new ScriptScoreQueryBuilder(QueryBuilders.matchAllQuery(), script); | ||
assertEquals(expectedQuery, newSearchRequest.source().query()); | ||
} | ||
|
||
@Override | ||
public void onFailure(Exception e) { | ||
throw new RuntimeException("Failed in executing processRequestAsync.", e); | ||
} | ||
}; | ||
|
||
requestProcessor.processRequestAsync(request, requestContext, Listener); | ||
} | ||
|
||
/** | ||
* Tests ML Processor can return a OpenSearch Query correctly when performing a rewrite query. | ||
* | ||
* This simulates a real world scenario where user has a llm return a OpenSearch Query to help them generate a new | ||
* query based on the context given in the prompt. | ||
* | ||
* @throws Exception when an error occurs on the test | ||
*/ | ||
public void testExecute_rewriteTermQueryWithNewQuerySuccess() throws Exception { | ||
String modelInputField = "inputs"; | ||
String originalQueryField = "query.term.text.value"; | ||
String newQueryField = "llm_query"; | ||
String modelInferenceJsonPathInput = "$.inference_results[0].output[0].dataAsMap.content[0].text"; | ||
|
||
String queryTemplate = "${llm_query}"; | ||
|
||
String llmQuery = "{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " }\n" + "}"; | ||
Map content = Map.of("content", List.of(Map.of("text", llmQuery))); | ||
|
||
List<Map<String, String>> optionalInputMap = new ArrayList<>(); | ||
Map<String, String> input = new HashMap<>(); | ||
input.put(modelInputField, originalQueryField); | ||
optionalInputMap.add(input); | ||
|
||
List<Map<String, String>> optionalOutputMap = new ArrayList<>(); | ||
Map<String, String> output = new HashMap<>(); | ||
output.put(newQueryField, modelInferenceJsonPathInput); | ||
optionalOutputMap.add(output); | ||
|
||
MLInferenceSearchRequestProcessor requestProcessor = new MLInferenceSearchRequestProcessor( | ||
"model1", | ||
queryTemplate, | ||
null, | ||
null, | ||
optionalInputMap, | ||
optionalOutputMap, | ||
null, | ||
DEFAULT_MAX_PREDICTION_TASKS, | ||
PROCESSOR_TAG, | ||
DESCRIPTION, | ||
false, | ||
"remote", | ||
true, | ||
false, | ||
"{ \"parameters\": ${ml_inference.parameters} }", | ||
client, | ||
TEST_XCONTENT_REGISTRY_FOR_QUERY | ||
); | ||
|
||
/** | ||
* { | ||
* "inference_results" : [ { | ||
* "output" : [ { | ||
* "name" : "response", | ||
* "dataAsMap" : { | ||
* "content": [ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure why its displayed out of place on github but looks normal in intelij |
||
* "text": "{\"query\": \"match_all\" : {}}" | ||
* } | ||
* } ] | ||
* } ] | ||
* } | ||
*/ | ||
ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(content).build(); | ||
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); | ||
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); | ||
|
||
doAnswer(invocation -> { | ||
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2); | ||
actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); | ||
return null; | ||
}).when(client).execute(any(), any(), any()); | ||
|
||
QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); | ||
SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); | ||
SearchRequest sampleRequest = new SearchRequest().source(source); | ||
|
||
ActionListener<SearchRequest> Listener = new ActionListener<>() { | ||
@Override | ||
public void onResponse(SearchRequest newSearchRequest) { | ||
MatchAllQueryBuilder expectedQuery = new MatchAllQueryBuilder(); | ||
assertEquals(expectedQuery, newSearchRequest.source().query()); | ||
} | ||
|
||
@Override | ||
public void onFailure(Exception e) { | ||
throw new RuntimeException("Failed in executing processRequestAsync.", e); | ||
} | ||
}; | ||
|
||
requestProcessor.processRequestAsync(sampleRequest, requestContext, Listener); | ||
} | ||
|
||
/** | ||
* Tests when there are two optional input fields | ||
* but only the second optional input is present in the query | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is the right place to fix the issue, the string convert to json string should happen before string substitution. So the next question is, should this toJson conversion only added when it's a map?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Map should be a catch-all for many scenarios that I can think of off the top of my head.
I may need to test that this solution doesn't break the query from a remote LLM. I wonder whether a list should also be supported — for example, a list of maps; then this instanceof check wouldn't work in that scenario.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I made a UT where a llm returns a query and a query rewrite occurs. This happened to succeed with the code change e.g.
Off the top of my head I can't think of another scenario where this parsing issue occurs, since I see that your UTs cover a lot of edge cases, such as a list.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should cover more use cases in testing, because the model output can be various format,
I have two cases in mind that you can try add and verify in the UT
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For scenario one (a list of Map), I don't think this is a valid OpenSearch query; consequently, if it runs through the processor it fails the query rewrite.
However, I did try this, taken from the docs: