|
41 | 41 | import org.opensearch.core.action.ActionListener;
|
42 | 42 | import org.opensearch.core.xcontent.NamedXContentRegistry;
|
43 | 43 | import org.opensearch.core.xcontent.XContentParser;
|
| 44 | +import org.opensearch.index.query.MatchAllQueryBuilder; |
44 | 45 | import org.opensearch.index.query.QueryBuilder;
|
| 46 | +import org.opensearch.index.query.QueryBuilders; |
45 | 47 | import org.opensearch.index.query.RangeQueryBuilder;
|
46 | 48 | import org.opensearch.index.query.TermQueryBuilder;
|
47 | 49 | import org.opensearch.index.query.TermsQueryBuilder;
|
| 50 | +import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder; |
48 | 51 | import org.opensearch.ingest.Processor;
|
49 | 52 | import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
|
50 | 53 | import org.opensearch.ml.common.input.MLInput;
|
|
58 | 61 | import org.opensearch.ml.searchext.MLInferenceRequestParameters;
|
59 | 62 | import org.opensearch.ml.searchext.MLInferenceRequestParametersExtBuilder;
|
60 | 63 | import org.opensearch.plugins.SearchPlugin;
|
| 64 | +import org.opensearch.script.Script; |
| 65 | +import org.opensearch.script.ScriptType; |
61 | 66 | import org.opensearch.search.SearchModule;
|
62 | 67 | import org.opensearch.search.builder.SearchSourceBuilder;
|
63 | 68 | import org.opensearch.search.pipeline.PipelineProcessingContext;
|
@@ -1312,7 +1317,7 @@ public void onFailure(Exception e) {
|
1312 | 1317 |
|
1313 | 1318 | /**
|
1314 | 1319 | * Tests the successful rewriting of a complex nested array in query extension based on the model output.
|
1315 |
| - * verify the pipelineConext is set from the extension |
| 1320 | + * verify the pipelineContext is set from the extension |
1316 | 1321 | * @throws Exception if an error occurs during the test
|
1317 | 1322 | */
|
1318 | 1323 | public void testExecute_rewriteTermQueryReadAndWriteComplexNestedArrayToExtensionSuccess() throws Exception {
|
@@ -1499,6 +1504,207 @@ public void onFailure(Exception e) {
|
1499 | 1504 |
|
1500 | 1505 | }
|
1501 | 1506 |
|
| 1507 | + /** |
| 1508 | + * Tests ML Processor can return a sparse vector correctly when performing a rewrite query. |
| 1509 | + * |
| 1510 | + * This simulates a real world scenario where user has a neural sparse model and attempts to parse |
| 1511 | + * it by asserting FullResponsePath to true. |
| 1512 | + * @throws Exception when an error occurs on the test |
| 1513 | + */ |
| 1514 | + public void testExecute_rewriteTermQueryWithSparseVectorSuccess() throws Exception { |
| 1515 | + String modelInputField = "inputs"; |
| 1516 | + String originalQueryField = "query.term.text.value"; |
| 1517 | + String newQueryField = "vector"; |
| 1518 | + String modelInferenceJsonPathInput = "$.inference_results[0].output[0].dataAsMap.response[0]"; |
| 1519 | + |
| 1520 | + String queryTemplate = "{\n" |
| 1521 | + + " \"query\": {\n" |
| 1522 | + + " \"script_score\": {\n" |
| 1523 | + + " \"query\": {\n" |
| 1524 | + + " \"match_all\": {}\n" |
| 1525 | + + " },\n" |
| 1526 | + + " \"script\": {\n" |
| 1527 | + + " \"source\": \"return 1;\",\n" |
| 1528 | + + " \"params\": {\n" |
| 1529 | + + " \"query_tokens\": ${vector}\n" |
| 1530 | + + " }\n" |
| 1531 | + + " }\n" |
| 1532 | + + " }\n" |
| 1533 | + + " }\n" |
| 1534 | + + "}"; |
| 1535 | + |
| 1536 | + Map<String, Double> sparseVector = Map.of("this", 1.3123, "which", 0.2447, "here", 0.6674); |
| 1537 | + |
| 1538 | + List<Map<String, String>> optionalInputMap = new ArrayList<>(); |
| 1539 | + Map<String, String> input = new HashMap<>(); |
| 1540 | + input.put(modelInputField, originalQueryField); |
| 1541 | + optionalInputMap.add(input); |
| 1542 | + |
| 1543 | + List<Map<String, String>> optionalOutputMap = new ArrayList<>(); |
| 1544 | + Map<String, String> output = new HashMap<>(); |
| 1545 | + output.put(newQueryField, modelInferenceJsonPathInput); |
| 1546 | + optionalOutputMap.add(output); |
| 1547 | + |
| 1548 | + MLInferenceSearchRequestProcessor requestProcessor = new MLInferenceSearchRequestProcessor( |
| 1549 | + "model1", |
| 1550 | + queryTemplate, |
| 1551 | + null, |
| 1552 | + null, |
| 1553 | + optionalInputMap, |
| 1554 | + optionalOutputMap, |
| 1555 | + null, |
| 1556 | + DEFAULT_MAX_PREDICTION_TASKS, |
| 1557 | + PROCESSOR_TAG, |
| 1558 | + DESCRIPTION, |
| 1559 | + false, |
| 1560 | + "remote", |
| 1561 | + true, |
| 1562 | + false, |
| 1563 | + "{ \"parameters\": ${ml_inference.parameters} }", |
| 1564 | + client, |
| 1565 | + TEST_XCONTENT_REGISTRY_FOR_QUERY |
| 1566 | + ); |
| 1567 | + |
| 1568 | + /** |
| 1569 | + * { |
| 1570 | + * "inference_results" : [ { |
| 1571 | + * "output" : [ { |
| 1572 | + * "name" : "response", |
| 1573 | + * "dataAsMap" : { |
| 1574 | + * "response" : [ { |
| 1575 | + * "this" : 1.3123, |
| 1576 | + * "which" : 0.2447, |
| 1577 | + * "here" : 0.6674 |
| 1578 | + * } ] |
| 1579 | + * } |
| 1580 | + * } ] |
| 1581 | + * } ] |
| 1582 | + * } |
| 1583 | + */ |
| 1584 | + ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(Map.of("response", List.of(sparseVector))).build(); |
| 1585 | + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); |
| 1586 | + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); |
| 1587 | + |
| 1588 | + doAnswer(invocation -> { |
| 1589 | + ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2); |
| 1590 | + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); |
| 1591 | + return null; |
| 1592 | + }).when(client).execute(any(), any(), any()); |
| 1593 | + |
| 1594 | + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); |
| 1595 | + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); |
| 1596 | + SearchRequest request = new SearchRequest().source(source); |
| 1597 | + |
| 1598 | + ActionListener<SearchRequest> Listener = new ActionListener<>() { |
| 1599 | + @Override |
| 1600 | + public void onResponse(SearchRequest newSearchRequest) { |
| 1601 | + Script script = new Script(ScriptType.INLINE, "painless", "return 1;", Map.of("query_tokens", sparseVector)); |
| 1602 | + |
| 1603 | + ScriptScoreQueryBuilder expectedQuery = new ScriptScoreQueryBuilder(QueryBuilders.matchAllQuery(), script); |
| 1604 | + assertEquals(expectedQuery, newSearchRequest.source().query()); |
| 1605 | + } |
| 1606 | + |
| 1607 | + @Override |
| 1608 | + public void onFailure(Exception e) { |
| 1609 | + throw new RuntimeException("Failed in executing processRequestAsync.", e); |
| 1610 | + } |
| 1611 | + }; |
| 1612 | + |
| 1613 | + requestProcessor.processRequestAsync(request, requestContext, Listener); |
| 1614 | + } |
| 1615 | + |
| 1616 | + /** |
| 1617 | + * Tests ML Processor can return a OpenSearch Query correctly when performing a rewrite query. |
| 1618 | + * |
| 1619 | + * This simulates a real world scenario where user has a llm return a OpenSearch Query to help them generate a new |
| 1620 | + * query based on the context given in the prompt. |
| 1621 | + * |
| 1622 | + * @throws Exception when an error occurs on the test |
| 1623 | + */ |
| 1624 | + public void testExecute_rewriteTermQueryWithNewQuerySuccess() throws Exception { |
| 1625 | + String modelInputField = "inputs"; |
| 1626 | + String originalQueryField = "query.term.text.value"; |
| 1627 | + String newQueryField = "llm_query"; |
| 1628 | + String modelInferenceJsonPathInput = "$.inference_results[0].output[0].dataAsMap.content[0].text"; |
| 1629 | + |
| 1630 | + String queryTemplate = "${llm_query}"; |
| 1631 | + |
| 1632 | + String llmQuery = "{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " }\n" + "}"; |
| 1633 | + Map content = Map.of("content", List.of(Map.of("text", llmQuery))); |
| 1634 | + |
| 1635 | + List<Map<String, String>> optionalInputMap = new ArrayList<>(); |
| 1636 | + Map<String, String> input = new HashMap<>(); |
| 1637 | + input.put(modelInputField, originalQueryField); |
| 1638 | + optionalInputMap.add(input); |
| 1639 | + |
| 1640 | + List<Map<String, String>> optionalOutputMap = new ArrayList<>(); |
| 1641 | + Map<String, String> output = new HashMap<>(); |
| 1642 | + output.put(newQueryField, modelInferenceJsonPathInput); |
| 1643 | + optionalOutputMap.add(output); |
| 1644 | + |
| 1645 | + MLInferenceSearchRequestProcessor requestProcessor = new MLInferenceSearchRequestProcessor( |
| 1646 | + "model1", |
| 1647 | + queryTemplate, |
| 1648 | + null, |
| 1649 | + null, |
| 1650 | + optionalInputMap, |
| 1651 | + optionalOutputMap, |
| 1652 | + null, |
| 1653 | + DEFAULT_MAX_PREDICTION_TASKS, |
| 1654 | + PROCESSOR_TAG, |
| 1655 | + DESCRIPTION, |
| 1656 | + false, |
| 1657 | + "remote", |
| 1658 | + true, |
| 1659 | + false, |
| 1660 | + "{ \"parameters\": ${ml_inference.parameters} }", |
| 1661 | + client, |
| 1662 | + TEST_XCONTENT_REGISTRY_FOR_QUERY |
| 1663 | + ); |
| 1664 | + |
| 1665 | + /* |
| 1666 | + * { |
| 1667 | + * "inference_results" : [ { |
| 1668 | + * "output" : [ { |
| 1669 | + * "name" : "response", |
| 1670 | + * "dataAsMap" : { |
| 1671 | + * "content": [ |
| 1672 | + * "text": "{\"query\": \"match_all\" : {}}" |
| 1673 | + * } |
| 1674 | + * } ] |
| 1675 | + * } ] |
| 1676 | + * } |
| 1677 | + */ |
| 1678 | + ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(content).build(); |
| 1679 | + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); |
| 1680 | + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); |
| 1681 | + |
| 1682 | + doAnswer(invocation -> { |
| 1683 | + ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2); |
| 1684 | + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); |
| 1685 | + return null; |
| 1686 | + }).when(client).execute(any(), any(), any()); |
| 1687 | + |
| 1688 | + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); |
| 1689 | + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); |
| 1690 | + SearchRequest sampleRequest = new SearchRequest().source(source); |
| 1691 | + |
| 1692 | + ActionListener<SearchRequest> Listener = new ActionListener<>() { |
| 1693 | + @Override |
| 1694 | + public void onResponse(SearchRequest newSearchRequest) { |
| 1695 | + MatchAllQueryBuilder expectedQuery = new MatchAllQueryBuilder(); |
| 1696 | + assertEquals(expectedQuery, newSearchRequest.source().query()); |
| 1697 | + } |
| 1698 | + |
| 1699 | + @Override |
| 1700 | + public void onFailure(Exception e) { |
| 1701 | + throw new RuntimeException("Failed in executing processRequestAsync.", e); |
| 1702 | + } |
| 1703 | + }; |
| 1704 | + |
| 1705 | + requestProcessor.processRequestAsync(sampleRequest, requestContext, Listener); |
| 1706 | + } |
| 1707 | + |
1502 | 1708 | /**
|
1503 | 1709 | * Tests when there are two optional input fields
|
1504 | 1710 | * but only the second optional input is present in the query
|
|
0 commit comments