Skip to content

Commit 27e6a5d

Browse files
committed
move getModel from rest to transport
Signed-off-by: Jing Zhang <jngz@amazon.com>
1 parent 231f2fa commit 27e6a5d

File tree

6 files changed

+85
-96
lines changed

6 files changed

+85
-96
lines changed

plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE;
99
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
10+
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
11+
12+
import java.util.Locale;
1013

1114
import org.opensearch.OpenSearchStatusException;
1215
import org.opensearch.action.ActionRequest;
@@ -126,9 +129,8 @@ public void onResponse(MLModel mlModel) {
126129
context.restore();
127130
modelCacheHelper.setModelInfo(modelId, mlModel);
128131
FunctionName functionName = mlModel.getAlgorithm();
129-
if (FunctionName.isDLModel(functionName) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
130-
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
131-
}
132+
String modelType = functionName.name();
133+
validateModelType(modelType);
132134
mlPredictionTaskRequest.getMlInput().setAlgorithm(functionName);
133135
modelAccessControlHelper
134136
.validateModelGroupAccess(
@@ -274,4 +276,13 @@ public void validateInputSchema(String modelId, MLInput mlInput) {
274276
}
275277
}
276278

279+
private void validateModelType(String modelType) {
280+
if (FunctionName.REMOTE.name().equals(modelType) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
281+
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
282+
} else if (FunctionName.isDLModel(FunctionName.from(modelType.toUpperCase(Locale.ROOT)))
283+
&& !mlFeatureEnabledSetting.isLocalModelEnabled()) {
284+
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
285+
}
286+
}
287+
277288
}

plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java

Lines changed: 8 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
99
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
1010
import static org.opensearch.ml.utils.MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG;
11-
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
12-
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
1311
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM;
1412
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
1513
import static org.opensearch.ml.utils.RestActionUtils.getActionTypeFromRestRequest;
@@ -22,20 +20,15 @@
2220
import java.util.Objects;
2321
import java.util.Optional;
2422

25-
import org.opensearch.common.util.concurrent.ThreadContext;
26-
import org.opensearch.core.action.ActionListener;
27-
import org.opensearch.core.rest.RestStatus;
2823
import org.opensearch.core.xcontent.XContentParser;
2924
import org.opensearch.ml.common.FunctionName;
30-
import org.opensearch.ml.common.MLModel;
3125
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
3226
import org.opensearch.ml.common.input.MLInput;
3327
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
3428
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
3529
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
3630
import org.opensearch.ml.model.MLModelManager;
3731
import org.opensearch.rest.BaseRestHandler;
38-
import org.opensearch.rest.BytesRestResponse;
3932
import org.opensearch.rest.RestRequest;
4033
import org.opensearch.rest.action.RestToXContentListener;
4134
import org.opensearch.transport.client.node.NodeClient;
@@ -88,67 +81,28 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
8881
String modelId = getParameterId(request, PARAMETER_MODEL_ID);
8982
Optional<FunctionName> functionName = modelManager.getOptionalModelFunctionName(modelId);
9083

91-
// check if the model is in cache
92-
if (functionName.isPresent()) {
93-
MLPredictionTaskRequest predictionRequest = getRequest(
94-
modelId,
95-
functionName.get().name(),
96-
Objects.requireNonNullElse(userAlgorithm, functionName.get().name()),
97-
request
98-
);
99-
return channel -> client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, new RestToXContentListener<>(channel));
100-
}
101-
102-
// If the model isn't in cache
103-
return channel -> {
104-
ActionListener<MLModel> listener = ActionListener.wrap(mlModel -> {
105-
String modelType = mlModel.getAlgorithm().name();
106-
String modelAlgorithm = Objects.requireNonNullElse(userAlgorithm, mlModel.getAlgorithm().name());
107-
client
108-
.execute(
109-
MLPredictionTaskAction.INSTANCE,
110-
getRequest(modelId, modelType, modelAlgorithm, request),
111-
new RestToXContentListener<>(channel)
112-
);
113-
}, e -> {
114-
log.error("Failed to get ML model", e);
115-
try {
116-
channel.sendResponse(new BytesRestResponse(channel, RestStatus.NOT_FOUND, e));
117-
} catch (IOException ex) {
118-
log.error("Failed to send error response", ex);
119-
}
120-
});
121-
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
122-
modelManager
123-
.getModel(
124-
modelId,
125-
getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request),
126-
ActionListener.runBefore(listener, context::restore)
127-
);
128-
}
129-
};
84+
MLPredictionTaskRequest predictionRequest = getRequest(
85+
modelId,
86+
Objects.requireNonNullElse(userAlgorithm, functionName.get().name()),
87+
request
88+
);
89+
return channel -> client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, new RestToXContentListener<>(channel));
13090
}
13191

