Skip to content

Commit 73129c6

Browse files
Add pre and post process functions for Bedrock Rerank API #3254 (#3339) (#3456) (#3500)
* Add pre and post process functions for Bedrock Rerank API #3254 Signed-off-by: tkykenmt <tkykenmto+github.com@gmail.com> * modify format using spotlessApply Signed-off-by: tkykenmt <tkykenmto+github.com@gmail.com> * Fix on validation/converting scores #3339 Signed-off-by: tkykenmt <tkykenmto+github.com@gmail.com> * Fix on method name of test case for list of maps data #3339 Signed-off-by: tkykenmt <tkykenmto+github.com@gmail.com> * remove unnecessary cast #3339 Signed-off-by: tkykenmt <tkykenmto+github.com@gmail.com> --------- Signed-off-by: tkykenmt <tkykenmto+github.com@gmail.com> (cherry picked from commit d05674a) Co-authored-by: Takayuki Enomoto <4161768+tkykenmt@users.noreply.github.com> (cherry picked from commit a29627e) Co-authored-by: opensearch-trigger-bot[bot] <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com>
1 parent 22f258b commit 73129c6

File tree

6 files changed

+323
-0
lines changed

6 files changed

+323
-0
lines changed

common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import org.opensearch.ml.common.connector.functions.postprocess.BedrockBatchJobArnPostProcessFunction;
1414
import org.opensearch.ml.common.connector.functions.postprocess.BedrockEmbeddingPostProcessFunction;
15+
import org.opensearch.ml.common.connector.functions.postprocess.BedrockRerankPostProcessFunction;
1516
import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction;
1617
import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction;
1718
import org.opensearch.ml.common.output.model.ModelTensor;
@@ -23,6 +24,7 @@ public class MLPostProcessFunction {
2324
public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding";
2425
public static final String BEDROCK_BATCH_JOB_ARN = "connector.post_process.bedrock.batch_job_arn";
2526
public static final String COHERE_RERANK = "connector.post_process.cohere.rerank";
27+
public static final String BEDROCK_RERANK = "connector.post_process.bedrock.rerank";
2628
public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";
2729
public static final String DEFAULT_RERANK = "connector.post_process.default.rerank";
2830

@@ -35,19 +37,22 @@ public class MLPostProcessFunction {
3537
BedrockEmbeddingPostProcessFunction bedrockEmbeddingPostProcessFunction = new BedrockEmbeddingPostProcessFunction();
3638
BedrockBatchJobArnPostProcessFunction batchJobArnPostProcessFunction = new BedrockBatchJobArnPostProcessFunction();
3739
CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction();
40+
BedrockRerankPostProcessFunction bedrockRerankPostProcessFunction = new BedrockRerankPostProcessFunction();
3841
JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
3942
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
4043
JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]");
4144
JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding");
4245
JSON_PATH_EXPRESSION.put(BEDROCK_BATCH_JOB_ARN, "$");
4346
JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results");
47+
JSON_PATH_EXPRESSION.put(BEDROCK_RERANK, "$.results");
4448
JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]");
4549
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction);
4650
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction);
4751
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, embeddingPostProcessFunction);
4852
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, bedrockEmbeddingPostProcessFunction);
4953
POST_PROCESS_FUNCTIONS.put(BEDROCK_BATCH_JOB_ARN, batchJobArnPostProcessFunction);
5054
POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction);
55+
POST_PROCESS_FUNCTIONS.put(BEDROCK_RERANK, bedrockRerankPostProcessFunction);
5156
POST_PROCESS_FUNCTIONS.put(DEFAULT_RERANK, cohereRerankPostProcessFunction);
5257
}
5358

common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import java.util.function.Function;
1111

