Skip to content

Commit a866cec

Browse files
authored
[Agent] PlanExecuteReflect: Return memory early to track progress (#3884)
* feat: return executor agent memory id when available
  Signed-off-by: Pavan Yekbote <pybot@amazon.com>
* feat: modify task state and return parent interaction id
  Signed-off-by: Pavan Yekbote <pybot@amazon.com>
* chore: add license header
  Signed-off-by: Pavan Yekbote <pybot@amazon.com>
* refactor: remove redundant try-catch
  Signed-off-by: Pavan Yekbote <pybot@amazon.com>
* test: add test case to validate contents of update and ensure updatemltask is called
  Signed-off-by: Pavan Yekbote <pybot@amazon.com>
* spotless
  Signed-off-by: Pavan Yekbote <pybot@amazon.com>
* test: adding test case for MLTaskUtilsTest
  Signed-off-by: Pavan Yekbote <pybot@amazon.com>
* fix: taskupdate state logic and test cases
  Signed-off-by: Pavan Yekbote <pybot@amazon.com>
* spotless
  Signed-off-by: Pavan Yekbote <pybot@amazon.com>
* test: add test case for empty task and thrown exception
  Signed-off-by: Pavan Yekbote <pybot@amazon.com>
---------
Signed-off-by: Pavan Yekbote <pybot@amazon.com>
1 parent 2a9baeb commit a866cec

File tree

7 files changed

+415
-35
lines changed

7 files changed

+415
-35
lines changed
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.ml.common.utils;
6+
7+
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
8+
import static org.opensearch.ml.common.MLTask.LAST_UPDATE_TIME_FIELD;
9+
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
10+
11+
import java.time.Instant;
12+
import java.util.HashMap;
13+
import java.util.Map;
14+
15+
import org.opensearch.action.support.WriteRequest;
16+
import org.opensearch.action.update.UpdateRequest;
17+
import org.opensearch.action.update.UpdateResponse;
18+
import org.opensearch.common.util.concurrent.ThreadContext;
19+
import org.opensearch.core.action.ActionListener;
20+
import org.opensearch.ml.common.MLTaskState;
21+
import org.opensearch.transport.client.Client;
22+
23+
import com.google.common.collect.ImmutableSet;
24+
25+
import lombok.extern.log4j.Log4j2;
26+
27+
@Log4j2
28+
public class MLTaskUtils {
29+
30+
public static final ImmutableSet<MLTaskState> TASK_DONE_STATES = ImmutableSet
31+
.of(MLTaskState.COMPLETED, MLTaskState.COMPLETED_WITH_ERROR, MLTaskState.FAILED, MLTaskState.CANCELLED);
32+
33+
/**
34+
* Updates an ML task document directly in the ML task index.
35+
* This method performs validation on the input parameters and updates the task with the provided fields.
36+
* It automatically adds a timestamp for the last update time.
37+
* For tasks that are being marked as done (completed, failed, etc.), it enables retry on conflict.
38+
*
39+
* @param taskId The ID of the ML task to update
40+
* @param updatedFields Map containing the fields to update in the ML task document
41+
* @param client The OpenSearch client to use for the update operation
42+
* @param listener ActionListener to handle the response or failure of the update operation
43+
* @throws IllegalArgumentException if taskId is null/empty, updatedFields is null/empty, or if the state field contains an invalid MLTaskState
44+
*/
45+
public static void updateMLTaskDirectly(
46+
String taskId,
47+
Map<String, Object> updatedFields,
48+
Client client,
49+
ActionListener<UpdateResponse> listener
50+
) {
51+
if (taskId == null || taskId.isEmpty()) {
52+
listener.onFailure(new IllegalArgumentException("Task ID is null or empty"));
53+
return;
54+
}
55+
56+
if (updatedFields == null || updatedFields.isEmpty()) {
57+
listener.onFailure(new IllegalArgumentException("Updated fields is null or empty"));
58+
return;
59+
}
60+
61+
if (updatedFields.containsKey(STATE_FIELD) && !(updatedFields.get(STATE_FIELD) instanceof MLTaskState)) {
62+
listener.onFailure(new IllegalArgumentException("Invalid task state"));
63+
return;
64+
}
65+
66+
UpdateRequest updateRequest = new UpdateRequest(ML_TASK_INDEX, taskId);
67+
68+
Map<String, Object> updatedContent = new HashMap<>(updatedFields);
69+
updatedContent.put(LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli());
70+
updateRequest.doc(updatedContent);
71+
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
72+
if (updatedFields.containsKey(STATE_FIELD) && TASK_DONE_STATES.contains((MLTaskState) updatedFields.get(STATE_FIELD))) {
73+
updateRequest.retryOnConflict(3);
74+
}
75+
76+
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
77+
client.update(updateRequest, ActionListener.runBefore(listener, context::restore));
78+
} catch (Exception e) {
79+
log.error("Failed to update ML task {}", taskId, e);
80+
listener.onFailure(e);
81+
}
82+
}
83+
}
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.utils;
7+
8+
import static org.mockito.ArgumentMatchers.any;
9+
import static org.mockito.Mockito.doAnswer;
10+
import static org.mockito.Mockito.mock;
11+
import static org.mockito.Mockito.verify;
12+
import static org.mockito.Mockito.when;
13+
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
14+
15+
import java.util.HashMap;
16+
import java.util.Map;
17+
18+
import org.junit.Before;
19+
import org.junit.Test;
20+
import org.junit.runner.RunWith;
21+
import org.mockito.junit.MockitoJUnitRunner;
22+
import org.opensearch.action.DocWriteResponse;
23+
import org.opensearch.action.update.UpdateRequest;
24+
import org.opensearch.action.update.UpdateResponse;
25+
import org.opensearch.common.settings.Settings;
26+
import org.opensearch.common.util.concurrent.ThreadContext;
27+
import org.opensearch.core.action.ActionListener;
28+
import org.opensearch.core.index.Index;
29+
import org.opensearch.core.index.shard.ShardId;
30+
import org.opensearch.ml.common.MLTaskState;
31+
import org.opensearch.threadpool.ThreadPool;
32+
import org.opensearch.transport.client.Client;
33+
34+
@RunWith(MockitoJUnitRunner.class)
35+
public class MLTaskUtilsTests {
36+
private Client client;
37+
private ThreadPool threadPool;
38+
private ThreadContext threadContext;
39+
40+
@Before
41+
public void setup() {
42+
this.client = mock(Client.class);
43+
this.threadPool = mock(ThreadPool.class);
44+
Settings settings = Settings.builder().build();
45+
this.threadContext = new ThreadContext(settings);
46+
when(client.threadPool()).thenReturn(threadPool);
47+
when(threadPool.getThreadContext()).thenReturn(threadContext);
48+
}
49+
50+
@Test
51+
public void testUpdateMLTaskDirectly_NullFields() {
52+
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
53+
MLTaskUtils.updateMLTaskDirectly("task_id", null, client, listener);
54+
verify(listener).onFailure(any(IllegalArgumentException.class));
55+
}
56+
57+
@Test
58+
public void testUpdateMLTaskDirectly_EmptyFields() {
59+
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
60+
MLTaskUtils.updateMLTaskDirectly("task_id", new HashMap<>(), client, listener);
61+
verify(listener).onFailure(any(IllegalArgumentException.class));
62+
}
63+
64+
@Test
65+
public void testUpdateMLTaskDirectly_NullTaskId() {
66+
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
67+
MLTaskUtils.updateMLTaskDirectly(null, new HashMap<>(), client, listener);
68+
verify(listener).onFailure(any(IllegalArgumentException.class));
69+
}
70+
71+
@Test
72+
public void testUpdateMLTaskDirectly_EmptyTaskId() {
73+
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
74+
MLTaskUtils.updateMLTaskDirectly("", new HashMap<>(), client, listener);
75+
verify(listener).onFailure(any(IllegalArgumentException.class));
76+
}
77+
78+
@Test
79+
public void testUpdateMLTaskDirectly_Success() {
80+
Map<String, Object> updatedFields = new HashMap<>();
81+
updatedFields.put("field1", "value1");
82+
83+
doAnswer(invocation -> {
84+
ActionListener<UpdateResponse> actionListener = invocation.getArgument(1);
85+
ShardId shardId = new ShardId(new Index(ML_TASK_INDEX, "_na_"), 0);
86+
UpdateResponse response = new UpdateResponse(shardId, "task_id", 1, 1, 1, DocWriteResponse.Result.CREATED);
87+
actionListener.onResponse(response);
88+
return null;
89+
}).when(client).update(any(UpdateRequest.class), any());
90+
91+
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
92+
MLTaskUtils.updateMLTaskDirectly("task_id", updatedFields, client, listener);
93+
verify(listener).onResponse(any(UpdateResponse.class));
94+
}
95+
96+
@Test
97+
public void testUpdateMLTaskDirectly_InvalidStateType() {
98+
Map<String, Object> updatedFields = new HashMap<>();
99+
updatedFields.put("state", "INVALID_STATE");
100+
101+
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
102+
MLTaskUtils.updateMLTaskDirectly("task_id", updatedFields, client, listener);
103+
verify(listener).onFailure(any(IllegalArgumentException.class));
104+
}
105+
106+
@Test
107+
public void testUpdateMLTaskDirectly_TaskDoneState() {
108+
Map<String, Object> updatedFields = new HashMap<>();
109+
updatedFields.put("state", MLTaskState.COMPLETED);
110+
111+
doAnswer(invocation -> {
112+
ActionListener<UpdateResponse> actionListener = invocation.getArgument(1);
113+
UpdateRequest request = invocation.getArgument(0);
114+
// Verify retry policy is set for task done state
115+
assert request.retryOnConflict() == 3;
116+
117+
ShardId shardId = new ShardId(new Index(ML_TASK_INDEX, "_na_"), 0);
118+
UpdateResponse response = new UpdateResponse(shardId, "task_id", 1, 1, 1, DocWriteResponse.Result.CREATED);
119+
actionListener.onResponse(response);
120+
return null;
121+
}).when(client).update(any(UpdateRequest.class), any());
122+
123+
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
124+
MLTaskUtils.updateMLTaskDirectly("task_id", updatedFields, client, listener);
125+
verify(listener).onResponse(any(UpdateResponse.class));
126+
}
127+
128+
@Test
129+
public void testUpdateMLTaskDirectly_ClientException() {
130+
Map<String, Object> updatedFields = new HashMap<>();
131+
updatedFields.put("field1", "value1");
132+
133+
doAnswer(invocation -> {
134+
ActionListener<UpdateResponse> actionListener = invocation.getArgument(1);
135+
actionListener.onFailure(new RuntimeException("Test exception"));
136+
return null;
137+
}).when(client).update(any(UpdateRequest.class), any());
138+
139+
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
140+
MLTaskUtils.updateMLTaskDirectly("task_id", updatedFields, client, listener);
141+
verify(listener).onFailure(any(RuntimeException.class));
142+
}
143+
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@
1010
import static org.opensearch.ml.common.CommonValue.MCP_CONNECTORS_FIELD;
1111
import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX;
1212
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
13-
import static org.opensearch.ml.common.MLTask.LAST_UPDATE_TIME_FIELD;
1413
import static org.opensearch.ml.common.MLTask.RESPONSE_FIELD;
1514
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
15+
import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD;
1616
import static org.opensearch.ml.common.output.model.ModelTensorOutput.INFERENCE_RESULT_FIELD;
1717
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE;
1818
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED;
19+
import static org.opensearch.ml.common.utils.MLTaskUtils.updateMLTaskDirectly;
1920

2021
import java.security.AccessController;
2122
import java.security.PrivilegedActionException;
@@ -34,9 +35,6 @@
3435
import org.opensearch.ResourceNotFoundException;
3536
import org.opensearch.action.get.GetResponse;
3637
import org.opensearch.action.index.IndexResponse;
37-
import org.opensearch.action.support.WriteRequest;
38-
import org.opensearch.action.update.UpdateRequest;
39-
import org.opensearch.action.update.UpdateResponse;
4038
import org.opensearch.cluster.service.ClusterService;
4139
import org.opensearch.common.settings.Settings;
4240
import org.opensearch.common.util.concurrent.ThreadContext;
@@ -82,7 +80,6 @@
8280

8381
import com.google.common.annotations.VisibleForTesting;
8482
import com.google.common.collect.ImmutableList;
85-
import com.google.common.collect.ImmutableSet;
8683
import com.google.gson.Gson;
8784

8885
import lombok.Data;
@@ -101,8 +98,6 @@ public class MLAgentExecutor implements Executable, SettingsChangeListener {
10198
public static final String REGENERATE_INTERACTION_ID = "regenerate_interaction_id";
10299
public static final String MESSAGE_HISTORY_LIMIT = "message_history_limit";
103100
public static final String ERROR_MESSAGE = "error_message";
104-
public static final ImmutableSet<MLTaskState> TASK_DONE_STATES = ImmutableSet
105-
.of(MLTaskState.COMPLETED, MLTaskState.COMPLETED_WITH_ERROR, MLTaskState.FAILED, MLTaskState.CANCELLED);
106101

107102
private Client client;
108103
private SdkClient sdkClient;
@@ -405,8 +400,14 @@ private void executeAgent(
405400
Map<String, Object> agentResponse = new HashMap<>();
406401
if (memoryId != null && !memoryId.isEmpty()) {
407402
agentResponse.put(MEMORY_ID, memoryId);
408-
mlTask.setResponse(agentResponse);
409403
}
404+
405+
String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID);
406+
if (parentInteractionId != null && !parentInteractionId.isEmpty()) {
407+
agentResponse.put(PARENT_INTERACTION_ID, parentInteractionId);
408+
}
409+
mlTask.setResponse(agentResponse);
410+
410411
indexMLTask(mlTask, ActionListener.wrap(indexResponse -> {
411412
String taskId = indexResponse.getId();
412413
mlTask.setTaskId(taskId);
@@ -418,6 +419,7 @@ private void executeAgent(
418419
}
419420
listener.onResponse(outputBuilder);
420421
ActionListener<Object> agentActionListener = createAsyncTaskUpdater(mlTask, outputs, modelTensors);
422+
inputDataSet.getParameters().put(TASK_ID_FIELD, taskId);
421423
mlAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener);
422424
}, e -> {
423425
log.error("Failed to create task for agent async execution", e);
@@ -468,6 +470,7 @@ private ActionListener<Object> createAsyncTaskUpdater(MLTask mlTask, List<ModelT
468470
updateMLTaskDirectly(
469471
taskId,
470472
updatedTask,
473+
client,
471474
ActionListener
472475
.wrap(
473476
response -> log.info("Updated ML task {} with agent execution results", taskId),
@@ -484,6 +487,7 @@ private ActionListener<Object> createAsyncTaskUpdater(MLTask mlTask, List<ModelT
484487
updateMLTaskDirectly(
485488
taskId,
486489
updatedTask,
490+
client,
487491
ActionListener
488492
.wrap(
489493
response -> log.info("Updated ML task {} with agent execution failed reason", taskId),
@@ -598,29 +602,4 @@ public void indexMLTask(MLTask mlTask, ActionListener<IndexResponse> listener) {
598602
listener.onFailure(e);
599603
}
600604
}
601-
602-
public void updateMLTaskDirectly(String taskId, Map<String, Object> updatedFields, ActionListener<UpdateResponse> listener) {
603-
try {
604-
if (updatedFields == null || updatedFields.isEmpty()) {
605-
listener.onFailure(new IllegalArgumentException("Updated fields is null or empty"));
606-
return;
607-
}
608-
UpdateRequest updateRequest = new UpdateRequest(ML_TASK_INDEX, taskId);
609-
Map<String, Object> updatedContent = new HashMap<>(updatedFields);
610-
updatedContent.put(LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli());
611-
updateRequest.doc(updatedContent);
612-
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
613-
if (updatedFields.containsKey(STATE_FIELD) && TASK_DONE_STATES.contains(updatedFields.containsKey(STATE_FIELD))) {
614-
updateRequest.retryOnConflict(3);
615-
}
616-
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
617-
client.update(updateRequest, ActionListener.runBefore(listener, context::restore));
618-
} catch (Exception e) {
619-
listener.onFailure(e);
620-
}
621-
} catch (Exception e) {
622-
log.error("Failed to update ML task {}", taskId, e);
623-
listener.onFailure(e);
624-
}
625-
}
626605
}

0 commit comments

Comments
 (0)