-
Notifications
You must be signed in to change notification settings - Fork 158
Prompt Management SEARCH API #3849
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: feature/prompt
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.prompt; | ||
|
||
import org.opensearch.action.ActionType; | ||
import org.opensearch.action.search.SearchResponse; | ||
|
||
public class MLPromptSearchAction extends ActionType<SearchResponse> { | ||
// External Action which used for public facing RestAPIs. | ||
public static final String NAME = "cluster:admin/opensearch/ml/prompts/search"; | ||
public static final MLPromptSearchAction INSTANCE = new MLPromptSearchAction(); | ||
Check warning on line 14 in common/src/main/java/org/opensearch/ml/common/transport/prompt/MLPromptSearchAction.java
|
||
|
||
private MLPromptSearchAction() { | ||
super(NAME, SearchResponse::new); | ||
} | ||
Check warning on line 18 in common/src/main/java/org/opensearch/ml/common/transport/prompt/MLPromptSearchAction.java
|
||
} |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: this test file should be one level higher to match code class MLPrompt.java file location |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.action.prompt; | ||
|
||
import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; | ||
|
||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.action.support.ActionFilters; | ||
import org.opensearch.action.support.HandledTransportAction; | ||
import org.opensearch.common.inject.Inject; | ||
import org.opensearch.common.util.concurrent.ThreadContext; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; | ||
import org.opensearch.ml.common.transport.prompt.MLPromptSearchAction; | ||
import org.opensearch.ml.common.transport.search.MLSearchActionRequest; | ||
import org.opensearch.ml.utils.TenantAwareHelper; | ||
import org.opensearch.remote.metadata.client.SdkClient; | ||
import org.opensearch.remote.metadata.client.SearchDataObjectRequest; | ||
import org.opensearch.remote.metadata.common.SdkClientUtils; | ||
import org.opensearch.tasks.Task; | ||
import org.opensearch.transport.TransportService; | ||
import org.opensearch.transport.client.Client; | ||
|
||
import lombok.extern.log4j.Log4j2; | ||
|
||
/** | ||
* Transport Action class that handles received validated ActionRequest from Rest Layer and | ||
* executes the actual operation of searching prompts based on the request parameters. | ||
*/ | ||
@Log4j2 | ||
public class SearchPromptTransportAction extends HandledTransportAction<MLSearchActionRequest, SearchResponse> { | ||
private final Client client; | ||
private final SdkClient sdkClient; | ||
|
||
private final MLFeatureEnabledSetting mlFeatureEnabledSetting; | ||
|
||
@Inject | ||
public SearchPromptTransportAction( | ||
TransportService transportService, | ||
ActionFilters actionFilters, | ||
Client client, | ||
SdkClient sdkClient, | ||
MLFeatureEnabledSetting mlFeatureEnabledSetting | ||
) { | ||
super(MLPromptSearchAction.NAME, transportService, actionFilters, MLSearchActionRequest::new); | ||
this.client = client; | ||
this.sdkClient = sdkClient; | ||
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; | ||
} | ||
|
||
/** | ||
* Executes the received request by searching prompts based on the request parameters. | ||
* Notify the listener with the SearchResponse if the operation is successful. Otherwise, failure exception | ||
* is notified to the listener. | ||
* | ||
* @param task The task | ||
* @param mlSearchActionRequest MLSearchActionRequest that contains search parameters | ||
* @param actionListener a listener to be notified of the response | ||
*/ | ||
@Override | ||
protected void doExecute(Task task, MLSearchActionRequest mlSearchActionRequest, ActionListener<SearchResponse> actionListener) { | ||
String tenantId = mlSearchActionRequest.getTenantId(); | ||
if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) { | ||
return; | ||
} | ||
|
||
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { | ||
final ActionListener<SearchResponse> wrappedListener = ActionListener | ||
.wrap(actionListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, actionListener)); | ||
|
||
SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest | ||
.builder() | ||
.indices(mlSearchActionRequest.indices()) | ||
.searchSourceBuilder(mlSearchActionRequest.source()) | ||
.tenantId(tenantId) | ||
.build(); | ||
|
||
sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete(SdkClientUtils.wrapSearchCompletion(wrappedListener)); | ||
} catch (Exception e) { | ||
log.error("Failed to search ML Prompt", e); | ||
actionListener.onFailure(e); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.rest; | ||
|
||
import static org.opensearch.ml.common.CommonValue.ML_PROMPT_INDEX; | ||
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; | ||
|
||
import org.opensearch.ml.common.MLPrompt; | ||
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; | ||
import org.opensearch.ml.common.transport.prompt.MLPromptSearchAction; | ||
|
||
import com.google.common.collect.ImmutableList; | ||
|
||
/** | ||
* Rest Action class that handles SEARCH REST API request | ||
*/ | ||
public class RestMLSearchPromptAction extends AbstractMLSearchAction<MLPrompt> { | ||
private static final String ML_SEARCH_PROMPT_ACTION = "ml_search_prompt_action"; | ||
private static final String SEARCH_PROMPT_PATH = ML_BASE_URI + "/prompts/_search"; | ||
|
||
public RestMLSearchPromptAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { | ||
super( | ||
ImmutableList.of(SEARCH_PROMPT_PATH), | ||
ML_PROMPT_INDEX, | ||
MLPrompt.class, | ||
MLPromptSearchAction.INSTANCE, | ||
mlFeatureEnabledSetting | ||
); | ||
} | ||
|
||
@Override | ||
public String getName() { | ||
return ML_SEARCH_PROMPT_ACTION; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like that version will now be auto initialized to 1 when creating a new prompt; version management should be invisible to the end user when possible.
What is the expected flow for creating a new version of prompt? Should the user call create with the same prompt name but specify version 2, or call update? If the latter, will update auto increment the version?