Skip to content

Commit f01de7f

Browse files
authored
excluding circuit breaker for Agent (#3814)
Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent 887c90a commit f01de7f

File tree

4 files changed

+30
-4
lines changed

4 files changed

+30
-4
lines changed

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -975,7 +975,9 @@ private <T> ThreadedActionListener<T> threadedActionListener(String threadPoolNa
975975
* @param runningTaskLimit limit
976976
*/
977977
public void checkAndAddRunningTask(MLTask mlTask, Integer runningTaskLimit) {
978-
if (Objects.nonNull(mlTask) && mlTask.getFunctionName() != FunctionName.REMOTE) {
978+
979+
// for agent and remote model prediction we don't need to check circuit breaker
980+
if (Objects.nonNull(mlTask) && mlTask.getFunctionName() != FunctionName.REMOTE && mlTask.getFunctionName() != FunctionName.AGENT) {
979981
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
980982
}
981983
mlTaskManager.checkLimitAndAddRunningTask(mlTask, runningTaskLimit);

plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,7 @@ public void dispatchTask(
112112
if (clusterService.localNode().getId().equals(nodeId)) {
113113
// Execute ML task locally
114114
log.debug("Execute ML request {} locally on node {}", request.getRequestID(), nodeId);
115-
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
116-
executeTask(request, listener);
115+
checkCBAndExecute(functionName, request, listener);
117116
} else {
118117
// Execute ML task remotely
119118
log.debug("Execute ML request {} remotely on node {}", request.getRequestID(), nodeId);
@@ -130,7 +129,8 @@ public void dispatchTask(
130129
protected abstract void executeTask(Request request, ActionListener<Response> listener);
131130

132131
protected void checkCBAndExecute(FunctionName functionName, Request request, ActionListener<Response> listener) {
133-
if (functionName != FunctionName.REMOTE) {
132+
// for agent and remote model prediction we don't need to check circuit breaker
133+
if (functionName != FunctionName.REMOTE && functionName != FunctionName.AGENT) {
134134
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
135135
}
136136
executeTask(request, listener);

plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,18 @@ public void testRegisterMLModel_CircuitBreakerOpen() {
361361
verify(mlTaskManager).updateMLTask(anyString(), any(), anyMap(), anyLong(), anyBoolean());
362362
}
363363

364+
public void testRegisterMLModel_CircuitBreakerNotOpenForAgent() {
365+
registerModelInput.setFunctionName(FunctionName.AGENT);
366+
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
367+
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(thresholdCircuitBreaker);
368+
when(thresholdCircuitBreaker.getName()).thenReturn("Disk Circuit Breaker");
369+
when(thresholdCircuitBreaker.getThreshold()).thenReturn(87);
370+
expectedEx.expect(CircuitBreakingException.class);
371+
expectedEx.expectMessage("Disk Circuit Breaker is open, please check your resources!");
372+
modelManager.registerMLModel(registerModelInput, mlTask);
373+
verify(mlTaskManager).updateMLTask(anyString(), any(), anyMap(), anyLong(), anyBoolean());
374+
}
375+
364376
public void testRegisterMLModel_InitModelIndexFailure() {
365377
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
366378
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(null);

plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,4 +155,16 @@ public void testRun_NoCircuitbreakerforRemote() {
155155
Long value = (Long) mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).getValue();
156156
assertEquals(0L, value.longValue());
157157
}
158+
159+
public void testRun_NoCircuitbreakerforAgent() {
160+
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(thresholdCircuitBreaker);
161+
when(thresholdCircuitBreaker.getName()).thenReturn("Memory Circuit Breaker");
162+
when(thresholdCircuitBreaker.getThreshold()).thenReturn(87);
163+
TransportService transportService = mock(TransportService.class);
164+
ActionListener listener = mock(ActionListener.class);
165+
MLTaskRequest request = new MLTaskRequest(false);
166+
mlTaskRunner.run(FunctionName.AGENT, request, transportService, listener);
167+
Long value = (Long) mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).getValue();
168+
assertEquals(0L, value.longValue());
169+
}
158170
}

0 commit comments

Comments
 (0)