Skip to content

Commit 4b7a08c

Browse files
authored
implement async mode in agent execution (#3714)
* implement async mode in agent execution Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> * address comments Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> * fix build failure Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> * fix tests Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> * change response field name Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> * bump task index schema version Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com> --------- Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com>
1 parent 46481bd commit 4b7a08c

File tree

10 files changed

+548
-49
lines changed

10 files changed

+548
-49
lines changed

common/src/main/java/org/opensearch/ml/common/MLTask.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ public class MLTask implements ToXContentObject, Writeable {
5050
public static final String ERROR_FIELD = "error";
5151
public static final String IS_ASYNC_TASK_FIELD = "is_async";
5252
public static final String REMOTE_JOB_FIELD = "remote_job";
53+
public static final String RESPONSE_FIELD = "response";
5354
public static final Version MINIMAL_SUPPORTED_VERSION_FOR_BATCH_PREDICTION_JOB = CommonValue.VERSION_2_17_0;
55+
public static final Version MINIMAL_SUPPORTED_VERSION_FOR_RESPONSE_FIELD = CommonValue.VERSION_3_0_0;
5456

5557
@Setter
5658
private String taskId;
@@ -74,6 +76,8 @@ public class MLTask implements ToXContentObject, Writeable {
7476
private boolean async;
7577
@Setter
7678
private Map<String, Object> remoteJob;
79+
@Setter
80+
private Map<String, Object> response;
7781
private String tenantId;
7882

7983
@Builder(toBuilder = true)
@@ -93,6 +97,7 @@ public MLTask(
9397
User user,
9498
boolean async,
9599
Map<String, Object> remoteJob,
100+
Map<String, Object> response,
96101
String tenantId
97102
) {
98103
this.taskId = taskId;
@@ -110,6 +115,7 @@ public MLTask(
110115
this.user = user;
111116
this.async = async;
112117
this.remoteJob = remoteJob;
118+
this.response = response;
113119
this.tenantId = tenantId;
114120
}
115121

@@ -142,6 +148,11 @@ public MLTask(StreamInput input) throws IOException {
142148
this.remoteJob = input.readMap(StreamInput::readString, StreamInput::readGenericValue);
143149
}
144150
}
151+
if (streamInputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_RESPONSE_FIELD)) {
152+
if (input.readBoolean()) {
153+
this.response = input.readMap(StreamInput::readString, StreamInput::readGenericValue);
154+
}
155+
}
145156
tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
146157
}
147158

@@ -179,6 +190,14 @@ public void writeTo(StreamOutput out) throws IOException {
179190
out.writeBoolean(false);
180191
}
181192
}
193+
if (streamOutputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_RESPONSE_FIELD)) {
194+
if (response != null) {
195+
out.writeBoolean(true);
196+
out.writeMap(response, StreamOutput::writeString, StreamOutput::writeGenericValue);
197+
} else {
198+
out.writeBoolean(false);
199+
}
200+
}
182201
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
183202
out.writeOptionalString(tenantId);
184203
}
@@ -230,6 +249,9 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
230249
if (remoteJob != null) {
231250
builder.field(REMOTE_JOB_FIELD, remoteJob);
232251
}
252+
if (response != null) {
253+
builder.field(RESPONSE_FIELD, response);
254+
}
233255
if (tenantId != null) {
234256
builder.field(TENANT_ID_FIELD, tenantId);
235257
}
@@ -256,6 +278,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
256278
User user = null;
257279
boolean async = false;
258280
Map<String, Object> remoteJob = null;
281+
Map<String, Object> response = null;
259282
String tenantId = null;
260283

261284
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
@@ -317,6 +340,9 @@ public static MLTask parse(XContentParser parser) throws IOException {
317340
case REMOTE_JOB_FIELD:
318341
remoteJob = parser.map();
319342
break;
343+
case RESPONSE_FIELD:
344+
response = parser.map();
345+
break;
320346
case TENANT_ID_FIELD:
321347
tenantId = parser.textOrNull();
322348
break;
@@ -342,6 +368,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
342368
.user(user)
343369
.async(async)
344370
.remoteJob(remoteJob)
371+
.response(response)
345372
.tenantId(tenantId)
346373
.build();
347374
}

common/src/main/java/org/opensearch/ml/common/MLTaskType.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,6 @@ public enum MLTaskType {
1717
REGISTER_MODEL,
1818
DEPLOY_MODEL,
1919
BATCH_INGEST,
20-
BATCH_PREDICTION
20+
BATCH_PREDICTION,
21+
AGENT_EXECUTION
2122
}