1212
import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction;
13+
import org.opensearch.ml.common.connector.functions.preprocess.BedrockRerankPreProcessFunction;
1314
import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction;
1415
import org.opensearch.ml.common.connector.functions.preprocess.CohereMultiModalEmbeddingPreProcessFunction;
1516
import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction;
@@ -28,6 +29,7 @@ public class MLPreProcessFunction {
2829
public static final String TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.multimodal_embedding";
2930
public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";
3031
public static final String TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT = "connector.pre_process.cohere.rerank";
32+
public static final String TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT = "connector.pre_process.bedrock.rerank";
3133
public static final String TEXT_SIMILARITY_TO_DEFAULT_INPUT = "connector.pre_process.default.rerank";
3234

3335
public static final String PROCESS_REMOTE_INFERENCE_INPUT = "pre_process_function.process_remote_inference_input";
@@ -38,6 +40,7 @@ public class MLPreProcessFunction {
3840
OpenAIEmbeddingPreProcessFunction openAIEmbeddingPreProcessFunction = new OpenAIEmbeddingPreProcessFunction();
3941
BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction();
4042
CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction();
43+
BedrockRerankPreProcessFunction bedrockRerankPreProcessFunction = new BedrockRerankPreProcessFunction();
4144
MultiModalConnectorPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalConnectorPreProcessFunction();
4245
CohereMultiModalEmbeddingPreProcessFunction cohereMultiModalEmbeddingPreProcessFunction =
4346
new CohereMultiModalEmbeddingPreProcessFunction();
@@ -49,6 +52,7 @@ public class MLPreProcessFunction {
4952
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction);
5053
PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_DEFAULT_INPUT, cohereRerankPreProcessFunction);
5154
PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT, cohereRerankPreProcessFunction);
55+
PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT, bedrockRerankPreProcessFunction);
5256
}
5357

