Skip to content

model predict: move getModel from rest to transport #3687

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;

import java.util.Locale;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
Expand Down Expand Up @@ -126,9 +129,8 @@ public void onResponse(MLModel mlModel) {
context.restore();
modelCacheHelper.setModelInfo(modelId, mlModel);
FunctionName functionName = mlModel.getAlgorithm();
if (FunctionName.isDLModel(functionName) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
}
String modelType = functionName.name();
validateModelType(modelType);
mlPredictionTaskRequest.getMlInput().setAlgorithm(functionName);
modelAccessControlHelper
.validateModelGroupAccess(
Expand Down Expand Up @@ -274,4 +276,13 @@ public void validateInputSchema(String modelId, MLInput mlInput) {
}
}

private void validateModelType(String modelType) {
if (FunctionName.REMOTE.name().equals(modelType) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
} else if (FunctionName.isDLModel(FunctionName.from(modelType.toUpperCase(Locale.ROOT)))
&& !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
}
Comment on lines +282 to +285
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: redundant else if, can be a separate if block

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
import static org.opensearch.ml.utils.RestActionUtils.getActionTypeFromRestRequest;
Expand All @@ -22,20 +20,15 @@
import java.util.Objects;
import java.util.Optional;

import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;
import org.opensearch.transport.client.node.NodeClient;
Expand Down Expand Up @@ -88,67 +81,28 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
String modelId = getParameterId(request, PARAMETER_MODEL_ID);
Optional<FunctionName> functionName = modelManager.getOptionalModelFunctionName(modelId);

// check if the model is in cache
if (functionName.isPresent()) {
MLPredictionTaskRequest predictionRequest = getRequest(
modelId,
functionName.get().name(),
Objects.requireNonNullElse(userAlgorithm, functionName.get().name()),
request
);
return channel -> client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, new RestToXContentListener<>(channel));
}

// If the model isn't in cache
return channel -> {
ActionListener<MLModel> listener = ActionListener.wrap(mlModel -> {
String modelType = mlModel.getAlgorithm().name();
String modelAlgorithm = Objects.requireNonNullElse(userAlgorithm, mlModel.getAlgorithm().name());
client
.execute(
MLPredictionTaskAction.INSTANCE,
getRequest(modelId, modelType, modelAlgorithm, request),
new RestToXContentListener<>(channel)
);
}, e -> {
log.error("Failed to get ML model", e);
try {
channel.sendResponse(new BytesRestResponse(channel, RestStatus.NOT_FOUND, e));
} catch (IOException ex) {
log.error("Failed to send error response", ex);
}
});
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
modelManager
.getModel(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After refactoring, this method will not be invoked, is that correct ?

modelId,
getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request),
ActionListener.runBefore(listener, context::restore)
);
}
};
MLPredictionTaskRequest predictionRequest = getRequest(
modelId,
Objects.requireNonNullElse(userAlgorithm, functionName.get().name()),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible both userAlgorithm and functionName are null?

request
);
return channel -> client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, new RestToXContentListener<>(channel));
}

