Skip to content

Commit 28d8c3a

Browse files
authored
Replace null GetResponse with valid response and not exists (#3759)
* Replace null GetResponse with valid response and not exists Signed-off-by: Daniel Widdis <widdis@gmail.com> * Remove unnecessary null check impacting coverage check Signed-off-by: Daniel Widdis <widdis@gmail.com> * Add tests to increase coverage Signed-off-by: Daniel Widdis <widdis@gmail.com> * Test IndexNotFound branch Signed-off-by: Daniel Widdis <widdis@gmail.com> --------- Signed-off-by: Daniel Widdis <widdis@gmail.com>
1 parent 088c1a5 commit 28d8c3a

File tree

10 files changed

+195
-72
lines changed

10 files changed

+195
-72
lines changed

ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import static org.mockito.Mockito.doThrow;
66
import static org.mockito.Mockito.mock;
77
import static org.mockito.Mockito.when;
8+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
9+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
810
import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD;
911
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
1012
import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX;
@@ -160,9 +162,10 @@ public void encrypt_NonExistingMasterKey() {
160162
}).when(mlIndicesHandler).initMLConfigIndex(any());
161163
IndexResponse indexResponse = prepareIndexResponse();
162164

165+
GetResponse getResponse = prepareNotExistsGetResponse();
163166
doAnswer(invocation -> {
164167
ActionListener<GetResponse> actionListener = (ActionListener) invocation.getArgument(1);
165-
actionListener.onResponse(null);
168+
actionListener.onResponse(getResponse);
166169
return null;
167170
}).when(client).get(any(), any());
168171

