66
66
import org .opensearch .remote .metadata .client .SdkClient ;
67
67
import org .opensearch .transport .client .Client ;
68
68
69
+ import com .google .common .annotations .VisibleForTesting ;
69
70
import com .jayway .jsonpath .JsonPath ;
70
71
71
72
import joptsimple .internal .Strings ;
@@ -154,7 +155,8 @@ public MLPlanExecuteAndReflectAgentRunner(
154
155
this .plannerWithHistoryPromptTemplate = DEFAULT_PLANNER_WITH_HISTORY_PROMPT_TEMPLATE ;
155
156
}
156
157
157
- private void setupPromptParameters (Map <String , String > params ) {
158
+ @ VisibleForTesting
159
+ void setupPromptParameters (Map <String , String > params ) {
158
160
// populated depending on whether LLM is asked to plan or re-evaluate
159
161
// removed here, so that error is thrown in case this field is not populated
160
162
params .remove (PROMPT_FIELD );
@@ -203,22 +205,26 @@ private void setupPromptParameters(Map<String, String> params) {
203
205
}
204
206
}
205
207
206
- private void usePlannerPromptTemplate (Map <String , String > params ) {
208
+ @ VisibleForTesting
209
+ void usePlannerPromptTemplate (Map <String , String > params ) {
207
210
params .put (PROMPT_TEMPLATE_FIELD , this .plannerPromptTemplate );
208
211
populatePrompt (params );
209
212
}
210
213
211
- private void useReflectPromptTemplate (Map <String , String > params ) {
214
+ @ VisibleForTesting
215
+ void useReflectPromptTemplate (Map <String , String > params ) {
212
216
params .put (PROMPT_TEMPLATE_FIELD , this .reflectPromptTemplate );
213
217
populatePrompt (params );
214
218
}
215
219
216
- private void usePlannerWithHistoryPromptTemplate (Map <String , String > params ) {
220
+ @ VisibleForTesting
221
+ void usePlannerWithHistoryPromptTemplate (Map <String , String > params ) {
217
222
params .put (PROMPT_TEMPLATE_FIELD , this .plannerWithHistoryPromptTemplate );
218
223
populatePrompt (params );
219
224
}
220
225
221
- private void populatePrompt (Map <String , String > allParams ) {
226
+ @ VisibleForTesting
227
+ void populatePrompt (Map <String , String > allParams ) {
222
228
String promptTemplate = allParams .get (PROMPT_TEMPLATE_FIELD );
223
229
StringSubstitutor promptSubstitutor = new StringSubstitutor (allParams , "${parameters." , "}" );
224
230
String prompt = promptSubstitutor .replace (promptTemplate );
@@ -475,7 +481,8 @@ private void executePlanningLoop(
475
481
client .execute (MLPredictionTaskAction .INSTANCE , request , planListener );
476
482
}
477
483
478
- private Map <String , String > parseLLMOutput (Map <String , String > allParams , ModelTensorOutput modelTensorOutput ) {
484
+ @ VisibleForTesting
485
+ Map <String , String > parseLLMOutput (Map <String , String > allParams , ModelTensorOutput modelTensorOutput ) {
479
486
Map <String , String > modelOutput = new HashMap <>();
480
487
Map <String , ?> dataAsMap = modelTensorOutput .getMlModelOutputs ().getFirst ().getMlModelTensors ().getFirst ().getDataAsMap ();
481
488
String llmResponse ;
@@ -513,7 +520,8 @@ private Map<String, String> parseLLMOutput(Map<String, String> allParams, ModelT
513
520
return modelOutput ;
514
521
}
515
522
516
- private String extractJsonFromMarkdown (String response ) {
523
+ @ VisibleForTesting
524
+ String extractJsonFromMarkdown (String response ) {
517
525
response = response .trim ();
518
526
if (response .contains ("```json" )) {
519
527
response = response .substring (response .indexOf ("```json" ) + "```json" .length ());
@@ -535,7 +543,8 @@ private String extractJsonFromMarkdown(String response) {
535
543
return response ;
536
544
}
537
545
538
- private void addToolsToPrompt (Map <String , Tool > tools , Map <String , String > allParams ) {
546
+ @ VisibleForTesting
547
+ void addToolsToPrompt (Map <String , Tool > tools , Map <String , String > allParams ) {
539
548
StringBuilder toolsPrompt = new StringBuilder ("In this environment, you have access to the below tools: \n " );
540
549
for (Map .Entry <String , Tool > entry : tools .entrySet ()) {
541
550
String toolName = entry .getKey ();
@@ -548,11 +557,13 @@ private void addToolsToPrompt(Map<String, Tool> tools, Map<String, String> allPa
548
557
cleanUpResource (tools );
549
558
}
550
559
551
- private void addSteps (List <String > steps , Map <String , String > allParams , String field ) {
560
+ @ VisibleForTesting
561
+ void addSteps (List <String > steps , Map <String , String > allParams , String field ) {
552
562
allParams .put (field , String .join (", " , steps ));
553
563
}
554
564
555
- private void saveAndReturnFinalResult (
565
+ @ VisibleForTesting
566
+ void saveAndReturnFinalResult (
556
567
ConversationIndexMemory memory ,
557
568
String parentInteractionId ,
558
569
String reactAgentMemoryId ,
@@ -591,7 +602,8 @@ private void saveAndReturnFinalResult(
591
602
}));
592
603
}
593
604
594
- private static List <ModelTensors > createModelTensors (
605
+ @ VisibleForTesting
606
+ static List <ModelTensors > createModelTensors (
595
607
String sessionId ,
596
608
String parentInteractionId ,
597
609
String reactAgentMemoryId ,
0 commit comments