Skip to content

Commit 50bcbda

Browse files
authored
Add error handling for plan&execute agent (#3845)
1 parent 563120b commit 50bcbda

File tree

5 files changed

+89
-2
lines changed

5 files changed

+89
-2
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,13 @@ private static void runTool(
608608
);
609609
nextStepListener
610610
.onResponse(
611-
String.format(Locale.ROOT, "Failed to run the tool %s with the error message %s.", finalAction, e.getMessage())
611+
String
612+
.format(
613+
Locale.ROOT,
614+
"Failed to run the tool %s with the error message %s.",
615+
finalAction,
616+
e.getMessage().replaceAll("\\n", "\n")
617+
)
612618
);
613619
});
614620
if (tools.get(action) instanceof MLModelTool) {

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,11 @@ String extractJsonFromMarkdown(String response) {
555555
if (response.contains("```")) {
556556
response = response.substring(0, response.lastIndexOf("```"));
557557
}
558+
} else {
559+
// extract content from {} block
560+
if (response.contains("{") && response.contains("}")) {
561+
response = response.substring(response.indexOf("{"), response.lastIndexOf("}") + 1);
562+
}
558563
}
559564

560565
response = response.trim();

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/IndexMappingTool.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,12 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
9595
try {
9696
List<String> indexList = new ArrayList<>();
9797
if (StringUtils.isNotBlank(parameters.get("index"))) {
98-
indexList = gson.fromJson(parameters.get("index"), List.class);
98+
try {
99+
indexList = gson.fromJson(parameters.get("index"), List.class);
100+
} catch (Exception e) {
101+
// sometimes the input comes from LLM is not a json string, it might a single value of index name
102+
indexList.add(parameters.get("index"));
103+
}
99104
}
100105

101106
if (indexList.isEmpty()) {

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,13 @@ public void testExtractJsonFromMarkdown() {
451451
assertEquals("{\"key\":\"value\"}", result);
452452
}
453453

454+
@Test
455+
public void testExtractJsonFromMarkdownWithoutJsonPrefix() {
456+
String markdown = "This is the json output {\"key\":\"value\"}\n";
457+
String result = mlPlanExecuteAndReflectAgentRunner.extractJsonFromMarkdown(markdown);
458+
assertEquals("{\"key\":\"value\"}", result);
459+
}
460+
454461
@Test
455462
public void testAddToolsToPrompt() {
456463
Map<String, String> testParams = new HashMap<>();

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/IndexMappingToolTests.java

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ public class IndexMappingToolTests {
5454
private GetIndexResponse getIndexResponse;
5555

5656
private Map<String, String> indexParams;
57+
private Map<String, String> indexParamsWithRawIndexName;
5758
private Map<String, String> otherParams;
5859
private Map<String, String> emptyParams;
5960

@@ -67,6 +68,7 @@ public void setup() {
6768
IndexMappingTool.Factory.getInstance().init(client);
6869

6970
indexParams = Map.of("index", "[\"foo\"]");
71+
indexParamsWithRawIndexName = Map.of("index", "foo");
7072
otherParams = Map.of("other", "[\"bar\"]");
7173
emptyParams = Collections.emptyMap();
7274
}
@@ -175,6 +177,68 @@ public void testRunAsyncIndexMapping() throws Exception {
175177
assertTrue(responseList.contains("test.int.setting=123"));
176178
}
177179

180+
@Test
181+
public void testRunWithRawIndexNameInput() throws Exception {
182+
String indexName = "foo";
183+
184+
@SuppressWarnings("unchecked")
185+
ArgumentCaptor<ActionListener<GetIndexResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);
186+
doNothing().when(indicesAdminClient).getIndex(any(), actionListenerCaptor.capture());
187+
188+
when(getIndexResponse.indices()).thenReturn(new String[] { indexName });
189+
Settings settings = Settings.builder().put("test.boolean.setting", false).put("test.int.setting", 123).build();
190+
when(getIndexResponse.settings()).thenReturn(Map.of(indexName, settings));
191+
String source = """
192+
{
193+
"foo": {
194+
"mappings": {
195+
"year": {
196+
"full_name": "year",
197+
"mapping": {
198+
"year": {
199+
"type": "text"
200+
}
201+
}
202+
},
203+
"age": {
204+
"full_name": "age",
205+
"mapping": {
206+
"age": {
207+
"type": "integer"
208+
}
209+
}
210+
}
211+
}
212+
}
213+
}""";
214+
MappingMetadata mapping = new MappingMetadata(indexName, XContentHelper.convertToMap(JsonXContent.jsonXContent, source, true));
215+
when(getIndexResponse.mappings()).thenReturn(Map.of(indexName, mapping));
216+
217+
// Now make the call
218+
Tool tool = IndexMappingTool.Factory.getInstance().create(Collections.emptyMap());
219+
final CompletableFuture<String> future = new CompletableFuture<>();
220+
ActionListener<String> listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); });
221+
222+
tool.run(indexParamsWithRawIndexName, listener);
223+
actionListenerCaptor.getValue().onResponse(getIndexResponse);
224+
225+
future.orTimeout(10, TimeUnit.SECONDS).join();
226+
String response = future.get();
227+
List<String> responseList = Arrays.asList(response.trim().split("\\n"));
228+
229+
assertTrue(responseList.contains("index: foo"));
230+
231+
assertTrue(responseList.contains("mappings:"));
232+
assertTrue(
233+
responseList
234+
.contains("mappings={year={full_name=year, mapping={year={type=text}}}, age={full_name=age, mapping={age={type=integer}}}}")
235+
);
236+
237+
assertTrue(responseList.contains("settings:"));
238+
assertTrue(responseList.contains("test.boolean.setting=false"));
239+
assertTrue(responseList.contains("test.int.setting=123"));
240+
}
241+
178242
@Test
179243
public void testTool() {
180244
Factory instance = IndexMappingTool.Factory.getInstance();

0 commit comments

Comments
 (0)