Skip to content

Commit 70391fc

Browse files
authored
fixing the circuit breaker issue for remote model (#3652)
Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent 368f593 commit 70391fc

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ protected void handleAsyncMLTaskComplete(MLTask mlTask) {
8787
public void run(FunctionName functionName, Request request, TransportService transportService, ActionListener<Response> listener) {
8888
if (!request.isDispatchTask()) {
8989
log.debug("Run ML request {} locally", request.getRequestID());
90-
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
9190
checkCBAndExecute(functionName, request, listener);
9291
return;
9392
}

plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,20 @@ public void testRun_CircuitBreakerOpen() {
139139
TransportService transportService = mock(TransportService.class);
140140
ActionListener listener = mock(ActionListener.class);
141141
MLTaskRequest request = new MLTaskRequest(false);
142-
expectThrows(CircuitBreakingException.class, () -> mlTaskRunner.run(FunctionName.REMOTE, request, transportService, listener));
142+
expectThrows(CircuitBreakingException.class, () -> mlTaskRunner.run(FunctionName.BATCH_RCF, request, transportService, listener));
143143
Long value = (Long) mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).getValue();
144144
assertEquals(1L, value.longValue());
145145
}
146+
147+
public void testRun_NoCircuitbreakerforRemote() {
148+
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(thresholdCircuitBreaker);
149+
when(thresholdCircuitBreaker.getName()).thenReturn("Memory Circuit Breaker");
150+
when(thresholdCircuitBreaker.getThreshold()).thenReturn(87);
151+
TransportService transportService = mock(TransportService.class);
152+
ActionListener listener = mock(ActionListener.class);
153+
MLTaskRequest request = new MLTaskRequest(false);
154+
mlTaskRunner.run(FunctionName.REMOTE, request, transportService, listener);
155+
Long value = (Long) mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).getValue();
156+
assertEquals(0L, value.longValue());
157+
}
146158
}

0 commit comments

Comments
 (0)