5458
public static boolean contains(String functionName) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.postprocess;
7+
8+
import java.math.BigDecimal;
9+
import java.util.ArrayList;
10+
import java.util.List;
11+
import java.util.Map;
12+
13+
import org.opensearch.ml.common.output.model.MLResultDataType;
14+
import org.opensearch.ml.common.output.model.ModelTensor;
15+
16+
public class BedrockRerankPostProcessFunction extends ConnectorPostProcessFunction<List<Map<String, Object>>> {
17+
18+
@Override
19+
public void validate(Object input) {
20+
21+
if (!(input instanceof List)) {
22+
throw new IllegalArgumentException("Post process function input is not a List.");
23+
}
24+
25+
List<?> outerList = (List<?>) input;
26+
27+
if (outerList.isEmpty()) {
28+
throw new IllegalArgumentException("Post process function input is empty.");
29+
}
30+
31+
for (Object item : outerList) {
32+
if (!(item instanceof Map)) {
33+
throw new IllegalArgumentException("Rerank result is not a Map.");
34+
}
35+
36+
Map<?, ?> innerMap = (Map<?, ?>) item;
37+
38+
if (innerMap.isEmpty()) {
39+
throw new IllegalArgumentException("Rerank result is empty.");
40+
}
41+
42+
if (!innerMap.containsKey("index") || !innerMap.containsKey("relevanceScore")) {
43+
throw new IllegalArgumentException("Rerank result should have both index and relevanceScore.");
44+
}
45+
46+
if (!(innerMap.get("relevanceScore") instanceof BigDecimal || innerMap.get("relevanceScore") instanceof Double)) {
47+
throw new IllegalArgumentException("relevanceScore is not BigDecimal or Double.");
48+
}
49+
}
50+
}
51+
52+
@Override
53+
public List<ModelTensor> process(List<Map<String, Object>> rerankResults) {
54+
List<ModelTensor> modelTensors = new ArrayList<>();
55+
56+
if (!rerankResults.isEmpty()) {
57+
Double[] scores = new Double[rerankResults.size()];
58+
for (Map rerankResult : rerankResults) {
59+
Integer index = (Integer) rerankResult.get("index");
60+
Object relevanceScore = rerankResult.get("relevanceScore");
61+
if (relevanceScore instanceof BigDecimal) {
62+
scores[index] = ((BigDecimal) relevanceScore).doubleValue();
63+
} else if (relevanceScore instanceof Double) {
64+
scores[index] = (Double) relevanceScore;
65+
}
66+
}
67+
for (Double score : scores) {
68+
modelTensors
69+
.add(
70+
ModelTensor
71+
.builder()
72+
.name("similarity")
73+
.shape(new long[] { 1 })
74+
.data(new Number[] { score })
75+
.dataType(MLResultDataType.FLOAT32)
76+
.build()
77+
);
78+
}
79+
}
80+
return modelTensors;
81+
}
82+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.preprocess;
7+
8+
import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;
9+
10+
import java.util.ArrayList;
11+
import java.util.List;
12+
import java.util.Map;
13+
14+
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
15+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
16+
import org.opensearch.ml.common.input.MLInput;
17+
18+
public class BedrockRerankPreProcessFunction extends ConnectorPreProcessFunction {
19+
20+
public BedrockRerankPreProcessFunction() {
21+
this.returnDirectlyForRemoteInferenceInput = true;
22+
}
23+
24+
@Override
25+
public void validate(MLInput mlInput) {
26+
27+
if (mlInput.getInputDataset() == null) {
28+
throw new IllegalArgumentException("Input dataset cannot be null.");
29+
}
30+
31+
if (!(mlInput.getInputDataset() instanceof TextSimilarityInputDataSet)) {
32+
throw new IllegalArgumentException("This pre_process_function can only support TextSimilarityInputDataSet");
33+
}
34+
}
35+
36+
@Override
37+
public RemoteInferenceInputDataSet process(MLInput mlInput) {
38+
TextSimilarityInputDataSet inputData = (TextSimilarityInputDataSet) mlInput.getInputDataset();
39+
String queryText = inputData.getQueryText();
40+
List<String> textDocs = inputData.getTextDocs();
41+
42+
List<Map<String, Object>> queries = new ArrayList<Map<String, Object>>();
43+
queries.add(Map.of("textQuery", Map.of("text", queryText), "type", "TEXT"));
44+
45+
List<Map<String, Object>> sources = new ArrayList<Map<String, Object>>();
46+
inputData.getTextDocs().forEach(textDoc -> {
47+
sources.add(Map.of("inlineDocumentSource", Map.of("textDocument", Map.of("text", textDoc), "type", "TEXT"), "type", "INLINE"));
48+
});
49+
50+
Map<String, Object> processedResult = Map
51+
.of("parameters", Map.of("queries", queries, "sources", sources, "numberOfResults", textDocs.size()));
52+
53+
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build();
54+
}
55+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.postprocess;
7+
8+
import static org.junit.Assert.assertEquals;
9+
10+
import java.util.Arrays;
11+
import java.util.List;
12+
import java.util.Map;
13+
14+
import org.junit.Before;
15+
import org.junit.Rule;
16+
import org.junit.Test;
17+
import org.junit.rules.ExpectedException;
18+
import org.opensearch.ml.common.output.model.ModelTensor;
19+
20+
public class BedrockRerankPostProcessFunctionTest {
21+
@Rule
22+
public ExpectedException exceptionRule = ExpectedException.none();
23+
24+
BedrockRerankPostProcessFunction function;
25+
26+
@Before
27+
public void setUp() {
28+
function = new BedrockRerankPostProcessFunction();
29+
}
30+
31+
@Test
32+
public void process_WrongInput_NotList() {
33+
exceptionRule.expect(IllegalArgumentException.class);
34+
exceptionRule.expectMessage("Post process function input is not a List.");
35+
function.apply("abc");
36+
}
37+
38+
@Test
39+
public void process_EmptyInput() {
40+
exceptionRule.expect(IllegalArgumentException.class);
41+
exceptionRule.expectMessage("Post process function input is empty.");
42+
function.apply(Arrays.asList());
43+
}
44+
45+
@Test
46+
public void process_WrongInput_NotCorrectListOfMapsFormat() {
47+
exceptionRule.expect(IllegalArgumentException.class);
48+
exceptionRule.expectMessage("Rerank result is not a Map.");
49+
function.apply(Arrays.asList("abc"));
50+
}
51+
52+
@Test
53+
public void process_EmptyMapInput() {
54+
exceptionRule.expect(IllegalArgumentException.class);
55+
exceptionRule.expectMessage("Rerank result is empty.");
56+
function.apply(Arrays.asList(Map.of()));
57+
}
58+
59+
@Test
60+
public void process_WrongInput_NotCorrectMap() {
61+
exceptionRule.expect(IllegalArgumentException.class);
62+
exceptionRule.expectMessage("Rerank result should have both index and relevanceScore.");
63+
List<Map<String, Object>> rerankResults = List
64+
.of(
65+
Map.of("index", 2, "relevanceScore", 0.7711548805236816),
66+
Map.of("index", 0, "relevanceScore", 0.0025114635936915874),
67+
Map.of("index", 1, "relevanceScore", 2.4876489987946115e-05),
68+
Map.of("test1", "value1")
69+
);
70+
function.apply(rerankResults);
71+
}
72+
73+
@Test
74+
public void process_WrongInput_NotCorrectRelevanceScore() {
75+
exceptionRule.expect(IllegalArgumentException.class);
76+
exceptionRule.expectMessage("relevanceScore is not BigDecimal or Double.");
77+
List<Map<String, Object>> rerankResults = List
78+
.of(
79+
Map.of("index", 2, "relevanceScore", 0.7711548805236816),
80+
Map.of("index", 0, "relevanceScore", 0.0025114635936915874),
81+
Map.of("index", 1, "relevanceScore", 2.4876489987946115e-05),
82+
Map.of("index", 3, "relevanceScore", "value1")
83+
);
84+
function.apply(rerankResults);
85+
}
86+
87+
@Test
88+
public void process_CorrectInput() {
89+
List<Map<String, Object>> rerankResults = List
90+
.of(
91+
Map.of("index", 2, "relevanceScore", 0.7711548805236816),
92+
Map.of("index", 0, "relevanceScore", 0.0025114635936915874),
93+
Map.of("index", 1, "relevanceScore", 2.4876489987946115e-05),
94+
Map.of("index", 3, "relevanceScore", 6.339210358419223e-06)
95+
);
96+
List<ModelTensor> result = function.apply(rerankResults);
97+
assertEquals(4, result.size());
98+
assertEquals(1, result.get(0).getData().length);
99+
assertEquals(0.0025114635936915874, result.get(0).getData()[0]);
100+
assertEquals(2.4876489987946115e-05, result.get(1).getData()[0]);
101+
assertEquals(0.7711548805236816, result.get(2).getData()[0]);
102+
assertEquals(6.339210358419223e-06, result.get(3).getData()[0]);
103+
}
104+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.preprocess;
7+
8+
import static org.junit.Assert.assertEquals;
9+
import static org.junit.Assert.assertTrue;
10+
11+
import java.util.Arrays;
12+
13+
import org.json.JSONArray;
14+
import org.junit.Before;
15+
import org.junit.Rule;
16+
import org.junit.Test;
17+
import org.junit.rules.ExpectedException;
18+
import org.opensearch.ml.common.FunctionName;
19+
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
20+
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
21+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
22+
import org.opensearch.ml.common.input.MLInput;
23+
24+
public class BedrockRerankPreProcessFunctionTest {
25+
@Rule
26+
public ExpectedException exceptionRule = ExpectedException.none();
27+
28+
BedrockRerankPreProcessFunction function;
29+
30+
TextSimilarityInputDataSet textSimilarityInputDataSet;
31+
TextDocsInputDataSet textDocsInputDataSet;
32+
33+
@Before
34+
public void setUp() {
35+
function = new BedrockRerankPreProcessFunction();
36+
textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build();
37+
textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build();
38+
}
39+
40+
@Test
41+
public void process_NullInput() {
42+
exceptionRule.expect(IllegalArgumentException.class);
43+
exceptionRule.expectMessage("Preprocess function input can't be null");
44+
function.apply(null);
45+
}
46+
47+
@Test
48+
public void process_WrongInput() {
49+
exceptionRule.expect(IllegalArgumentException.class);
50+
exceptionRule.expectMessage("This pre_process_function can only support TextSimilarityInputDataSet");
51+
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
52+
function.apply(mlInput);
53+
}
54+
55+
@Test
56+
public void process_CorrectInput() {
57+
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build();
58+
RemoteInferenceInputDataSet dataSet = function.apply(mlInput);
59+
assertEquals(3, dataSet.getParameters().size());
60+
61+
JSONArray expectedSources = new JSONArray(
62+
"[{\"type\": \"INLINE\", \"inlineDocumentSource\": {\"type\": \"TEXT\", \"textDocument\": {\"text\": \"hello\"}}}]"
63+
);
64+
JSONArray actualSources = new JSONArray(dataSet.getParameters().get("sources"));
65+
assertTrue(expectedSources.getJSONObject(0).similar(actualSources.getJSONObject(0)));
66+
67+
JSONArray expectedQueries = new JSONArray("[{\"textQuery\": {\"text\": \"test\"}, \"type\": \"TEXT\"}]");
68+
JSONArray actualQueries = new JSONArray(dataSet.getParameters().get("queries"));
69+
assertTrue(expectedQueries.getJSONObject(0).similar(actualQueries.getJSONObject(0)));
70+
71+
assertEquals("1", dataSet.getParameters().get("numberOfResults"));
72+
}
73+
}

0 commit comments

Comments
 (0)