diff --git a/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLCreatePromptInput.java b/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLCreatePromptInput.java index 2dcc768a23..d26afdd7f4 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLCreatePromptInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLCreatePromptInput.java @@ -38,6 +38,7 @@ public class MLCreatePromptInput implements ToXContentObject, Writeable { public static final String PROMPT_FIELD_USER_PROMPT = "user"; public static final String PROMPT_FIELD_SYSTEM_PROMPT = "system"; + public static final String PROMPT_VERSION_INITIAL_VERSION = "1"; private String name; private String description; @@ -69,11 +70,8 @@ public MLCreatePromptInput( if (name == null) { throw new IllegalArgumentException("MLPrompt name field is null"); } - if (prompt == null) { - throw new IllegalArgumentException("MLPrompt prompt field is null"); - } - if (prompt.isEmpty()) { - throw new IllegalArgumentException("MLPrompt prompt field cannot be empty"); + if (prompt == null || prompt.isEmpty()) { + throw new IllegalArgumentException("MLPrompt prompt field cannot be empty or null"); } if (!prompt.containsKey(PROMPT_FIELD_SYSTEM_PROMPT)) { throw new IllegalArgumentException("MLPrompt prompt field requires " + PROMPT_FIELD_SYSTEM_PROMPT + " parameter"); @@ -81,13 +79,10 @@ public MLCreatePromptInput( if (!prompt.containsKey(PROMPT_FIELD_USER_PROMPT)) { throw new IllegalArgumentException("MLPrompt prompt field requires " + PROMPT_FIELD_USER_PROMPT + " parameter"); } - if (version == null) { - throw new IllegalArgumentException("MLPrompt version field is null"); - } this.name = name; this.description = description; - this.version = version; + this.version = version == null ? PROMPT_VERSION_INITIAL_VERSION : version; this.prompt = prompt; this.tags = tags; this.tenantId = tenantId; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLPromptSearchAction.java b/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLPromptSearchAction.java new file mode 100644 index 0000000000..a614a4f2ec --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/prompt/MLPromptSearchAction.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.prompt; + +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchResponse; + +public class MLPromptSearchAction extends ActionType { + // External Action which used for public facing RestAPIs. + public static final String NAME = "cluster:admin/opensearch/ml/prompts/search"; + public static final MLPromptSearchAction INSTANCE = new MLPromptSearchAction(); + + private MLPromptSearchAction() { + super(NAME, SearchResponse::new); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/prompt/MLPromptTest.java b/common/src/test/java/org/opensearch/ml/common/prompt/MLPromptTest.java index d50c7074f8..6986f9fe37 100644 --- a/common/src/test/java/org/opensearch/ml/common/prompt/MLPromptTest.java +++ b/common/src/test/java/org/opensearch/ml/common/prompt/MLPromptTest.java @@ -36,7 +36,7 @@ public class MLPromptTest { @Before public void setup() { - Instant time = Instant.ofEpochSecond(1641600000); + Instant time = Instant.parse("2022-01-08T00:00:00Z"); Map testPrompt = new HashMap<>(); testPrompt.put("system", "some system prompt"); testPrompt.put("user", "some user prompt"); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/prompt/MLCreatePromptInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/prompt/MLCreatePromptInputTest.java index 8d9af25a5a..c3c7d6a759 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/prompt/MLCreatePromptInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/prompt/MLCreatePromptInputTest.java @@ -78,8 +78,6 @@ public void constructMLCreatePromptInput_NullName() { @Test public void constructMLCreatePromptInput_NullVersion() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("MLPrompt version field is null"); MLCreatePromptInput mlCreatePromptInput = MLCreatePromptInput .builder() .name(TEST_PROMPT_NAME) @@ -89,12 +87,14 @@ public void constructMLCreatePromptInput_NullVersion() { .tags(TEST_PROMPT_TAGS) .tenantId(TEST_PROMPT_TENANTID) .build(); + + Assert.assertEquals(mlCreatePromptInput.getVersion(), TEST_PROMPT_VERSION); } @Test public void constructMLCreatePromptInput_NullPrompt() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("MLPrompt prompt field is null"); + exceptionRule.expectMessage("MLPrompt prompt field cannot be empty or null"); MLCreatePromptInput mlCreatePromptInput = MLCreatePromptInput .builder() .name(TEST_PROMPT_NAME) @@ -109,7 +109,7 @@ public void constructMLCreatePromptInput_NullPrompt() { @Test public void constructMLCreatePromptInput_EmptyPromptField() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("MLPrompt prompt field cannot be empty"); + exceptionRule.expectMessage("MLPrompt prompt field cannot be empty or null"); MLCreatePromptInput mlCreatePromptInput = MLCreatePromptInput .builder() .name(TEST_PROMPT_NAME) @@ -232,7 +232,7 @@ public void testParse_MissingPromptField_ShouldThrowException() throws IOExcepti XContentParser parser = createParser(jsonMissingPrompt); exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("MLPrompt prompt field is null"); + exceptionRule.expectMessage("MLPrompt prompt field cannot be empty or null"); MLCreatePromptInput.parse(parser); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/prompt/SearchPromptTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/prompt/SearchPromptTransportAction.java new file mode 100644 index 0000000000..1ee8c38a03 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/prompt/SearchPromptTransportAction.java @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.prompt; + +import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; + +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.prompt.MLPromptSearchAction; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.SearchDataObjectRequest; +import org.opensearch.remote.metadata.common.SdkClientUtils; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +/** + * Transport Action class that handles received validated ActionRequest from Rest Layer and + * executes the actual operation of searching prompts based on the request parameters. + */ +@Log4j2 +public class SearchPromptTransportAction extends HandledTransportAction { + private final Client client; + private final SdkClient sdkClient; + + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Inject + public SearchPromptTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + SdkClient sdkClient, + MLFeatureEnabledSetting mlFeatureEnabledSetting + ) { + super(MLPromptSearchAction.NAME, transportService, actionFilters, MLSearchActionRequest::new); + this.client = client; + this.sdkClient = sdkClient; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } + + /** + * Executes the received request by searching prompts based on the request parameters. + * Notify the listener with the SearchResponse if the operation is successful. Otherwise, failure exception + * is notified to the listener. + * + * @param task The task + * @param mlSearchActionRequest MLSearchActionRequest that contains search parameters + * @param actionListener a listener to be notified of the response + */ + @Override + protected void doExecute(Task task, MLSearchActionRequest mlSearchActionRequest, ActionListener actionListener) { + String tenantId = mlSearchActionRequest.getTenantId(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) { + return; + } + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + final ActionListener wrappedListener = ActionListener + .wrap(actionListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, actionListener)); + + SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest + .builder() + .indices(mlSearchActionRequest.indices()) + .searchSourceBuilder(mlSearchActionRequest.source()) + .tenantId(tenantId) + .build(); + + sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete(SdkClientUtils.wrapSearchCompletion(wrappedListener)); + } catch (Exception e) { + log.error("Failed to search ML Prompt", e); + actionListener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index addce574fe..d63c86af58 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -102,6 +102,7 @@ import org.opensearch.ml.action.prediction.TransportPredictionTaskAction; import org.opensearch.ml.action.profile.MLProfileAction; import org.opensearch.ml.action.profile.MLProfileTransportAction; +import org.opensearch.ml.action.prompt.SearchPromptTransportAction; import org.opensearch.ml.action.prompt.TransportCreatePromptAction; import org.opensearch.ml.action.register.TransportRegisterModelAction; import org.opensearch.ml.action.stats.MLStatsNodesAction; @@ -181,6 +182,7 @@ import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prompt.MLCreatePromptAction; +import org.opensearch.ml.common.transport.prompt.MLPromptSearchAction; import org.opensearch.ml.common.transport.register.MLRegisterModelAction; import org.opensearch.ml.common.transport.sync.MLSyncUpAction; import org.opensearch.ml.common.transport.task.MLCancelBatchJobAction; @@ -290,6 +292,7 @@ import org.opensearch.ml.rest.RestMLSearchConnectorAction; import org.opensearch.ml.rest.RestMLSearchModelAction; import org.opensearch.ml.rest.RestMLSearchModelGroupAction; +import org.opensearch.ml.rest.RestMLSearchPromptAction; import org.opensearch.ml.rest.RestMLSearchTaskAction; import org.opensearch.ml.rest.RestMLStatsAction; import org.opensearch.ml.rest.RestMLTrainAndPredictAction; @@ -472,6 +475,7 @@ public MachineLearningPlugin(Settings settings) { new ActionHandler<>(MLConnectorDeleteAction.INSTANCE, DeleteConnectorTransportAction.class), new ActionHandler<>(MLConnectorSearchAction.INSTANCE, SearchConnectorTransportAction.class), new ActionHandler<>(MLCreatePromptAction.INSTANCE, TransportCreatePromptAction.class), + new ActionHandler<>(MLPromptSearchAction.INSTANCE, SearchPromptTransportAction.class), new ActionHandler<>(CreateConversationAction.INSTANCE, CreateConversationTransportAction.class), new ActionHandler<>(GetConversationsAction.INSTANCE, GetConversationsTransportAction.class), new ActionHandler<>(CreateInteractionAction.INSTANCE, CreateInteractionTransportAction.class), @@ -847,6 +851,7 @@ public List getRestHandlers( RestMLDeleteConnectorAction restMLDeleteConnectorAction = new RestMLDeleteConnectorAction(mlFeatureEnabledSetting); RestMLSearchConnectorAction restMLSearchConnectorAction = new RestMLSearchConnectorAction(mlFeatureEnabledSetting); RestMLCreatePromptAction restMLCreatePromptAction = new RestMLCreatePromptAction(mlFeatureEnabledSetting); + RestMLSearchPromptAction restMLSearchPromptAction = new RestMLSearchPromptAction(mlFeatureEnabledSetting); RestMemoryCreateConversationAction restCreateConversationAction = new RestMemoryCreateConversationAction(); RestMemoryGetConversationsAction restListConversationsAction = new RestMemoryGetConversationsAction(); RestMemoryCreateInteractionAction restCreateInteractionAction = new RestMemoryCreateInteractionAction(); @@ -909,6 +914,7 @@ public List getRestHandlers( restMLDeleteConnectorAction, restMLSearchConnectorAction, restMLCreatePromptAction, + restMLSearchPromptAction, restCreateConversationAction, restListConversationsAction, restCreateInteractionAction, diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreatePromptAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreatePromptAction.java index fa2825434f..effa4a8f78 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreatePromptAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreatePromptAction.java @@ -7,7 +7,6 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; -import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; import java.io.IOException; @@ -85,9 +84,6 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client */ @VisibleForTesting MLCreatePromptRequest getRequest(RestRequest request) throws IOException { - if (!mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { - throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); - } if (!request.hasContent()) { throw new IOException("Create Prompt request has empty body"); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchPromptAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchPromptAction.java new file mode 100644 index 0000000000..ab564c8bfe --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchPromptAction.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.common.CommonValue.ML_PROMPT_INDEX; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; + +import org.opensearch.ml.common.MLPrompt; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.prompt.MLPromptSearchAction; + +import com.google.common.collect.ImmutableList; + +/** + * Rest Action class that handles SEARCH REST API request + */ +public class RestMLSearchPromptAction extends AbstractMLSearchAction { + private static final String ML_SEARCH_PROMPT_ACTION = "ml_search_prompt_action"; + private static final String SEARCH_PROMPT_PATH = ML_BASE_URI + "/prompts/_search"; + + public RestMLSearchPromptAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + super( + ImmutableList.of(SEARCH_PROMPT_PATH), + ML_PROMPT_INDEX, + MLPrompt.class, + MLPromptSearchAction.INSTANCE, + mlFeatureEnabledSetting + ); + } + + @Override + public String getName() { + return ML_SEARCH_PROMPT_ACTION; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/prompt/SearchPromptTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/prompt/SearchPromptTransportActionTests.java new file mode 100644 index 0000000000..72a0b34cbc --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/prompt/SearchPromptTransportActionTests.java @@ -0,0 +1,163 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.prompt; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.Collections; + +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.search.MLSearchActionRequest; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +public class SearchPromptTransportActionTests extends OpenSearchTestCase { + @Mock + Client client; + SdkClient sdkClient; + + SearchResponse searchResponse; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Mock + ActionListener actionListener; + + SearchPromptTransportAction searchPromptTransportAction; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); + searchPromptTransportAction = new SearchPromptTransportAction( + transportService, + actionFilters, + client, + sdkClient, + mlFeatureEnabledSetting + ); + ThreadPool threadPool = mock(ThreadPool.class); + when(client.threadPool()).thenReturn(threadPool); + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + when(threadPool.getThreadContext()).thenReturn(threadContext); + SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(0L, TotalHits.Relation.EQUAL_TO), Float.NaN); + InternalSearchResponse internalSearchResponse = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 0 + ); + searchResponse = new SearchResponse( + internalSearchResponse, + null, + 0, + 0, + 0, + 1, + ShardSearchFailure.EMPTY_ARRAY, + mock(SearchResponse.Clusters.class), + null + ); + } + + public void testConstructor() { + SearchPromptTransportAction searchPromptTransportAction = new SearchPromptTransportAction( + transportService, + actionFilters, + client, + sdkClient, + mlFeatureEnabledSetting + ); + assertNotNull(searchPromptTransportAction); + } + + public void testDoExecute_success() { + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(request, null); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), any(ActionListener.class)); + + searchPromptTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(SearchResponse.class); + verify(actionListener).onResponse(responseCaptor.capture()); + + SearchResponse capturedResponse = responseCaptor.getValue(); + assertEquals(searchResponse.getHits().getTotalHits(), capturedResponse.getHits().getTotalHits()); + assertEquals(searchResponse.getHits().getHits().length, capturedResponse.getHits().getHits().length); + assertEquals(searchResponse.status(), capturedResponse.status()); + } + + public void testDoExecute_fail() { + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + MLSearchActionRequest mlSearchACtionRequest = new MLSearchActionRequest(request, null); + + when(client.threadPool()).thenReturn(null); + + searchPromptTransportAction.doExecute(null, mlSearchACtionRequest, actionListener); + verify(actionListener).onFailure(any(NullPointerException.class)); + } + + public void testDoExecute_multi_tenancy_fail() throws InterruptedException { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + MLSearchActionRequest mlSearchACtionRequest = new MLSearchActionRequest(request, null); + + searchPromptTransportAction.doExecute(null, mlSearchACtionRequest, actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "You don't have permission to access this resource", + argumentCaptor.getValue().getMessage() + ); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreatePromptActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreatePromptActionTests.java index a24c767cd7..2588230650 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreatePromptActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreatePromptActionTests.java @@ -12,7 +12,6 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.TestHelper.getCreatePromptRestRequest; import static org.opensearch.ml.utils.TestHelper.verifyParsedCreatePromptInput; @@ -63,7 +62,6 @@ public class RestMLCreatePromptActionTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); - when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true); restMLCreatePromptAction = new RestMLCreatePromptAction(mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); @@ -138,13 +136,4 @@ public void testPrepareRequest_EmptyContent() throws Exception { restMLCreatePromptAction.handleRequest(request, channel, client); } - - public void testPrepareRequestFeatureDisabled() throws Exception { - exceptionRule.expect(IllegalStateException.class); - exceptionRule.expectMessage(REMOTE_INFERENCE_DISABLED_ERR_MSG); - - when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false); - RestRequest request = getCreatePromptRestRequest(null); - restMLCreatePromptAction.handleRequest(request, channel, client); - } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchPromptActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchPromptActionTests.java new file mode 100644 index 0000000000..5717da07ca --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchPromptActionTests.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import java.util.List; + +import org.hamcrest.Matchers; +import org.junit.Before; +import org.mockito.Mock; +import org.opensearch.core.common.Strings; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; + +public class RestMLSearchPromptActionTests extends OpenSearchTestCase { + + private RestMLSearchPromptAction restMLSearchPromptAction; + + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Before + public void setup() { + restMLSearchPromptAction = new RestMLSearchPromptAction(mlFeatureEnabledSetting); + } + + public void testConstructor() { + RestMLSearchPromptAction restMLSearchPromptAction = new RestMLSearchPromptAction(mlFeatureEnabledSetting); + assertNotNull(restMLSearchPromptAction); + } + + public void testGetName() { + String actionName = restMLSearchPromptAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_search_prompt_action", actionName); + } + + public void testRoutes() { + List routes = restMLSearchPromptAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.POST, route.getMethod()); + assertThat(route.getMethod(), Matchers.either(Matchers.is(RestRequest.Method.POST)).or(Matchers.is(RestRequest.Method.GET))); + assertEquals("/_plugins/_ml/prompts/_search", route.getPath()); + } +}