@@ -884,7 +884,7 @@ public void deleteTask() {
884
884
}).when (client ).execute (eq (MLTaskDeleteAction .INSTANCE ), any (), any ());
885
885
886
886
ArgumentCaptor <DeleteResponse > argumentCaptor = ArgumentCaptor .forClass (DeleteResponse .class );
887
- machineLearningNodeClient .deleteTask (taskId , deleteTaskActionListener );
887
+ machineLearningNodeClient .deleteTask (taskId , null , deleteTaskActionListener );
888
888
889
889
verify (client ).execute (eq (MLTaskDeleteAction .INSTANCE ), isA (MLTaskDeleteRequest .class ), any ());
890
890
verify (deleteTaskActionListener ).onResponse (argumentCaptor .capture ());
@@ -1276,6 +1276,185 @@ public void getConfigRejectedMasterKey() {
1276
1276
assertEquals ("You are not allowed to access this config doc" , argumentCaptor .getValue ().getLocalizedMessage ());
1277
1277
}
1278
1278
1279
+ @ Test
1280
+ public void predict_withTenantId () {
1281
+ String tenantId = "testTenant" ;
1282
+ doAnswer (invocation -> {
1283
+ ActionListener <MLTaskResponse > actionListener = invocation .getArgument (2 );
1284
+ MLPredictionOutput predictionOutput = MLPredictionOutput
1285
+ .builder ()
1286
+ .status ("Success" )
1287
+ .predictionResult (output )
1288
+ .taskId ("taskId" )
1289
+ .build ();
1290
+ actionListener .onResponse (MLTaskResponse .builder ().output (predictionOutput ).build ());
1291
+ return null ;
1292
+ }).when (client ).execute (eq (MLPredictionTaskAction .INSTANCE ), any (), any ());
1293
+
1294
+ ArgumentCaptor <MLPredictionTaskRequest > requestCaptor = ArgumentCaptor .forClass (MLPredictionTaskRequest .class );
1295
+ MLInput mlInput = MLInput .builder ().algorithm (FunctionName .KMEANS ).inputDataset (input ).build ();
1296
+ machineLearningNodeClient .predict ("modelId" , tenantId , mlInput , dataFrameActionListener );
1297
+
1298
+ verify (client ).execute (eq (MLPredictionTaskAction .INSTANCE ), requestCaptor .capture (), any ());
1299
+ assertEquals (tenantId , requestCaptor .getValue ().getTenantId ());
1300
+ assertEquals ("modelId" , requestCaptor .getValue ().getModelId ());
1301
+ }
1302
+
1303
+ @ Test
1304
+ public void getTask_withFailure () {
1305
+ String taskId = "taskId" ;
1306
+ String errorMessage = "Task not found" ;
1307
+
1308
+ doAnswer (invocation -> {
1309
+ ActionListener <MLTaskGetResponse > actionListener = invocation .getArgument (2 );
1310
+ actionListener .onFailure (new IllegalArgumentException (errorMessage ));
1311
+ return null ;
1312
+ }).when (client ).execute (eq (MLTaskGetAction .INSTANCE ), any (), any ());
1313
+
1314
+ ArgumentCaptor <Exception > exceptionCaptor = ArgumentCaptor .forClass (Exception .class );
1315
+
1316
+ machineLearningNodeClient .getTask (taskId , new ActionListener <>() {
1317
+ @ Override
1318
+ public void onResponse (MLTask mlTask ) {
1319
+ fail ("Expected failure but got success" );
1320
+ }
1321
+
1322
+ @ Override
1323
+ public void onFailure (Exception e ) {
1324
+ assertEquals (errorMessage , e .getMessage ());
1325
+ }
1326
+ });
1327
+
1328
+ verify (client ).execute (eq (MLTaskGetAction .INSTANCE ), isA (MLTaskGetRequest .class ), any ());
1329
+ }
1330
+
1331
+ @ Test
1332
+ public void deploy_withTenantId () {
1333
+ String modelId = "testModel" ;
1334
+ String tenantId = "testTenant" ;
1335
+ String taskId = "taskId" ;
1336
+ String status = MLTaskState .CREATED .name ();
1337
+
1338
+ doAnswer (invocation -> {
1339
+ ActionListener <MLDeployModelResponse > actionListener = invocation .getArgument (2 );
1340
+ MLDeployModelResponse output = new MLDeployModelResponse (taskId , MLTaskType .DEPLOY_MODEL , status );
1341
+ actionListener .onResponse (output );
1342
+ return null ;
1343
+ }).when (client ).execute (eq (MLDeployModelAction .INSTANCE ), any (), any ());
1344
+
1345
+ ArgumentCaptor <MLDeployModelRequest > requestCaptor = ArgumentCaptor .forClass (MLDeployModelRequest .class );
1346
+ machineLearningNodeClient .deploy (modelId , tenantId , deployModelActionListener );
1347
+
1348
+ verify (client ).execute (eq (MLDeployModelAction .INSTANCE ), requestCaptor .capture (), any ());
1349
+ assertEquals (modelId , requestCaptor .getValue ().getModelId ());
1350
+ assertEquals (tenantId , requestCaptor .getValue ().getTenantId ());
1351
+ }
1352
+
1353
+ @ Test
1354
+ public void trainAndPredict_withNullInput () {
1355
+ exceptionRule .expect (IllegalArgumentException .class );
1356
+ exceptionRule .expectMessage ("ML Input can't be null" );
1357
+
1358
+ machineLearningNodeClient .trainAndPredict (null , trainingActionListener );
1359
+ }
1360
+
1361
+ @ Test
1362
+ public void trainAndPredict_withNullDataSet () {
1363
+ exceptionRule .expect (IllegalArgumentException .class );
1364
+ exceptionRule .expectMessage ("input data set can't be null" );
1365
+
1366
+ MLInput mlInput = MLInput .builder ().algorithm (FunctionName .KMEANS ).build ();
1367
+ machineLearningNodeClient .trainAndPredict (mlInput , trainingActionListener );
1368
+ }
1369
+
1370
+ @ Test
1371
+ public void getTask_withTaskIdAndTenantId () {
1372
+ String taskId = "taskId" ;
1373
+ String tenantId = "testTenant" ;
1374
+ String modelId = "modelId" ;
1375
+
1376
+ doAnswer (invocation -> {
1377
+ ActionListener <MLTaskGetResponse > actionListener = invocation .getArgument (2 );
1378
+ MLTask mlTask = MLTask .builder ().taskId (taskId ).modelId (modelId ).functionName (FunctionName .KMEANS ).build ();
1379
+ MLTaskGetResponse output = MLTaskGetResponse .builder ().mlTask (mlTask ).build ();
1380
+ actionListener .onResponse (output );
1381
+ return null ;
1382
+ }).when (client ).execute (eq (MLTaskGetAction .INSTANCE ), any (), any ());
1383
+
1384
+ ArgumentCaptor <MLTaskGetRequest > requestCaptor = ArgumentCaptor .forClass (MLTaskGetRequest .class );
1385
+ ArgumentCaptor <MLTask > taskCaptor = ArgumentCaptor .forClass (MLTask .class );
1386
+
1387
+ machineLearningNodeClient .getTask (taskId , tenantId , getTaskActionListener );
1388
+
1389
+ verify (client ).execute (eq (MLTaskGetAction .INSTANCE ), requestCaptor .capture (), any ());
1390
+ verify (getTaskActionListener ).onResponse (taskCaptor .capture ());
1391
+
1392
+ // Verify request parameters
1393
+ assertEquals (taskId , requestCaptor .getValue ().getTaskId ());
1394
+ assertEquals (tenantId , requestCaptor .getValue ().getTenantId ());
1395
+
1396
+ // Verify response
1397
+ assertEquals (taskId , taskCaptor .getValue ().getTaskId ());
1398
+ assertEquals (modelId , taskCaptor .getValue ().getModelId ());
1399
+ assertEquals (FunctionName .KMEANS , taskCaptor .getValue ().getFunctionName ());
1400
+ }
1401
+
1402
+ @ Test
1403
+ public void deleteTask_withTaskId () {
1404
+ String taskId = "taskId" ;
1405
+
1406
+ doAnswer (invocation -> {
1407
+ ActionListener <DeleteResponse > actionListener = invocation .getArgument (2 );
1408
+ ShardId shardId = new ShardId (new Index ("indexName" , "uuid" ), 1 );
1409
+ DeleteResponse output = new DeleteResponse (shardId , taskId , 1 , 1 , 1 , true );
1410
+ actionListener .onResponse (output );
1411
+ return null ;
1412
+ }).when (client ).execute (eq (MLTaskDeleteAction .INSTANCE ), any (), any ());
1413
+
1414
+ ArgumentCaptor <MLTaskDeleteRequest > requestCaptor = ArgumentCaptor .forClass (MLTaskDeleteRequest .class );
1415
+ ArgumentCaptor <DeleteResponse > responseCaptor = ArgumentCaptor .forClass (DeleteResponse .class );
1416
+
1417
+ machineLearningNodeClient .deleteTask (taskId , deleteTaskActionListener );
1418
+
1419
+ verify (client ).execute (eq (MLTaskDeleteAction .INSTANCE ), requestCaptor .capture (), any ());
1420
+ verify (deleteTaskActionListener ).onResponse (responseCaptor .capture ());
1421
+
1422
+ // Verify request parameter
1423
+ assertEquals (taskId , requestCaptor .getValue ().getTaskId ());
1424
+
1425
+ // Verify response
1426
+ assertEquals (taskId , responseCaptor .getValue ().getId ());
1427
+ assertEquals ("DELETED" , responseCaptor .getValue ().getResult ().toString ());
1428
+ }
1429
+
1430
+ @ Test
1431
+ public void deleteTask_withFailure () {
1432
+ String taskId = "taskId" ;
1433
+ String errorMessage = "Task deletion failed" ;
1434
+
1435
+ doAnswer (invocation -> {
1436
+ ActionListener <DeleteResponse > actionListener = invocation .getArgument (2 );
1437
+ actionListener .onFailure (new RuntimeException (errorMessage ));
1438
+ return null ;
1439
+ }).when (client ).execute (eq (MLTaskDeleteAction .INSTANCE ), any (), any ());
1440
+
1441
+ ArgumentCaptor <Exception > exceptionCaptor = ArgumentCaptor .forClass (Exception .class );
1442
+
1443
+ machineLearningNodeClient .deleteTask (taskId , new ActionListener <>() {
1444
+ @ Override
1445
+ public void onResponse (DeleteResponse deleteResponse ) {
1446
+ fail ("Expected failure but got success" );
1447
+ }
1448
+
1449
+ @ Override
1450
+ public void onFailure (Exception e ) {
1451
+ assertEquals (errorMessage , e .getMessage ());
1452
+ }
1453
+ });
1454
+
1455
+ verify (client ).execute (eq (MLTaskDeleteAction .INSTANCE ), isA (MLTaskDeleteRequest .class ), any ());
1456
+ }
1457
+
1279
1458
private SearchResponse createSearchResponse (ToXContentObject o ) throws IOException {
1280
1459
XContentBuilder content = o .toXContent (XContentFactory .jsonBuilder (), ToXContent .EMPTY_PARAMS );
1281
1460
0 commit comments