Skip to content

Prompt Management SEARCH API #3849

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: feature/prompt
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public class MLCreatePromptInput implements ToXContentObject, Writeable {

public static final String PROMPT_FIELD_USER_PROMPT = "user";
public static final String PROMPT_FIELD_SYSTEM_PROMPT = "system";
public static final String PROMPT_VERSION_INITIAL_VERSION = "1";

private String name;
private String description;
Expand Down Expand Up @@ -69,25 +70,19 @@ public MLCreatePromptInput(
if (name == null) {
throw new IllegalArgumentException("MLPrompt name field is null");
}
if (prompt == null) {
throw new IllegalArgumentException("MLPrompt prompt field is null");
}
if (prompt.isEmpty()) {
throw new IllegalArgumentException("MLPrompt prompt field cannot be empty");
if (prompt == null || prompt.isEmpty()) {
throw new IllegalArgumentException("MLPrompt prompt field cannot be empty or null");
}
if (!prompt.containsKey(PROMPT_FIELD_SYSTEM_PROMPT)) {
throw new IllegalArgumentException("MLPrompt prompt field requires " + PROMPT_FIELD_SYSTEM_PROMPT + " parameter");
}
if (!prompt.containsKey(PROMPT_FIELD_USER_PROMPT)) {
throw new IllegalArgumentException("MLPrompt prompt field requires " + PROMPT_FIELD_USER_PROMPT + " parameter");
}
if (version == null) {
throw new IllegalArgumentException("MLPrompt version field is null");
}

this.name = name;
this.description = description;
this.version = version;
this.version = version == null ? PROMPT_VERSION_INITIAL_VERSION : version;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like that version will now be auto initialized to 1 when creating a new prompt; version management should be invisible to the end user when possible.

What is the expected flow for creating a new version of prompt? Should the user call create with the same prompt name but specify version 2, or call update? If the latter, will update auto increment the version?

this.prompt = prompt;
this.tags = tags;
this.tenantId = tenantId;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.prompt;

import org.opensearch.action.ActionType;
import org.opensearch.action.search.SearchResponse;

public class MLPromptSearchAction extends ActionType<SearchResponse> {
// External Action which used for public facing RestAPIs.
public static final String NAME = "cluster:admin/opensearch/ml/prompts/search";
public static final MLPromptSearchAction INSTANCE = new MLPromptSearchAction();

Check warning on line 14 in common/src/main/java/org/opensearch/ml/common/transport/prompt/MLPromptSearchAction.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/transport/prompt/MLPromptSearchAction.java#L14

Added line #L14 was not covered by tests

private MLPromptSearchAction() {
super(NAME, SearchResponse::new);
}

Check warning on line 18 in common/src/main/java/org/opensearch/ml/common/transport/prompt/MLPromptSearchAction.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/transport/prompt/MLPromptSearchAction.java#L17-L18

Added lines #L17 - L18 were not covered by tests
}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this test file should be one level higher to match code class MLPrompt.java file location

Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class MLPromptTest {

@Before
public void setup() {
Instant time = Instant.ofEpochSecond(1641600000);
Instant time = Instant.parse("2022-01-08T00:00:00Z");
Map<String, String> testPrompt = new HashMap<>();
testPrompt.put("system", "some system prompt");
testPrompt.put("user", "some user prompt");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ public void constructMLCreatePromptInput_NullName() {

@Test
public void constructMLCreatePromptInput_NullVersion() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("MLPrompt version field is null");
MLCreatePromptInput mlCreatePromptInput = MLCreatePromptInput
.builder()
.name(TEST_PROMPT_NAME)
Expand All @@ -89,12 +87,14 @@ public void constructMLCreatePromptInput_NullVersion() {
.tags(TEST_PROMPT_TAGS)
.tenantId(TEST_PROMPT_TENANTID)
.build();

Assert.assertEquals(mlCreatePromptInput.getVersion(), TEST_PROMPT_VERSION);
}

@Test
public void constructMLCreatePromptInput_NullPrompt() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("MLPrompt prompt field is null");
exceptionRule.expectMessage("MLPrompt prompt field cannot be empty or null");
MLCreatePromptInput mlCreatePromptInput = MLCreatePromptInput
.builder()
.name(TEST_PROMPT_NAME)
Expand All @@ -109,7 +109,7 @@ public void constructMLCreatePromptInput_NullPrompt() {
@Test
public void constructMLCreatePromptInput_EmptyPromptField() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("MLPrompt prompt field cannot be empty");
exceptionRule.expectMessage("MLPrompt prompt field cannot be empty or null");
MLCreatePromptInput mlCreatePromptInput = MLCreatePromptInput
.builder()
.name(TEST_PROMPT_NAME)
Expand Down Expand Up @@ -232,7 +232,7 @@ public void testParse_MissingPromptField_ShouldThrowException() throws IOExcepti
XContentParser parser = createParser(jsonMissingPrompt);

exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("MLPrompt prompt field is null");
exceptionRule.expectMessage("MLPrompt prompt field cannot be empty or null");
MLCreatePromptInput.parse(parser);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.prompt;

import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound;

import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.prompt.MLPromptSearchAction;
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
import org.opensearch.ml.utils.TenantAwareHelper;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
import org.opensearch.remote.metadata.common.SdkClientUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.client.Client;

import lombok.extern.log4j.Log4j2;

/**
* Transport Action class that handles received validated ActionRequest from Rest Layer and
* executes the actual operation of searching prompts based on the request parameters.
*/
@Log4j2
public class SearchPromptTransportAction extends HandledTransportAction<MLSearchActionRequest, SearchResponse> {
private final Client client;
private final SdkClient sdkClient;

private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public SearchPromptTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
SdkClient sdkClient,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLPromptSearchAction.NAME, transportService, actionFilters, MLSearchActionRequest::new);
this.client = client;
this.sdkClient = sdkClient;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

/**
* Executes the received request by searching prompts based on the request parameters.
* Notify the listener with the SearchResponse if the operation is successful. Otherwise, failure exception
* is notified to the listener.
*
* @param task The task
* @param mlSearchActionRequest MLSearchActionRequest that contains search parameters
* @param actionListener a listener to be notified of the response
*/
@Override
protected void doExecute(Task task, MLSearchActionRequest mlSearchActionRequest, ActionListener<SearchResponse> actionListener) {
String tenantId = mlSearchActionRequest.getTenantId();
if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) {
return;
}

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
final ActionListener<SearchResponse> wrappedListener = ActionListener
.wrap(actionListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, actionListener));

SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
.builder()
.indices(mlSearchActionRequest.indices())
.searchSourceBuilder(mlSearchActionRequest.source())
.tenantId(tenantId)
.build();

sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete(SdkClientUtils.wrapSearchCompletion(wrappedListener));
} catch (Exception e) {
log.error("Failed to search ML Prompt", e);
actionListener.onFailure(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
import org.opensearch.ml.action.prediction.TransportPredictionTaskAction;
import org.opensearch.ml.action.profile.MLProfileAction;
import org.opensearch.ml.action.profile.MLProfileTransportAction;
import org.opensearch.ml.action.prompt.SearchPromptTransportAction;
import org.opensearch.ml.action.prompt.TransportCreatePromptAction;
import org.opensearch.ml.action.register.TransportRegisterModelAction;
import org.opensearch.ml.action.stats.MLStatsNodesAction;
Expand Down Expand Up @@ -181,6 +182,7 @@
import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prompt.MLCreatePromptAction;
import org.opensearch.ml.common.transport.prompt.MLPromptSearchAction;
import org.opensearch.ml.common.transport.register.MLRegisterModelAction;
import org.opensearch.ml.common.transport.sync.MLSyncUpAction;
import org.opensearch.ml.common.transport.task.MLCancelBatchJobAction;
Expand Down Expand Up @@ -290,6 +292,7 @@
import org.opensearch.ml.rest.RestMLSearchConnectorAction;
import org.opensearch.ml.rest.RestMLSearchModelAction;
import org.opensearch.ml.rest.RestMLSearchModelGroupAction;
import org.opensearch.ml.rest.RestMLSearchPromptAction;
import org.opensearch.ml.rest.RestMLSearchTaskAction;
import org.opensearch.ml.rest.RestMLStatsAction;
import org.opensearch.ml.rest.RestMLTrainAndPredictAction;
Expand Down Expand Up @@ -472,6 +475,7 @@ public MachineLearningPlugin(Settings settings) {
new ActionHandler<>(MLConnectorDeleteAction.INSTANCE, DeleteConnectorTransportAction.class),
new ActionHandler<>(MLConnectorSearchAction.INSTANCE, SearchConnectorTransportAction.class),
new ActionHandler<>(MLCreatePromptAction.INSTANCE, TransportCreatePromptAction.class),
new ActionHandler<>(MLPromptSearchAction.INSTANCE, SearchPromptTransportAction.class),
new ActionHandler<>(CreateConversationAction.INSTANCE, CreateConversationTransportAction.class),
new ActionHandler<>(GetConversationsAction.INSTANCE, GetConversationsTransportAction.class),
new ActionHandler<>(CreateInteractionAction.INSTANCE, CreateInteractionTransportAction.class),
Expand Down Expand Up @@ -847,6 +851,7 @@ public List<RestHandler> getRestHandlers(
RestMLDeleteConnectorAction restMLDeleteConnectorAction = new RestMLDeleteConnectorAction(mlFeatureEnabledSetting);
RestMLSearchConnectorAction restMLSearchConnectorAction = new RestMLSearchConnectorAction(mlFeatureEnabledSetting);
RestMLCreatePromptAction restMLCreatePromptAction = new RestMLCreatePromptAction(mlFeatureEnabledSetting);
RestMLSearchPromptAction restMLSearchPromptAction = new RestMLSearchPromptAction(mlFeatureEnabledSetting);
RestMemoryCreateConversationAction restCreateConversationAction = new RestMemoryCreateConversationAction();
RestMemoryGetConversationsAction restListConversationsAction = new RestMemoryGetConversationsAction();
RestMemoryCreateInteractionAction restCreateInteractionAction = new RestMemoryCreateInteractionAction();
Expand Down Expand Up @@ -909,6 +914,7 @@ public List<RestHandler> getRestHandlers(
restMLDeleteConnectorAction,
restMLSearchConnectorAction,
restMLCreatePromptAction,
restMLSearchPromptAction,
restCreateConversationAction,
restListConversationsAction,
restCreateInteractionAction,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID;

import java.io.IOException;
Expand Down Expand Up @@ -85,9 +84,6 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
*/
@VisibleForTesting
MLCreatePromptRequest getRequest(RestRequest request) throws IOException {
if (!mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
}
if (!request.hasContent()) {
throw new IOException("Create Prompt request has empty body");
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.rest;

import static org.opensearch.ml.common.CommonValue.ML_PROMPT_INDEX;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;

import org.opensearch.ml.common.MLPrompt;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.prompt.MLPromptSearchAction;

import com.google.common.collect.ImmutableList;

/**
* Rest Action class that handles SEARCH REST API request
*/
public class RestMLSearchPromptAction extends AbstractMLSearchAction<MLPrompt> {
private static final String ML_SEARCH_PROMPT_ACTION = "ml_search_prompt_action";
private static final String SEARCH_PROMPT_PATH = ML_BASE_URI + "/prompts/_search";

public RestMLSearchPromptAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
super(
ImmutableList.of(SEARCH_PROMPT_PATH),
ML_PROMPT_INDEX,
MLPrompt.class,
MLPromptSearchAction.INSTANCE,
mlFeatureEnabledSetting
);
}

@Override
public String getName() {
return ML_SEARCH_PROMPT_ACTION;
}
}
Loading
Loading