@@ -191,7 +194,8 @@ public void encrypt_NonExistingMasterKey_FailedToCreateNewKey() {
191194
}).when(mlIndicesHandler).initMLConfigIndex(any());
192195
doAnswer(invocation -> {
193196
ActionListener<GetResponse> actionListener = (ActionListener) invocation.getArgument(1);
194-
actionListener.onResponse(null);
197+
GetResponse getResponse = prepareNotExistsGetResponse();
198+
actionListener.onResponse(getResponse);
195199
return null;
196200
}).when(client).get(any(), any());
197201
doAnswer(invocation -> {
@@ -216,7 +220,8 @@ public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_NonRuntimeExceptio
216220
}).when(mlIndicesHandler).initMLConfigIndex(any());
217221
doAnswer(invocation -> {
218222
ActionListener<GetResponse> actionListener = (ActionListener) invocation.getArgument(1);
219-
actionListener.onResponse(null);
223+
GetResponse getResponse = prepareNotExistsGetResponse();
224+
actionListener.onResponse(getResponse);
220225
return null;
221226
}).when(client).get(any(), any());
222227
doAnswer(invocation -> {
@@ -245,7 +250,8 @@ public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_VersionConflict()
245250
}).when(mlIndicesHandler).initMLConfigIndex(any());
246251
doAnswer(invocation -> {
247252
ActionListener<GetResponse> actionListener = (ActionListener) invocation.getArgument(1);
248-
actionListener.onResponse(null);
253+
GetResponse getResponse = prepareNotExistsGetResponse();
254+
actionListener.onResponse(getResponse);
249255
return null;
250256
}).doAnswer(invocation -> {
251257
ActionListener<GetResponse> actionListener = (ActionListener) invocation.getArgument(1);
@@ -500,7 +506,8 @@ public void encrypt_SdkClientPutDataObjectFailure() {
500506

501507
doAnswer(invocation -> {
502508
ActionListener<GetResponse> listener = invocation.getArgument(1);
503-
listener.onResponse(null);
509+
GetResponse getResponse = prepareNotExistsGetResponse();
510+
listener.onResponse(getResponse);
504511
return null;
505512
}).when(client).get(any(), any());
506513

@@ -777,4 +784,20 @@ private IndexResponse prepareIndexResponse() {
777784
ShardId shardId = new ShardId(ML_CONFIG_INDEX, "index_uuid", 0);
778785
return new IndexResponse(shardId, MASTER_KEY, 1L, 1L, 1L, true);
779786
}
787+
788+
// Helper method to prepare a valid GetResponse
789+
private GetResponse prepareNotExistsGetResponse() {
790+
GetResult getResult = new GetResult(
791+
ML_CONFIG_INDEX,
792+
"fake_id",
793+
UNASSIGNED_SEQ_NO,
794+
UNASSIGNED_PRIMARY_TERM,
795+
-1L,
796+
false,
797+
null,
798+
null,
799+
null
800+
);
801+
return new GetResponse(getResult);
802+
}
780803
}

plugin/src/main/java/org/opensearch/ml/action/tasks/DeleteTaskTransportAction.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,9 @@ private void processGetDataObjectResponse(
125125
ActionListener<DeleteResponse> actionListener
126126
) {
127127
try {
128-
GetResponse getResponse = getDataObjectResponse.parser() == null
129-
? null
130-
: GetResponse.fromXContent(getDataObjectResponse.parser());
128+
GetResponse getResponse = getDataObjectResponse.getResponse();
131129

132-
if (getResponse == null || !getResponse.isExists()) {
130+
if (!getResponse.isExists()) {
133131
actionListener.onFailure(new OpenSearchStatusException("Failed to find task", RestStatus.NOT_FOUND));
134132
return;
135133
}

plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
import static org.mockito.ArgumentMatchers.any;
88
import static org.mockito.Mockito.*;
99
import static org.mockito.Mockito.verify;
10+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
11+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
12+
import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX;
1013

1114
import java.io.IOException;
1215
import java.time.Instant;
@@ -205,14 +208,25 @@ public void testDoExecute_RuntimeException() {
205208
}
206209

207210
@Test
208-
public void testGetTask_NullResponse() {
211+
public void testGetTask_NotFoundResponse() {
209212
String agentId = "test-agent-id-NullResponse";
210213
Task task = mock(Task.class);
211214
ActionListener<MLAgentGetResponse> actionListener = mock(ActionListener.class);
212215
MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId, true, null);
216+
GetResult getResult = new GetResult(
217+
ML_AGENT_INDEX,
218+
"fake_id",
219+
UNASSIGNED_SEQ_NO,
220+
UNASSIGNED_PRIMARY_TERM,
221+
-1L,
222+
false,
223+
null,
224+
null,
225+
null
226+
);
213227
doAnswer(invocation -> {
214228
ActionListener<GetResponse> listener = invocation.getArgument(1);
215-
listener.onResponse(null);
229+
listener.onResponse(new GetResponse(getResult));
216230
return null;
217231
}).when(client).get(any(), any());
218232
getAgentTransportAction.doExecute(task, getRequest, actionListener);

plugin/src/test/java/org/opensearch/ml/action/model_group/GetModelGroupTransportActionTests.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
import static org.mockito.Mockito.spy;
1111
import static org.mockito.Mockito.verify;
1212
import static org.mockito.Mockito.when;
13+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
14+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
15+
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
1316

1417
import java.io.IOException;
1518
import java.util.Collections;
@@ -167,10 +170,21 @@ public void testGetModel_ValidateAccessFailed() throws IOException {
167170
assertEquals("Failed to validate access", argumentCaptor.getValue().getMessage());
168171
}
169172

170-
public void testGetModel_NullResponse() {
173+
public void testGetModel_NotExistsResponse() {
174+
GetResult getResult = new GetResult(
175+
ML_MODEL_GROUP_INDEX,
176+
"fake_id",
177+
UNASSIGNED_SEQ_NO,
178+
UNASSIGNED_PRIMARY_TERM,
179+
-1L,
180+
false,
181+
null,
182+
null,
183+
null
184+
);
171185
doAnswer(invocation -> {
172186
ActionListener<GetResponse> listener = invocation.getArgument(1);
173-
listener.onResponse(null);
187+
listener.onResponse(new GetResponse(getResult));
174188
return null;
175189
}).when(client).get(any(), any());
176190
getModelGroupTransportAction.doExecute(null, mlModelGroupGetRequest, actionListener);

plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
import static org.mockito.Mockito.spy;
1212
import static org.mockito.Mockito.verify;
1313
import static org.mockito.Mockito.when;
14+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
15+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
16+
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
1417

1518
import java.io.IOException;
1619
import java.util.Collections;
@@ -212,10 +215,21 @@ public void testGetModel_ValidateAccessFailed() throws IOException, InterruptedE
212215
assertEquals("Failed to validate access", argumentCaptor.getValue().getMessage());
213216
}
214217

215-
public void testGetModel_NullResponse() {
218+
public void testGetModel_NotExistsResponse() {
219+
GetResult getResult = new GetResult(
220+
ML_MODEL_INDEX,
221+
"fake_id",
222+
UNASSIGNED_SEQ_NO,
223+
UNASSIGNED_PRIMARY_TERM,
224+
-1L,
225+
false,
226+
null,
227+
null,
228+
null
229+
);
216230
doAnswer(invocation -> {
217231
ActionListener<GetResponse> listener = invocation.getArgument(1);
218-
listener.onResponse(null);
232+
listener.onResponse(new GetResponse(getResult));
219233
return null;
220234
}).when(client).get(any(), any());
221235

plugin/src/test/java/org/opensearch/ml/action/tasks/DeleteTaskTransportActionTests.java

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import static org.mockito.Mockito.verify;
1212
import static org.mockito.Mockito.when;
1313
import static org.opensearch.action.DocWriteResponse.Result.DELETED;
14+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
15+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
1416
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
1517

1618
import java.io.IOException;
@@ -34,6 +36,7 @@
3436
import org.opensearch.core.xcontent.NamedXContentRegistry;
3537
import org.opensearch.core.xcontent.ToXContent;
3638
import org.opensearch.core.xcontent.XContentBuilder;
39+
import org.opensearch.index.IndexNotFoundException;
3740
import org.opensearch.index.get.GetResult;
3841
import org.opensearch.ml.common.MLTask;
3942
import org.opensearch.ml.common.MLTaskState;
@@ -144,10 +147,35 @@ public void testDeleteTask_ResourceNotFoundException() throws IOException {
144147
assertEquals("Failed to get data object from index .plugins-ml-task", argumentCaptor.getValue().getMessage());
145148
}
146149

147-
public void testDeleteTask_GetResponseNullException() {
150+
public void testDeleteTask_IndexNotFoundException() {
148151
doAnswer(invocation -> {
149152
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
150-
actionListener.onResponse(null);
153+
actionListener.onFailure(new IndexNotFoundException(ML_TASK_INDEX));
154+
return null;
155+
}).when(client).get(any(), any());
156+
157+
deleteTaskTransportAction.doExecute(null, mlTaskDeleteRequest, actionListener);
158+
159+
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
160+
verify(actionListener).onFailure(argumentCaptor.capture());
161+
assertEquals("Failed to find task", argumentCaptor.getValue().getMessage());
162+
}
163+
164+
public void testDeleteTask_GetResponseNotExistsException() {
165+
GetResult getResult = new GetResult(
166+
ML_TASK_INDEX,
167+
"fake_id",
168+
UNASSIGNED_SEQ_NO,
169+
UNASSIGNED_PRIMARY_TERM,
170+
-1L,
171+
false,
172+
null,
173+
null,
174+
null
175+
);
176+
doAnswer(invocation -> {
177+
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
178+
actionListener.onResponse(new GetResponse(getResult));
151179
return null;
152180
}).when(client).get(any(), any());
153181

@@ -185,6 +213,51 @@ public void testDeleteTask_ThreadContextError() {
185213
assertEquals("thread context error", argumentCaptor.getValue().getMessage());
186214
}
187215

216+
public void testDeleteTask_NullTenantValidation() {
217+
when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true);
218+
MLTaskDeleteRequest request = MLTaskDeleteRequest.builder()
219+
.taskId("test_id")
220+
.tenantId(null)
221+
.build();
222+
223+
deleteTaskTransportAction.doExecute(null, request, actionListener);
224+
225+
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
226+
verify(actionListener).onFailure(argumentCaptor.capture());
227+
assertEquals("You don't have permission to access this resource", argumentCaptor.getValue().getMessage());
228+
}
229+
230+
public void testDeleteTask_TenantMismatch() throws IOException {
231+
when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true);
232+
MLTaskDeleteRequest request = MLTaskDeleteRequest.builder()
233+
.taskId("test_id")
234+
.tenantId("tenant1")
235+
.build();
236+
237+
MLTask mlTask = MLTask.builder()
238+
.taskId("taskID")
239+
.state(MLTaskState.COMPLETED)
240+
.tenantId("tenant2")
241+
.build();
242+
243+
XContentBuilder content = mlTask.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
244+
BytesReference bytesReference = BytesReference.bytes(content);
245+
GetResult getResult = new GetResult("indexName", "111", 111L, 111L, 111L, true, bytesReference, null, null);
246+
GetResponse getResponse = new GetResponse(getResult);
247+
248+
doAnswer(invocation -> {
249+
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
250+
actionListener.onResponse(getResponse);
251+
return null;
252+
}).when(client).get(any(), any());
253+
254+
deleteTaskTransportAction.doExecute(null, request, actionListener);
255+
256+
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
257+
verify(actionListener).onFailure(argumentCaptor.capture());
258+
assertEquals("You don't have permission to access this resource", argumentCaptor.getValue().getMessage());
259+
}
260+
188261
public GetResponse prepareMLTask(MLTaskState mlTaskState) throws IOException {
189262
MLTask mlTask = MLTask.builder().taskId("taskID").state(mlTaskState).build();
190263
XContentBuilder content = mlTask.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);

plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
import static org.mockito.Mockito.spy;
1616
import static org.mockito.Mockito.verify;
1717
import static org.mockito.Mockito.when;
18+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
19+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
20+
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
1821
import static org.opensearch.ml.common.connector.AbstractConnector.*;
1922
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX;
2023
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX;
@@ -255,9 +258,20 @@ public void setup() throws IOException {
255258
}
256259

257260
public void testGetTask_NullResponse() {
261+
GetResult getResult = new GetResult(
262+
ML_TASK_INDEX,
263+
"fake_id",
264+
UNASSIGNED_SEQ_NO,
265+
UNASSIGNED_PRIMARY_TERM,
266+
-1L,
267+
false,
268+
null,
269+
null,
270+
null
271+
);
258272
doAnswer(invocation -> {
259273
ActionListener<GetResponse> listener = invocation.getArgument(1);
260-
listener.onResponse(null);
274+
listener.onResponse(new GetResponse(getResult));
261275
return null;
262276
}).when(client).get(any(), any());
263277
getTaskTransportAction.doExecute(null, mlTaskGetRequest, actionListener);

plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import static org.mockito.Mockito.doAnswer;
1111
import static org.mockito.Mockito.verify;
1212
import static org.mockito.Mockito.when;
13+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
14+
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
1315
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
1416

1517
import java.io.IOException;
@@ -398,9 +400,20 @@ public void test_OtherExceptionGetModelGroup() throws IOException {
398400
}
399401

400402
public void test_NotFoundGetModelGroup() throws IOException {
403+
GetResult getResult = new GetResult(
404+
ML_MODEL_GROUP_INDEX,
405+
"fake_id",
406+
UNASSIGNED_SEQ_NO,
407+
UNASSIGNED_PRIMARY_TERM,
408+
-1L,
409+
false,
410+
null,
411+
null,
412+
null
413+
);
401414
doAnswer(invocation -> {
402415
ActionListener<GetResponse> listener = invocation.getArgument(1);
403-
listener.onResponse(null);
416+
listener.onResponse(new GetResponse(getResult));
404417
return null;
405418
}).when(client).get(any(GetRequest.class), isA(ActionListener.class));
406419

0 commit comments

Comments
 (0)