Skip to content

Commit 2de431e

Browse files
[BACKPORT 2.x] applying multi-tenancy in search [model, model group, agent, connector] (#3433) (#3443) (#3469)
* applying multi-tenancy in search [model, model group, agent, connector] (#3433) * applying multi-tenancy in search Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * addressed comments Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> --------- Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> * changing MLSearchActionRequest to an instance subclass of SearchActionRequest Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> --------- Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> (cherry picked from commit 34a7fb6) Co-authored-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent daff761 commit 2de431e

File tree

32 files changed

+1160
-215
lines changed

32 files changed

+1160
-215
lines changed
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
package org.opensearch.ml.common.transport.search;
2+
3+
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;
4+
5+
import java.io.ByteArrayInputStream;
6+
import java.io.ByteArrayOutputStream;
7+
import java.io.IOException;
8+
import java.io.UncheckedIOException;
9+
10+
import org.opensearch.Version;
11+
import org.opensearch.action.ActionRequest;
12+
import org.opensearch.action.search.SearchRequest;
13+
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
14+
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
15+
import org.opensearch.core.common.io.stream.StreamInput;
16+
import org.opensearch.core.common.io.stream.StreamOutput;
17+
18+
import lombok.Builder;
19+
import lombok.Getter;
20+
21+
/**
22+
* Represents an extended search action request that includes a tenant ID.
23+
* This class allows OpenSearch to include a tenant ID in search requests,
24+
* which is not natively supported in the standard {@link SearchRequest}.
25+
*/
26+
@Getter
27+
public class MLSearchActionRequest extends SearchRequest {
28+
String tenantId;
29+
30+
/**
31+
* Constructor for building an MLSearchActionRequest.
32+
*
33+
* @param searchRequest The original {@link SearchRequest} to be wrapped.
34+
* @param tenantId The tenant ID associated with the request.
35+
*/
36+
@Builder
37+
public MLSearchActionRequest(SearchRequest searchRequest, String tenantId) {
38+
super(searchRequest);
39+
this.tenantId = tenantId;
40+
}
41+
42+
/**
43+
* Deserializes an {@link MLSearchActionRequest} from a {@link StreamInput}.
44+
*
45+
* @param input The stream input to read from.
46+
* @throws IOException If an I/O error occurs during deserialization.
47+
*/
48+
public MLSearchActionRequest(StreamInput input) throws IOException {
49+
super(input);
50+
Version streamInputVersion = input.getVersion();
51+
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
52+
53+
}
54+
55+
/**
56+
* Serializes this {@link MLSearchActionRequest} to a {@link StreamOutput}.
57+
*
58+
* @param output The stream output to write to.
59+
* @throws IOException If an I/O error occurs during serialization.
60+
*/
61+
@Override
62+
public void writeTo(StreamOutput output) throws IOException {
63+
super.writeTo(output);
64+
Version streamOutputVersion = output.getVersion();
65+
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
66+
output.writeOptionalString(tenantId);
67+
}
68+
}
69+
70+
/**
71+
* Converts a generic {@link ActionRequest} into an {@link MLSearchActionRequest}.
72+
* This is useful when handling requests that may need to be converted for compatibility.
73+
*
74+
* @param actionRequest The original {@link ActionRequest}.
75+
* @return The converted {@link MLSearchActionRequest}.
76+
* @throws UncheckedIOException If the conversion fails due to an I/O error.
77+
*/
78+
public static MLSearchActionRequest fromActionRequest(ActionRequest actionRequest) {
79+
if (actionRequest instanceof MLSearchActionRequest) {
80+
return (MLSearchActionRequest) actionRequest;
81+
}
82+
83+
if (actionRequest instanceof SearchRequest) {
84+
return MLSearchActionRequest
85+
.builder()
86+
.searchRequest((SearchRequest) actionRequest)
87+
.tenantId(null) // No tenant ID in the original request
88+
.build();
89+
}
90+
91+
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
92+
actionRequest.writeTo(osso);
93+
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
94+
return new MLSearchActionRequest(input);
95+
}
96+
} catch (IOException e) {
97+
throw new UncheckedIOException("failed to parse ActionRequest into MLSearchActionRequest", e);
98+
}
99+
}
100+
}
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package org.opensearch.ml.common.transport.search;
2+
3+
import static org.junit.Assert.assertEquals;
4+
import static org.junit.Assert.assertNotNull;
5+
import static org.junit.Assert.assertNull;
6+
import static org.junit.Assert.assertSame;
7+
8+
import java.io.IOException;
9+
10+
import org.junit.Before;
11+
import org.junit.Test;
12+
import org.opensearch.Version;
13+
import org.opensearch.action.search.SearchRequest;
14+
import org.opensearch.common.io.stream.BytesStreamOutput;
15+
import org.opensearch.core.common.io.stream.StreamInput;
16+
17+
public class MLSearchActionRequestTest {
18+
19+
private SearchRequest searchRequest;
20+
21+
@Before
22+
public void setUp() {
23+
searchRequest = new SearchRequest("test-index");
24+
}
25+
26+
@Test
27+
public void testSerializationDeserialization_Version_2_19_0() throws IOException {
28+
// Set up a valid SearchRequest
29+
SearchRequest searchRequest = new SearchRequest("test-index");
30+
31+
// Create the MLSearchActionRequest
32+
MLSearchActionRequest originalRequest = MLSearchActionRequest
33+
.builder()
34+
.searchRequest(searchRequest)
35+
.tenantId("test-tenant")
36+
.build();
37+
38+
BytesStreamOutput out = new BytesStreamOutput();
39+
out.setVersion(Version.V_2_19_0);
40+
originalRequest.writeTo(out);
41+
42+
StreamInput in = out.bytes().streamInput();
43+
in.setVersion(Version.V_2_19_0);
44+
MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(in);
45+
46+
assertEquals("test-tenant", deserializedRequest.getTenantId());
47+
}
48+
49+
@Test
50+
public void testSerializationDeserialization_Version_2_18_0() throws IOException {
51+
52+
// Create the MLSearchActionRequest
53+
MLSearchActionRequest originalRequest = MLSearchActionRequest
54+
.builder()
55+
.searchRequest(searchRequest)
56+
.tenantId("test-tenant")
57+
.build();
58+
59+
BytesStreamOutput out = new BytesStreamOutput();
60+
out.setVersion(Version.V_2_18_0);
61+
originalRequest.writeTo(out);
62+
63+
StreamInput in = out.bytes().streamInput();
64+
in.setVersion(Version.V_2_18_0);
65+
MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(in);
66+
67+
assertNull(deserializedRequest.getTenantId());
68+
}
69+
70+
@Test
71+
public void testFromActionRequest_WithMLSearchActionRequest() {
72+
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
73+
74+
MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(request);
75+
76+
assertSame(request, result);
77+
}
78+
79+
@Test
80+
public void testFromActionRequest_WithSearchRequest() throws IOException {
81+
SearchRequest simpleRequest = new SearchRequest("test-index");
82+
83+
MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(simpleRequest);
84+
85+
assertNotNull(result);
86+
assertNull(result.getTenantId()); // Since tenantId wasn't in original request
87+
}
88+
89+
}

memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE;
2121

2222
import org.opensearch.OpenSearchException;
23-
import org.opensearch.action.search.SearchRequest;
2423
import org.opensearch.action.search.SearchResponse;
2524
import org.opensearch.action.support.ActionFilters;
2625
import org.opensearch.action.support.HandledTransportAction;
@@ -30,6 +29,7 @@
3029
import org.opensearch.common.util.concurrent.ThreadContext;
3130
import org.opensearch.core.action.ActionListener;
3231
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
32+
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
3333
import org.opensearch.ml.memory.ConversationalMemoryHandler;
3434
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
3535
import org.opensearch.tasks.Task;
@@ -38,7 +38,7 @@
3838
import lombok.extern.log4j.Log4j2;
3939

4040
@Log4j2
41-
public class SearchConversationsTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> {
41+
public class SearchConversationsTransportAction extends HandledTransportAction<MLSearchActionRequest, SearchResponse> {
4242

4343
private ConversationalMemoryHandler cmHandler;
4444
private Client client;
@@ -61,7 +61,7 @@ public SearchConversationsTransportAction(
6161
Client client,
6262
ClusterService clusterService
6363
) {
64-
super(SearchConversationsAction.NAME, transportService, actionFilters, SearchRequest::new);
64+
super(SearchConversationsAction.NAME, transportService, actionFilters, MLSearchActionRequest::new);
6565
this.cmHandler = cmHandler;
6666
this.client = client;
6767
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
@@ -71,14 +71,14 @@ public SearchConversationsTransportAction(
7171
}
7272

7373
@Override
74-
public void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
74+
public void doExecute(Task task, MLSearchActionRequest mlSearchActionRequest, ActionListener<SearchResponse> actionListener) {
7575
if (!featureIsEnabled) {
7676
actionListener.onFailure(new OpenSearchException(ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE));
7777
return;
7878
} else {
7979
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
80-
ActionListener<SearchResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
81-
cmHandler.searchConversations(request, internalListener);
80+
ActionListener<SearchResponse> internalListener = ActionListener.runBefore(actionListener, context::restore);
81+
cmHandler.searchConversations(mlSearchActionRequest, internalListener);
8282
} catch (Exception e) {
8383
log.error("Failed to search memories", e);
8484
actionListener.onFailure(e);

memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.opensearch.core.action.ActionListener;
4545
import org.opensearch.core.xcontent.NamedXContentRegistry;
4646
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
47+
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
4748
import org.opensearch.ml.memory.MemoryTestUtil;
4849
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
4950
import org.opensearch.test.OpenSearchTestCase;
@@ -79,12 +80,15 @@ public class SearchConversationsTransportActionTests extends OpenSearchTestCase
7980
@Mock
8081
SearchRequest request;
8182

83+
MLSearchActionRequest mlSearchActionRequest;
84+
8285
SearchConversationsTransportAction action;
8386
ThreadContext threadContext;
8487

8588
@Before
8689
public void setup() throws IOException {
8790
MockitoAnnotations.openMocks(this);
91+
mlSearchActionRequest = new MLSearchActionRequest(request, null);
8892

8993
Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build();
9094
this.threadContext = new ThreadContext(settings);
@@ -104,7 +108,7 @@ public void testEnabled_ThenSucceed() {
104108
listener.onResponse(response);
105109
return null;
106110
}).when(cmHandler).searchConversations(any(), any());
107-
action.doExecute(null, request, actionListener);
111+
action.doExecute(null, mlSearchActionRequest, actionListener);
108112
ArgumentCaptor<SearchResponse> argCaptor = ArgumentCaptor.forClass(SearchResponse.class);
109113
verify(actionListener, times(1)).onResponse(argCaptor.capture());
110114
assert (argCaptor.getValue().equals(response));
@@ -114,7 +118,7 @@ public void testDisabled_ThenFail() {
114118
clusterService = MemoryTestUtil.clusterServiceWithMemoryFeatureDisabled();
115119
this.action = spy(new SearchConversationsTransportAction(transportService, actionFilters, cmHandler, client, clusterService));
116120

117-
action.doExecute(null, request, actionListener);
121+
action.doExecute(null, mlSearchActionRequest, actionListener);
118122
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
119123
verify(actionListener).onFailure(argCaptor.capture());
120124
assertEquals(argCaptor.getValue().getMessage(), ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE);

plugin/src/main/java/org/opensearch/ml/action/agents/TransportSearchAgentAction.java

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.ml.action.agents;
77

88
import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener;
9+
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
910

1011
import org.opensearch.action.search.SearchRequest;
1112
import org.opensearch.action.search.SearchResponse;
@@ -20,28 +21,46 @@
2021
import org.opensearch.ml.common.CommonValue;
2122
import org.opensearch.ml.common.agent.MLAgent;
2223
import org.opensearch.ml.common.transport.agent.MLSearchAgentAction;
24+
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
25+
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
26+
import org.opensearch.ml.utils.TenantAwareHelper;
27+
import org.opensearch.remote.metadata.client.SdkClient;
2328
import org.opensearch.tasks.Task;
2429
import org.opensearch.transport.TransportService;
2530

2631
import lombok.extern.log4j.Log4j2;
2732

2833
@Log4j2
29-
public class TransportSearchAgentAction extends HandledTransportAction<SearchRequest, SearchResponse> {
34+
public class TransportSearchAgentAction extends HandledTransportAction<MLSearchActionRequest, SearchResponse> {
3035
private final Client client;
36+
private final SdkClient sdkClient;
37+
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
3138

3239
@Inject
33-
public TransportSearchAgentAction(TransportService transportService, ActionFilters actionFilters, Client client) {
34-
super(MLSearchAgentAction.NAME, transportService, actionFilters, SearchRequest::new);
40+
public TransportSearchAgentAction(
41+
TransportService transportService,
42+
ActionFilters actionFilters,
43+
Client client,
44+
SdkClient sdkClient,
45+
MLFeatureEnabledSetting mlFeatureEnabledSetting
46+
) {
47+
super(MLSearchAgentAction.NAME, transportService, actionFilters, MLSearchActionRequest::new);
3548
this.client = client;
49+
this.sdkClient = sdkClient;
50+
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
3651
}
3752

3853
@Override
39-
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
54+
protected void doExecute(Task task, MLSearchActionRequest request, ActionListener<SearchResponse> actionListener) {
4055
request.indices(CommonValue.ML_AGENT_INDEX);
41-
search(request, actionListener);
56+
String tenantId = request.getTenantId();
57+
if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) {
58+
return;
59+
}
60+
search(request, tenantId, actionListener);
4261
}
4362

44-
private void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
63+
private void search(SearchRequest request, String tenantId, ActionListener<SearchResponse> actionListener) {
4564
ActionListener<SearchResponse> listener = wrapRestActionListener(actionListener, "Fail to search agent");
4665
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
4766
ActionListener<SearchResponse> wrappedListener = ActionListener.runBefore(listener, context::restore);
@@ -57,6 +76,11 @@ private void search(SearchRequest request, ActionListener<SearchResponse> action
5776
// Add a should clause to include documents where IS_HIDDEN_FIELD is false
5877
shouldQuery.should(QueryBuilders.termQuery(MLAgent.IS_HIDDEN_FIELD, false));
5978

79+
// For multi-tenancy
80+
if (tenantId != null) {
81+
shouldQuery.should(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId));
82+
}
83+
6084
// Add a should clause to include documents where IS_HIDDEN_FIELD does not exist or is null
6185
shouldQuery.should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD)));
6286

0 commit comments

Comments
 (0)