13292
/**
13393
* Creates a MLPredictionTaskRequest from a RestRequest. This method validates the request based on
13494
* enabled features and the requested action type, and parses the input data for prediction.
13595
*
13696
* @param modelId The ID of the ML model to use for prediction
137-
* @param modelType The type of the ML model, extracted from model cache to specify if its a remote model or a local model
13897
* @param userAlgorithm The algorithm specified by the user for prediction; this is used to determine the interface of the model
13998
* @param request The REST request containing prediction input data
14099
* @return MLPredictionTaskRequest configured with the model and input parameters
141100
*/
142101
@VisibleForTesting
143-
MLPredictionTaskRequest getRequest(String modelId, String modelType, String userAlgorithm, RestRequest request) throws IOException {
102+
MLPredictionTaskRequest getRequest(String modelId, String userAlgorithm, RestRequest request) throws IOException {
144103
String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request);
145104
ActionType actionType = ActionType.from(getActionTypeFromRestRequest(request));
146-
if (FunctionName.REMOTE.name().equals(modelType) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
147-
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
148-
} else if (FunctionName.isDLModel(FunctionName.from(modelType.toUpperCase(Locale.ROOT)))
149-
&& !mlFeatureEnabledSetting.isLocalModelEnabled()) {
150-
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
151-
} else if (ActionType.BATCH_PREDICT == actionType && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
105+
if (ActionType.BATCH_PREDICT == actionType && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
152106
throw new IllegalStateException(BATCH_INFERENCE_DISABLED_ERR_MSG);
153107
} else if (!ActionType.isValidActionInModelPrediction(actionType)) {
154108
throw new IllegalArgumentException("Wrong action type in the rest request path!");

plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import static org.mockito.Mockito.when;
1414
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE;
1515
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
16+
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
1617

1718
import java.util.Arrays;
1819
import java.util.Collections;
@@ -200,6 +201,22 @@ public void testPrediction_local_model_not_exception() {
200201
);
201202
}
202203

204+
@Test
205+
public void testPrediction_remote_inference_not_exception() {
206+
when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model);
207+
when(model.getAlgorithm()).thenReturn(FunctionName.REMOTE);
208+
when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false);
209+
210+
IllegalStateException e = assertThrows(
211+
IllegalStateException.class,
212+
() -> transportPredictionTaskAction.doExecute(null, mlPredictionTaskRequest, actionListener)
213+
);
214+
assertEquals(
215+
e.getMessage(),
216+
REMOTE_INFERENCE_DISABLED_ERR_MSG
217+
);
218+
}
219+
203220
@Test
204221
public void testPrediction_OpenSearchStatusException() {
205222
when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model);

plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java

Lines changed: 3 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
import static org.mockito.ArgumentMatchers.any;
99
import static org.mockito.ArgumentMatchers.eq;
1010
import static org.mockito.Mockito.*;
11-
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
12-
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
1311
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
1412
import static org.opensearch.ml.utils.TestHelper.getBatchRestRequest;
1513
import static org.opensearch.ml.utils.TestHelper.getBatchRestRequest_WrongActionType;
@@ -127,35 +125,12 @@ public void testRoutes_Batch() {
127125
@Test
128126
public void testGetRequest() throws IOException {
129127
RestRequest request = getRestRequest_PredictModel();
130-
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
131-
.getRequest("modelId", FunctionName.KMEANS.name(), FunctionName.KMEANS.name(), request);
128+
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.KMEANS.name(), request);
132129

133130
MLInput mlInput = mlPredictionTaskRequest.getMlInput();
134131
verifyParsedKMeansMLInput(mlInput);
135132
}
136133

