Skip to content

Commit 517890e

Browse files
authored
sdk client implementation for search connector, model group and task (#3707)
Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent 12b0000 commit 517890e

File tree

4 files changed

+32
-66
lines changed

4 files changed

+32
-66
lines changed

plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import java.util.stream.Collectors;
1515

1616
import org.opensearch.ExceptionsHelper;
17-
import org.opensearch.OpenSearchStatusException;
1817
import org.opensearch.action.search.SearchRequest;
1918
import org.opensearch.action.search.SearchResponse;
2019
import org.opensearch.action.search.ShardSearchFailure;
@@ -24,7 +23,6 @@
2423
import org.opensearch.common.util.concurrent.ThreadContext;
2524
import org.opensearch.commons.authuser.User;
2625
import org.opensearch.core.action.ActionListener;
27-
import org.opensearch.core.rest.RestStatus;
2826
import org.opensearch.index.IndexNotFoundException;
2927
import org.opensearch.ml.common.CommonValue;
3028
import org.opensearch.ml.common.connector.HttpConnector;
@@ -120,23 +118,9 @@ private void search(SearchRequest request, String tenantId, ActionListener<Searc
120118
.searchSourceBuilder(request.source())
121119
.tenantId(tenantId)
122120
.build();
123-
sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete((r, throwable) -> {
124-
if (throwable != null) {
125-
Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable);
126-
log.error("Failed to search connector", cause);
127-
doubleWrappedListener.onFailure(cause);
128-
} else {
129-
try {
130-
SearchResponse searchResponse = SearchResponse.fromXContent(r.parser());
131-
log.info("Connector search complete: {}", searchResponse.getHits().getTotalHits());
132-
doubleWrappedListener.onResponse(searchResponse);
133-
} catch (Exception e) {
134-
log.error("Failed to parse search response", e);
135-
doubleWrappedListener
136-
.onFailure(new OpenSearchStatusException("Failed to parse search response", RestStatus.INTERNAL_SERVER_ERROR));
137-
}
138-
}
139-
});
121+
sdkClient
122+
.searchDataObjectAsync(searchDataObjectRequest)
123+
.whenComplete(SdkClientUtils.wrapSearchCompletion(doubleWrappedListener));
140124
} catch (Exception e) {
141125
log.error(e.getMessage(), e);
142126
actionListener.onFailure(e);

plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
package org.opensearch.ml.action.model_group;
77

88
import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener;
9-
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
109
import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound;
1110

1211
import org.opensearch.action.search.SearchRequest;
@@ -18,15 +17,15 @@
1817
import org.opensearch.common.util.concurrent.ThreadContext;
1918
import org.opensearch.commons.authuser.User;
2019
import org.opensearch.core.action.ActionListener;
21-
import org.opensearch.index.query.BoolQueryBuilder;
22-
import org.opensearch.index.query.QueryBuilders;
2320
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction;
2421
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
2522
import org.opensearch.ml.helper.ModelAccessControlHelper;
2623
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
2724
import org.opensearch.ml.utils.RestActionUtils;
2825
import org.opensearch.ml.utils.TenantAwareHelper;
2926
import org.opensearch.remote.metadata.client.SdkClient;
27+
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
28+
import org.opensearch.remote.metadata.common.SdkClientUtils;
3029
import org.opensearch.tasks.Task;
3130
import org.opensearch.transport.TransportService;
3231
import org.opensearch.transport.client.Client;
@@ -83,29 +82,20 @@ private void preProcessRoleAndPerformSearch(
8382
final ActionListener<SearchResponse> doubleWrappedListener = ActionListener
8483
.wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener));
8584

