-
Notifications
You must be signed in to change notification settings - Fork 158
model predict: move getModel from rest to transport #3687
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,8 +8,6 @@ | |
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; | ||
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; | ||
import static org.opensearch.ml.utils.MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG; | ||
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG; | ||
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; | ||
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM; | ||
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; | ||
import static org.opensearch.ml.utils.RestActionUtils.getActionTypeFromRestRequest; | ||
|
@@ -22,20 +20,15 @@ | |
import java.util.Objects; | ||
import java.util.Optional; | ||
|
||
import org.opensearch.common.util.concurrent.ThreadContext; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.rest.RestStatus; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.ml.common.FunctionName; | ||
import org.opensearch.ml.common.MLModel; | ||
import org.opensearch.ml.common.connector.ConnectorAction.ActionType; | ||
import org.opensearch.ml.common.input.MLInput; | ||
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; | ||
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; | ||
import org.opensearch.ml.model.MLModelManager; | ||
import org.opensearch.ml.settings.MLFeatureEnabledSetting; | ||
import org.opensearch.rest.BaseRestHandler; | ||
import org.opensearch.rest.BytesRestResponse; | ||
import org.opensearch.rest.RestRequest; | ||
import org.opensearch.rest.action.RestToXContentListener; | ||
import org.opensearch.transport.client.node.NodeClient; | ||
|
@@ -88,67 +81,28 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client | |
String modelId = getParameterId(request, PARAMETER_MODEL_ID); | ||
Optional<FunctionName> functionName = modelManager.getOptionalModelFunctionName(modelId); | ||
|
||
// check if the model is in cache | ||
if (functionName.isPresent()) { | ||
MLPredictionTaskRequest predictionRequest = getRequest( | ||
modelId, | ||
functionName.get().name(), | ||
Objects.requireNonNullElse(userAlgorithm, functionName.get().name()), | ||
request | ||
); | ||
return channel -> client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, new RestToXContentListener<>(channel)); | ||
} | ||
|
||
// If the model isn't in cache | ||
return channel -> { | ||
ActionListener<MLModel> listener = ActionListener.wrap(mlModel -> { | ||
String modelType = mlModel.getAlgorithm().name(); | ||
String modelAlgorithm = Objects.requireNonNullElse(userAlgorithm, mlModel.getAlgorithm().name()); | ||
client | ||
.execute( | ||
MLPredictionTaskAction.INSTANCE, | ||
getRequest(modelId, modelType, modelAlgorithm, request), | ||
new RestToXContentListener<>(channel) | ||
); | ||
}, e -> { | ||
log.error("Failed to get ML model", e); | ||
try { | ||
channel.sendResponse(new BytesRestResponse(channel, RestStatus.NOT_FOUND, e)); | ||
} catch (IOException ex) { | ||
log.error("Failed to send error response", ex); | ||
} | ||
}); | ||
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { | ||
modelManager | ||
.getModel( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After refactoring, this method will not be invoked, is that correct ? |
||
modelId, | ||
getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request), | ||
ActionListener.runBefore(listener, context::restore) | ||
); | ||
} | ||
}; | ||
MLPredictionTaskRequest predictionRequest = getRequest( | ||
modelId, | ||
Objects.requireNonNullElse(userAlgorithm, functionName.get().name()), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it possible both |
||
request | ||
); | ||
return channel -> client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, new RestToXContentListener<>(channel)); | ||
} | ||
|
||
/** | ||
* Creates a MLPredictionTaskRequest from a RestRequest. This method validates the request based on | ||
* enabled features and model types, and parses the input data for prediction. | ||
* | ||
* @param modelId The ID of the ML model to use for prediction | ||
* @param modelType The type of the ML model, extracted from model cache to specify if it's a remote model or a local model | ||
* @param userAlgorithm The algorithm specified by the user for prediction, this is used to determine the interface of the model | ||
* @param request The REST request containing prediction input data | ||
* @return MLPredictionTaskRequest configured with the model and input parameters | ||
*/ | ||
@VisibleForTesting | ||
MLPredictionTaskRequest getRequest(String modelId, String modelType, String userAlgorithm, RestRequest request) throws IOException { | ||
MLPredictionTaskRequest getRequest(String modelId, String userAlgorithm, RestRequest request) throws IOException { | ||
String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); | ||
ActionType actionType = ActionType.from(getActionTypeFromRestRequest(request)); | ||
if (FunctionName.REMOTE.name().equals(modelType) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { | ||
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); | ||
} else if (FunctionName.isDLModel(FunctionName.from(modelType.toUpperCase(Locale.ROOT))) | ||
&& !mlFeatureEnabledSetting.isLocalModelEnabled()) { | ||
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG); | ||
} else if (ActionType.BATCH_PREDICT == actionType && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) { | ||
if (ActionType.BATCH_PREDICT == actionType && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) { | ||
Comment on lines
-146
to
+105
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
throw new IllegalStateException(BATCH_INFERENCE_DISABLED_ERR_MSG); | ||
} else if (!ActionType.isValidActionInModelPrediction(actionType)) { | ||
throw new IllegalArgumentException("Wrong action type in the rest request path!"); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: redundant else if, can be a separate if block