Skip to content

Commit 76f0f3b

Browse files
adding tenantID to the request + undeploy request (#3425) (#3429)
Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> (cherry picked from commit af96fe0) Co-authored-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent 880b674 commit 76f0f3b

File tree

20 files changed

+517
-152
lines changed

20 files changed

+517
-152
lines changed

client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,19 @@ default ActionFuture<MLOutput> predict(String modelId, MLInput mlInput) {
6060
* @param mlInput ML input
6161
* @param listener a listener to be notified of the result
6262
*/
63-
void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener);
63+
default void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
64+
predict(modelId, null, mlInput, listener);
65+
}
66+
67+
/**
68+
* Do prediction machine learning job
69+
* For additional info on Predict, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#predict
70+
* @param modelId the trained model id
71+
* @param tenantId tenant id
72+
* @param mlInput ML input
73+
* @param listener a listener to be notified of the result
74+
*/
75+
void predict(String modelId, String tenantId, MLInput mlInput, ActionListener<MLOutput> listener);
6476

6577
/**
6678
* Train model then predict with the same data set.
@@ -352,7 +364,19 @@ default ActionFuture<MLUndeployModelsResponse> undeploy(String[] modelIds, @Null
352364
* @param modelIds the node ids. May be null for all nodes.
353365
* @param listener a listener to be notified of the result
354366
*/
355-
void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener);
367+
default void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
368+
undeploy(modelIds, nodeIds, null, listener);
369+
}
370+
371+
/**
372+
* Undeploy model
373+
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/
374+
* @param modelIds the model ids
375+
* @param modelIds the node ids. May be null for all nodes.
376+
* @param tenantId the tenant id. This is necessary for multi-tenancy.
377+
* @param listener a listener to be notified of the result
378+
*/
379+
void undeploy(String[] modelIds, String[] nodeIds, String tenantId, ActionListener<MLUndeployModelsResponse> listener);
356380

