Skip to content

Commit 9ffe7ef

Browse files
authored
[PlanExecuteReflect Agent] Feature: support custom prompts from user (#3731)
1 parent 2ef82b6 commit 9ffe7ef

File tree

2 files changed

+63
-28
lines changed

2 files changed

+63
-28
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs;
1717
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.LLM_INTERFACE;
1818
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.saveTraceData;
19-
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.PLANNER_PROMPT;
20-
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.PLANNER_PROMPT_TEMPLATE;
21-
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.PLANNER_WITH_HISTORY_PROMPT_TEMPLATE;
19+
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.DEFAULT_PLANNER_PROMPT;
20+
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.DEFAULT_PLANNER_PROMPT_TEMPLATE;
21+
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.DEFAULT_PLANNER_WITH_HISTORY_PROMPT_TEMPLATE;
22+
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.DEFAULT_REFLECT_PROMPT;
23+
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.DEFAULT_REFLECT_PROMPT_TEMPLATE;
2224
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.PLAN_EXECUTE_REFLECT_RESPONSE_FORMAT;
23-
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.REFLECT_PROMPT;
24-
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.REFLECT_PROMPT_TEMPLATE;
2525

2626
import java.util.ArrayList;
2727
import java.util.Arrays;
@@ -75,8 +75,15 @@ public class MLPlanExecuteAndReflectAgentRunner implements MLAgentRunner {
7575
private final Map<String, Tool.Factory> toolFactories;
7676
private final Map<String, Memory.Factory> memoryFactoryMap;
7777

78+
// prompts
79+
private String plannerPrompt;
80+
private String plannerPromptTemplate;
81+
private String reflectPrompt;
82+
private String reflectPromptTemplate;
83+
private String plannerWithHistoryPromptTemplate;
84+
7885
// defaults
79-
private static final String DEFAULT_DEEP_RESEARCH_SYSTEM_PROMPT = "Always respond in JSON format.";
86+
private static final String DEFAULT_SYSTEM_PROMPT = "Always respond in JSON format.";
8087
private static final String DEFAULT_REACT_SYSTEM_PROMPT = "You are a helpful assistant.";
8188
private static final String DEFAULT_NO_ESCAPE_PARAMS = "tool_configs,_tools";
8289
private static final String DEFAULT_MAX_STEPS_EXECUTED = "20";
@@ -89,8 +96,8 @@ public class MLPlanExecuteAndReflectAgentRunner implements MLAgentRunner {
8996
public static final String STEPS_FIELD = "steps";
9097
public static final String COMPLETED_STEPS_FIELD = "completed_steps";
9198
public static final String PLANNER_PROMPT_FIELD = "planner_prompt";
92-
public static final String REVAL_PROMPT_FIELD = "reval_prompt";
93-
public static final String DEEP_RESEARCH_RESPONSE_FORMAT_FIELD = "deep_research_response_format";
99+
public static final String REFLECT_PROMPT_FIELD = "reflect_prompt";
100+
public static final String PLAN_EXECUTE_REFLECT_RESPONSE_FORMAT_FIELD = "plan_execute_reflect_response_format";
94101
public static final String PROMPT_TEMPLATE_FIELD = "prompt_template";
95102
public static final String SYSTEM_PROMPT_FIELD = "system_prompt";
96103
public static final String QUESTION_FIELD = "question";
@@ -104,6 +111,9 @@ public class MLPlanExecuteAndReflectAgentRunner implements MLAgentRunner {
104111
public static final String NO_ESCAPE_PARAMS_FIELD = "no_escape_params";
105112
public static final String DEFAULT_PROMPT_TOOLS_FIELD = "tools_prompt";
106113
public static final String MAX_STEPS_EXECUTED_FIELD = "max_steps";
114+
public static final String PLANNER_PROMPT_TEMPLATE_FIELD = "planner_prompt_template";
115+
public static final String REFLECT_PROMPT_TEMPLATE_FIELD = "reflect_prompt_template";
116+
public static final String PLANNER_WITH_HISTORY_TEMPLATE_FIELD = "planner_with_history_template";
107117

108118
public MLPlanExecuteAndReflectAgentRunner(
109119
Client client,
@@ -119,6 +129,11 @@ public MLPlanExecuteAndReflectAgentRunner(
119129
this.xContentRegistry = registry;
120130
this.toolFactories = toolFactories;
121131
this.memoryFactoryMap = memoryFactoryMap;
132+
this.plannerPrompt = DEFAULT_PLANNER_PROMPT;
133+
this.plannerPromptTemplate = DEFAULT_PLANNER_PROMPT_TEMPLATE;
134+
this.reflectPrompt = DEFAULT_REFLECT_PROMPT;
135+
this.reflectPromptTemplate = DEFAULT_REFLECT_PROMPT_TEMPLATE;
136+
this.plannerWithHistoryPromptTemplate = DEFAULT_PLANNER_WITH_HISTORY_PROMPT_TEMPLATE;
122137
}
123138

124139
private void setupPromptParameters(Map<String, String> params) {
@@ -130,11 +145,31 @@ private void setupPromptParameters(Map<String, String> params) {
130145
params.put(USER_PROMPT_FIELD, userPrompt);
131146

132147
String userSystemPrompt = params.getOrDefault(SYSTEM_PROMPT_FIELD, "");
133-
params.put(SYSTEM_PROMPT_FIELD, userSystemPrompt + DEFAULT_DEEP_RESEARCH_SYSTEM_PROMPT);
148+
params.put(SYSTEM_PROMPT_FIELD, userSystemPrompt + DEFAULT_SYSTEM_PROMPT);
149+
150+
if (params.get(PLANNER_PROMPT_FIELD) != null) {
151+
this.plannerPrompt = params.get(PLANNER_PROMPT_FIELD);
152+
}
153+
params.put(PLANNER_PROMPT_FIELD, this.plannerPrompt);
154+
155+
if (params.get(PLANNER_PROMPT_TEMPLATE_FIELD) != null) {
156+
this.plannerPromptTemplate = params.get(PLANNER_PROMPT_TEMPLATE_FIELD);
157+
}
158+
159+
if (params.get(REFLECT_PROMPT_FIELD) != null) {
160+
this.reflectPrompt = params.get(REFLECT_PROMPT_FIELD);
161+
}
162+
params.put(REFLECT_PROMPT_FIELD, this.reflectPrompt);
163+
164+
if (params.get(REFLECT_PROMPT_TEMPLATE_FIELD) != null) {
165+
this.reflectPromptTemplate = params.get(REFLECT_PROMPT_TEMPLATE_FIELD);
166+
}
167+
168+
if (params.get(PLANNER_WITH_HISTORY_TEMPLATE_FIELD) != null) {
169+
this.plannerWithHistoryPromptTemplate = params.get(PLANNER_WITH_HISTORY_TEMPLATE_FIELD);
170+
}
134171

135-
params.put(PLANNER_PROMPT_FIELD, PLANNER_PROMPT);
136-
params.put(REVAL_PROMPT_FIELD, REFLECT_PROMPT);
137-
params.put(DEEP_RESEARCH_RESPONSE_FORMAT_FIELD, PLAN_EXECUTE_REFLECT_RESPONSE_FORMAT);
172+
params.put(PLAN_EXECUTE_REFLECT_RESPONSE_FORMAT_FIELD, PLAN_EXECUTE_REFLECT_RESPONSE_FORMAT);
138173

139174
params.put(NO_ESCAPE_PARAMS_FIELD, DEFAULT_NO_ESCAPE_PARAMS);
140175

@@ -153,17 +188,17 @@ private void setupPromptParameters(Map<String, String> params) {
153188
}
154189

155190
private void usePlannerPromptTemplate(Map<String, String> params) {
156-
params.put(PROMPT_TEMPLATE_FIELD, PLANNER_PROMPT_TEMPLATE);
191+
params.put(PROMPT_TEMPLATE_FIELD, this.plannerPromptTemplate);
157192
populatePrompt(params);
158193
}
159194

160-
private void useRevalPromptTemplate(Map<String, String> params) {
161-
params.put(PROMPT_TEMPLATE_FIELD, REFLECT_PROMPT_TEMPLATE);
195+
private void useReflectPromptTemplate(Map<String, String> params) {
196+
params.put(PROMPT_TEMPLATE_FIELD, this.reflectPromptTemplate);
162197
populatePrompt(params);
163198
}
164199

165200
private void usePlannerWithHistoryPromptTemplate(Map<String, String> params) {
166-
params.put(PROMPT_TEMPLATE_FIELD, PLANNER_WITH_HISTORY_PROMPT_TEMPLATE);
201+
params.put(PROMPT_TEMPLATE_FIELD, this.plannerWithHistoryPromptTemplate);
167202
populatePrompt(params);
168203
}
169204

@@ -366,7 +401,7 @@ private void executePlanningLoop(
366401

367402
addSteps(completedSteps, allParams, COMPLETED_STEPS_FIELD);
368403

369-
useRevalPromptTemplate(allParams);
404+
useReflectPromptTemplate(allParams);
370405

371406
executePlanningLoop(
372407
llm,

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
package org.opensearch.ml.engine.algorithms.agent;
22

33
import static org.opensearch.ml.engine.algorithms.agent.MLPlanExecuteAndReflectAgentRunner.COMPLETED_STEPS_FIELD;
4-
import static org.opensearch.ml.engine.algorithms.agent.MLPlanExecuteAndReflectAgentRunner.DEEP_RESEARCH_RESPONSE_FORMAT_FIELD;
54
import static org.opensearch.ml.engine.algorithms.agent.MLPlanExecuteAndReflectAgentRunner.DEFAULT_PROMPT_TOOLS_FIELD;
65
import static org.opensearch.ml.engine.algorithms.agent.MLPlanExecuteAndReflectAgentRunner.PLANNER_PROMPT_FIELD;
7-
import static org.opensearch.ml.engine.algorithms.agent.MLPlanExecuteAndReflectAgentRunner.REVAL_PROMPT_FIELD;
6+
import static org.opensearch.ml.engine.algorithms.agent.MLPlanExecuteAndReflectAgentRunner.PLAN_EXECUTE_REFLECT_RESPONSE_FORMAT_FIELD;
7+
import static org.opensearch.ml.engine.algorithms.agent.MLPlanExecuteAndReflectAgentRunner.REFLECT_PROMPT_FIELD;
88
import static org.opensearch.ml.engine.algorithms.agent.MLPlanExecuteAndReflectAgentRunner.STEPS_FIELD;
99
import static org.opensearch.ml.engine.algorithms.agent.MLPlanExecuteAndReflectAgentRunner.USER_PROMPT_FIELD;
1010

@@ -23,17 +23,17 @@ public class PromptTemplate {
2323
public static final String CHAT_HISTORY_PREFIX =
2424
"Human:CONVERSATION HISTORY WITH AI ASSISTANT\n----------------------------\nBelow is Chat History between Human and AI which sorted by time with asc order:\n";
2525

26-
public static final String PLANNER_PROMPT_TEMPLATE = "${parameters."
26+
public static final String DEFAULT_PLANNER_PROMPT_TEMPLATE = "${parameters."
2727
+ PLANNER_PROMPT_FIELD
2828
+ "} \n"
2929
+ "Objective: ${parameters."
3030
+ USER_PROMPT_FIELD
3131
+ "} \n\n"
3232
+ "${parameters."
33-
+ DEEP_RESEARCH_RESPONSE_FORMAT_FIELD
33+
+ PLAN_EXECUTE_REFLECT_RESPONSE_FORMAT_FIELD
3434
+ "}";
3535

36-
public static final String REFLECT_PROMPT_TEMPLATE = "${parameters."
36+
public static final String DEFAULT_REFLECT_PROMPT_TEMPLATE = "${parameters."
3737
+ PLANNER_PROMPT_FIELD
3838
+ "} \n\n"
3939
+ "Objective: ${parameters."
@@ -46,13 +46,13 @@ public class PromptTemplate {
4646
+ COMPLETED_STEPS_FIELD
4747
+ "}] \n\n"
4848
+ "${parameters."
49-
+ REVAL_PROMPT_FIELD
49+
+ REFLECT_PROMPT_FIELD
5050
+ "} \n\n"
5151
+ "${parameters."
52-
+ DEEP_RESEARCH_RESPONSE_FORMAT_FIELD
52+
+ PLAN_EXECUTE_REFLECT_RESPONSE_FORMAT_FIELD
5353
+ "}";
5454

55-
public static final String PLANNER_WITH_HISTORY_PROMPT_TEMPLATE = "${parameters."
55+
public static final String DEFAULT_PLANNER_WITH_HISTORY_PROMPT_TEMPLATE = "${parameters."
5656
+ PLANNER_PROMPT_FIELD
5757
+ "} \n"
5858
+ "Objective: ${parameters."
@@ -62,13 +62,13 @@ public class PromptTemplate {
6262
+ COMPLETED_STEPS_FIELD
6363
+ "}] \n\n"
6464
+ "${parameters."
65-
+ DEEP_RESEARCH_RESPONSE_FORMAT_FIELD
65+
+ PLAN_EXECUTE_REFLECT_RESPONSE_FORMAT_FIELD
6666
+ "}";
6767

68-
public static final String PLANNER_PROMPT =
68+
public static final String DEFAULT_PLANNER_PROMPT =
6969
"For the given objective, come up with a simple step by step plan. This plan should involve individual tasks, that if executed correctly will yield the correct answer. Do not add any superfluous steps. The result of the final step should be the final answer. Make sure that each step has all the information needed - do not skip steps. At all costs, do not execute the steps. You will be told when to execute the steps.";
7070

71-
public static final String REFLECT_PROMPT =
71+
public static final String DEFAULT_REFLECT_PROMPT =
7272
"Update your plan accordingly. If no more steps are needed and you can return to the user, then respond with that. Otherwise, fill out the plan. Only add steps to the plan that still NEED to be done. Do not return previously done steps as part of the plan. Please follow the below response format.";
7373

7474
public static final String PLAN_EXECUTE_REFLECT_RESPONSE_FORMAT = "${parameters."

0 commit comments

Comments
 (0)