Skip to content

Commit 8f604cb

Browse files
addressing client changes due to adding tenantId in the apis (#3474) (#3480)
Signed-off-by: Dhrubo Saha <dhrubo@amazon.com> (cherry picked from commit 17b4d74) Co-authored-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent 6b6dd65 commit 8f604cb

File tree

4 files changed

+198
-54
lines changed

4 files changed

+198
-54
lines changed

client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ default void deleteModel(String modelId, ActionListener<DeleteResponse> listener
237237
*/
238238
default ActionFuture<DeleteResponse> deleteTask(String taskId) {
239239
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
240-
deleteModel(taskId, actionFuture);
240+
deleteTask(taskId, actionFuture);
241241
return actionFuture;
242242
}
243243

@@ -361,7 +361,7 @@ default ActionFuture<MLUndeployModelsResponse> undeploy(String[] modelIds, @Null
361361
* Undeploy model
362362
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/
363363
* @param modelIds the model ids
364-
* @param modelIds the node ids. May be null for all nodes.
364+
* @param nodeIds the node ids. May be null for all nodes.
365365
* @param listener a listener to be notified of the result
366366
*/
367367
default void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
@@ -372,7 +372,7 @@ default void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUnde
372372
* Undeploy model
373373
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/
374374
* @param modelIds the model ids
375-
* @param modelIds the node ids. May be null for all nodes.
375+
* @param nodeIds the node ids. May be null for all nodes.
376376
* @param tenantId the tenant id. This is necessary for multi-tenancy.
377377
* @param listener a listener to be notified of the result
378378
*/
@@ -480,8 +480,7 @@ default ActionFuture<DeleteResponse> deleteAgent(String agentId) {
480480
* @param listener a listener to be notified of the result
481481
*/
482482
default void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
483-
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
484-
deleteAgent(agentId, null, actionFuture);
483+
deleteAgent(agentId, null, listener);
485484
}
486485

487486
/**
@@ -543,5 +542,15 @@ default ActionFuture<MLConfig> getConfig(String configId) {
543542
* @param configId ML config id
544543
* @param listener a listener to be notified of the result
545544
*/
546-
void getConfig(String configId, ActionListener<MLConfig> listener);
545+
default void getConfig(String configId, ActionListener<MLConfig> listener) {
546+
getConfig(configId, null, listener);
547+
}
548+
549+
/**
550+
* Delete agent
551+
* @param configId ML config id
552+
* @param tenantId the tenant id. This is necessary for multi-tenancy.
553+
* @param listener a listener to be notified of the result
554+
*/
555+
void getConfig(String configId, String tenantId, ActionListener<MLConfig> listener);
547556
}

client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,8 @@ public void getTool(String toolName, ActionListener<ToolMetadata> listener) {
312312
}
313313

314314
@Override
315-
public void getConfig(String configId, ActionListener<MLConfig> listener) {
316-
MLConfigGetRequest mlConfigGetRequest = MLConfigGetRequest.builder().configId(configId).build();
315+
public void getConfig(String configId, String tenantId, ActionListener<MLConfig> listener) {
316+
MLConfigGetRequest mlConfigGetRequest = MLConfigGetRequest.builder().configId(configId).tenantId(tenantId).build();
317317

318318
client.execute(MLConfigGetAction.INSTANCE, mlConfigGetRequest, getMlGetConfigResponseActionListener(listener));
319319
}

client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -144,10 +144,6 @@ public void setUp() {
144144
.build();
145145

146146
machineLearningClient = new MachineLearningClient() {
147-
@Override
148-
public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
149-
listener.onResponse(output);
150-
}
151147

152148
@Override
153149
public void predict(String modelId, String tenantId, MLInput mlInput, ActionListener<MLOutput> listener) {
@@ -169,21 +165,11 @@ public void run(MLInput mlInput, Map<String, Object> args, ActionListener<MLOutp
169165
listener.onResponse(output);
170166
}
171167

172-
@Override
173-
public void getModel(String modelId, ActionListener<MLModel> listener) {
174-
listener.onResponse(mlModel);
175-
}
176-
177168
@Override
178169
public void getModel(String modelId, String tenantId, ActionListener<MLModel> listener) {
179170
listener.onResponse(mlModel);
180171
}
181172

182-
@Override
183-
public void deleteModel(String modelId, ActionListener<DeleteResponse> listener) {
184-
listener.onResponse(deleteResponse);
185-
}
186-
187173
@Override
188174
public void deleteModel(String modelId, String tenantId, ActionListener<DeleteResponse> listener) {
189175
listener.onResponse(deleteResponse);
@@ -194,21 +180,11 @@ public void searchModel(SearchRequest searchRequest, ActionListener<SearchRespon
194180
listener.onResponse(searchResponse);
195181
}
196182

197-
@Override
198-
public void getTask(String taskId, ActionListener<MLTask> listener) {
199-
listener.onResponse(mlTask);
200-
}
201-
202183
@Override
203184
public void getTask(String taskId, String tenantId, ActionListener<MLTask> listener) {
204185
listener.onResponse(mlTask);
205186
}
206187

207-
@Override
208-
public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
209-
listener.onResponse(deleteResponse);
210-
}
211-
212188
@Override
213189
public void deleteTask(String taskId, String tenantId, ActionListener<DeleteResponse> listener) {
214190
listener.onResponse(deleteResponse);
@@ -224,21 +200,11 @@ public void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterMode
224200
listener.onResponse(registerModelResponse);
225201
}
226202

227-
@Override
228-
public void deploy(String modelId, ActionListener<MLDeployModelResponse> listener) {
229-
listener.onResponse(deployModelResponse);
230-
}
231-
232203
@Override
233204
public void deploy(String modelId, String tenantId, ActionListener<MLDeployModelResponse> listener) {
234205
listener.onResponse(deployModelResponse);
235206
}
236207

237-
@Override
238-
public void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
239-
listener.onResponse(undeployModelsResponse);
240-
}
241-
242208
@Override
243209
public void undeploy(String[] modelIds, String[] nodeIds, String tenantId, ActionListener<MLUndeployModelsResponse> listener) {
244210
listener.onResponse(undeployModelsResponse);
@@ -259,11 +225,6 @@ public void deleteConnector(String connectorId, String tenantId, ActionListener<
259225
listener.onResponse(deleteResponse);
260226
}
261227

262-
@Override
263-
public void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener) {
264-
listener.onResponse(deleteResponse);
265-
}
266-
267228
@Override
268229
public void listTools(ActionListener<List<ToolMetadata>> listener) {
269230
listener.onResponse(toolsList);
@@ -286,18 +247,13 @@ public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentRespons
286247
listener.onResponse(registerAgentResponse);
287248
}
288249

289-
@Override
290-
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
291-
listener.onResponse(deleteResponse);
292-
}
293-
294250
@Override
295251
public void deleteAgent(String agentId, String tenantId, ActionListener<DeleteResponse> listener) {
296252
listener.onResponse(deleteResponse);
297253
}
298254

299255
@Override
300-
public void getConfig(String configId, ActionListener<MLConfig> listener) {
256+
public void getConfig(String configId, String tenantId, ActionListener<MLConfig> listener) {
301257
listener.onResponse(mlConfig);
302258
}
303259
};

client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java

Lines changed: 180 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,7 @@ public void deleteTask() {
884884
}).when(client).execute(eq(MLTaskDeleteAction.INSTANCE), any(), any());
885885

886886
ArgumentCaptor<DeleteResponse> argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class);
887-
machineLearningNodeClient.deleteTask(taskId, deleteTaskActionListener);
887+
machineLearningNodeClient.deleteTask(taskId, null, deleteTaskActionListener);
888888

889889
verify(client).execute(eq(MLTaskDeleteAction.INSTANCE), isA(MLTaskDeleteRequest.class), any());
890890
verify(deleteTaskActionListener).onResponse(argumentCaptor.capture());
@@ -1276,6 +1276,185 @@ public void getConfigRejectedMasterKey() {
12761276
assertEquals("You are not allowed to access this config doc", argumentCaptor.getValue().getLocalizedMessage());
12771277
}
12781278

1279+
@Test
1280+
public void predict_withTenantId() {
1281+
String tenantId = "testTenant";
1282+
doAnswer(invocation -> {
1283+
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
1284+
MLPredictionOutput predictionOutput = MLPredictionOutput
1285+
.builder()
1286+
.status("Success")
1287+
.predictionResult(output)
1288+
.taskId("taskId")
1289+
.build();
1290+
actionListener.onResponse(MLTaskResponse.builder().output(predictionOutput).build());
1291+
return null;
1292+
}).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any());
1293+
1294+
ArgumentCaptor<MLPredictionTaskRequest> requestCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class);
1295+
MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(input).build();
1296+
machineLearningNodeClient.predict("modelId", tenantId, mlInput, dataFrameActionListener);
1297+
1298+
verify(client).execute(eq(MLPredictionTaskAction.INSTANCE), requestCaptor.capture(), any());
1299+
assertEquals(tenantId, requestCaptor.getValue().getTenantId());
1300+
assertEquals("modelId", requestCaptor.getValue().getModelId());
1301+
}
1302+
1303+
@Test
1304+
public void getTask_withFailure() {
1305+
String taskId = "taskId";
1306+
String errorMessage = "Task not found";
1307+
1308+
doAnswer(invocation -> {
1309+
ActionListener<MLTaskGetResponse> actionListener = invocation.getArgument(2);
1310+
actionListener.onFailure(new IllegalArgumentException(errorMessage));
1311+
return null;
1312+
}).when(client).execute(eq(MLTaskGetAction.INSTANCE), any(), any());
1313+
1314+
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
1315+
1316+
machineLearningNodeClient.getTask(taskId, new ActionListener<>() {
1317+
@Override
1318+
public void onResponse(MLTask mlTask) {
1319+
fail("Expected failure but got success");
1320+
}
1321+
1322+
@Override
1323+
public void onFailure(Exception e) {
1324+
assertEquals(errorMessage, e.getMessage());
1325+
}
1326+
});
1327+
1328+
verify(client).execute(eq(MLTaskGetAction.INSTANCE), isA(MLTaskGetRequest.class), any());
1329+
}
1330+
1331+
@Test
1332+
public void deploy_withTenantId() {
1333+
String modelId = "testModel";
1334+
String tenantId = "testTenant";
1335+
String taskId = "taskId";
1336+
String status = MLTaskState.CREATED.name();
1337+
1338+
doAnswer(invocation -> {
1339+
ActionListener<MLDeployModelResponse> actionListener = invocation.getArgument(2);
1340+
MLDeployModelResponse output = new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, status);
1341+
actionListener.onResponse(output);
1342+
return null;
1343+
}).when(client).execute(eq(MLDeployModelAction.INSTANCE), any(), any());
1344+
1345+
ArgumentCaptor<MLDeployModelRequest> requestCaptor = ArgumentCaptor.forClass(MLDeployModelRequest.class);
1346+
machineLearningNodeClient.deploy(modelId, tenantId, deployModelActionListener);
1347+
1348+
verify(client).execute(eq(MLDeployModelAction.INSTANCE), requestCaptor.capture(), any());
1349+
assertEquals(modelId, requestCaptor.getValue().getModelId());
1350+
assertEquals(tenantId, requestCaptor.getValue().getTenantId());
1351+
}
1352+
1353+
@Test
1354+
public void trainAndPredict_withNullInput() {
1355+
exceptionRule.expect(IllegalArgumentException.class);
1356+
exceptionRule.expectMessage("ML Input can't be null");
1357+
1358+
machineLearningNodeClient.trainAndPredict(null, trainingActionListener);
1359+
}
1360+
1361+
@Test
1362+
public void trainAndPredict_withNullDataSet() {
1363+
exceptionRule.expect(IllegalArgumentException.class);
1364+
exceptionRule.expectMessage("input data set can't be null");
1365+
1366+
MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).build();
1367+
machineLearningNodeClient.trainAndPredict(mlInput, trainingActionListener);
1368+
}
1369+
1370+
@Test
1371+
public void getTask_withTaskIdAndTenantId() {
1372+
String taskId = "taskId";
1373+
String tenantId = "testTenant";
1374+
String modelId = "modelId";
1375+
1376+
doAnswer(invocation -> {
1377+
ActionListener<MLTaskGetResponse> actionListener = invocation.getArgument(2);
1378+
MLTask mlTask = MLTask.builder().taskId(taskId).modelId(modelId).functionName(FunctionName.KMEANS).build();
1379+
MLTaskGetResponse output = MLTaskGetResponse.builder().mlTask(mlTask).build();
1380+
actionListener.onResponse(output);
1381+
return null;
1382+
}).when(client).execute(eq(MLTaskGetAction.INSTANCE), any(), any());
1383+
1384+
ArgumentCaptor<MLTaskGetRequest> requestCaptor = ArgumentCaptor.forClass(MLTaskGetRequest.class);
1385+
ArgumentCaptor<MLTask> taskCaptor = ArgumentCaptor.forClass(MLTask.class);
1386+
1387+
machineLearningNodeClient.getTask(taskId, tenantId, getTaskActionListener);
1388+
1389+
verify(client).execute(eq(MLTaskGetAction.INSTANCE), requestCaptor.capture(), any());
1390+
verify(getTaskActionListener).onResponse(taskCaptor.capture());
1391+
1392+
// Verify request parameters
1393+
assertEquals(taskId, requestCaptor.getValue().getTaskId());
1394+
assertEquals(tenantId, requestCaptor.getValue().getTenantId());
1395+
1396+
// Verify response
1397+
assertEquals(taskId, taskCaptor.getValue().getTaskId());
1398+
assertEquals(modelId, taskCaptor.getValue().getModelId());
1399+
assertEquals(FunctionName.KMEANS, taskCaptor.getValue().getFunctionName());
1400+
}
1401+
1402+
@Test
1403+
public void deleteTask_withTaskId() {
1404+
String taskId = "taskId";
1405+
1406+
doAnswer(invocation -> {
1407+
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
1408+
ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1);
1409+
DeleteResponse output = new DeleteResponse(shardId, taskId, 1, 1, 1, true);
1410+
actionListener.onResponse(output);
1411+
return null;
1412+
}).when(client).execute(eq(MLTaskDeleteAction.INSTANCE), any(), any());
1413+
1414+
ArgumentCaptor<MLTaskDeleteRequest> requestCaptor = ArgumentCaptor.forClass(MLTaskDeleteRequest.class);
1415+
ArgumentCaptor<DeleteResponse> responseCaptor = ArgumentCaptor.forClass(DeleteResponse.class);
1416+
1417+
machineLearningNodeClient.deleteTask(taskId, deleteTaskActionListener);
1418+
1419+
verify(client).execute(eq(MLTaskDeleteAction.INSTANCE), requestCaptor.capture(), any());
1420+
verify(deleteTaskActionListener).onResponse(responseCaptor.capture());
1421+
1422+
// Verify request parameter
1423+
assertEquals(taskId, requestCaptor.getValue().getTaskId());
1424+
1425+
// Verify response
1426+
assertEquals(taskId, responseCaptor.getValue().getId());
1427+
assertEquals("DELETED", responseCaptor.getValue().getResult().toString());
1428+
}
1429+
1430+
@Test
1431+
public void deleteTask_withFailure() {
1432+
String taskId = "taskId";
1433+
String errorMessage = "Task deletion failed";
1434+
1435+
doAnswer(invocation -> {
1436+
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
1437+
actionListener.onFailure(new RuntimeException(errorMessage));
1438+
return null;
1439+
}).when(client).execute(eq(MLTaskDeleteAction.INSTANCE), any(), any());
1440+
1441+
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
1442+
1443+
machineLearningNodeClient.deleteTask(taskId, new ActionListener<>() {
1444+
@Override
1445+
public void onResponse(DeleteResponse deleteResponse) {
1446+
fail("Expected failure but got success");
1447+
}
1448+
1449+
@Override
1450+
public void onFailure(Exception e) {
1451+
assertEquals(errorMessage, e.getMessage());
1452+
}
1453+
});
1454+
1455+
verify(client).execute(eq(MLTaskDeleteAction.INSTANCE), isA(MLTaskDeleteRequest.class), any());
1456+
}
1457+
12791458
private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
12801459
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
12811460

0 commit comments

Comments
 (0)