/**
* Creates a MLPredictionTaskRequest from a RestRequest. This method validates the request based on
* enabled features and model types, and parses the input data for prediction.
*
* @param modelId The ID of the ML model to use for prediction
* @param userAlgorithm The algorithm specified by the user for prediction; this is used to determine the interface of the model
* @param request The REST request containing prediction input data
* @return MLPredictionTaskRequest configured with the model and input parameters
*/
@VisibleForTesting
MLPredictionTaskRequest getRequest(String modelId, String modelType, String userAlgorithm, RestRequest request) throws IOException {
MLPredictionTaskRequest getRequest(String modelId, String userAlgorithm, RestRequest request) throws IOException {
String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request);
ActionType actionType = ActionType.from(getActionTypeFromRestRequest(request));
if (FunctionName.REMOTE.name().equals(modelType) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
} else if (FunctionName.isDLModel(FunctionName.from(modelType.toUpperCase(Locale.ROOT)))
&& !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
} else if (ActionType.BATCH_PREDICT == actionType && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
if (ActionType.BATCH_PREDICT == actionType && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
Comment on lines -146 to +105
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

throw new IllegalStateException(BATCH_INFERENCE_DISABLED_ERR_MSG);
} else if (!ActionType.isValidActionInModelPrediction(actionType)) {
throw new IllegalArgumentException("Wrong action type in the rest request path!");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import static org.mockito.Mockito.when;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;

import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -200,6 +201,22 @@ public void testPrediction_local_model_not_exception() {
);
}

@Test
public void testPrediction_remote_inference_not_exception() {
    // Model in cache resolves to a REMOTE function while the remote-inference flag is off,
    // so doExecute must reject the request before dispatching the prediction.
    when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model);
    when(model.getAlgorithm()).thenReturn(FunctionName.REMOTE);
    when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false);

    IllegalStateException e = assertThrows(
        IllegalStateException.class,
        () -> transportPredictionTaskAction.doExecute(null, mlPredictionTaskRequest, actionListener)
    );
    // JUnit contract is assertEquals(expected, actual); the original had them swapped,
    // which produces a misleading failure message.
    assertEquals(REMOTE_INFERENCE_DISABLED_ERR_MSG, e.getMessage());
}

@Test
public void testPrediction_OpenSearchStatusException() {
when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
import static org.opensearch.ml.utils.TestHelper.getBatchRestRequest;
import static org.opensearch.ml.utils.TestHelper.getBatchRestRequest_WrongActionType;
Expand Down Expand Up @@ -127,35 +125,12 @@ public void testRoutes_Batch() {
@Test
public void testGetRequest() throws IOException {
RestRequest request = getRestRequest_PredictModel();
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
.getRequest("modelId", FunctionName.KMEANS.name(), FunctionName.KMEANS.name(), request);
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.KMEANS.name(), request);

MLInput mlInput = mlPredictionTaskRequest.getMlInput();
verifyParsedKMeansMLInput(mlInput);
}

@Test
public void testGetRequest_RemoteInferenceDisabled() throws IOException {
thrown.expect(IllegalStateException.class);
thrown.expectMessage(REMOTE_INFERENCE_DISABLED_ERR_MSG);

when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false);
RestRequest request = getRestRequest_PredictModel();
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
.getRequest("modelId", FunctionName.REMOTE.name(), "text_embedding", request);
}

@Test
public void testGetRequest_LocalModelInferenceDisabled() throws IOException {
thrown.expect(IllegalStateException.class);
thrown.expectMessage(LOCAL_MODEL_DISABLED_ERR_MSG);

when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false);
RestRequest request = getRestRequest_PredictModel();
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
.getRequest("modelId", FunctionName.TEXT_EMBEDDING.name(), "text_embedding", request);
}

@Test
public void testPrepareRequest() throws Exception {
RestRequest request = getRestRequest_PredictModel();
Expand Down Expand Up @@ -196,7 +171,7 @@ public void testPrepareBatchRequest_WrongActionType() throws Exception {
thrown.expectMessage("Wrong Action Type");

RestRequest request = getBatchRestRequest_WrongActionType();
restMLPredictionAction.getRequest("model id", "remote", "text_embedding", request);
restMLPredictionAction.getRequest("model id", "text_embedding", request);
}

@Ignore
Expand Down Expand Up @@ -234,17 +209,7 @@ public void testGetRequest_InvalidActionType() throws IOException {
thrown.expectMessage("Wrong Action Type of models");

RestRequest request = getBatchRestRequest_WrongActionType();
restMLPredictionAction.getRequest("model_id", FunctionName.REMOTE.name(), "text_embedding", request);
}

@Test
public void testGetRequest_UnsupportedAlgorithm() throws IOException {
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("Wrong function name");

// Create a RestRequest with an unsupported algorithm
RestRequest request = getRestRequest_PredictModel();
restMLPredictionAction.getRequest("model_id", "INVALID_ALGO", "text_embedding", request);
restMLPredictionAction.getRequest("model_id", "text_embedding", request);
}

private RestRequest getRestRequest_PredictModel() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1177,4 +1177,39 @@ public String registerRemoteModelWithInterface(String testCase) throws IOExcepti
logger.info("task ID created: {}", taskId);
return taskId;
}

