Skip to content

Commit c82f916

Browse files
adding tenantId in the agent registration request (#3489) (#3491)
Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> (cherry picked from commit 9e014fa) Co-authored-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent 6ba54cf commit c82f916

File tree

2 files changed

+139
-1
lines changed

2 files changed

+139
-1
lines changed

plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterAgentAction.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
99
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
1010
import static org.opensearch.ml.utils.MLExceptionUtils.AGENT_FRAMEWORK_DISABLED_ERR_MSG;
11+
import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID;
1112

1213
import java.io.IOException;
1314
import java.util.List;
@@ -64,9 +65,10 @@ MLRegisterAgentRequest getRequest(RestRequest request) throws IOException {
6465
if (!mlFeatureEnabledSetting.isAgentFrameworkEnabled()) {
6566
throw new IllegalStateException(AGENT_FRAMEWORK_DISABLED_ERR_MSG);
6667
}
68+
String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request);
6769
XContentParser parser = request.contentParser();
6870
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
69-
MLAgent mlAgent = MLAgent.parseFromUserInput(parser);
71+
MLAgent mlAgent = MLAgent.parseFromUserInput(parser).toBuilder().tenantId(tenantId).build();
7072
return new MLRegisterAgentRequest(mlAgent);
7173
}
7274
}
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.rest;
7+
8+
import static org.mockito.ArgumentMatchers.any;
9+
import static org.mockito.ArgumentMatchers.eq;
10+
import static org.mockito.Mockito.doAnswer;
11+
import static org.mockito.Mockito.spy;
12+
import static org.mockito.Mockito.times;
13+
import static org.mockito.Mockito.verify;
14+
import static org.mockito.Mockito.when;
15+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED;
16+
17+
import java.util.List;
18+
import java.util.Map;
19+
import java.util.Set;
20+
21+
import org.junit.Before;
22+
import org.junit.Rule;
23+
import org.junit.rules.ExpectedException;
24+
import org.mockito.ArgumentCaptor;
25+
import org.mockito.Mock;
26+
import org.mockito.MockitoAnnotations;
27+
import org.opensearch.client.node.NodeClient;
28+
import org.opensearch.cluster.service.ClusterService;
29+
import org.opensearch.common.settings.ClusterSettings;
30+
import org.opensearch.common.settings.Settings;
31+
import org.opensearch.common.xcontent.XContentType;
32+
import org.opensearch.core.common.bytes.BytesArray;
33+
import org.opensearch.core.xcontent.NamedXContentRegistry;
34+
import org.opensearch.ml.common.agent.MLAgent;
35+
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
36+
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
37+
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
38+
import org.opensearch.rest.RestChannel;
39+
import org.opensearch.rest.RestHandler;
40+
import org.opensearch.rest.RestRequest;
41+
import org.opensearch.test.OpenSearchTestCase;
42+
import org.opensearch.test.rest.FakeRestRequest;
43+
import org.opensearch.threadpool.TestThreadPool;
44+
import org.opensearch.threadpool.ThreadPool;
45+
46+
import com.google.gson.Gson;
47+
48+
public class RestMLRegisterAgentActionTests extends OpenSearchTestCase {
49+
@Rule
50+
public ExpectedException exceptionRule = ExpectedException.none();
51+
52+
private RestMLRegisterAgentAction restMLRegisterAgentAction;
53+
private NodeClient client;
54+
private ThreadPool threadPool;
55+
56+
@Mock
57+
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
58+
Settings settings;
59+
60+
@Mock
61+
private ClusterService clusterService;
62+
63+
@Mock
64+
RestChannel channel;
65+
66+
@Before
67+
public void setup() {
68+
MockitoAnnotations.openMocks(this);
69+
threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool");
70+
client = spy(new NodeClient(Settings.EMPTY, threadPool));
71+
settings = Settings.builder().put(ML_COMMONS_MULTI_TENANCY_ENABLED.getKey(), false).build();
72+
when(clusterService.getSettings()).thenReturn(settings);
73+
when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(ML_COMMONS_MULTI_TENANCY_ENABLED)));
74+
when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false);
75+
when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(true);
76+
restMLRegisterAgentAction = new RestMLRegisterAgentAction(mlFeatureEnabledSetting);
77+
doAnswer(invocation -> null).when(client).execute(eq(MLRegisterAgentAction.INSTANCE), any(), any());
78+
}
79+
80+
@Override
81+
public void tearDown() throws Exception {
82+
super.tearDown();
83+
threadPool.shutdown();
84+
client.close();
85+
}
86+
87+
public void testConstructor() {
88+
RestMLRegisterAgentAction registerAgentAction = new RestMLRegisterAgentAction(mlFeatureEnabledSetting);
89+
assertNotNull(registerAgentAction);
90+
}
91+
92+
public void testGetName() {
93+
String actionName = restMLRegisterAgentAction.getName();
94+
assertFalse(actionName.isEmpty());
95+
assertEquals("ml_register_agent_action", actionName);
96+
}
97+
98+
public void testRoutes() {
99+
List<RestHandler.Route> routes = restMLRegisterAgentAction.routes();
100+
assertNotNull(routes);
101+
assertFalse(routes.isEmpty());
102+
RestHandler.Route route = routes.get(0);
103+
assertEquals(RestRequest.Method.POST, route.getMethod());
104+
assertEquals("/_plugins/_ml/agents/_register", route.getPath());
105+
}
106+
107+
public void testRegisterAgentRequest() throws Exception {
108+
RestRequest request = getRestRequest();
109+
restMLRegisterAgentAction.handleRequest(request, channel, client);
110+
ArgumentCaptor<MLRegisterAgentRequest> argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentRequest.class);
111+
verify(client, times(1)).execute(eq(MLRegisterAgentAction.INSTANCE), argumentCaptor.capture(), any());
112+
MLAgent mlAgent = argumentCaptor.getValue().getMlAgent();
113+
assertEquals("testAgentName", mlAgent.getName());
114+
assertEquals("This is a test agent description", mlAgent.getDescription());
115+
}
116+
117+
public void testRegisterAgentRequestWhenFrameworkDisabled() throws Exception {
118+
when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false);
119+
exceptionRule.expect(IllegalStateException.class);
120+
exceptionRule.expectMessage("Agent Framework is currently disabled. To enable it, update the setting \"plugins.ml_commons.agent_framework_enabled\" to true.");
121+
RestRequest request = getRestRequest();
122+
restMLRegisterAgentAction.handleRequest(request, channel, client);
123+
}
124+
125+
private RestRequest getRestRequest() {
126+
RestRequest.Method method = RestRequest.Method.POST;
127+
final Map<String, Object> agentData = Map
128+
.of("name", "testAgentName", "description", "This is a test agent description", "type", "FLOW");
129+
String requestContent = new Gson().toJson(agentData);
130+
return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
131+
.withMethod(method)
132+
.withPath("/_plugins/_ml/agents/_register")
133+
.withContent(new BytesArray(requestContent), XContentType.JSON)
134+
.build();
135+
}
136+
}

0 commit comments

Comments
 (0)