
Commit 2da7d7a

Authored by Mingshi Liu <mingshl@amazon.com>
fix error message when input map and output map length not match (#3730)
* fix error message
* fix store array in ml inference search extension
* add key checks

Signed-off-by: Mingshi Liu <mingshl@amazon.com>
1 parent b53a7e5 commit 2da7d7a

File tree

4 files changed: +285 -20 lines changed


plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java

Lines changed: 1 addition & 1 deletion
@@ -707,7 +707,7 @@ public MLInferenceSearchRequestProcessor create(
                     "when output_maps/optional_output_maps and input_maps/optional_input_maps are provided, their length needs to match. The input is in length of "
                         + combinedInputMaps.size()
                         + ", while output_maps is in the length of "
-                        + combinedInputMaps.size()
+                        + combinedOutputMaps.size()
                         + ". Please adjust mappings."
                 );
             }
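
The one-line change above replaces the second combinedInputMaps.size() with combinedOutputMaps.size(), so the exception now reports the output length instead of repeating the input length. As a minimal, hypothetical sketch (not code from this commit), the same message could be built by a helper that takes both lists explicitly, which makes the copy-paste mistake this commit fixes harder to reintroduce:

import java.util.List;
import java.util.Map;

// Hypothetical helper (assumed name and signature): mirrors the validation shown in the diff.
final class MappingValidation {
    static void requireMatchingLengths(List<Map<String, String>> combinedInputMaps, List<Map<String, String>> combinedOutputMaps) {
        boolean bothProvided = combinedInputMaps != null && !combinedInputMaps.isEmpty()
            && combinedOutputMaps != null && !combinedOutputMaps.isEmpty();
        if (bothProvided && combinedInputMaps.size() != combinedOutputMaps.size()) {
            throw new IllegalArgumentException(
                "when output_maps/optional_output_maps and input_maps/optional_input_maps are provided, their length needs to match. The input is in length of "
                    + combinedInputMaps.size()      // length of the combined input maps
                    + ", while output_maps is in the length of "
                    + combinedOutputMaps.size()     // length of the combined output maps (the operand the commit corrects)
                    + ". Please adjust mappings."
            );
        }
    }
}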

plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java

Lines changed: 24 additions & 13 deletions
@@ -626,29 +626,37 @@ public void onResponse(Map<Integer, MLOutput> multipleMLOutputs) {
                     ignoreMissing,
                     fullResponsePath
                 );
-                Object modelOutputValuePerDoc;
-                if (modelOutputValue instanceof List
-                    && ((List) modelOutputValue).size() == hitCountInPredictions.get(mappingIndex)) {
-                    Object valuePerDoc = ((List) modelOutputValue)
-                        .get(MapUtils.getCounter(writeOutputMapDocCounter, mappingIndex, modelOutputFieldName));
-                    modelOutputValuePerDoc = valuePerDoc;
-                } else {
-                    modelOutputValuePerDoc = modelOutputValue;
-                }
                 // writing to search response extension
                 if (newDocumentFieldName.startsWith(EXTENSION_PREFIX)) {
                     Map<String, Object> params = ((MLInferenceSearchResponse) response).getParams();
                     String paramsName = newDocumentFieldName.replaceFirst(EXTENSION_PREFIX + ".", "");

                     if (params != null) {
-                        params.put(paramsName, modelOutputValuePerDoc);
+                        params.put(paramsName, modelOutputValue);
                         ((MLInferenceSearchResponse) response).setParams(params);
                     } else {
                         Map<String, Object> newParams = new HashMap<>();
-                        newParams.put(paramsName, modelOutputValuePerDoc);
+                        newParams.put(paramsName, modelOutputValue);
                         ((MLInferenceSearchResponse) response).setParams(newParams);
                     }
                 } else {
+                    Object modelOutputValuePerDoc;
+                    if (hitCountInPredictions.containsKey(mappingIndex)) {
+                        if (modelOutputValue instanceof List
+                            && ((List) modelOutputValue).size() == hitCountInPredictions.get(mappingIndex)
+                            && !oneToOne) {
+                            Object valuePerDoc = ((List) modelOutputValue)
+                                .get(
+                                    MapUtils
+                                        .getCounter(writeOutputMapDocCounter, mappingIndex, modelOutputFieldName)
+                                );
+                            modelOutputValuePerDoc = valuePerDoc;
+                        } else {
+                            modelOutputValuePerDoc = modelOutputValue;
+                        }
+                    } else {
+                        modelOutputValuePerDoc = modelOutputValue;
+                    }
                     // writing to search response hits
                     if (sourceAsMap.containsKey(newDocumentFieldName)) {
                         if (override) {

@@ -902,12 +910,15 @@ public MLInferenceSearchResponseProcessor create(
                         + ". Please reduce the size of input_map or optional_input_map or increase max_prediction_tasks."
                 );
             }
-            if (combinedOutputMaps != null && combinedInputMaps != null && combinedOutputMaps.size() != combinedInputMaps.size()) {
+
+            if (!CollectionUtils.isEmpty(combinedOutputMaps)
+                && !CollectionUtils.isEmpty(combinedInputMaps)
+                && combinedOutputMaps.size() != combinedInputMaps.size()) {
                 throw new IllegalArgumentException(
                     "when output_maps/optional_output_maps and input_maps/optional_input_maps are provided, their length needs to match. The input is in length of "
                         + combinedInputMaps.size()
                         + ", while output_maps is in the length of "
-                        + combinedInputMaps.size()
+                        + combinedOutputMaps.size()
                         + ". Please adjust mappings."
                 );
             }
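
Taken together, the response-processor changes above mean that values written to the search extension (fields starting with the "ext." prefix) now keep the full model output, for example the whole array selected by results[*].document.text, while only values written into individual hits are split per document, and only when the output is a list whose size matches the hit count for that mapping and the processor is not running one-to-one. A condensed, hypothetical sketch of that decision (names follow the diff; the real processor additionally uses MapUtils.getCounter(writeOutputMapDocCounter, ...) to pick the element for the current document):

import java.util.List;

// Illustrative sketch only (assumed signature): chooses what to write back for one output mapping.
final class OutputWriteback {
    static Object valueForTarget(
        Object modelOutputValue,    // value extracted from the model output for this mapping
        boolean writeToExtension,   // true when the target field starts with the "ext." prefix
        Integer hitCountForMapping, // hit count recorded for this mapping, or null if absent
        boolean oneToOne,           // true when the processor runs one prediction per document
        int docIndex                // index of the current document's element in the list
    ) {
        if (writeToExtension) {
            // the search extension stores the model output as-is, arrays included
            return modelOutputValue;
        }
        if (hitCountForMapping != null
            && modelOutputValue instanceof List
            && ((List<?>) modelOutputValue).size() == hitCountForMapping
            && !oneToOne) {
            // one element per hit: pick this document's value
            return ((List<?>) modelOutputValue).get(docIndex);
        }
        // otherwise every hit receives the same value
        return modelOutputValue;
    }
}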

plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java

Lines changed: 1 addition & 1 deletion
@@ -2250,7 +2250,7 @@ public void testOutputMapsExceedInputMaps() throws Exception {
         } catch (IllegalArgumentException e) {
             assertEquals(
                 e.getMessage(),
-                ("when output_maps/optional_output_maps and input_maps/optional_input_maps are provided, their length needs to match. The input is in length of 2, while output_maps is in the length of 2. Please adjust mappings.")
+                ("when output_maps/optional_output_maps and input_maps/optional_input_maps are provided, their length needs to match. The input is in length of 2, while output_maps is in the length of 3. Please adjust mappings.")
             );
         }
     }

plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java

Lines changed: 259 additions & 5 deletions
@@ -1036,6 +1036,126 @@ public void testProcessResponseSuccessWriteToExt() throws Exception {
             @Override
             public void onResponse(SearchResponse newSearchResponse) {
                 assertEquals(newSearchResponse.getHits().getHits().length, 5);
+                MLInferenceSearchResponse mLInferenceSearchResponse = (MLInferenceSearchResponse) newSearchResponse;
+                String resultsInResponse = (String) mLInferenceSearchResponse.getParams().get("llm_response");
+                assertEquals("there is 1 value", resultsInResponse);
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                throw new RuntimeException(e);
+            }
+
+        };
+        responseProcessor.processResponseAsync(request, response, responseContext, listener);
+        verify(client, times(1)).execute(any(), any(), any());
+    }
+
+    /**
+     * Tests the successful processing of a response with a single pair of input and output mappings.
+     * read the query text into model config with query extensions
+     * read the prediction outcome as array and store in search extension
+     * @throws Exception if an error occurs during the test
+     */
+    @Test
+    public void testProcessResponseSuccessArrayWriteToExt() throws Exception {
+        String documentField = "text";
+        String modelInputField = "context";
+        List<Map<String, String>> inputMap = new ArrayList<>();
+        Map<String, String> input = new HashMap<>();
+        input.put(modelInputField, documentField);
+        inputMap.add(input);
+
+        String newDocumentField = "ext.ml_inference.results";
+        String modelOutputField = "results[*].document.text";
+        List<Map<String, String>> outputMap = new ArrayList<>();
+        Map<String, String> output = new HashMap<>();
+        output.put(newDocumentField, modelOutputField);
+        outputMap.add(output);
+        Map<String, String> modelConfig = new HashMap<>();
+        modelConfig.put("query", "positive review");
+        MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor(
+            "model1",
+            inputMap,
+            outputMap,
+            optionalInputMaps,
+            optionalOutputMaps,
+            modelConfig,
+            DEFAULT_MAX_PREDICTION_TASKS,
+            PROCESSOR_TAG,
+            DESCRIPTION,
+            false,
+            "remote",
+            false,
+            false,
+            false,
+            "{ \"parameters\": ${ml_inference.parameters} }",
+            client,
+            TEST_XCONTENT_REGISTRY_FOR_QUERY,
+            false
+        );
+
+        SearchRequest request = getSearchRequest();
+        String fieldName = "text";
+        SearchResponse response = getSearchResponse(5, true, fieldName);
+
+        Map<String, Object> inferenceResultMap = new HashMap<>();
+
+        Map<String, Object> doc1 = new HashMap<>();
+        Map<String, Object> doc1Text = new HashMap<>();
+        doc1Text.put("text", "value1");
+        doc1.put("document", doc1Text);
+        doc1.put("index", 0.0);
+        doc1.put("relevance_score", 2.6480842E-5);
+
+        Map<String, Object> doc2 = new HashMap<>();
+        Map<String, Object> doc2Text = new HashMap<>();
+        doc2Text.put("text", "value5");
+        doc2.put("document", doc2Text);
+        doc2.put("index", 4.0);
+        doc2.put("relevance_score", 2.5071593E-5);
+
+        Map<String, Object> doc3 = new HashMap<>();
+        Map<String, Object> doc3Text = new HashMap<>();
+        doc3Text.put("text", "value4");
+        doc3.put("document", doc3Text);
+        doc3.put("index", 3.0);
+        doc3.put("relevance_score", 2.373734E-5);
+
+        Map<String, Object> doc4 = new HashMap<>();
+        Map<String, Object> doc4Text = new HashMap<>();
+        doc4Text.put("text", "value2");
+        doc4.put("document", doc4Text);
+        doc4.put("index", 1.0);
+        doc4.put("relevance_score", 2.1112483E-5);
+
+        Map<String, Object> doc5 = new HashMap<>();
+        Map<String, Object> doc5Text = new HashMap<>();
+        doc5Text.put("text", "value3");
+        doc5.put("document", doc5Text);
+        doc5.put("index", 2.0);
+        doc5.put("relevance_score", 1.6187581E-5);
+
+        inferenceResultMap.put("results", Arrays.asList(doc1, doc2, doc3, doc4, doc5));
+
+        ModelTensor modelTensor = ModelTensor.builder().dataAsMap(inferenceResultMap).build();
+        ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
+        ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
+
+        doAnswer(invocation -> {
+            ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
+            actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
+            return null;
+        }).when(client).execute(any(), any(), any());
+
+        ActionListener<SearchResponse> listener = new ActionListener<>() {
+            @Override
+            public void onResponse(SearchResponse newSearchResponse) {
+                assertEquals(newSearchResponse.getHits().getHits().length, 5);
+                MLInferenceSearchResponse mLInferenceSearchResponse = (MLInferenceSearchResponse) newSearchResponse;
+                List<Map<String, Object>> results = (List<Map<String, Object>>) inferenceResultMap.get("results");
+                List<String> resultsInResponse = (List<String>) mLInferenceSearchResponse.getParams().get("results");
+                assertEquals(results.size(), resultsInResponse.size());
             }

             @Override
@@ -5166,17 +5286,151 @@ public void testOutputMapsExceedInputMaps() throws Exception {
         } catch (IllegalArgumentException e) {
             assertEquals(
                 e.getMessage(),
-                "when output_maps/optional_output_maps and input_maps/optional_input_maps are provided, their length needs to match. The input is in length of 2, while output_maps is in the length of 2. Please adjust mappings."
+                "when output_maps/optional_output_maps and input_maps/optional_input_maps are provided, their length needs to match. The input is in length of 2, while output_maps is in the length of 3. Please adjust mappings."
             );

         }
     }

     /**
-     * Tests the creation of the MLInferenceSearchResponseProcessor with optional fields.
-     *
-     * @throws Exception if an error occurs during the test
-     */
+     * Tests the case where only the input maps are provided in the configuration.
+     *
+     * @throws Exception if an error occurs during the test
+     */
+    public void testOnlyInputMapsProvided() throws Exception {
+        Map<String, Object> config = new HashMap<>();
+        config.put(MODEL_ID, "model2");
+        List<Map<String, String>> inputMap = new ArrayList<>();
+        Map<String, String> input0 = new HashMap<>();
+        input0.put("inputs", "text");
+        inputMap.add(input0);
+        Map<String, String> input1 = new HashMap<>();
+        input1.put("inputs", "hashtag");
+        inputMap.add(input1);
+        config.put(INPUT_MAP, inputMap);
+        config.put(MAX_PREDICTION_TASKS, 2);
+        String processorTag = randomAlphaOfLength(10);
+
+        factory.create(Collections.emptyMap(), processorTag, null, false, config, null);
+    }
+
+    /**
+     * Tests the case where the input maps and empty output map are provided in the configuration.
+     *
+     * @throws Exception if an error occurs during the test
+     */
+    public void testInputMapsEmptyOutputMapProvided() throws Exception {
+        Map<String, Object> config = new HashMap<>();
+        config.put(MODEL_ID, "model2");
+        List<Map<String, String>> inputMap = new ArrayList<>();
+        Map<String, String> input0 = new HashMap<>();
+        input0.put("inputs", "text");
+        inputMap.add(input0);
+        Map<String, String> input1 = new HashMap<>();
+        input1.put("inputs", "hashtag");
+        inputMap.add(input1);
+        config.put(INPUT_MAP, inputMap);
+        config.put(MAX_PREDICTION_TASKS, 2);
+        String processorTag = randomAlphaOfLength(10);
+
+        List<Map<String, String>> outputMap = new ArrayList<>();
+        config.put(OUTPUT_MAP, outputMap);
+
+        factory.create(Collections.emptyMap(), processorTag, null, false, config, null);
+    }
+
+    /**
+     * Tests the case where only the Optional input maps are provided in the configuration.
+     *
+     * @throws Exception if an error occurs during the test
+     */
+    public void testOnlyOptionalInputMapsProvided() throws Exception {
+        Map<String, Object> config = new HashMap<>();
+        config.put(MODEL_ID, "model2");
+        List<Map<String, String>> inputMap = new ArrayList<>();
+        Map<String, String> input0 = new HashMap<>();
+        input0.put("inputs", "text");
+        inputMap.add(input0);
+        Map<String, String> input1 = new HashMap<>();
+        input1.put("inputs", "hashtag");
+        inputMap.add(input1);
+        config.put(OPTIONAL_INPUT_MAP, inputMap);
+        config.put(MAX_PREDICTION_TASKS, 2);
+        String processorTag = randomAlphaOfLength(10);
+
+        factory.create(Collections.emptyMap(), processorTag, null, false, config, null);
+
+    }
+
+    /**
+     * Tests the case where the Optional input maps and an empty Optional output map are provided in the configuration.
+     *
+     * @throws Exception if an error occurs during the test
+     */
+    public void testOnlyOptionalInputMapsEmptyOptionalOutputProvided() throws Exception {
+        Map<String, Object> config = new HashMap<>();
+        config.put(MODEL_ID, "model2");
+        List<Map<String, String>> inputMap = new ArrayList<>();
+        Map<String, String> input0 = new HashMap<>();
+        input0.put("inputs", "text");
+        inputMap.add(input0);
+        Map<String, String> input1 = new HashMap<>();
+        input1.put("inputs", "hashtag");
+        inputMap.add(input1);
+        config.put(OPTIONAL_INPUT_MAP, inputMap);
+        config.put(MAX_PREDICTION_TASKS, 2);
+        String processorTag = randomAlphaOfLength(10);
+        List<Map<String, String>> outputMap = new ArrayList<>();
+        config.put(OPTIONAL_OUTPUT_MAP, outputMap);
+        factory.create(Collections.emptyMap(), processorTag, null, false, config, null);
+
+    }
+
+    /**
+     * Tests the case where only the output maps are provided in the configuration.
+     *
+     * @throws Exception if an error occurs during the test
+     */
+    public void testOnlyOutputMapsProvided() throws Exception {
+        Map<String, Object> config = new HashMap<>();
+        config.put(MODEL_ID, "model2");
+        List<Map<String, String>> outputMap = new ArrayList<>();
+        Map<String, String> output = new HashMap<>();
+        output.put("text_embedding", "$.inference_results[0].output[0].data");
+        outputMap.add(output);
+        config.put(OUTPUT_MAP, outputMap);
+        config.put(MAX_PREDICTION_TASKS, 2);
+        String processorTag = randomAlphaOfLength(10);
+
+        factory.create(Collections.emptyMap(), processorTag, null, false, config, null);
+    }
+
+    /**
+     * Tests the case where the output maps and an empty input map are provided in the configuration.
+     *
+     * @throws Exception if an error occurs during the test
+     */
+    public void testOnlyOutputMapsEmptyInputProvided() throws Exception {
+        Map<String, Object> config = new HashMap<>();
+        config.put(MODEL_ID, "model2");
+        List<Map<String, String>> inputMap = new ArrayList<>();
+        List<Map<String, String>> outputMap = new ArrayList<>();
+        Map<String, String> output = new HashMap<>();
+        output.put("text_embedding", "$.inference_results[0].output[0].data");
+        outputMap.add(output);
+        config.put(INPUT_MAP, inputMap);
+        config.put(OUTPUT_MAP, outputMap);
+        config.put(MAX_PREDICTION_TASKS, 2);
+        String processorTag = randomAlphaOfLength(10);
+
+        factory.create(Collections.emptyMap(), processorTag, null, false, config, null);
+    }
+
+    /**
+     * Tests the creation of the MLInferenceSearchResponseProcessor with optional fields.
+     *
+     * @throws Exception if an error occurs during the test
+     */
     public void testCreateOptionalFields() throws Exception {
         Map<String, Object> config = new HashMap<>();
         config.put(MODEL_ID, "model2");