137-
@Test
138-
public void testGetRequest_RemoteInferenceDisabled() throws IOException {
139-
thrown.expect(IllegalStateException.class);
140-
thrown.expectMessage(REMOTE_INFERENCE_DISABLED_ERR_MSG);
141-
142-
when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false);
143-
RestRequest request = getRestRequest_PredictModel();
144-
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
145-
.getRequest("modelId", FunctionName.REMOTE.name(), "text_embedding", request);
146-
}
147-
148-
@Test
149-
public void testGetRequest_LocalModelInferenceDisabled() throws IOException {
150-
thrown.expect(IllegalStateException.class);
151-
thrown.expectMessage(LOCAL_MODEL_DISABLED_ERR_MSG);
152-
153-
when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false);
154-
RestRequest request = getRestRequest_PredictModel();
155-
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
156-
.getRequest("modelId", FunctionName.TEXT_EMBEDDING.name(), "text_embedding", request);
157-
}
158-
159134
@Test
160135
public void testPrepareRequest() throws Exception {
161136
RestRequest request = getRestRequest_PredictModel();
@@ -196,7 +171,7 @@ public void testPrepareBatchRequest_WrongActionType() throws Exception {
196171
thrown.expectMessage("Wrong Action Type");
197172

198173
RestRequest request = getBatchRestRequest_WrongActionType();
199-
restMLPredictionAction.getRequest("model id", "remote", "text_embedding", request);
174+
restMLPredictionAction.getRequest("model id", "text_embedding", request);
200175
}
201176

202177
@Ignore
@@ -234,17 +209,7 @@ public void testGetRequest_InvalidActionType() throws IOException {
234209
thrown.expectMessage("Wrong Action Type of models");
235210

236211
RestRequest request = getBatchRestRequest_WrongActionType();
237-
restMLPredictionAction.getRequest("model_id", FunctionName.REMOTE.name(), "text_embedding", request);
238-
}
239-
240-
@Test
241-
public void testGetRequest_UnsupportedAlgorithm() throws IOException {
242-
thrown.expect(IllegalArgumentException.class);
243-
thrown.expectMessage("Wrong function name");
244-
245-
// Create a RestRequest with an unsupported algorithm
246-
RestRequest request = getRestRequest_PredictModel();
247-
restMLPredictionAction.getRequest("model_id", "INVALID_ALGO", "text_embedding", request);
212+
restMLPredictionAction.getRequest("model_id", "text_embedding", request);
248213
}
249214

250215
private RestRequest getRestRequest_PredictModel() {

plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,4 +1177,39 @@ public String registerRemoteModelWithInterface(String testCase) throws IOExcepti
11771177
logger.info("task ID created: {}", taskId);
11781178
return taskId;
11791179
}
1180+
1181+
public void testPredictRemoteModelFeatureDisabled() throws IOException, InterruptedException {
1182+
Response response = createConnector(completionModelConnectorEntity);
1183+
Map responseMap = parseResponseToMap(response);
1184+
String connectorId = (String) responseMap.get("connector_id");
1185+
response = registerRemoteModelWithInterface("openAI-GPT-3.5 completions", connectorId, "correctInterface");
1186+
responseMap = parseResponseToMap(response);
1187+
String taskId = (String) responseMap.get("task_id");
1188+
waitForTask(taskId, MLTaskState.COMPLETED);
1189+
response = getTask(taskId);
1190+
responseMap = parseResponseToMap(response);
1191+
String modelId = (String) responseMap.get("model_id");
1192+
response = deployRemoteModel(modelId);
1193+
responseMap = parseResponseToMap(response);
1194+
taskId = (String) responseMap.get("task_id");
1195+
waitForTask(taskId, MLTaskState.COMPLETED);
1196+
response = TestHelper
1197+
.makeRequest(
1198+
client(),
1199+
"PUT",
1200+
"_cluster/settings",
1201+
null,
1202+
"{\"persistent\":{\"plugins.ml_commons.remote_inference.enabled\":false}}",
1203+
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
1204+
);
1205+
assertEquals(200, response.getStatusLine().getStatusCode());
1206+
String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}";
1207+
try {
1208+
predictRemoteModel(modelId, predictInput);
1209+
} catch (Exception e) {
1210+
assertTrue(e instanceof org.opensearch.client.ResponseException);
1211+
String stackTrace = ExceptionUtils.getStackTrace(e);
1212+
assertTrue(stackTrace.contains("Remote Inference is currently disabled."));
1213+
}
1214+
}
11801215
}

plugin/src/test/java/org/opensearch/ml/rest/SecureMLRestIT.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ public void testTrainWithReadOnlyMLAccess() throws IOException {
408408
train(mlReadOnlyClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, null, false);
409409
}
410410

411-
public void testPredictWithReadOnlyMLAccess() throws IOException {
411+
public void testPredictWithReadOnlyMLAccessModelExisting() throws IOException {
412412
KMeansParams kMeansParams = KMeansParams.builder().build();
413413
train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> {
414414
String modelId = (String) trainResult.get("model_id");
@@ -436,6 +436,13 @@ public void testPredictWithReadOnlyMLAccess() throws IOException {
436436
}, false);
437437
}
438438

439+
public void testPredictWithReadOnlyMLAccessModelNonExisting() throws IOException {
440+
exceptionRule.expect(ResponseException.class);
441+
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/predict]");
442+
KMeansParams kMeansParams = KMeansParams.builder().build();
443+
predict(mlReadOnlyClient, FunctionName.KMEANS, "modelId", irisIndex, kMeansParams, searchSourceBuilder, null);
444+
}
445+
439446
public void testTrainAndPredictWithFullAccess() throws IOException {
440447
trainAndPredict(
441448
mlFullAccessClient,

0 commit comments

Comments
 (0)