Skip to content

Commit a9293d1

Browse files
authored
Merge branch 'main' into agent_bug_fix
2 parents 6f5f6ed + 310f556 commit a9293d1

File tree

10 files changed

+1028
-16
lines changed

10 files changed

+1028
-16
lines changed

common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequest.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ public ActionRequestValidationException validate() {
4646
return addValidationError("ML Connector input can't be null", null);
4747
}
4848
Map<String, FieldDescriptor> fieldsToValidate = new HashMap<>();
49-
fieldsToValidate.put("Model connector name", new FieldDescriptor(mlCreateConnectorInput.getName(), true));
49+
50+
fieldsToValidate
51+
.put("Model connector name", new FieldDescriptor(mlCreateConnectorInput.getName(), !mlCreateConnectorInput.isDryRun()));
5052
fieldsToValidate.put("Model connector description", new FieldDescriptor(mlCreateConnectorInput.getDescription(), false));
5153

5254
return validateFields(fieldsToValidate);

common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.ml.common.transport.connector;
77

88
import static org.junit.Assert.assertEquals;
9+
import static org.junit.Assert.assertNotNull;
910
import static org.junit.Assert.assertNotSame;
1011
import static org.junit.Assert.assertNull;
1112
import static org.junit.Assert.assertSame;
@@ -267,4 +268,47 @@ public void validateWithEmptyAndInvalidModelConnectorNameAndDescription() {
267268
);
268269
}
269270

271+
@Test
272+
public void validateWithDryRun() {
273+
// Test with dry run set to true
274+
MLCreateConnectorInput dryRunInput = MLCreateConnectorInput
275+
.builder()
276+
.name("") // Empty name, which would normally fail validation
277+
.description("Test description")
278+
.version("1")
279+
.protocol("http")
280+
.parameters(Map.of("input", "test"))
281+
.credential(Map.of("key", "value"))
282+
.actions(List.of())
283+
.access(AccessMode.PUBLIC)
284+
.backendRoles(Arrays.asList("role1"))
285+
.addAllBackendRoles(false)
286+
.dryRun(true) // Set dry run to true
287+
.build();
288+
289+
MLCreateConnectorRequest dryRunRequest = MLCreateConnectorRequest.builder().mlCreateConnectorInput(dryRunInput).build();
290+
ActionRequestValidationException dryRunException = dryRunRequest.validate();
291+
assertNull("Validation should pass when dry run is true, even with empty name", dryRunException);
292+
293+
// Test with dry run set to false (default behavior)
294+
MLCreateConnectorInput nonDryRunInput = MLCreateConnectorInput
295+
.builder()
296+
.name("") // Empty name, which should fail validation
297+
.description("Test description")
298+
.version("1")
299+
.protocol("http")
300+
.parameters(Map.of("input", "test"))
301+
.credential(Map.of("key", "value"))
302+
.actions(List.of())
303+
.access(AccessMode.PUBLIC)
304+
.backendRoles(Arrays.asList("role1"))
305+
.addAllBackendRoles(false)
306+
.dryRun(false) // Set dry run to false
307+
.build();
308+
309+
MLCreateConnectorRequest nonDryRunRequest = MLCreateConnectorRequest.builder().mlCreateConnectorInput(nonDryRunInput).build();
310+
ActionRequestValidationException nonDryRunException = nonDryRunRequest.validate();
311+
assertNotNull("Validation should fail when dry run is false and name is empty", nonDryRunException);
312+
assertTrue(nonDryRunException.getMessage().contains("Model connector name is required"));
313+
}
270314
}

docs/tutorials/aws/AIConnectorHelper.ipynb

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,22 @@
3737
" domain_info = response['DomainStatus']\n",
3838
"\n",
3939
" # Extract the domain URL and ARN\n",
40-
" domain_url = domain_info['Endpoint']\n",
4140
" domain_arn = domain_info['ARN']\n",
4241
"\n",
42+
" # Check if domain has VPC endpoints\n",
43+
" if 'Endpoints' in domain_info:\n",
44+
" # VPC domain case\n",
45+
" if 'vpc' in domain_info['Endpoints']:\n",
46+
" domain_url = domain_info['Endpoints']['vpc']\n",
47+
" else:\n",
48+
" domain_url = next(iter(domain_info['Endpoints'].values()))\n",
49+
" # Non-VPC domain case\n",
50+
" elif 'Endpoint' in domain_info:\n",
51+
" domain_url = domain_info['Endpoint']\n",
52+
" else:\n",
53+
" print(f\"No endpoint found for domain '{domain_name}'\")\n",
54+
" return None, None\n",
55+
"\n",
4356
" return f'https://{domain_url}', domain_arn\n",
4457
"\n",
4558
" except opensearch_client.exceptions.ResourceNotFoundException:\n",

docs/tutorials/aws/DeepSeek_demo_notebook_for_RAG.ipynb

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,22 @@
6565
" try:\n",
6666
" response = self.opensearch_client.describe_elasticsearch_domain(DomainName=domain_name)\n",
6767
" domain_info = response['DomainStatus']\n",
68-
" domain_url = domain_info['Endpoint']\n",
6968
" domain_arn = domain_info['ARN']\n",
69+
"\n",
70+
" # Check if domain has VPC endpoints\n",
71+
" if 'Endpoints' in domain_info:\n",
72+
" # VPC domain case\n",
73+
" if 'vpc' in domain_info['Endpoints']:\n",
74+
" domain_url = domain_info['Endpoints']['vpc']\n",
75+
" else:\n",
76+
" domain_url = next(iter(domain_info['Endpoints'].values()))\n",
77+
" # Non-VPC domain case\n",
78+
" elif 'Endpoint' in domain_info:\n",
79+
" domain_url = domain_info['Endpoint']\n",
80+
" else:\n",
81+
" print(f\"No endpoint found for domain '{domain_name}'\")\n",
82+
" return None, None\n",
83+
"\n",
7084
" return f'https://{domain_url}', domain_arn\n",
7185
" except self.opensearch_client.exceptions.ResourceNotFoundException:\n",
7286
" print(f\"Domain '{domain_name}' not found.\")\n",

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
import org.opensearch.remote.metadata.client.SdkClient;
6767
import org.opensearch.transport.client.Client;
6868