common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.opensearch.core.common.io.stream.StreamInput;
1717
import org.opensearch.core.common.io.stream.StreamOutput;
1818
import org.opensearch.core.xcontent.XContentParser;
19+
import org.opensearch.ml.common.CommonValue;
1920
import org.opensearch.ml.common.FunctionName;
2021
import org.opensearch.ml.common.dataset.MLInputDataset;
2122
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
@@ -30,6 +31,9 @@
3031
public class AgentMLInput extends MLInput {
3132
public static final String AGENT_ID_FIELD = "agent_id";
3233
public static final String PARAMETERS_FIELD = "parameters";
34+
public static final String ASYNC_FIELD = "isAsync";
35+
36+
public static final Version MINIMAL_SUPPORTED_VERSION_FOR_ASYNC_EXECUTION = CommonValue.VERSION_3_0_0;
3337

3438
@Getter
3539
@Setter
@@ -39,12 +43,22 @@ public class AgentMLInput extends MLInput {
3943
@Setter
4044
private String tenantId;
4145

46+
@Getter
47+
@Setter
48+
private Boolean isAsync;
49+
4250
@Builder(builderMethodName = "AgentMLInputBuilder")
4351
public AgentMLInput(String agentId, String tenantId, FunctionName functionName, MLInputDataset inputDataset) {
52+
this(agentId, tenantId, functionName, inputDataset, false);
53+
}
54+
55+
@Builder(builderMethodName = "AgentMLInputBuilder")
56+
public AgentMLInput(String agentId, String tenantId, FunctionName functionName, MLInputDataset inputDataset, Boolean isAsync) {
4457
this.agentId = agentId;
4558
this.tenantId = tenantId;
4659
this.algorithm = functionName;
4760
this.inputDataset = inputDataset;
61+
this.isAsync = isAsync;
4862
}
4963

5064
@Override
@@ -55,13 +69,19 @@ public void writeTo(StreamOutput out) throws IOException {
5569
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
5670
out.writeOptionalString(tenantId);
5771
}
72+
if (streamOutputVersion.onOrAfter(AgentMLInput.MINIMAL_SUPPORTED_VERSION_FOR_ASYNC_EXECUTION)) {
73+
out.writeOptionalBoolean(isAsync);
74+
}
5875
}
5976

6077
public AgentMLInput(StreamInput in) throws IOException {
6178
super(in);
6279
Version streamInputVersion = in.getVersion();
6380
this.agentId = in.readString();
6481
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null;
82+
if (streamInputVersion.onOrAfter(AgentMLInput.MINIMAL_SUPPORTED_VERSION_FOR_ASYNC_EXECUTION)) {
83+
this.isAsync = in.readOptionalBoolean();
84+
}
6585
}
6686

6787
public AgentMLInput(XContentParser parser, FunctionName functionName) throws IOException {
@@ -83,6 +103,9 @@ public AgentMLInput(XContentParser parser, FunctionName functionName) throws IOE
83103
Map<String, String> parameters = StringUtils.getParameterMap(parser.map());
84104
inputDataset = new RemoteInferenceInputDataSet(parameters);
85105
break;
106+
case ASYNC_FIELD:
107+
isAsync = parser.booleanValue();
108+
break;
86109
default:
87110
parser.skipChildren();
88111
break;

common/src/main/java/org/opensearch/ml/common/output/MLOutputType.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@ public enum MLOutputType {
1010
PREDICTION,
1111
SAMPLE_ALGO,
1212
MODEL_TENSOR,
13-
MCORR_TENSOR
13+
MCORR_TENSOR,
14+
ML_TASK_OUTPUT
1415
}
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.output;
7+
8+
import java.io.IOException;
9+
import java.util.Map;
10+
11+
import org.opensearch.core.common.io.stream.StreamInput;
12+
import org.opensearch.core.common.io.stream.StreamOutput;
13+
import org.opensearch.core.xcontent.XContentBuilder;
14+
import org.opensearch.ml.common.annotation.MLAlgoOutput;
15+
16+
import lombok.Builder;
17+
import lombok.Data;
18+
import lombok.EqualsAndHashCode;
19+
20+
@Data
21+
@EqualsAndHashCode(callSuper = false)
22+
@MLAlgoOutput(MLOutputType.ML_TASK_OUTPUT)
23+
public class MLTaskOutput extends MLOutput {
24+
25+
private static final MLOutputType OUTPUT_TYPE = MLOutputType.ML_TASK_OUTPUT;
26+
public static final String TASK_ID_FIELD = "task_id";
27+
public static final String STATUS_FIELD = "status";
28+
public static final String RESPONSE_FIELD = "response";
29+
30+
String taskId;
31+
String status;
32+
Map<String, Object> response;
33+
34+
@Builder
35+
public MLTaskOutput(String taskId, String status, Map<String, Object> response) {
36+
super(OUTPUT_TYPE);
37+
this.taskId = taskId;
38+
this.status = status;
39+
this.response = response;
40+
}
41+
42+
public MLTaskOutput(StreamInput in) throws IOException {
43+
super(OUTPUT_TYPE);
44+
this.taskId = in.readOptionalString();
45+
this.status = in.readOptionalString();
46+
if (in.readBoolean()) {
47+
this.response = in.readMap(s -> s.readString(), s -> s.readGenericValue());
48+
}
49+
}
50+
51+
@Override
52+
public void writeTo(StreamOutput out) throws IOException {
53+
super.writeTo(out);
54+
out.writeOptionalString(taskId);
55+
out.writeOptionalString(status);
56+
if (response != null) {
57+
out.writeBoolean(true);
58+
out.writeMap(response, StreamOutput::writeString, StreamOutput::writeGenericValue);
59+
} else {
60+
out.writeBoolean(false);
61+
}
62+
}
63+
64+
@Override
65+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
66+
builder.startObject();
67+
if (taskId != null) {
68+
builder.field(TASK_ID_FIELD, taskId);
69+
}
70+
if (status != null) {
71+
builder.field(STATUS_FIELD, status);
72+
}
73+
74+
if (response != null) {
75+
builder.field(RESPONSE_FIELD, response);
76+
}
77+
78+
builder.endObject();
79+
return builder;
80+
}
81+
82+
@Override
83+
public MLOutputType getType() {
84+
return OUTPUT_TYPE;
85+
}
86+
}

common/src/main/resources/index-mappings/ml_task.json

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"_meta": {
3-
"schema_version": 4
3+
"schema_version": 5
44
},
55
"properties": {
66
"model_id": {
@@ -47,6 +47,9 @@
4747
"remote_job": {
4848
"type": "flat_object"
4949
},
50-
"user": USER_MAPPING_PLACEHOLDER
50+
"user": USER_MAPPING_PLACEHOLDER,
51+
"response": {
52+
"type": "flat_object"
53+
}
5154
}
5255
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.output;
7+
8+
import static org.junit.Assert.assertEquals;
9+
10+
import java.io.IOException;
11+
import java.util.ArrayList;
12+
import java.util.HashMap;
13+
import java.util.List;
14+
import java.util.Map;
15+
16+
import org.junit.Before;
17+
import org.junit.Test;
18+
import org.opensearch.common.io.stream.BytesStreamOutput;
19+
import org.opensearch.common.xcontent.XContentType;
20+
import org.opensearch.core.common.io.stream.StreamInput;
21+
import org.opensearch.core.xcontent.MediaTypeRegistry;
22+
import org.opensearch.core.xcontent.ToXContent;
23+
import org.opensearch.core.xcontent.XContentBuilder;
24+
import org.opensearch.ml.common.dataframe.ColumnMeta;
25+
import org.opensearch.ml.common.dataframe.ColumnType;
26+
import org.opensearch.ml.common.dataframe.ColumnValue;
27+
import org.opensearch.ml.common.dataframe.IntValue;
28+
import org.opensearch.ml.common.dataframe.Row;
29+
30+
public class MLTaskOutputTest {
31+
32+
MLTaskOutput output;
33+
34+
@Before
35+
public void setUp() {
36+
ColumnMeta[] columnMetas = new ColumnMeta[] { new ColumnMeta("test", ColumnType.INTEGER) };
37+
List<Row> rows = new ArrayList<>();
38+
rows.add(new Row(new ColumnValue[] { new IntValue(1) }));
39+
rows.add(new Row(new ColumnValue[] { new IntValue(2) }));
40+
Map<String, Object> response = new HashMap<>();
41+
response.put("memory_id", "test-memory-id");
42+
output = MLTaskOutput.builder().taskId("test_task_id").status("test_status").response(response).build();
43+
}
44+
45+
@Test
46+
public void toXContent() throws IOException {
47+
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
48+
XContentBuilder builderWithExecuteResponse = MediaTypeRegistry.contentBuilder(XContentType.JSON);
49+
output.toXContent(builder, ToXContent.EMPTY_PARAMS);
50+
String jsonStr = builder.toString();
51+
assertEquals("{\"task_id\":\"test_task_id\",\"status\":\"test_status\",\"response\":{\"memory_id\":\"test-memory-id\"}}", jsonStr);
52+
output.toXContent(builderWithExecuteResponse, ToXContent.EMPTY_PARAMS);
53+
String jsonStr2 = builderWithExecuteResponse.toString();
54+
assertEquals("{\"task_id\":\"test_task_id\",\"status\":\"test_status\",\"response\":{\"memory_id\":\"test-memory-id\"}}", jsonStr2);
55+
}
56+
57+
@Test
58+
public void toXContent_EmptyOutput() throws IOException {
59+
MLTaskOutput output = MLTaskOutput.builder().build();
60+
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
61+
output.toXContent(builder, ToXContent.EMPTY_PARAMS);
62+
String jsonStr = builder.toString();
63+
assertEquals("{}", jsonStr);
64+
}
65+
66+
@Test
67+
public void readInputStream_Success() throws IOException {
68+
readInputStream(output);
69+
}
70+
71+
private void readInputStream(MLTaskOutput output) throws IOException {
72+
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
73+
output.writeTo(bytesStreamOutput);
74+
75+
StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
76+
MLOutputType outputType = streamInput.readEnum(MLOutputType.class);
77+
assertEquals(MLOutputType.ML_TASK_OUTPUT, outputType);
78+
MLTaskOutput parsedOutput = new MLTaskOutput(streamInput);
79+
assertEquals(output.getType(), parsedOutput.getType());
80+
assertEquals(output.getTaskId(), parsedOutput.getTaskId());
81+
assertEquals(output.getStatus(), parsedOutput.getStatus());
82+
}
83+
}

0 commit comments

Comments
 (0)