86-
// Modify the query to include tenant ID filtering
87-
if (tenantId != null) {
88-
BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery();
89-
90-
// Preserve existing query if present
91-
if (request.source().query() != null) {
92-
queryBuilder.must(request.source().query());
93-
}
94-
// Add tenancy filter
95-
queryBuilder.filter(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId));
96-
97-
// Update the request's source with the new query
98-
request.source().query(queryBuilder);
99-
}
100-
101-
if (modelAccessControlHelper.skipModelAccessControl(user)) {
102-
client.search(request, doubleWrappedListener);
103-
} else {
85+
if (!modelAccessControlHelper.skipModelAccessControl(user)) {
10486
// Security is enabled, filter is enabled and user isn't admin
10587
modelAccessControlHelper.addUserBackendRolesFilter(user, request.source());
10688
log.debug("Filtering result by {}", user.getBackendRoles());
107-
client.search(request, doubleWrappedListener);
10889
}
90+
SearchDataObjectRequest searchDataObjecRequest = SearchDataObjectRequest
91+
.builder()
92+
.indices(request.indices())
93+
.searchSourceBuilder(request.source())
94+
.tenantId(tenantId)
95+
.build();
96+
sdkClient
97+
.searchDataObjectAsync(searchDataObjecRequest)
98+
.whenComplete(SdkClientUtils.wrapSearchCompletion(doubleWrappedListener));
10999
} catch (Exception e) {
110100
log.error("Failed to search", e);
111101
listener.onFailure(e);

plugin/src/main/java/org/opensearch/ml/action/tasks/SearchTaskTransportAction.java

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
package org.opensearch.ml.action.tasks;
77

8-
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
98
import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound;
109

1110
import org.opensearch.action.search.SearchResponse;
@@ -14,13 +13,13 @@
1413
import org.opensearch.common.inject.Inject;
1514
import org.opensearch.common.util.concurrent.ThreadContext;
1615
import org.opensearch.core.action.ActionListener;
17-
import org.opensearch.index.query.BoolQueryBuilder;
18-
import org.opensearch.index.query.QueryBuilders;
1916
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
2017
import org.opensearch.ml.common.transport.task.MLTaskSearchAction;
2118
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
2219
import org.opensearch.ml.utils.TenantAwareHelper;
2320
import org.opensearch.remote.metadata.client.SdkClient;
21+
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
22+
import org.opensearch.remote.metadata.common.SdkClientUtils;
2423
import org.opensearch.tasks.Task;
2524
import org.opensearch.transport.TransportService;
2625
import org.opensearch.transport.client.Client;
@@ -58,21 +57,14 @@ protected void doExecute(Task task, MLSearchActionRequest mlSearchActionRequest,
5857
final ActionListener<SearchResponse> wrappedListener = ActionListener
5958
.wrap(actionListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, actionListener));
6059

61-
// Modify the query to include tenant ID filtering
62-
if (tenantId != null) {
63-
BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery();
60+
SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
61+
.builder()
62+
.indices(mlSearchActionRequest.indices())
63+
.searchSourceBuilder(mlSearchActionRequest.source())
64+
.tenantId(tenantId)
65+
.build();
6466

65-
// Preserve existing query if present
66-
if (mlSearchActionRequest.source().query() != null) {
67-
queryBuilder.must(mlSearchActionRequest.source().query());
68-
}
69-
// Add tenancy filter
70-
queryBuilder.filter(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId)); // Replace 'tenant_id_field' with actual field name
71-
72-
// Update the request's source with the new query
73-
mlSearchActionRequest.source().query(queryBuilder);
74-
}
75-
client.search(mlSearchActionRequest, ActionListener.runBefore(wrappedListener, context::restore));
67+
sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete(SdkClientUtils.wrapSearchCompletion(wrappedListener));
7668
} catch (Exception e) {
7769
log.error(e.getMessage(), e);
7870
actionListener.onFailure(e);

plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,8 @@
66
package org.opensearch.ml.action.tasks;
77

88
import static org.mockito.ArgumentMatchers.any;
9-
import static org.mockito.ArgumentMatchers.eq;
109
import static org.mockito.Mockito.doAnswer;
1110
import static org.mockito.Mockito.mock;
12-
import static org.mockito.Mockito.times;
1311
import static org.mockito.Mockito.verify;
1412
import static org.mockito.Mockito.when;
1513

@@ -116,19 +114,21 @@ public void test_DoExecute() {
116114
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
117115
SearchRequest request = new SearchRequest("my_index").source(sourceBuilder);
118116
MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(request, null);
117+
118+
// Mock the response
119119
doAnswer(invocation -> {
120120
ActionListener<SearchResponse> listener = invocation.getArgument(1);
121121
listener.onResponse(searchResponse);
122122
return null;
123-
}).when(client).search(eq(mlSearchActionRequest), any());
123+
}).when(client).search(any(SearchRequest.class), any(ActionListener.class));
124124

125+
// Execute the action
125126
searchTaskTransportAction.doExecute(null, mlSearchActionRequest, actionListener);
126-
verify(client, times(1)).search(eq(mlSearchActionRequest), any());
127-
// Use ArgumentCaptor to capture the SearchResponse
127+
128+
// Verify the response
128129
ArgumentCaptor<SearchResponse> responseCaptor = ArgumentCaptor.forClass(SearchResponse.class);
129-
// Capture the response passed to actionListener.onResponse
130-
verify(actionListener, times(1)).onResponse(responseCaptor.capture());
131-
// Assert that the captured response matches the expected values
130+
verify(actionListener).onResponse(responseCaptor.capture());
131+
132132
SearchResponse capturedResponse = responseCaptor.getValue();
133133
assertEquals(searchResponse.getHits().getTotalHits(), capturedResponse.getHits().getTotalHits());
134134
assertEquals(searchResponse.getHits().getHits().length, capturedResponse.getHits().getHits().length);

0 commit comments

Comments
 (0)