69+
import com.google.common.annotations.VisibleForTesting;
6970
import com.jayway.jsonpath.JsonPath;
7071

7172
import joptsimple.internal.Strings;
@@ -154,7 +155,8 @@ public MLPlanExecuteAndReflectAgentRunner(
154155
this.plannerWithHistoryPromptTemplate = DEFAULT_PLANNER_WITH_HISTORY_PROMPT_TEMPLATE;
155156
}
156157

157-
private void setupPromptParameters(Map<String, String> params) {
158+
@VisibleForTesting
159+
void setupPromptParameters(Map<String, String> params) {
158160
// populated depending on whether LLM is asked to plan or re-evaluate
159161
// removed here, so that error is thrown in case this field is not populated
160162
params.remove(PROMPT_FIELD);
@@ -203,22 +205,26 @@ private void setupPromptParameters(Map<String, String> params) {
203205
}
204206
}
205207

206-
private void usePlannerPromptTemplate(Map<String, String> params) {
208+
@VisibleForTesting
209+
void usePlannerPromptTemplate(Map<String, String> params) {
207210
params.put(PROMPT_TEMPLATE_FIELD, this.plannerPromptTemplate);
208211
populatePrompt(params);
209212
}
210213

211-
private void useReflectPromptTemplate(Map<String, String> params) {
214+
@VisibleForTesting
215+
void useReflectPromptTemplate(Map<String, String> params) {
212216
params.put(PROMPT_TEMPLATE_FIELD, this.reflectPromptTemplate);
213217
populatePrompt(params);
214218
}
215219

216-
private void usePlannerWithHistoryPromptTemplate(Map<String, String> params) {
220+
@VisibleForTesting
221+
void usePlannerWithHistoryPromptTemplate(Map<String, String> params) {
217222
params.put(PROMPT_TEMPLATE_FIELD, this.plannerWithHistoryPromptTemplate);
218223
populatePrompt(params);
219224
}
220225

221-
private void populatePrompt(Map<String, String> allParams) {
226+
@VisibleForTesting
227+
void populatePrompt(Map<String, String> allParams) {
222228
String promptTemplate = allParams.get(PROMPT_TEMPLATE_FIELD);
223229
StringSubstitutor promptSubstitutor = new StringSubstitutor(allParams, "${parameters.", "}");
224230
String prompt = promptSubstitutor.replace(promptTemplate);
@@ -475,7 +481,8 @@ private void executePlanningLoop(
475481
client.execute(MLPredictionTaskAction.INSTANCE, request, planListener);
476482
}
477483

478-
private Map<String, String> parseLLMOutput(Map<String, String> allParams, ModelTensorOutput modelTensorOutput) {
484+
@VisibleForTesting
485+
Map<String, String> parseLLMOutput(Map<String, String> allParams, ModelTensorOutput modelTensorOutput) {
479486
Map<String, String> modelOutput = new HashMap<>();
480487
Map<String, ?> dataAsMap = modelTensorOutput.getMlModelOutputs().getFirst().getMlModelTensors().getFirst().getDataAsMap();
481488
String llmResponse;
@@ -513,7 +520,8 @@ private Map<String, String> parseLLMOutput(Map<String, String> allParams, ModelT
513520
return modelOutput;
514521
}
515522

516-
private String extractJsonFromMarkdown(String response) {
523+
@VisibleForTesting
524+
String extractJsonFromMarkdown(String response) {
517525
response = response.trim();
518526
if (response.contains("```json")) {
519527
response = response.substring(response.indexOf("```json") + "```json".length());
@@ -535,7 +543,8 @@ private String extractJsonFromMarkdown(String response) {
535543
return response;
536544
}
537545

538-
private void addToolsToPrompt(Map<String, Tool> tools, Map<String, String> allParams) {
546+
@VisibleForTesting
547+
void addToolsToPrompt(Map<String, Tool> tools, Map<String, String> allParams) {
539548
StringBuilder toolsPrompt = new StringBuilder("In this environment, you have access to the below tools: \n");
540549
for (Map.Entry<String, Tool> entry : tools.entrySet()) {
541550
String toolName = entry.getKey();
@@ -548,11 +557,13 @@ private void addToolsToPrompt(Map<String, Tool> tools, Map<String, String> allPa
548557
cleanUpResource(tools);
549558
}
550559

551-
private void addSteps(List<String> steps, Map<String, String> allParams, String field) {
560+
@VisibleForTesting
561+
void addSteps(List<String> steps, Map<String, String> allParams, String field) {
552562
allParams.put(field, String.join(", ", steps));
553563
}
554564

555-
private void saveAndReturnFinalResult(
565+
@VisibleForTesting
566+
void saveAndReturnFinalResult(
556567
ConversationIndexMemory memory,
557568
String parentInteractionId,
558569
String reactAgentMemoryId,
@@ -591,7 +602,8 @@ private void saveAndReturnFinalResult(
591602
}));
592603
}
593604

594-
private static List<ModelTensors> createModelTensors(
605+
@VisibleForTesting
606+
static List<ModelTensors> createModelTensors(
595607
String sessionId,
596608
String parentInteractionId,
597609
String reactAgentMemoryId,

0 commit comments

Comments
 (0)