357381
/**
358382
* Create connector for remote model

client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,14 +101,15 @@ public class MachineLearningNodeClient implements MachineLearningClient {
101101
Client client;
102102

103103
@Override
104-
public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
104+
public void predict(String modelId, String tenantId, MLInput mlInput, ActionListener<MLOutput> listener) {
105105
validateMLInput(mlInput, true);
106106

107107
MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest
108108
.builder()
109109
.mlInput(mlInput)
110110
.modelId(modelId)
111111
.dispatchTask(true)
112+
.tenantId(tenantId)
112113
.build();
113114
client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, getMlPredictionTaskResponseActionListener(listener));
114115
}
@@ -262,8 +263,8 @@ public void deploy(String modelId, String tenantId, ActionListener<MLDeployModel
262263
}
263264

264265
@Override
265-
public void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
266-
MLUndeployModelsRequest undeployModelRequest = new MLUndeployModelsRequest(modelIds, nodeIds);
266+
public void undeploy(String[] modelIds, String[] nodeIds, String tenantId, ActionListener<MLUndeployModelsResponse> listener) {
267+
MLUndeployModelsRequest undeployModelRequest = new MLUndeployModelsRequest(modelIds, nodeIds, tenantId);
267268
client.execute(MLUndeployModelsAction.INSTANCE, undeployModelRequest, getMlUndeployModelsResponseActionListener(listener));
268269
}
269270

client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,11 @@ public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> li
149149
listener.onResponse(output);
150150
}
151151

152+
@Override
153+
public void predict(String modelId, String tenantId, MLInput mlInput, ActionListener<MLOutput> listener) {
154+
listener.onResponse(output);
155+
}
156+
152157
@Override
153158
public void trainAndPredict(MLInput mlInput, ActionListener<MLOutput> listener) {
154159
listener.onResponse(output);
@@ -234,6 +239,11 @@ public void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndep
234239
listener.onResponse(undeployModelsResponse);
235240
}
236241

242+
@Override
243+
public void undeploy(String[] modelIds, String[] nodeIds, String tenantId, ActionListener<MLUndeployModelsResponse> listener) {
244+
listener.onResponse(undeployModelsResponse);
245+
}
246+
237247
@Override
238248
public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener<MLCreateConnectorResponse> listener) {
239249
listener.onResponse(createConnectorResponse);
@@ -320,7 +330,7 @@ public void predict_WithAlgoAndParametersAndInputDataAndModelId() {
320330
public void predict_WithAlgoAndInputDataAndListener() {
321331
MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(new DataFrameInputDataset(input)).build();
322332
ArgumentCaptor<MLOutput> dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class);
323-
machineLearningClient.predict(null, mlInput, dataFrameActionListener);
333+
machineLearningClient.predict(null, null, mlInput, dataFrameActionListener);
324334
verify(dataFrameActionListener).onResponse(dataFrameArgumentCaptor.capture());
325335
assertEquals(output, dataFrameArgumentCaptor.getValue());
326336
}

common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequest.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,32 @@
55

66
package org.opensearch.ml.common.transport.undeploy;
77

8+
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;
9+
810
import java.io.IOException;
911

12+
import org.opensearch.Version;
1013
import org.opensearch.action.support.nodes.BaseNodesRequest;
1114
import org.opensearch.cluster.node.DiscoveryNode;
1215
import org.opensearch.core.common.io.stream.StreamInput;
1316
import org.opensearch.core.common.io.stream.StreamOutput;
1417

1518
import lombok.Getter;
19+
import lombok.Setter;
1620

1721
public class MLUndeployModelNodesRequest extends BaseNodesRequest<MLUndeployModelNodesRequest> {
1822

1923
@Getter
2024
private String[] modelIds;
25+
@Getter
26+
@Setter
27+
private String tenantId;
2128

2229
public MLUndeployModelNodesRequest(StreamInput in) throws IOException {
2330
super(in);
31+
Version streamInputVersion = in.getVersion();
2432
this.modelIds = in.readOptionalStringArray();
33+
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null;
2534
}
2635

2736
public MLUndeployModelNodesRequest(String[] nodeIds, String[] modelIds) {
@@ -36,7 +45,11 @@ public MLUndeployModelNodesRequest(DiscoveryNode... nodes) {
3645
@Override
3746
public void writeTo(StreamOutput out) throws IOException {
3847
super.writeTo(out);
48+
Version streamOutputVersion = out.getVersion();
3949
out.writeOptionalStringArray(modelIds);
50+
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
51+
out.writeOptionalString(tenantId);
52+
}
4053
}
4154

4255
}

common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsRequest.java

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
package org.opensearch.ml.common.transport.undeploy;
77

88
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
9+
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
10+
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;
911

1012
import java.io.ByteArrayInputStream;
1113
import java.io.ByteArrayOutputStream;
@@ -14,6 +16,7 @@
1416
import java.util.ArrayList;
1517
import java.util.List;
1618

19+
import org.opensearch.Version;
1720
import org.opensearch.action.ActionRequest;
1821
import org.opensearch.action.ActionRequestValidationException;
1922
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
@@ -39,24 +42,28 @@ public class MLUndeployModelsRequest extends MLTaskRequest {
3942
private String[] modelIds;
4043
private String[] nodeIds;
4144
boolean async;
45+
private String tenantId;
4246

4347
@Builder
44-
public MLUndeployModelsRequest(String[] modelIds, String[] nodeIds, boolean async, boolean dispatchTask) {
48+
public MLUndeployModelsRequest(String[] modelIds, String[] nodeIds, boolean async, boolean dispatchTask, String tenantId) {
4549
super(dispatchTask);
4650
this.modelIds = modelIds;
4751
this.nodeIds = nodeIds;
4852
this.async = async;
53+
this.tenantId = tenantId;
4954
}
5055

51-
public MLUndeployModelsRequest(String[] modelIds, String[] nodeIds) {
52-
this(modelIds, nodeIds, false, false);
56+
public MLUndeployModelsRequest(String[] modelIds, String[] nodeIds, String tenantId) {
57+
this(modelIds, nodeIds, false, false, tenantId);
5358
}
5459

5560
public MLUndeployModelsRequest(StreamInput in) throws IOException {
5661
super(in);
62+
Version streamInputVersion = in.getVersion();
5763
this.modelIds = in.readOptionalStringArray();
5864
this.nodeIds = in.readOptionalStringArray();
5965
this.async = in.readBoolean();
66+
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null;
6067
}
6168

6269
@Override
@@ -68,15 +75,20 @@ public ActionRequestValidationException validate() {
6875
@Override
6976
public void writeTo(StreamOutput out) throws IOException {
7077
super.writeTo(out);
78+
Version streamOutputVersion = out.getVersion();
7179
out.writeOptionalStringArray(modelIds);
7280
out.writeOptionalStringArray(nodeIds);
7381
out.writeBoolean(async);
82+
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
83+
out.writeOptionalString(tenantId);
84+
}
7485
}
7586

7687
public static MLUndeployModelsRequest parse(XContentParser parser, String modelId) throws IOException {
7788
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
7889
List<String> modelIdList = new ArrayList<>();
7990
List<String> nodeIdList = new ArrayList<>();
91+
String tenantId = null;
8092
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
8193
String fieldName = parser.currentName();
8294
parser.nextToken();
@@ -94,14 +106,17 @@ public static MLUndeployModelsRequest parse(XContentParser parser, String modelI
94106
nodeIdList.add(parser.text());
95107
}
96108
break;
109+
case TENANT_ID_FIELD:
110+
tenantId = parser.textOrNull();
111+
break;
97112
default:
98113
parser.skipChildren();
99114
break;
100115
}
101116
}
102-
String[] modelIds = modelIdList == null ? null : modelIdList.toArray(new String[0]);
103-
String[] nodeIds = nodeIdList == null ? null : nodeIdList.toArray(new String[0]);
104-
return new MLUndeployModelsRequest(modelIds, nodeIds, false, true);
117+
String[] modelIds = modelIdList.toArray(new String[0]);
118+
String[] nodeIds = nodeIdList.toArray(new String[0]);
119+
return new MLUndeployModelsRequest(modelIds, nodeIds, false, true, tenantId);
105120
}
106121

107122
public static MLUndeployModelsRequest fromActionRequest(ActionRequest actionRequest) {
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
package org.opensearch.ml.common.transport.undeploy;
2+
3+
import static org.junit.Assert.*;
4+
import static org.opensearch.ml.common.CommonValue.VERSION_2_18_0;
5+
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;
6+
7+
import java.io.IOException;
8+
import java.io.UncheckedIOException;
9+
import java.util.Collections;
10+
import java.util.function.Consumer;
11+
12+
import org.junit.Before;
13+
import org.junit.Test;
14+
import org.opensearch.action.ActionRequest;
15+
import org.opensearch.action.ActionRequestValidationException;
16+
import org.opensearch.common.io.stream.BytesStreamOutput;
17+
import org.opensearch.common.settings.Settings;
18+
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
19+
import org.opensearch.common.xcontent.XContentType;
20+
import org.opensearch.core.common.io.stream.StreamInput;
21+
import org.opensearch.core.common.io.stream.StreamOutput;
22+
import org.opensearch.core.xcontent.NamedXContentRegistry;
23+
import org.opensearch.core.xcontent.XContentParser;
24+
import org.opensearch.search.SearchModule;
25+
26+
public class MLUndeployModelsRequestTest {
27+
28+
private MLUndeployModelsRequest mlUndeployModelsRequest;
29+
30+
@Before
31+
public void setUp() {
32+
mlUndeployModelsRequest = MLUndeployModelsRequest
33+
.builder()
34+
.modelIds(new String[] { "model1", "model2" })
35+
.nodeIds(new String[] { "node1", "node2" })
36+
.async(true)
37+
.dispatchTask(true)
38+
.tenantId("tenant1")
39+
.build();
40+
}
41+
42+
@Test
43+
public void testValidate() {
44+
MLUndeployModelsRequest request = MLUndeployModelsRequest.builder().modelIds(new String[] { "model1" }).build();
45+
assertNull(request.validate());
46+
}
47+
48+
@Test
49+
public void testStreamInputVersionBefore_2_19_0() throws IOException {
50+
BytesStreamOutput out = new BytesStreamOutput();
51+
out.setVersion(VERSION_2_18_0);
52+
mlUndeployModelsRequest.writeTo(out);
53+
54+
StreamInput in = out.bytes().streamInput();
55+
in.setVersion(VERSION_2_18_0);
56+
MLUndeployModelsRequest request = new MLUndeployModelsRequest(in);
57+
58+
assertArrayEquals(mlUndeployModelsRequest.getModelIds(), request.getModelIds());
59+
assertArrayEquals(mlUndeployModelsRequest.getNodeIds(), request.getNodeIds());
60+
assertEquals(mlUndeployModelsRequest.isAsync(), request.isAsync());
61+
assertEquals(mlUndeployModelsRequest.isDispatchTask(), request.isDispatchTask());
62+
assertNull(request.getTenantId());
63+
}
64+
65+
@Test
66+
public void testStreamInputVersionAfter_2_19_0() throws IOException {
67+
BytesStreamOutput out = new BytesStreamOutput();
68+
out.setVersion(VERSION_2_19_0);
69+
mlUndeployModelsRequest.writeTo(out);
70+
71+
StreamInput in = out.bytes().streamInput();
72+
in.setVersion(VERSION_2_19_0);
73+
MLUndeployModelsRequest request = new MLUndeployModelsRequest(in);
74+
75+
assertArrayEquals(mlUndeployModelsRequest.getModelIds(), request.getModelIds());
76+
assertArrayEquals(mlUndeployModelsRequest.getNodeIds(), request.getNodeIds());
77+
assertEquals(mlUndeployModelsRequest.isAsync(), request.isAsync());
78+
assertEquals(mlUndeployModelsRequest.isDispatchTask(), request.isDispatchTask());
79+
assertEquals(mlUndeployModelsRequest.getTenantId(), request.getTenantId());
80+
}
81+
82+
@Test
83+
public void testWriteToWithNullFields() throws IOException {
84+
MLUndeployModelsRequest request = MLUndeployModelsRequest
85+
.builder()
86+
.modelIds(null)
87+
.nodeIds(null)
88+
.async(true)
89+
.dispatchTask(true)
90+
.build();
91+
92+
BytesStreamOutput out = new BytesStreamOutput();
93+
out.setVersion(VERSION_2_19_0);
94+
request.writeTo(out);
95+
96+
StreamInput in = out.bytes().streamInput();
97+
in.setVersion(VERSION_2_19_0);
98+
MLUndeployModelsRequest result = new MLUndeployModelsRequest(in);
99+
100+
assertNull(result.getModelIds());
101+
assertNull(result.getNodeIds());
102+
assertEquals(request.isAsync(), result.isAsync());
103+
assertEquals(request.isDispatchTask(), result.isDispatchTask());
104+
}
105+
106+
@Test(expected = UncheckedIOException.class)
107+
public void fromActionRequest_IOException() {
108+
ActionRequest actionRequest = new ActionRequest() {
109+
@Override
110+
public ActionRequestValidationException validate() {
111+
return null;
112+
}
113+
114+
@Override
115+
public void writeTo(StreamOutput out) throws IOException {
116+
throw new IOException("test");
117+
}
118+
};
119+
MLUndeployModelsRequest.fromActionRequest(actionRequest);
120+
}
121+
122+
@Test
123+
public void fromActionRequest_Success_WithMLUndeployModelsRequest() {
124+
MLUndeployModelsRequest request = MLUndeployModelsRequest.builder().modelIds(new String[] { "model1" }).build();
125+
assertSame(MLUndeployModelsRequest.fromActionRequest(request), request);
126+
}
127+
128+
@Test
129+
public void testParse() throws Exception {
130+
String expectedInputStr = "{\"model_ids\":[\"model1\"],\"node_ids\":[\"node1\"]}";
131+
parseFromJsonString(expectedInputStr, parsedInput -> {
132+
assertArrayEquals(new String[] { "model1" }, parsedInput.getModelIds());
133+
assertArrayEquals(new String[] { "node1" }, parsedInput.getNodeIds());
134+
assertFalse(parsedInput.isAsync());
135+
assertTrue(parsedInput.isDispatchTask());
136+
});
137+
}
138+
139+
@Test
140+
public void testParseWithInvalidField() throws Exception {
141+
String withInvalidFieldInputStr = "{\"invalid_field\":\"void\",\"model_ids\":[\"model1\"],\"node_ids\":[\"node1\"]}";
142+
parseFromJsonString(withInvalidFieldInputStr, parsedInput -> {
143+
assertArrayEquals(new String[] { "model1" }, parsedInput.getModelIds());
144+
assertArrayEquals(new String[] { "node1" }, parsedInput.getNodeIds());
145+
});
146+
}
147+
148+
private void parseFromJsonString(String expectedInputStr, Consumer<MLUndeployModelsRequest> verify) throws Exception {
149+
XContentParser parser = XContentType.JSON
150+
.xContent()
151+
.createParser(
152+
new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
153+
LoggingDeprecationHandler.INSTANCE,
154+
expectedInputStr
155+
);
156+
parser.nextToken();
157+
MLUndeployModelsRequest parsedInput = MLUndeployModelsRequest.parse(parser, null);
158+
verify.accept(parsedInput);
159+
}
160+
}

0 commit comments

Comments
 (0)