@@ -8,8 +8,6 @@
 import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
 import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
 import static org.opensearch.ml.utils.MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG;
-import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
-import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
 import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM;
 import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
 import static org.opensearch.ml.utils.RestActionUtils.getActionTypeFromRestRequest;
@@ -22,20 +20,15 @@
 import java.util.Objects;
 import java.util.Optional;
 
-import org.opensearch.common.util.concurrent.ThreadContext;
-import org.opensearch.core.action.ActionListener;
-import org.opensearch.core.rest.RestStatus;
 import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.ml.common.FunctionName;
-import org.opensearch.ml.common.MLModel;
 import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
 import org.opensearch.ml.common.input.MLInput;
 import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
 import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
 import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
 import org.opensearch.ml.model.MLModelManager;
 import org.opensearch.rest.BaseRestHandler;
-import org.opensearch.rest.BytesRestResponse;
 import org.opensearch.rest.RestRequest;
 import org.opensearch.rest.action.RestToXContentListener;
 import org.opensearch.transport.client.node.NodeClient;
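Note: the five deleted imports above (ThreadContext, ActionListener, RestStatus, MLModel, BytesRestResponse) were referenced only by the cache-miss fallback that the hunk below removes, so they are dropped together with it.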
@@ -88,67 +81,28 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
         String modelId = getParameterId(request, PARAMETER_MODEL_ID);
         Optional<FunctionName> functionName = modelManager.getOptionalModelFunctionName(modelId);
 
-        // check if the model is in cache
-        if (functionName.isPresent()) {
-            MLPredictionTaskRequest predictionRequest = getRequest(
-                modelId,
-                functionName.get().name(),
-                Objects.requireNonNullElse(userAlgorithm, functionName.get().name()),
-                request
-            );
-            return channel -> client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, new RestToXContentListener<>(channel));
-        }
-
-        // If the model isn't in cache
-        return channel -> {
-            ActionListener<MLModel> listener = ActionListener.wrap(mlModel -> {
-                String modelType = mlModel.getAlgorithm().name();
-                String modelAlgorithm = Objects.requireNonNullElse(userAlgorithm, mlModel.getAlgorithm().name());
-                client
-                    .execute(
-                        MLPredictionTaskAction.INSTANCE,
-                        getRequest(modelId, modelType, modelAlgorithm, request),
-                        new RestToXContentListener<>(channel)
-                    );
-            }, e -> {
-                log.error("Failed to get ML model", e);
-                try {
-                    channel.sendResponse(new BytesRestResponse(channel, RestStatus.NOT_FOUND, e));
-                } catch (IOException ex) {
-                    log.error("Failed to send error response", ex);
-                }
-            });
-            try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
-                modelManager
-                    .getModel(
-                        modelId,
-                        getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request),
-                        ActionListener.runBefore(listener, context::restore)
-                    );
-            }
-        };
+        MLPredictionTaskRequest predictionRequest = getRequest(
+            modelId,
+            Objects.requireNonNullElse(userAlgorithm, functionName.get().name()),
+            request
+        );
+        return channel -> client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, new RestToXContentListener<>(channel));
     }
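One behavioral note on the simplified path: Objects.requireNonNullElse is a plain method call, so functionName.get() is evaluated even when userAlgorithm is non-null, and it throws NoSuchElementException when the Optional is empty. A minimal defensive sketch for that spot (a drop-in for the method body above; the exception type and message are illustrative, not part of this change):

    // Resolve the algorithm lazily so an absent function name fails with a
    // clear message instead of a bare NoSuchElementException.
    String algorithm = userAlgorithm != null
        ? userAlgorithm
        : functionName
            .map(FunctionName::name)
            .orElseThrow(() -> new IllegalArgumentException("Unable to resolve function name for model: " + modelId));
    MLPredictionTaskRequest predictionRequest = getRequest(modelId, algorithm, request);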
 
     /**
      * Creates a MLPredictionTaskRequest from a RestRequest. This method validates the request based on
      * enabled features and model types, and parses the input data for prediction.
      *
      * @param modelId The ID of the ML model to use for prediction
-     * @param modelType The type of the ML model, extracted from model cache to specify if its a remote model or a local model
      * @param userAlgorithm The algorithm specified by the user for prediction, this is used to determine the interface of the model
      * @param request The REST request containing prediction input data
      * @return MLPredictionTaskRequest configured with the model and input parameters
      */
     @VisibleForTesting
-    MLPredictionTaskRequest getRequest(String modelId, String modelType, String userAlgorithm, RestRequest request) throws IOException {
+    MLPredictionTaskRequest getRequest(String modelId, String userAlgorithm, RestRequest request) throws IOException {
         String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request);
         ActionType actionType = ActionType.from(getActionTypeFromRestRequest(request));
-        if (FunctionName.REMOTE.name().equals(modelType) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
-            throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
-        } else if (FunctionName.isDLModel(FunctionName.from(modelType.toUpperCase(Locale.ROOT)))
-            && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
-            throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
-        } else if (ActionType.BATCH_PREDICT == actionType && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
+        if (ActionType.BATCH_PREDICT == actionType && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
             throw new IllegalStateException(BATCH_INFERENCE_DISABLED_ERR_MSG);
         } else if (!ActionType.isValidActionInModelPrediction(actionType)) {
             throw new IllegalArgumentException("Wrong action type in the rest request path!");
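With the modelType-based checks gone, the validation left in getRequest covers only the offline-batch-inference flag and the action-type validity check. A hypothetical Mockito/JUnit sketch of how the surviving guard could be exercised (the helper name is an assumption, not from this diff; assumes static imports from org.mockito.Mockito and org.junit.Assert):

    // Batch predict should fail fast while the offline batch inference flag is off.
    when(mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()).thenReturn(false);
    RestRequest batchRequest = createBatchPredictRestRequest(); // assumed test helper
    IllegalStateException e = assertThrows(
        IllegalStateException.class,
        () -> restMLPredictionAction.getRequest("test_model_id", FunctionName.REMOTE.name(), batchRequest)
    );
    assertEquals(BATCH_INFERENCE_DISABLED_ERR_MSG, e.getMessage());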