/**
 * Registers and deploys a remote model, disables the remote-inference cluster setting,
 * and verifies that a subsequent predict call is rejected with the disabled-feature error.
 */
public void testPredictRemoteModelFeatureDisabled() throws IOException, InterruptedException {
    Response response = createConnector(completionModelConnectorEntity);
    Map responseMap = parseResponseToMap(response);
    String connectorId = (String) responseMap.get("connector_id");
    response = registerRemoteModelWithInterface("openAI-GPT-3.5 completions", connectorId, "correctInterface");
    responseMap = parseResponseToMap(response);
    String taskId = (String) responseMap.get("task_id");
    waitForTask(taskId, MLTaskState.COMPLETED);
    response = getTask(taskId);
    responseMap = parseResponseToMap(response);
    String modelId = (String) responseMap.get("model_id");
    response = deployRemoteModel(modelId);
    responseMap = parseResponseToMap(response);
    taskId = (String) responseMap.get("task_id");
    waitForTask(taskId, MLTaskState.COMPLETED);
    response = TestHelper
        .makeRequest(
            client(),
            "PUT",
            "_cluster/settings",
            null,
            "{\"persistent\":{\"plugins.ml_commons.remote_inference.enabled\":false}}",
            ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
        );
    assertEquals(200, response.getStatusLine().getStatusCode());
    String predictInput = "{\n" + "  \"parameters\": {\n" + "      \"prompt\": \"Say this is a test\"\n" + "  }\n" + "}";
    boolean rejected = false;
    try {
        predictRemoteModel(modelId, predictInput);
    } catch (Exception e) {
        rejected = true;
        assertTrue(e instanceof org.opensearch.client.ResponseException);
        String stackTrace = ExceptionUtils.getStackTrace(e);
        assertTrue(stackTrace.contains("Remote Inference is currently disabled."));
    } finally {
        // Restore the flag so later tests in this suite are not poisoned by a disabled feature.
        Response enable = TestHelper
            .makeRequest(
                client(),
                "PUT",
                "_cluster/settings",
                null,
                "{\"persistent\":{\"plugins.ml_commons.remote_inference.enabled\":true}}",
                ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
            );
        assertEquals(200, enable.getStatusLine().getStatusCode());
    }
    // The original test passed silently when no exception was thrown, masking regressions.
    assertTrue("predict should be rejected when remote inference is disabled", rejected);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ public void testTrainWithReadOnlyMLAccess() throws IOException {
train(mlReadOnlyClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, null, false);
}

public void testPredictWithReadOnlyMLAccess() throws IOException {
public void testPredictWithReadOnlyMLAccessModelExisting() throws IOException {
KMeansParams kMeansParams = KMeansParams.builder().build();
train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> {
String modelId = (String) trainResult.get("model_id");
Expand Down Expand Up @@ -436,6 +436,13 @@ public void testPredictWithReadOnlyMLAccess() throws IOException {
}, false);
}

/**
 * A read-only ML user must be denied the predict action even when the
 * referenced model id does not exist in the cluster.
 */
public void testPredictWithReadOnlyMLAccessModelNonExisting() throws IOException {
    exceptionRule.expect(ResponseException.class);
    exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/predict]");
    final KMeansParams params = KMeansParams.builder().build();
    predict(mlReadOnlyClient, FunctionName.KMEANS, "modelId", irisIndex, params, searchSourceBuilder, null);
}

public void testTrainAndPredictWithFullAccess() throws IOException {
trainAndPredict(
mlFullAccessClient,
Expand Down
Loading