Skip to content

Commit 1690295

Browse files
Expose Update Agent API (#3820)
* update agent draft Signed-off-by: Jiaping Zeng <jpz@amazon.com> * Moved tenant check & refactored response Signed-off-by: Jiaping Zeng <jpz@amazon.com> * added tests Signed-off-by: Jiaping Zeng <jpz@amazon.com> * added agent ID in success logging Signed-off-by: Jiaping Zeng <jpz@amazon.com> * modified update agent request input Signed-off-by: Jiaping Zeng <jpz@amazon.com> * updated unit tests Signed-off-by: Jiaping Zeng <jpz@amazon.com> * added input validation for update agent & replaced assert with exception Signed-off-by: Jiaping Zeng <jpz@amazon.com> * updated and added new tests Signed-off-by: Jiaping Zeng <jpz@amazon.com> * updated test name prefixes Signed-off-by: Jiaping Zeng <jpz@amazon.com> --------- Signed-off-by: Jiaping Zeng <jpz@amazon.com> Co-authored-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent 08f0a88 commit 1690295

File tree

12 files changed

+1656
-0
lines changed

12 files changed

+1656
-0
lines changed

common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ public class MLAgent implements ToXContentObject, Writeable {
5151
public static final String APP_TYPE_FIELD = "app_type";
5252
public static final String IS_HIDDEN_FIELD = "is_hidden";
5353

54+
public static final int AGENT_NAME_MAX_LENGTH = 128;
55+
5456
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT = CommonValue.VERSION_2_13_0;
5557

5658
private String name;
@@ -102,6 +104,11 @@ private void validate() {
102104
if (name == null) {
103105
throw new IllegalArgumentException("Agent name can't be null");
104106
}
107+
if (name.isBlank() || name.length() > AGENT_NAME_MAX_LENGTH) {
108+
throw new IllegalArgumentException(
109+
String.format("Agent name cannot be empty or exceed max length of %d characters", MLAgent.AGENT_NAME_MAX_LENGTH)
110+
);
111+
}
105112
validateMLAgentType(type);
106113
if (type.equalsIgnoreCase(MLAgentType.CONVERSATIONAL.toString()) && llm == null) {
107114
throw new IllegalArgumentException("We need model information for the conversational agent type");
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.agent;
7+
8+
import org.opensearch.action.ActionType;
9+
import org.opensearch.action.update.UpdateResponse;
10+
11+
public class MLAgentUpdateAction extends ActionType<UpdateResponse> {
12+
public static final MLAgentUpdateAction INSTANCE = new MLAgentUpdateAction();
13+
public static final String NAME = "cluster:admin/opensearch/ml/agents/update";
14+
15+
private MLAgentUpdateAction() {
16+
super(NAME, UpdateResponse::new);
17+
}
18+
}
Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.agent;
7+
8+
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
9+
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
10+
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;
11+
12+
import java.io.IOException;
13+
import java.time.Instant;
14+
import java.util.ArrayList;
15+
import java.util.HashSet;
16+
import java.util.List;
17+
import java.util.Map;
18+
import java.util.Optional;
19+
import java.util.Set;
20+
21+
import org.opensearch.Version;
22+
import org.opensearch.core.common.io.stream.StreamInput;
23+
import org.opensearch.core.common.io.stream.StreamOutput;
24+
import org.opensearch.core.common.io.stream.Writeable;
25+
import org.opensearch.core.xcontent.ToXContentObject;
26+
import org.opensearch.core.xcontent.XContentBuilder;
27+
import org.opensearch.core.xcontent.XContentParser;
28+
import org.opensearch.ml.common.agent.LLMSpec;
29+
import org.opensearch.ml.common.agent.MLAgent;
30+
import org.opensearch.ml.common.agent.MLMemorySpec;
31+
import org.opensearch.ml.common.agent.MLToolSpec;
32+
33+
import lombok.Builder;
34+
import lombok.Data;
35+
import lombok.Getter;
36+
37+
@Data
38+
public class MLAgentUpdateInput implements ToXContentObject, Writeable {
39+
40+
public static final String AGENT_ID_FIELD = "agent_id";
41+
public static final String AGENT_NAME_FIELD = "name";
42+
public static final String DESCRIPTION_FIELD = "description";
43+
public static final String LLM_FIELD = "llm";
44+
public static final String TOOLS_FIELD = "tools";
45+
public static final String PARAMETERS_FIELD = "parameters";
46+
public static final String MEMORY_FIELD = "memory";
47+
public static final String APP_TYPE_FIELD = "app_type";
48+
public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time";
49+
50+
@Getter
51+
private String agentId;
52+
private String name;
53+
private String description;
54+
private LLMSpec llm;
55+
private List<MLToolSpec> tools;
56+
private Map<String, String> parameters;
57+
private MLMemorySpec memory;
58+
private String appType;
59+
private Instant lastUpdateTime;
60+
private String tenantId;
61+
62+
@Builder(toBuilder = true)
63+
public MLAgentUpdateInput(
64+
String agentId,
65+
String name,
66+
String description,
67+
LLMSpec llm,
68+
List<MLToolSpec> tools,
69+
Map<String, String> parameters,
70+
MLMemorySpec memory,
71+
String appType,
72+
Instant lastUpdateTime,
73+
String tenantId
74+
) {
75+
this.agentId = agentId;
76+
this.name = name;
77+
this.description = description;
78+
this.llm = llm;
79+
this.tools = tools;
80+
this.parameters = parameters;
81+
this.memory = memory;
82+
this.appType = appType;
83+
this.lastUpdateTime = lastUpdateTime;
84+
this.tenantId = tenantId;
85+
validate();
86+
}
87+
88+
public MLAgentUpdateInput(StreamInput in) throws IOException {
89+
Version streamInputVersion = in.getVersion();
90+
agentId = in.readString();
91+
name = in.readOptionalString();
92+
description = in.readOptionalString();
93+
if (in.readBoolean()) {
94+
llm = new LLMSpec(in);
95+
}
96+
if (in.readBoolean()) {
97+
tools = new ArrayList<>();
98+
int size = in.readInt();
99+
for (int i = 0; i < size; i++) {
100+
tools.add(new MLToolSpec(in));
101+
}
102+
}
103+
if (in.readBoolean()) {
104+
parameters = in.readMap(StreamInput::readString, StreamInput::readOptionalString);
105+
}
106+
if (in.readBoolean()) {
107+
memory = new MLMemorySpec(in);
108+
}
109+
lastUpdateTime = in.readOptionalInstant();
110+
appType = in.readOptionalString();
111+
tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null;
112+
}
113+
114+
@Override
115+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
116+
builder.startObject();
117+
builder.field(AGENT_ID_FIELD, agentId);
118+
if (name != null) {
119+
builder.field(AGENT_NAME_FIELD, name);
120+
}
121+
if (description != null) {
122+
builder.field(DESCRIPTION_FIELD, description);
123+
}
124+
if (llm != null) {
125+
builder.field(LLM_FIELD, llm);
126+
}
127+
if (tools != null && !tools.isEmpty()) {
128+
builder.field(TOOLS_FIELD, tools);
129+
}
130+
if (parameters != null && !parameters.isEmpty()) {
131+
builder.field(PARAMETERS_FIELD, parameters);
132+
}
133+
if (memory != null) {
134+
builder.field(MEMORY_FIELD, memory);
135+
}
136+
if (appType != null) {
137+
builder.field(APP_TYPE_FIELD, appType);
138+
}
139+
if (lastUpdateTime != null) {
140+
builder.field(LAST_UPDATED_TIME_FIELD, lastUpdateTime.toEpochMilli());
141+
}
142+
if (tenantId != null) {
143+
builder.field(TENANT_ID_FIELD, tenantId);
144+
}
145+
builder.endObject();
146+
return builder;
147+
}
148+
149+
@Override
150+
public void writeTo(StreamOutput out) throws IOException {
151+
Version streamOutputVersion = out.getVersion();
152+
out.writeString(agentId);
153+
out.writeOptionalString(name);
154+
out.writeOptionalString(description);
155+
if (llm != null) {
156+
out.writeBoolean(true);
157+
llm.writeTo(out);
158+
} else {
159+
out.writeBoolean(false);
160+
}
161+
if (tools != null && !tools.isEmpty()) {
162+
out.writeBoolean(true);
163+
out.writeInt(tools.size());
164+
for (MLToolSpec tool : tools) {
165+
tool.writeTo(out);
166+
}
167+
} else {
168+
out.writeBoolean(false);
169+
}
170+
if (parameters != null && !parameters.isEmpty()) {
171+
out.writeBoolean(true);
172+
out.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeOptionalString);
173+
} else {
174+
out.writeBoolean(false);
175+
}
176+
if (memory != null) {
177+
out.writeBoolean(true);
178+
memory.writeTo(out);
179+
} else {
180+
out.writeBoolean(false);
181+
}
182+
out.writeOptionalInstant(lastUpdateTime);
183+
out.writeOptionalString(appType);
184+
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
185+
out.writeOptionalString(tenantId);
186+
}
187+
}
188+
189+
public static MLAgentUpdateInput parse(XContentParser parser) throws IOException {
190+
String agentId = null;
191+
String name = null;
192+
String description = null;
193+
LLMSpec llm = null;
194+
List<MLToolSpec> tools = null;
195+
Map<String, String> parameters = null;
196+
MLMemorySpec memory = null;
197+
String appType = null;
198+
Instant lastUpdateTime = null;
199+
String tenantId = null;
200+
201+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
202+
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
203+
String fieldName = parser.currentName();
204+
parser.nextToken();
205+
switch (fieldName) {
206+
case AGENT_ID_FIELD:
207+
agentId = parser.text();
208+
break;
209+
case AGENT_NAME_FIELD:
210+
name = parser.text();
211+
break;
212+
case DESCRIPTION_FIELD:
213+
description = parser.text();
214+
break;
215+
case LLM_FIELD:
216+
llm = LLMSpec.parse(parser);
217+
break;
218+
case TOOLS_FIELD:
219+
tools = new ArrayList<>();
220+
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
221+
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
222+
tools.add(MLToolSpec.parse(parser));
223+
}
224+
break;
225+
case PARAMETERS_FIELD:
226+
parameters = parser.mapStrings();
227+
break;
228+
case MEMORY_FIELD:
229+
memory = MLMemorySpec.parse(parser);
230+
break;
231+
case APP_TYPE_FIELD:
232+
appType = parser.text();
233+
break;
234+
case LAST_UPDATED_TIME_FIELD:
235+
lastUpdateTime = Instant.ofEpochMilli(parser.longValue());
236+
break;
237+
case TENANT_ID_FIELD:
238+
tenantId = parser.textOrNull();
239+
break;
240+
default:
241+
parser.skipChildren();
242+
break;
243+
}
244+
}
245+
246+
return new MLAgentUpdateInput(agentId, name, description, llm, tools, parameters, memory, appType, lastUpdateTime, tenantId);
247+
}
248+
249+
public MLAgent toMLAgent(MLAgent originalAgent) {
250+
return MLAgent
251+
.builder()
252+
.type(originalAgent.getType())
253+
.createdTime(originalAgent.getCreatedTime())
254+
.isHidden(originalAgent.getIsHidden())
255+
.name(name == null ? originalAgent.getName() : name)
256+
.description(description == null ? originalAgent.getDescription() : description)
257+
.llm(llm == null ? originalAgent.getLlm() : llm)
258+
.tools(tools == null ? originalAgent.getTools() : tools)
259+
.parameters(parameters == null ? originalAgent.getParameters() : parameters)
260+
.memory(memory == null ? originalAgent.getMemory() : memory)
261+
.lastUpdateTime(lastUpdateTime)
262+
.appType(appType)
263+
.tenantId(tenantId)
264+
.build();
265+
}
266+
267+
private void validate() {
268+
if (name != null && (name.isBlank() || name.length() > MLAgent.AGENT_NAME_MAX_LENGTH)) {
269+
throw new IllegalArgumentException(
270+
String.format("Agent name cannot be empty or exceed max length of %d characters", MLAgent.AGENT_NAME_MAX_LENGTH)
271+
);
272+
}
273+
if (memory != null && !memory.getType().equals("conversation_index")) {
274+
throw new IllegalArgumentException(String.format("Invalid memory type: %s", memory.getType()));
275+
}
276+
if (tools != null) {
277+
Set<String> toolNames = new HashSet<>();
278+
for (MLToolSpec toolSpec : tools) {
279+
String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType());
280+
if (toolNames.contains(toolName)) {
281+
throw new IllegalArgumentException("Duplicate tool defined: " + toolName);
282+
} else {
283+
toolNames.add(toolName);
284+
}
285+
}
286+
}
287+
}
288+
}

0 commit comments

Comments
 (0)