|
42 | 42 | import org.opensearch.core.xcontent.NamedXContentRegistry;
|
43 | 43 | import org.opensearch.core.xcontent.XContentParser;
|
44 | 44 | import org.opensearch.index.query.QueryBuilder;
|
| 45 | +import org.opensearch.index.query.QueryBuilders; |
45 | 46 | import org.opensearch.index.query.RangeQueryBuilder;
|
46 | 47 | import org.opensearch.index.query.TermQueryBuilder;
|
47 | 48 | import org.opensearch.index.query.TermsQueryBuilder;
|
| 49 | +import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder; |
48 | 50 | import org.opensearch.ingest.Processor;
|
49 | 51 | import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
|
50 | 52 | import org.opensearch.ml.common.input.MLInput;
|
|
58 | 60 | import org.opensearch.ml.searchext.MLInferenceRequestParameters;
|
59 | 61 | import org.opensearch.ml.searchext.MLInferenceRequestParametersExtBuilder;
|
60 | 62 | import org.opensearch.plugins.SearchPlugin;
|
| 63 | +import org.opensearch.script.Script; |
| 64 | +import org.opensearch.script.ScriptType; |
61 | 65 | import org.opensearch.search.SearchModule;
|
62 | 66 | import org.opensearch.search.builder.SearchSourceBuilder;
|
63 | 67 | import org.opensearch.search.pipeline.PipelineProcessingContext;
|
@@ -1499,6 +1503,115 @@ public void onFailure(Exception e) {
|
1499 | 1503 |
|
1500 | 1504 | }
|
1501 | 1505 |
|
| 1506 | + /** |
| 1507 | + * Tests ML Processor can return a sparse vector correctly when performing a rewrite query. |
| 1508 | + * |
| 1509 | + * This simulates a real world scenario where user has a neural sparse model and attempts to parse |
| 1510 | + * it by asserting FullResponsePath to true. |
| 1511 | + * @throws Exception when an error occurs on the test |
| 1512 | + */ |
| 1513 | + public void testExecute_rewriteTermQueryWithSparseVectorSuccess() throws Exception { |
| 1514 | + String modelInputField = "inputs"; |
| 1515 | + String originalQueryField = "query.term.text.value"; |
| 1516 | + String newQueryField = "vector"; |
| 1517 | + String modelInferenceJsonPathInput = "$.inference_results[0].output[0].dataAsMap.response[0]"; |
| 1518 | + |
| 1519 | + String queryTemplate = "{\n" |
| 1520 | + + " \"query\": {\n" |
| 1521 | + + " \"script_score\": {\n" |
| 1522 | + + " \"query\": {\n" |
| 1523 | + + " \"match_all\": {}\n" |
| 1524 | + + " },\n" |
| 1525 | + + " \"script\": {\n" |
| 1526 | + + " \"source\": \"return 1;\",\n" |
| 1527 | + + " \"params\": {\n" |
| 1528 | + + " \"query_tokens\": ${vector}\n" |
| 1529 | + + " }\n" |
| 1530 | + + " }\n" |
| 1531 | + + " }\n" |
| 1532 | + + " }\n" |
| 1533 | + + "}"; |
| 1534 | + |
| 1535 | + Map<String, Double> sparseVector = Map.of("this", 1.3123, "which", 0.2447, "here", 0.6674); |
| 1536 | + |
| 1537 | + List<Map<String, String>> optionalInputMap = new ArrayList<>(); |
| 1538 | + Map<String, String> input = new HashMap<>(); |
| 1539 | + input.put(modelInputField, originalQueryField); |
| 1540 | + optionalInputMap.add(input); |
| 1541 | + |
| 1542 | + List<Map<String, String>> optionalOutputMap = new ArrayList<>(); |
| 1543 | + Map<String, String> output = new HashMap<>(); |
| 1544 | + output.put(newQueryField, modelInferenceJsonPathInput); |
| 1545 | + optionalOutputMap.add(output); |
| 1546 | + |
| 1547 | + MLInferenceSearchRequestProcessor requestProcessor = new MLInferenceSearchRequestProcessor( |
| 1548 | + "model1", |
| 1549 | + queryTemplate, |
| 1550 | + null, |
| 1551 | + null, |
| 1552 | + optionalInputMap, |
| 1553 | + optionalOutputMap, |
| 1554 | + null, |
| 1555 | + DEFAULT_MAX_PREDICTION_TASKS, |
| 1556 | + PROCESSOR_TAG, |
| 1557 | + DESCRIPTION, |
| 1558 | + false, |
| 1559 | + "remote", |
| 1560 | + true, |
| 1561 | + false, |
| 1562 | + "{ \"parameters\": ${ml_inference.parameters} }", |
| 1563 | + client, |
| 1564 | + TEST_XCONTENT_REGISTRY_FOR_QUERY |
| 1565 | + ); |
| 1566 | + |
| 1567 | + /** |
| 1568 | + * { |
| 1569 | + * "inference_results" : [ { |
| 1570 | + * "output" : [ { |
| 1571 | + * "name" : "response", |
| 1572 | + * "dataAsMap" : { |
| 1573 | + * "response" : [ { |
| 1574 | + * "this" : 1.3123, |
| 1575 | + * "which" : 0.2447, |
| 1576 | + * "here" : 0.6674 |
| 1577 | + * } ] |
| 1578 | + * } |
| 1579 | + * } ] |
| 1580 | + * } ] |
| 1581 | + * } |
| 1582 | + */ |
| 1583 | + ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(Map.of("response", List.of(sparseVector))).build(); |
| 1584 | + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); |
| 1585 | + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); |
| 1586 | + |
| 1587 | + doAnswer(invocation -> { |
| 1588 | + ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2); |
| 1589 | + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); |
| 1590 | + return null; |
| 1591 | + }).when(client).execute(any(), any(), any()); |
| 1592 | + |
| 1593 | + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); |
| 1594 | + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); |
| 1595 | + SearchRequest request = new SearchRequest().source(source); |
| 1596 | + |
| 1597 | + ActionListener<SearchRequest> Listener = new ActionListener<>() { |
| 1598 | + @Override |
| 1599 | + public void onResponse(SearchRequest newSearchRequest) { |
| 1600 | + Script script = new Script(ScriptType.INLINE, "painless", "return 1;", Map.of("query_tokens", sparseVector)); |
| 1601 | + |
| 1602 | + ScriptScoreQueryBuilder expectedQuery = new ScriptScoreQueryBuilder(QueryBuilders.matchAllQuery(), script); |
| 1603 | + assertEquals(expectedQuery, newSearchRequest.source().query()); |
| 1604 | + } |
| 1605 | + |
| 1606 | + @Override |
| 1607 | + public void onFailure(Exception e) { |
| 1608 | + throw new RuntimeException("Failed in executing processRequestAsync.", e); |
| 1609 | + } |
| 1610 | + }; |
| 1611 | + |
| 1612 | + requestProcessor.processRequestAsync(request, requestContext, Listener); |
| 1613 | + } |
| 1614 | + |
1502 | 1615 | /**
|
1503 | 1616 | * Tests when there are two optional input fields
|
1504 | 1617 | * but only the second optional input is present in the query
|
|
0 commit comments