@@ -1036,6 +1036,126 @@ public void testProcessResponseSuccessWriteToExt() throws Exception {
1036
1036
@ Override
1037
1037
public void onResponse (SearchResponse newSearchResponse ) {
1038
1038
assertEquals (newSearchResponse .getHits ().getHits ().length , 5 );
1039
+ MLInferenceSearchResponse mLInferenceSearchResponse = (MLInferenceSearchResponse ) newSearchResponse ;
1040
+ String resultsInResponse = (String ) mLInferenceSearchResponse .getParams ().get ("llm_response" );
1041
+ assertEquals ("there is 1 value" , resultsInResponse );
1042
+ }
1043
+
1044
+ @ Override
1045
+ public void onFailure (Exception e ) {
1046
+ throw new RuntimeException (e );
1047
+ }
1048
+
1049
+ };
1050
+ responseProcessor .processResponseAsync (request , response , responseContext , listener );
1051
+ verify (client , times (1 )).execute (any (), any (), any ());
1052
+ }
1053
+
1054
+ /**
1055
+ * Tests the successful processing of a response with a single pair of input and output mappings.
1056
+ * read the query text into model config with query extensions
1057
+ * read the prediction outcome as array and store in search extension
1058
+ * @throws Exception if an error occurs during the test
1059
+ */
1060
+ @ Test
1061
+ public void testProcessResponseSuccessArrayWriteToExt () throws Exception {
1062
+ String documentField = "text" ;
1063
+ String modelInputField = "context" ;
1064
+ List <Map <String , String >> inputMap = new ArrayList <>();
1065
+ Map <String , String > input = new HashMap <>();
1066
+ input .put (modelInputField , documentField );
1067
+ inputMap .add (input );
1068
+
1069
+ String newDocumentField = "ext.ml_inference.results" ;
1070
+ String modelOutputField = "results[*].document.text" ;
1071
+ List <Map <String , String >> outputMap = new ArrayList <>();
1072
+ Map <String , String > output = new HashMap <>();
1073
+ output .put (newDocumentField , modelOutputField );
1074
+ outputMap .add (output );
1075
+ Map <String , String > modelConfig = new HashMap <>();
1076
+ modelConfig .put ("query" , "positive review" );
1077
+ MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor (
1078
+ "model1" ,
1079
+ inputMap ,
1080
+ outputMap ,
1081
+ optionalInputMaps ,
1082
+ optionalOutputMaps ,
1083
+ modelConfig ,
1084
+ DEFAULT_MAX_PREDICTION_TASKS ,
1085
+ PROCESSOR_TAG ,
1086
+ DESCRIPTION ,
1087
+ false ,
1088
+ "remote" ,
1089
+ false ,
1090
+ false ,
1091
+ false ,
1092
+ "{ \" parameters\" : ${ml_inference.parameters} }" ,
1093
+ client ,
1094
+ TEST_XCONTENT_REGISTRY_FOR_QUERY ,
1095
+ false
1096
+ );
1097
+
1098
+ SearchRequest request = getSearchRequest ();
1099
+ String fieldName = "text" ;
1100
+ SearchResponse response = getSearchResponse (5 , true , fieldName );
1101
+
1102
+ Map <String , Object > inferenceResultMap = new HashMap <>();
1103
+
1104
+ Map <String , Object > doc1 = new HashMap <>();
1105
+ Map <String , Object > doc1Text = new HashMap <>();
1106
+ doc1Text .put ("text" , "value1" );
1107
+ doc1 .put ("document" , doc1Text );
1108
+ doc1 .put ("index" , 0.0 );
1109
+ doc1 .put ("relevance_score" , 2.6480842E-5 );
1110
+
1111
+ Map <String , Object > doc2 = new HashMap <>();
1112
+ Map <String , Object > doc2Text = new HashMap <>();
1113
+ doc2Text .put ("text" , "value5" );
1114
+ doc2 .put ("document" , doc2Text );
1115
+ doc2 .put ("index" , 4.0 );
1116
+ doc2 .put ("relevance_score" , 2.5071593E-5 );
1117
+
1118
+ Map <String , Object > doc3 = new HashMap <>();
1119
+ Map <String , Object > doc3Text = new HashMap <>();
1120
+ doc3Text .put ("text" , "value4" );
1121
+ doc3 .put ("document" , doc3Text );
1122
+ doc3 .put ("index" , 3.0 );
1123
+ doc3 .put ("relevance_score" , 2.373734E-5 );
1124
+
1125
+ Map <String , Object > doc4 = new HashMap <>();
1126
+ Map <String , Object > doc4Text = new HashMap <>();
1127
+ doc4Text .put ("text" , "value2" );
1128
+ doc4 .put ("document" , doc4Text );
1129
+ doc4 .put ("index" , 1.0 );
1130
+ doc4 .put ("relevance_score" , 2.1112483E-5 );
1131
+
1132
+ Map <String , Object > doc5 = new HashMap <>();
1133
+ Map <String , Object > doc5Text = new HashMap <>();
1134
+ doc5Text .put ("text" , "value3" );
1135
+ doc5 .put ("document" , doc5Text );
1136
+ doc5 .put ("index" , 2.0 );
1137
+ doc5 .put ("relevance_score" , 1.6187581E-5 );
1138
+
1139
+ inferenceResultMap .put ("results" , Arrays .asList (doc1 , doc2 , doc3 , doc4 , doc5 ));
1140
+
1141
+ ModelTensor modelTensor = ModelTensor .builder ().dataAsMap (inferenceResultMap ).build ();
1142
+ ModelTensors modelTensors = ModelTensors .builder ().mlModelTensors (Arrays .asList (modelTensor )).build ();
1143
+ ModelTensorOutput mlModelTensorOutput = ModelTensorOutput .builder ().mlModelOutputs (Arrays .asList (modelTensors )).build ();
1144
+
1145
+ doAnswer (invocation -> {
1146
+ ActionListener <MLTaskResponse > actionListener = invocation .getArgument (2 );
1147
+ actionListener .onResponse (MLTaskResponse .builder ().output (mlModelTensorOutput ).build ());
1148
+ return null ;
1149
+ }).when (client ).execute (any (), any (), any ());
1150
+
1151
+ ActionListener <SearchResponse > listener = new ActionListener <>() {
1152
+ @ Override
1153
+ public void onResponse (SearchResponse newSearchResponse ) {
1154
+ assertEquals (newSearchResponse .getHits ().getHits ().length , 5 );
1155
+ MLInferenceSearchResponse mLInferenceSearchResponse = (MLInferenceSearchResponse ) newSearchResponse ;
1156
+ List <Map <String , Object >> results = (List <Map <String , Object >>) inferenceResultMap .get ("results" );
1157
+ List <String > resultsInResponse = (List <String >) mLInferenceSearchResponse .getParams ().get ("results" );
1158
+ assertEquals (results .size (), resultsInResponse .size ());
1039
1159
}
1040
1160
1041
1161
@ Override
@@ -5166,17 +5286,151 @@ public void testOutputMapsExceedInputMaps() throws Exception {
5166
5286
} catch (IllegalArgumentException e ) {
5167
5287
assertEquals (
5168
5288
e .getMessage (),
5169
- "when output_maps/optional_output_maps and input_maps/optional_input_maps are provided, their length needs to match. The input is in length of 2, while output_maps is in the length of 2 . Please adjust mappings."
5289
+ "when output_maps/optional_output_maps and input_maps/optional_input_maps are provided, their length needs to match. The input is in length of 2, while output_maps is in the length of 3 . Please adjust mappings."
5170
5290
);
5171
5291
5172
5292
}
5173
5293
}
5174
5294
5175
5295
/**
5176
- * Tests the creation of the MLInferenceSearchResponseProcessor with optional fields.
5177
- *
5178
- * @throws Exception if an error occurs during the test
5179
- */
5296
+ * Tests the case where only the input maps are provided in the configuration.
5297
+ *
5298
+ * @throws Exception if an error occurs during the test
5299
+ */
5300
+ public void testOnlyInputMapsProvided () throws Exception {
5301
+ Map <String , Object > config = new HashMap <>();
5302
+ config .put (MODEL_ID , "model2" );
5303
+ List <Map <String , String >> inputMap = new ArrayList <>();
5304
+ Map <String , String > input0 = new HashMap <>();
5305
+ input0 .put ("inputs" , "text" );
5306
+ inputMap .add (input0 );
5307
+ Map <String , String > input1 = new HashMap <>();
5308
+ input1 .put ("inputs" , "hashtag" );
5309
+ inputMap .add (input1 );
5310
+ config .put (INPUT_MAP , inputMap );
5311
+ config .put (MAX_PREDICTION_TASKS , 2 );
5312
+ String processorTag = randomAlphaOfLength (10 );
5313
+
5314
+ factory .create (Collections .emptyMap (), processorTag , null , false , config , null );
5315
+ }
5316
+
5317
+ /**
5318
+ * Tests the case where the input maps and empty output map are provided in the configuration.
5319
+ *
5320
+ * @throws Exception if an error occurs during the test
5321
+ */
5322
+ public void testInputMapsEmptyOutputMapProvided () throws Exception {
5323
+ Map <String , Object > config = new HashMap <>();
5324
+ config .put (MODEL_ID , "model2" );
5325
+ List <Map <String , String >> inputMap = new ArrayList <>();
5326
+ Map <String , String > input0 = new HashMap <>();
5327
+ input0 .put ("inputs" , "text" );
5328
+ inputMap .add (input0 );
5329
+ Map <String , String > input1 = new HashMap <>();
5330
+ input1 .put ("inputs" , "hashtag" );
5331
+ inputMap .add (input1 );
5332
+ config .put (INPUT_MAP , inputMap );
5333
+ config .put (MAX_PREDICTION_TASKS , 2 );
5334
+ String processorTag = randomAlphaOfLength (10 );
5335
+
5336
+ List <Map <String , String >> outputMap = new ArrayList <>();
5337
+ config .put (OUTPUT_MAP , outputMap );
5338
+
5339
+ factory .create (Collections .emptyMap (), processorTag , null , false , config , null );
5340
+ }
5341
+
5342
+ /**
5343
+ * Tests the case where only the Optional input maps are provided in the configuration.
5344
+ *
5345
+ * @throws Exception if an error occurs during the test
5346
+ */
5347
+ public void testOnlyOptionalInputMapsProvided () throws Exception {
5348
+ Map <String , Object > config = new HashMap <>();
5349
+ config .put (MODEL_ID , "model2" );
5350
+ List <Map <String , String >> inputMap = new ArrayList <>();
5351
+ Map <String , String > input0 = new HashMap <>();
5352
+ input0 .put ("inputs" , "text" );
5353
+ inputMap .add (input0 );
5354
+ Map <String , String > input1 = new HashMap <>();
5355
+ input1 .put ("inputs" , "hashtag" );
5356
+ inputMap .add (input1 );
5357
+ config .put (OPTIONAL_INPUT_MAP , inputMap );
5358
+ config .put (MAX_PREDICTION_TASKS , 2 );
5359
+ String processorTag = randomAlphaOfLength (10 );
5360
+
5361
+ factory .create (Collections .emptyMap (), processorTag , null , false , config , null );
5362
+
5363
+ }
5364
+
5365
+ /**
5366
+ * Tests the case where only the Optional input maps are provided in the configuration.
5367
+ *
5368
+ * @throws Exception if an error occurs during the test
5369
+ */
5370
+ public void testOnlyOptionalInputMapsEmptyOptionalOutputProvided () throws Exception {
5371
+ Map <String , Object > config = new HashMap <>();
5372
+ config .put (MODEL_ID , "model2" );
5373
+ List <Map <String , String >> inputMap = new ArrayList <>();
5374
+ Map <String , String > input0 = new HashMap <>();
5375
+ input0 .put ("inputs" , "text" );
5376
+ inputMap .add (input0 );
5377
+ Map <String , String > input1 = new HashMap <>();
5378
+ input1 .put ("inputs" , "hashtag" );
5379
+ inputMap .add (input1 );
5380
+ config .put (OPTIONAL_INPUT_MAP , inputMap );
5381
+ config .put (MAX_PREDICTION_TASKS , 2 );
5382
+ String processorTag = randomAlphaOfLength (10 );
5383
+ List <Map <String , String >> outputMap = new ArrayList <>();
5384
+ config .put (OPTIONAL_OUTPUT_MAP , outputMap );
5385
+ factory .create (Collections .emptyMap (), processorTag , null , false , config , null );
5386
+
5387
+ }
5388
+
5389
+ /**
5390
+ * Tests the case where only the output maps are provided in the configuration.
5391
+ *
5392
+ * @throws Exception if an error occurs during the test
5393
+ */
5394
+ public void testOnlyOutputMapsProvided () throws Exception {
5395
+ Map <String , Object > config = new HashMap <>();
5396
+ config .put (MODEL_ID , "model2" );
5397
+ List <Map <String , String >> outputMap = new ArrayList <>();
5398
+ Map <String , String > output = new HashMap <>();
5399
+ output .put ("text_embedding" , "$.inference_results[0].output[0].data" );
5400
+ outputMap .add (output );
5401
+ config .put (OUTPUT_MAP , outputMap );
5402
+ config .put (MAX_PREDICTION_TASKS , 2 );
5403
+ String processorTag = randomAlphaOfLength (10 );
5404
+
5405
+ factory .create (Collections .emptyMap (), processorTag , null , false , config , null );
5406
+ }
5407
+
5408
+ /**
5409
+ * Tests the case where only the output maps are provided in the configuration.
5410
+ *
5411
+ * @throws Exception if an error occurs during the test
5412
+ */
5413
+ public void testOnlyOutputMapsEmptyInputProvided () throws Exception {
5414
+ Map <String , Object > config = new HashMap <>();
5415
+ config .put (MODEL_ID , "model2" );
5416
+ List <Map <String , String >> inputMap = new ArrayList <>();
5417
+ List <Map <String , String >> outputMap = new ArrayList <>();
5418
+ Map <String , String > output = new HashMap <>();
5419
+ output .put ("text_embedding" , "$.inference_results[0].output[0].data" );
5420
+ outputMap .add (output );
5421
+ config .put (INPUT_MAP , inputMap );
5422
+ config .put (OUTPUT_MAP , outputMap );
5423
+ config .put (MAX_PREDICTION_TASKS , 2 );
5424
+ String processorTag = randomAlphaOfLength (10 );
5425
+
5426
+ factory .create (Collections .emptyMap (), processorTag , null , false , config , null );
5427
+ }
5428
+
5429
+ /**
5430
+ * Tests the creation of the MLInferenceSearchResponseProcessor with optional fields.
5431
+ *
5432
+ * @throws Exception if an error occurs during the test
5433
+ */
5180
5434
public void testCreateOptionalFields () throws Exception {
5181
5435
Map <String , Object > config = new HashMap <>();
5182
5436
config .put (MODEL_ID , "model2" );
0 commit comments