From eee764d1b7e420c240d51b80172c1b3e0ad8f0e4 Mon Sep 17 00:00:00 2001
From: Jing Zhang
Date: Tue, 25 Mar 2025 14:32:45 -0700
Subject: [PATCH] move getModel from rest to transport

Signed-off-by: Jing Zhang
---
 .../TransportPredictionTaskAction.java        | 17 ++++-
 .../ml/rest/RestMLPredictionAction.java       | 62 +++----------------
 .../TransportPredictionTaskActionTests.java   | 17 +++++
 .../ml/rest/RestMLPredictionActionTests.java  | 41 +-----------
 .../ml/rest/RestMLRemoteInferenceIT.java      | 35 +++++++++++
 .../opensearch/ml/rest/SecureMLRestIT.java    |  9 ++-
 6 files changed, 85 insertions(+), 96 deletions(-)

diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java
index 423cb1ed71..4f207a2c16 100644
--- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java
+++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java
@@ -7,6 +7,9 @@
 
 import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE;
 import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
+import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
+
+import java.util.Locale;
 
 import org.opensearch.OpenSearchStatusException;
 import org.opensearch.action.ActionRequest;
@@ -126,9 +129,8 @@ public void onResponse(MLModel mlModel) {
                 context.restore();
                 modelCacheHelper.setModelInfo(modelId, mlModel);
                 FunctionName functionName = mlModel.getAlgorithm();
-                if (FunctionName.isDLModel(functionName) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
-                    throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
-                }
+                String modelType = functionName.name();
+                validateModelType(modelType);
                 mlPredictionTaskRequest.getMlInput().setAlgorithm(functionName);
                 modelAccessControlHelper
                     .validateModelGroupAccess(
@@ -274,4 +276,13 @@ public void validateInputSchema(String modelId, MLInput mlInput) {
         }
     }
 
+    private void validateModelType(String modelType) {
+        if (FunctionName.REMOTE.name().equals(modelType) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
+            throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
+        } else if (FunctionName.isDLModel(FunctionName.from(modelType.toUpperCase(Locale.ROOT)))
+            && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
+            throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
+        }
+    }
+
 }
diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java
index e0e028d9f0..8a5b768e71 100644
--- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java
+++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java
@@ -8,8 +8,6 @@
 import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
 import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
 import static org.opensearch.ml.utils.MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG;
-import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
-import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
 import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM;
 import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
 import static org.opensearch.ml.utils.RestActionUtils.getActionTypeFromRestRequest;
@@ -22,12 +20,8 @@
 import java.util.Objects;
 import java.util.Optional;
 
-import org.opensearch.common.util.concurrent.ThreadContext;
-import org.opensearch.core.action.ActionListener;
-import org.opensearch.core.rest.RestStatus;
 import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.ml.common.FunctionName;
-import org.opensearch.ml.common.MLModel;
 import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
 import org.opensearch.ml.common.input.MLInput;
 import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
@@ -35,7 +29,6 @@
 import org.opensearch.ml.model.MLModelManager;
 import org.opensearch.ml.settings.MLFeatureEnabledSetting;
 import org.opensearch.rest.BaseRestHandler;
-import org.opensearch.rest.BytesRestResponse;
 import org.opensearch.rest.RestRequest;
 import org.opensearch.rest.action.RestToXContentListener;
 import org.opensearch.transport.client.node.NodeClient;
@@ -88,45 +81,12 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
         String modelId = getParameterId(request, PARAMETER_MODEL_ID);
         Optional<FunctionName> functionName = modelManager.getOptionalModelFunctionName(modelId);
 
-        // check if the model is in cache
-        if (functionName.isPresent()) {
-            MLPredictionTaskRequest predictionRequest = getRequest(
-                modelId,
-                functionName.get().name(),
-                Objects.requireNonNullElse(userAlgorithm, functionName.get().name()),
-                request
-            );
-            return channel -> client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, new RestToXContentListener<>(channel));
-        }
-
-        // If the model isn't in cache
-        return channel -> {
-            ActionListener<MLModel> listener = ActionListener.wrap(mlModel -> {
-                String modelType = mlModel.getAlgorithm().name();
-                String modelAlgorithm = Objects.requireNonNullElse(userAlgorithm, mlModel.getAlgorithm().name());
-                client
-                    .execute(
-                        MLPredictionTaskAction.INSTANCE,
-                        getRequest(modelId, modelType, modelAlgorithm, request),
-                        new RestToXContentListener<>(channel)
-                    );
-            }, e -> {
-                log.error("Failed to get ML model", e);
-                try {
-                    channel.sendResponse(new BytesRestResponse(channel, RestStatus.NOT_FOUND, e));
-                } catch (IOException ex) {
-                    log.error("Failed to send error response", ex);
-                }
-            });
-            try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
-                modelManager
-                    .getModel(
-                        modelId,
-                        getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request),
-                        ActionListener.runBefore(listener, context::restore)
-                    );
-            }
-        };
+        MLPredictionTaskRequest predictionRequest = getRequest(
+            modelId,
+            Objects.requireNonNullElse(userAlgorithm, functionName.get().name()),
+            request
+        );
+        return channel -> client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, new RestToXContentListener<>(channel));
     }
 
     /**
@@ -134,21 +94,15 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
      * enabled features and model types, and parses the input data for prediction.
      *
      * @param modelId The ID of the ML model to use for prediction
-     * @param modelType The type of the ML model, extracted from model cache to specify if its a remote model or a local model
      * @param userAlgorithm The algorithm specified by the user for prediction, this is used todetermine the interface of the model
      * @param request The REST request containing prediction input data
      * @return MLPredictionTaskRequest configured with the model and input parameters
      */
     @VisibleForTesting
-    MLPredictionTaskRequest getRequest(String modelId, String modelType, String userAlgorithm, RestRequest request) throws IOException {
+    MLPredictionTaskRequest getRequest(String modelId, String userAlgorithm, RestRequest request) throws IOException {
         String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request);
         ActionType actionType = ActionType.from(getActionTypeFromRestRequest(request));
-        if (FunctionName.REMOTE.name().equals(modelType) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
-            throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
-        } else if (FunctionName.isDLModel(FunctionName.from(modelType.toUpperCase(Locale.ROOT)))
-            && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
-            throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
-        } else if (ActionType.BATCH_PREDICT == actionType && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
+        if (ActionType.BATCH_PREDICT == actionType && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
             throw new IllegalStateException(BATCH_INFERENCE_DISABLED_ERR_MSG);
         } else if (!ActionType.isValidActionInModelPrediction(actionType)) {
             throw new IllegalArgumentException("Wrong action type in the rest request path!");
diff --git a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java
index 2b253a60b5..9a809b5909 100644
--- a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java
@@ -13,6 +13,7 @@
 import static org.mockito.Mockito.when;
 import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE;
 import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
+import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
 
 import java.util.Arrays;
 import java.util.Collections;
@@ -200,6 +201,22 @@ public void testPrediction_local_model_not_exception() {
         );
     }
 
+    @Test
+    public void testPrediction_remote_inference_not_exception() {
+        when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model);
+        when(model.getAlgorithm()).thenReturn(FunctionName.REMOTE);
+        when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false);
+
+        IllegalStateException e = assertThrows(
+            IllegalStateException.class,
+            () -> transportPredictionTaskAction.doExecute(null, mlPredictionTaskRequest, actionListener)
+        );
+        assertEquals(
+            e.getMessage(),
+            REMOTE_INFERENCE_DISABLED_ERR_MSG
+        );
+    }
+
     @Test
     public void testPrediction_OpenSearchStatusException() {
         when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model);
diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java
index 9028ce174e..6814c181c7 100644
--- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java
@@ -8,8 +8,6 @@
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.*;
-import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
-import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
 import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
 import static org.opensearch.ml.utils.TestHelper.getBatchRestRequest;
 import static org.opensearch.ml.utils.TestHelper.getBatchRestRequest_WrongActionType;
@@ -127,35 +125,12 @@ public void testRoutes_Batch() {
 
     @Test
     public void testGetRequest() throws IOException {
         RestRequest request = getRestRequest_PredictModel();
-        MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
-            .getRequest("modelId", FunctionName.KMEANS.name(), FunctionName.KMEANS.name(), request);
+        MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.KMEANS.name(), request);
         MLInput mlInput = mlPredictionTaskRequest.getMlInput();
         verifyParsedKMeansMLInput(mlInput);
     }
 
-    @Test
-    public void testGetRequest_RemoteInferenceDisabled() throws IOException {
-        thrown.expect(IllegalStateException.class);
-        thrown.expectMessage(REMOTE_INFERENCE_DISABLED_ERR_MSG);
-
-        when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false);
-        RestRequest request = getRestRequest_PredictModel();
-        MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
-            .getRequest("modelId", FunctionName.REMOTE.name(), "text_embedding", request);
-    }
-
-    @Test
-    public void testGetRequest_LocalModelInferenceDisabled() throws IOException {
-        thrown.expect(IllegalStateException.class);
-        thrown.expectMessage(LOCAL_MODEL_DISABLED_ERR_MSG);
-
-        when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false);
-        RestRequest request = getRestRequest_PredictModel();
-        MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
-            .getRequest("modelId", FunctionName.TEXT_EMBEDDING.name(), "text_embedding", request);
-    }
-
     @Test
     public void testPrepareRequest() throws Exception {
         RestRequest request = getRestRequest_PredictModel();
@@ -196,7 +171,7 @@ public void testPrepareBatchRequest_WrongActionType() throws Exception {
         thrown.expectMessage("Wrong Action Type");
 
         RestRequest request = getBatchRestRequest_WrongActionType();
-        restMLPredictionAction.getRequest("model id", "remote", "text_embedding", request);
+        restMLPredictionAction.getRequest("model id", "text_embedding", request);
     }
 
     @Ignore
@@ -234,17 +209,7 @@ public void testGetRequest_InvalidActionType() throws IOException {
         thrown.expectMessage("Wrong Action Type of models");
 
         RestRequest request = getBatchRestRequest_WrongActionType();
-        restMLPredictionAction.getRequest("model_id", FunctionName.REMOTE.name(), "text_embedding", request);
-    }
-
-    @Test
-    public void testGetRequest_UnsupportedAlgorithm() throws IOException {
-        thrown.expect(IllegalArgumentException.class);
-        thrown.expectMessage("Wrong function name");
-
-        // Create a RestRequest with an unsupported algorithm
-        RestRequest request = getRestRequest_PredictModel();
-        restMLPredictionAction.getRequest("model_id", "INVALID_ALGO", "text_embedding", request);
+        restMLPredictionAction.getRequest("model_id", "text_embedding", request);
     }
 
     private RestRequest getRestRequest_PredictModel() {
diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java
index 0b8e713f06..75e146fa75 100644
--- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java
+++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java
@@ -1177,4 +1177,39 @@ public String registerRemoteModelWithInterface(String testCase) throws IOExcepti
         logger.info("task ID created: {}", taskId);
         return taskId;
     }
+
+    public void testPredictRemoteModelFeatureDisabled() throws IOException, InterruptedException {
+        Response response = createConnector(completionModelConnectorEntity);
+        Map responseMap = parseResponseToMap(response);
+        String connectorId = (String) responseMap.get("connector_id");
+        response = registerRemoteModelWithInterface("openAI-GPT-3.5 completions", connectorId, "correctInterface");
+        responseMap = parseResponseToMap(response);
+        String taskId = (String) responseMap.get("task_id");
+        waitForTask(taskId, MLTaskState.COMPLETED);
+        response = getTask(taskId);
+        responseMap = parseResponseToMap(response);
+        String modelId = (String) responseMap.get("model_id");
+        response = deployRemoteModel(modelId);
+        responseMap = parseResponseToMap(response);
+        taskId = (String) responseMap.get("task_id");
+        waitForTask(taskId, MLTaskState.COMPLETED);
+        response = TestHelper
+            .makeRequest(
+                client(),
+                "PUT",
+                "_cluster/settings",
+                null,
+                "{\"persistent\":{\"plugins.ml_commons.remote_inference.enabled\":false}}",
+                ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
+            );
+        assertEquals(200, response.getStatusLine().getStatusCode());
+        String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}";
+        try {
+            predictRemoteModel(modelId, predictInput);
+        } catch (Exception e) {
+            assertTrue(e instanceof org.opensearch.client.ResponseException);
+            String stackTrace = ExceptionUtils.getStackTrace(e);
+            assertTrue(stackTrace.contains("Remote Inference is currently disabled."));
+        }
+    }
 }
diff --git a/plugin/src/test/java/org/opensearch/ml/rest/SecureMLRestIT.java b/plugin/src/test/java/org/opensearch/ml/rest/SecureMLRestIT.java
index 189e08c0a9..0c82ef6d7d 100644
--- a/plugin/src/test/java/org/opensearch/ml/rest/SecureMLRestIT.java
+++ b/plugin/src/test/java/org/opensearch/ml/rest/SecureMLRestIT.java
@@ -408,7 +408,7 @@ public void testTrainWithReadOnlyMLAccess() throws IOException {
         train(mlReadOnlyClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, null, false);
     }
 
-    public void testPredictWithReadOnlyMLAccess() throws IOException {
+    public void testPredictWithReadOnlyMLAccessModelExisting() throws IOException {
         KMeansParams kMeansParams = KMeansParams.builder().build();
         train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> {
             String modelId = (String) trainResult.get("model_id");
@@ -436,6 +436,13 @@ public void testPredictWithReadOnlyMLAccess() throws IOException {
         }, false);
     }
 
+    public void testPredictWithReadOnlyMLAccessModelNonExisting() throws IOException {
+        exceptionRule.expect(ResponseException.class);
+        exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/predict]");
+        KMeansParams kMeansParams = KMeansParams.builder().build();
+        predict(mlReadOnlyClient, FunctionName.KMEANS, "modelId", irisIndex, kMeansParams, searchSourceBuilder, null);
+    }
+
     public void testTrainAndPredictWithFullAccess() throws IOException {
         trainAndPredict(
             mlFullAccessClient,
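
Reviewer sketch (not part of the patch): a minimal, self-contained illustration of the check that validateModelType() now applies in the transport layer, so the feature flags are enforced whether or not the model was already cached. FeatureFlags and Algo below are hypothetical stand-ins for MLFeatureEnabledSetting and FunctionName; the remote-inference message matches the string asserted in testPredictRemoteModelFeatureDisabled, while the local-model message is only a placeholder.

import java.util.Locale;

public class ValidateModelTypeSketch {

    /** Hypothetical stand-in for MLFeatureEnabledSetting. */
    static class FeatureFlags {
        final boolean remoteInferenceEnabled;
        final boolean localModelEnabled;

        FeatureFlags(boolean remote, boolean local) {
            this.remoteInferenceEnabled = remote;
            this.localModelEnabled = local;
        }
    }

    /** Hypothetical stand-in for FunctionName; isDLModel is deliberately simplified. */
    enum Algo {
        REMOTE, KMEANS, TEXT_EMBEDDING;

        static boolean isDLModel(Algo algo) {
            return algo == TEXT_EMBEDDING;
        }
    }

    /** Mirrors the shape of the patch's TransportPredictionTaskAction#validateModelType. */
    static void validateModelType(String modelType, FeatureFlags flags) {
        if (Algo.REMOTE.name().equals(modelType) && !flags.remoteInferenceEnabled) {
            throw new IllegalStateException("Remote Inference is currently disabled.");
        } else if (Algo.isDLModel(Algo.valueOf(modelType.toUpperCase(Locale.ROOT))) && !flags.localModelEnabled) {
            throw new IllegalStateException("local model inference is disabled"); // placeholder message
        }
    }

    public static void main(String[] args) {
        // Same scenario as testPrediction_remote_inference_not_exception: remote model, flag off.
        try {
            validateModelType("REMOTE", new FeatureFlags(false, true));
        } catch (IllegalStateException e) {
            System.out.println("rejected: " + e.getMessage());
        }
        // A local DL model passes when the local-model flag is on.
        validateModelType("TEXT_EMBEDDING", new FeatureFlags(true, true));
        System.out.println("TEXT_EMBEDDING accepted");
    }
}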