Skip to content

Commit f94f138

Browse files
committed
[Enhancement] Enhance validation for create connector API
This change will address the second part of validation "pre and post processing function validation". Partially resolves opensearch-project#2993 Signed-off-by: Abdul Muneer Kolarkunnu <muneer.kolarkunnu@netapp.com>
1 parent 0e91533 commit f94f138

File tree

2 files changed

+58
-98
lines changed

2 files changed

+58
-98
lines changed

common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java

Lines changed: 34 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,6 @@
66
package org.opensearch.ml.common.connector;
77

88
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
9-
import static org.opensearch.ml.common.connector.MLPostProcessFunction.DEFAULT_EMBEDDING;
10-
import static org.opensearch.ml.common.connector.MLPostProcessFunction.DEFAULT_RERANK;
11-
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT;
12-
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_SIMILARITY_TO_DEFAULT_INPUT;
139

1410
import java.io.IOException;
1511
import java.util.HashSet;
@@ -19,6 +15,8 @@
1915
import java.util.Set;
2016

2117
import org.apache.commons.text.StringSubstitutor;
18+
import org.apache.logging.log4j.LogManager;
19+
import org.apache.logging.log4j.Logger;
2220
import org.opensearch.core.common.io.stream.StreamInput;
2321
import org.opensearch.core.common.io.stream.StreamOutput;
2422
import org.opensearch.core.common.io.stream.Writeable;
@@ -45,9 +43,11 @@ public class ConnectorAction implements ToXContentObject, Writeable {
4543
public static final String COHERE = "cohere";
4644
public static final String BEDROCK = "bedrock";
4745
public static final String SAGEMAKER = "sagemaker";
46+
public static final String SAGEMAKER_PRE_POST_FUNC_TEXT = "default";
4847
public static final List<String> SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES = List.of(SAGEMAKER, OPENAI, BEDROCK, COHERE);
4948

5049
private static final String INBUILT_FUNC_PREFIX = "connector.";
50+
private static final Logger logger = LogManager.getLogger(ConnectorAction.class);
5151

5252
private ActionType actionType;
5353
private String method;
@@ -210,16 +210,12 @@ public void validatePrePostProcessFunctions(Map<String, String> parameters) {
210210
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
211211
String endPoint = substitutor.replace(url);
212212
String remoteServer = getRemoteServerFromURL(endPoint);
213-
if (isInBuiltProcessFunction(preProcessFunction)) {
214-
validatePreProcessFunctions(remoteServer);
215-
}
216-
if (isInBuiltProcessFunction(postProcessFunction)) {
217-
validatePostProcessFunctions(remoteServer);
218-
}
213+
validateProcessFunctions(remoteServer, preProcessFunction);
214+
validateProcessFunctions(remoteServer, postProcessFunction);
219215
}
220216

221217
/**
222-
* To get the remote server name from ULR
218+
* To get the remote server name from url
223219
*
224220
* @param url - remote server url
225221
* @return - returns the corresponding remote server name for url, if server is not in the pre-defined list,
@@ -229,70 +225,38 @@ public static String getRemoteServerFromURL(String url) {
229225
return SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES.stream().filter(url::contains).findFirst().orElse("");
230226
}
231227

232-
private boolean isInBuiltProcessFunction(String processFunction) {
233-
return (processFunction != null && processFunction.startsWith(INBUILT_FUNC_PREFIX));
234-
}
235-
236-
private void validatePreProcessFunctions(String remoteServer) {
237-
switch (remoteServer) {
238-
case OPENAI:
239-
if (!preProcessFunction.contains(OPENAI)) {
240-
throw new IllegalArgumentException(invalidProcessFuncExcText(OPENAI, "PreProcessFunction"));
241-
}
242-
break;
243-
case COHERE:
244-
if (!preProcessFunction.contains(COHERE)) {
245-
throw new IllegalArgumentException(invalidProcessFuncExcText(COHERE, "PreProcessFunction"));
246-
}
247-
break;
248-
case BEDROCK:
249-
if (!preProcessFunction.contains(BEDROCK)) {
250-
throw new IllegalArgumentException(invalidProcessFuncExcText(BEDROCK, "PreProcessFunction"));
251-
}
252-
break;
253-
case SAGEMAKER:
254-
if (!(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT.equals(preProcessFunction)
255-
|| TEXT_SIMILARITY_TO_DEFAULT_INPUT.equals(preProcessFunction))) {
256-
throw new IllegalArgumentException(
257-
"LLM service is "
258-
+ SAGEMAKER
259-
+ ", so PreProcessFunction should be "
260-
+ TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT
261-
+ " or "
262-
+ TEXT_SIMILARITY_TO_DEFAULT_INPUT
263-
);
264-
}
228+
private void validateProcessFunctions(String remoteServer, String processFunction) {
229+
if (isInBuiltProcessFunction(processFunction)) {
230+
switch (remoteServer) {
231+
case OPENAI:
232+
if (!processFunction.contains(OPENAI)) {
233+
logger.warn(invalidProcessFuncWarnText(OPENAI));
234+
}
235+
break;
236+
case COHERE:
237+
if (!processFunction.contains(COHERE)) {
238+
logger.warn(invalidProcessFuncWarnText(COHERE));
239+
}
240+
break;
241+
case BEDROCK:
242+
if (!processFunction.contains(BEDROCK)) {
243+
logger.warn(invalidProcessFuncWarnText(BEDROCK));
244+
}
245+
break;
246+
case SAGEMAKER:
247+
if (!processFunction.contains(SAGEMAKER_PRE_POST_FUNC_TEXT)) {
248+
logger.warn(invalidProcessFuncWarnText(SAGEMAKER));
249+
}
250+
}
265251
}
266252
}
267253

268-
private void validatePostProcessFunctions(String remoteServer) {
269-
switch (remoteServer) {
270-
case OPENAI:
271-
if (!postProcessFunction.contains(OPENAI)) {
272-
throw new IllegalArgumentException(invalidProcessFuncExcText(OPENAI, "PostProcessFunction"));
273-
}
274-
break;
275-
case COHERE:
276-
if (!postProcessFunction.contains(COHERE)) {
277-
throw new IllegalArgumentException(invalidProcessFuncExcText(COHERE, "PostProcessFunction"));
278-
}
279-
break;
280-
case BEDROCK:
281-
if (!postProcessFunction.contains(BEDROCK)) {
282-
throw new IllegalArgumentException(invalidProcessFuncExcText(BEDROCK, "PostProcessFunction"));
283-
}
284-
break;
285-
case SAGEMAKER:
286-
if (!(DEFAULT_EMBEDDING.equals(postProcessFunction) || DEFAULT_RERANK.equals(postProcessFunction))) {
287-
throw new IllegalArgumentException(
288-
"LLM service is " + SAGEMAKER + ", so PostProcessFunction should be " + DEFAULT_EMBEDDING + " or " + DEFAULT_RERANK
289-
);
290-
}
291-
}
254+
private boolean isInBuiltProcessFunction(String processFunction) {
255+
return (processFunction != null && processFunction.startsWith(INBUILT_FUNC_PREFIX));
292256
}
293257

294-
private String invalidProcessFuncExcText(String remoteServer, String func) {
295-
return "LLM service is " + remoteServer + ", so " + func + " should be " + remoteServer + " " + func;
258+
private String invalidProcessFuncWarnText(String remoteServer) {
259+
return "LLM service is " + remoteServer + ", so PrePostProcessFunction should be corresponding to " + remoteServer;
296260
}
297261

298262
public enum ActionType {

common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ public void testValidatePrePostProcessFunctionsWithOpenAIConnectorCorrectInBuilt
122122
}
123123

124124
@Test
125-
public void testValidatePrePostProcessFunctionsWithOpenAIConnectorWrongInBuiltPrePostProcessFunctionThrowsException() {
125+
public void testValidatePrePostProcessFunctionsWithOpenAIConnectorWrongInBuiltPrePostProcessFunction() {
126+
// Testing with wrong PreProcessFunction
126127
ConnectorAction action1 = new ConnectorAction(
127128
TEST_ACTION_TYPE,
128129
TEST_METHOD_HTTP,
@@ -132,8 +133,9 @@ public void testValidatePrePostProcessFunctionsWithOpenAIConnectorWrongInBuiltPr
132133
TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT,
133134
OPENAI_EMBEDDING
134135
);
135-
Throwable exception = assertThrows(IllegalArgumentException.class, () -> action1.validatePrePostProcessFunctions(Map.of()));
136-
assertEquals("LLM service is openai, so PreProcessFunction should be openai PreProcessFunction", exception.getMessage());
136+
action1.validatePrePostProcessFunctions(Map.of());
137+
138+
// Testing with wrong PostProcessFunction
137139
ConnectorAction action2 = new ConnectorAction(
138140
TEST_ACTION_TYPE,
139141
TEST_METHOD_HTTP,
@@ -143,8 +145,7 @@ public void testValidatePrePostProcessFunctionsWithOpenAIConnectorWrongInBuiltPr
143145
TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT,
144146
COHERE_EMBEDDING
145147
);
146-
exception = assertThrows(IllegalArgumentException.class, () -> action2.validatePrePostProcessFunctions(Map.of()));
147-
assertEquals("LLM service is openai, so PostProcessFunction should be openai PostProcessFunction", exception.getMessage());
148+
action2.validatePrePostProcessFunctions(Map.of());
148149
}
149150

150151
@Test
@@ -183,7 +184,8 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorCorrectInBuilt
183184
}
184185

185186
@Test
186-
public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPrePostProcessFunctionThrowsException() {
187+
public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPrePostProcessFunction() {
188+
// Testing with wrong PreProcessFunction
187189
ConnectorAction action1 = new ConnectorAction(
188190
TEST_ACTION_TYPE,
189191
TEST_METHOD_HTTP,
@@ -193,8 +195,9 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPr
193195
TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT,
194196
COHERE_EMBEDDING
195197
);
196-
Throwable exception = assertThrows(IllegalArgumentException.class, () -> action1.validatePrePostProcessFunctions(Map.of()));
197-
assertEquals("LLM service is cohere, so PreProcessFunction should be cohere PreProcessFunction", exception.getMessage());
198+
action1.validatePrePostProcessFunctions(Map.of());
199+
200+
// Testing with wrong PostProcessFunction
198201
ConnectorAction action2 = new ConnectorAction(
199202
TEST_ACTION_TYPE,
200203
TEST_METHOD_HTTP,
@@ -204,8 +207,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPr
204207
TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT,
205208
OPENAI_EMBEDDING
206209
);
207-
exception = assertThrows(IllegalArgumentException.class, () -> action2.validatePrePostProcessFunctions(Map.of()));
208-
assertEquals("LLM service is cohere, so PostProcessFunction should be cohere PostProcessFunction", exception.getMessage());
210+
action2.validatePrePostProcessFunctions(Map.of());
209211
}
210212

211213
@Test
@@ -243,7 +245,8 @@ public void testValidatePrePostProcessFunctionsWithBedrockConnectorCorrectInBuil
243245
}
244246

245247
@Test
246-
public void testValidatePrePostProcessFunctionsWithBedrockConnectorWrongInBuiltPrePostProcessFunctionThrowsException() {
248+
public void testValidatePrePostProcessFunctionsWithBedrockConnectorWrongInBuiltPrePostProcessFunction() {
249+
// Testing with wrong PreProcessFunction
247250
ConnectorAction action1 = new ConnectorAction(
248251
TEST_ACTION_TYPE,
249252
TEST_METHOD_HTTP,
@@ -253,8 +256,9 @@ public void testValidatePrePostProcessFunctionsWithBedrockConnectorWrongInBuiltP
253256
TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT,
254257
BEDROCK_EMBEDDING
255258
);
256-
Throwable exception = assertThrows(IllegalArgumentException.class, () -> action1.validatePrePostProcessFunctions(Map.of()));
257-
assertEquals("LLM service is bedrock, so PreProcessFunction should be bedrock PreProcessFunction", exception.getMessage());
259+
action1.validatePrePostProcessFunctions(Map.of());
260+
261+
// Testing with wrong PostProcessFunction
258262
ConnectorAction action2 = new ConnectorAction(
259263
TEST_ACTION_TYPE,
260264
TEST_METHOD_HTTP,
@@ -264,8 +268,7 @@ public void testValidatePrePostProcessFunctionsWithBedrockConnectorWrongInBuiltP
264268
TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT,
265269
COHERE_EMBEDDING
266270
);
267-
exception = assertThrows(IllegalArgumentException.class, () -> action2.validatePrePostProcessFunctions(Map.of()));
268-
assertEquals("LLM service is bedrock, so PostProcessFunction should be bedrock PostProcessFunction", exception.getMessage());
271+
action2.validatePrePostProcessFunctions(Map.of());
269272
}
270273

271274
@Test
@@ -293,7 +296,8 @@ public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWithCorrect
293296
}
294297

295298
@Test
296-
public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWrongInBuiltPrePostProcessFunctionThrowsException() {
299+
public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWrongInBuiltPrePostProcessFunction() {
300+
// Testing with wrong PreProcessFunction
297301
ConnectorAction action1 = new ConnectorAction(
298302
TEST_ACTION_TYPE,
299303
TEST_METHOD_HTTP,
@@ -303,12 +307,9 @@ public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWrongInBuil
303307
TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT,
304308
DEFAULT_EMBEDDING
305309
);
306-
Throwable exception = assertThrows(IllegalArgumentException.class, () -> action1.validatePrePostProcessFunctions(Map.of()));
307-
assertEquals(
308-
"LLM service is sagemaker, so PreProcessFunction should be connector.pre_process.default.embedding"
309-
+ " or connector.pre_process.default.rerank",
310-
exception.getMessage()
311-
);
310+
action1.validatePrePostProcessFunctions(Map.of());
311+
312+
// Testing with wrong PostProcessFunction
312313
ConnectorAction action2 = new ConnectorAction(
313314
TEST_ACTION_TYPE,
314315
TEST_METHOD_HTTP,
@@ -318,12 +319,7 @@ public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWrongInBuil
318319
TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT,
319320
BEDROCK_EMBEDDING
320321
);
321-
exception = assertThrows(IllegalArgumentException.class, () -> action2.validatePrePostProcessFunctions(Map.of()));
322-
assertEquals(
323-
"LLM service is sagemaker, so PostProcessFunction should be connector.post_process.default.embedding"
324-
+ " or connector.post_process.default.rerank",
325-
exception.getMessage()
326-
);
322+
action2.validatePrePostProcessFunctions(Map.of());
327323
}
328324

329325
@Test

0 commit comments

Comments
 (0)