Skip to content

Commit 1c2c122

Browse files
committed
Implement Batch Mode for Ai Review
Allow to specify the model to use in the CLI.
1 parent c1da88d commit 1c2c122

File tree

6 files changed

+364
-92
lines changed

6 files changed

+364
-92
lines changed

cli/src/main/java/com/box/l10n/mojito/cli/command/RepositoryAiReviewCommand.java

Lines changed: 176 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,19 @@
44
import com.beust.jcommander.Parameters;
55
import com.box.l10n.mojito.cli.command.param.Param;
66
import com.box.l10n.mojito.cli.console.ConsoleWriter;
7+
import com.box.l10n.mojito.json.ObjectMapper;
8+
import com.box.l10n.mojito.openai.OpenAIClient.RetrieveBatchResponse;
9+
import com.box.l10n.mojito.rest.client.PollableTaskClient;
710
import com.box.l10n.mojito.rest.client.RepositoryAiReviewClient;
11+
import com.box.l10n.mojito.rest.client.exception.PollableTaskException;
812
import com.box.l10n.mojito.rest.entity.PollableTask;
13+
import java.util.Comparator;
914
import java.util.List;
15+
import java.util.Map;
16+
import java.util.Optional;
17+
import java.util.concurrent.atomic.AtomicBoolean;
1018
import java.util.stream.Collectors;
19+
import org.fusesource.jansi.Ansi;
1120
import org.fusesource.jansi.Ansi.Color;
1221
import org.slf4j.Logger;
1322
import org.slf4j.LoggerFactory;
@@ -67,15 +76,26 @@ public class RepositoryAiReviewCommand extends Command {
6776
boolean useBatch = false;
6877

6978
@Parameter(
70-
names = {"--use-model"},
71-
arity = 1,
72-
description = "Use a specific model for the review")
79+
names = {"--use-model"},
80+
arity = 1,
81+
description = "Use a specific model for the review")
7382
String useModel;
7483

84+
@Parameter(
85+
names = "--attach-job-id",
86+
arity = 1,
87+
description =
88+
"ID of an existing job to re-attach to; the CLI will only poll its status and will not start any new work.")
89+
Long attachJobId;
90+
7591
@Autowired CommandHelper commandHelper;
7692

7793
@Autowired RepositoryAiReviewClient repositoryAiReviewClient;
7894

95+
@Autowired PollableTaskClient pollableTaskClient;
96+
97+
@Autowired ObjectMapper objectMapper;
98+
7999
@Override
80100
public boolean shouldShowInCommandList() {
81101
return false;
@@ -84,23 +104,160 @@ public boolean shouldShowInCommandList() {
84104
@Override
85105
public void execute() throws CommandException {
86106

107+
if (attachJobId == null) {
108+
consoleWriter
109+
.newLine()
110+
.a("Ai review repository: ")
111+
.fg(Color.CYAN)
112+
.a(repositoryParam)
113+
.reset()
114+
.a(" for locales: ")
115+
.fg(Color.CYAN)
116+
.a(
117+
locales == null
118+
? "<all>"
119+
: locales.stream().collect(Collectors.joining(", ", "[", "]")))
120+
.println(2);
121+
122+
RepositoryAiReviewClient.ProtoAiReviewResponse protoAiTranslateResponse =
123+
repositoryAiReviewClient.reviewRepository(
124+
new RepositoryAiReviewClient.ProtoAiReviewRequest(
125+
repositoryParam, locales, sourceTextMaxCount, textUnitIds, useBatch, useModel));
126+
127+
PollableTask pollableTask = protoAiTranslateResponse.pollableTask();
128+
consoleWriter.a("Running: ").fg(Color.MAGENTA).a(pollableTask.getId()).println();
129+
waitForPollable(pollableTask.getId());
130+
} else {
131+
consoleWriter.a("Attaching to: ").fg(Color.MAGENTA).a(attachJobId).println();
132+
waitForPollable(attachJobId);
133+
}
134+
135+
consoleWriter.fg(Ansi.Color.GREEN).newLine().a("Finished").println(2);
136+
}
137+
138+
void waitForPollable(Long pollableTaskId) {
139+
try {
140+
final AtomicBoolean firstRender = new AtomicBoolean(true);
141+
142+
pollableTaskClient.waitForPollableTask(
143+
pollableTaskId,
144+
PollableTaskClient.NO_TIMEOUT,
145+
pollableTask -> {
146+
Optional<PollableTask> lastFinishedForOutput =
147+
pollableTask.getSubTasks().stream()
148+
.filter(t -> t.getCreatedDate() != null)
149+
.sorted(Comparator.comparing(PollableTask::getCreatedDate).reversed())
150+
.filter(PollableTask::isAllFinished)
151+
.findFirst();
152+
153+
if (lastFinishedForOutput.isPresent()) {
154+
if (!firstRender.get()) {
155+
consoleWriter.erasePreviouslyPrintedLines();
156+
} else {
157+
firstRender.set(false);
158+
}
159+
160+
Long lastFinishedTaskId = lastFinishedForOutput.get().getId();
161+
consoleWriter.a("Checking: ").fg(Color.MAGENTA).a(lastFinishedTaskId).newLine();
162+
String pollableTaskOutput =
163+
pollableTaskClient.getPollableTaskOutput(lastFinishedTaskId);
164+
try {
165+
renderAiReviewBatchesImportOutput(
166+
objectMapper.readValueUnchecked(
167+
pollableTaskOutput, AiReviewBatchesImportOutput.class));
168+
} catch (Exception e) {
169+
consoleWriter.a(pollableTaskOutput).println();
170+
}
171+
}
172+
});
173+
174+
} catch (PollableTaskException e) {
175+
throw new CommandException(e.getMessage(), e.getCause());
176+
}
177+
}
178+
179+
void renderAiReviewBatchesImportOutput(AiReviewBatchesImportOutput aiReviewBatchesImportOutput) {
180+
aiReviewBatchesImportOutput
181+
.retrieveBatchResponses()
182+
.forEach(
183+
r ->
184+
renderBatch(
185+
r,
186+
aiReviewBatchesImportOutput.failedToImport.get(r.id()),
187+
aiReviewBatchesImportOutput.processed.contains(r.id())));
188+
consoleWriter.println();
189+
}
190+
191+
void renderBatch(
192+
RetrieveBatchResponse retrieveBatchResponse, String importError, boolean processed) {
193+
consoleWriter.a("- ").fg(Color.CYAN).a(retrieveBatchResponse.id()).a(" ");
194+
195+
consoleWriter.reset().a("[import: ");
196+
if (importError != null) {
197+
consoleWriter.fg(Color.RED).a("failed");
198+
} else {
199+
if (processed) {
200+
if ("completed".equals(retrieveBatchResponse.status())) {
201+
consoleWriter.fg(Color.GREEN).a("success");
202+
} else {
203+
consoleWriter.fg(Color.YELLOW).a(" - ");
204+
}
205+
} else {
206+
consoleWriter.fg(Color.YELLOW).a("waiting");
207+
}
208+
}
209+
consoleWriter.reset().a("]");
210+
211+
Color batchStatusColor =
212+
switch (retrieveBatchResponse.status()) {
213+
case "completed" -> Color.GREEN;
214+
case "failed" -> Color.RED;
215+
case "running", "queued", "in_progress" -> Color.YELLOW;
216+
default -> Color.DEFAULT;
217+
};
218+
219+
RetrieveBatchResponse.RequestCounts c = retrieveBatchResponse.requestCounts();
220+
87221
consoleWriter
88-
.newLine()
89-
.a("Ai review repository: ")
90-
.fg(Color.CYAN)
91-
.a(repositoryParam)
92222
.reset()
93-
.a(" for locales: ")
94-
.fg(Color.CYAN)
95-
.a(locales == null ? "<all>" : locales.stream().collect(Collectors.joining(", ", "[", "]")))
96-
.println(2);
97-
98-
RepositoryAiReviewClient.ProtoAiReviewResponse protoAiTranslateResponse =
99-
repositoryAiReviewClient.reviewRepository(
100-
new RepositoryAiReviewClient.ProtoAiReviewRequest(
101-
repositoryParam, locales, sourceTextMaxCount, textUnitIds, useBatch, useModel));
102-
103-
PollableTask pollableTask = protoAiTranslateResponse.pollableTask();
104-
commandHelper.waitForPollableTask(pollableTask.getId());
223+
.a(" [batch: ")
224+
.fg(batchStatusColor)
225+
.a(retrieveBatchResponse.status())
226+
.reset()
227+
.a(" ; total=")
228+
.a(c.total())
229+
.a(", completed=")
230+
.a(c.completed())
231+
.a(", failed=")
232+
.a(c.failed())
233+
.a("]")
234+
.newLine();
235+
236+
if (importError != null) {
237+
consoleWriter.fg(Color.RED).a("Import error: ").newLine().a(importError).newLine();
238+
}
239+
240+
if (retrieveBatchResponse.errors() != null
241+
&& retrieveBatchResponse.errors().data() != null
242+
&& !retrieveBatchResponse.errors().data().isEmpty()) {
243+
consoleWriter.fg(Color.RED).a(" Errors:").reset().newLine();
244+
retrieveBatchResponse
245+
.errors()
246+
.data()
247+
.forEach(
248+
e ->
249+
consoleWriter
250+
.a(" - ")
251+
.a(
252+
"[%s] %s (param=%s, line=%s)"
253+
.formatted(e.code(), e.message(), e.param(), e.line()))
254+
.newLine());
255+
}
105256
}
257+
258+
public record AiReviewBatchesImportOutput(
259+
List<RetrieveBatchResponse> retrieveBatchResponses,
260+
List<String> processed,
261+
Map<String, String> failedToImport,
262+
Long nextJob) {}
106263
}

common/src/main/java/com/box/l10n/mojito/openai/OpenAIClient.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -920,7 +920,7 @@ public record RetrieveBatchResponse(
920920
@JsonProperty("expired_at") Long expiredAt,
921921
@JsonProperty("request_counts") RequestCounts requestCounts,
922922
Map<String, String> metadata) {
923-
record RequestCounts(int total, int completed, int failed) {}
923+
public record RequestCounts(int total, int completed, int failed) {}
924924

925925
public record Errors(
926926
@JsonProperty("object") String objectType, // e.g. "list"

restclient/src/main/java/com/box/l10n/mojito/rest/client/RepositoryAiReviewClient.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@ public ProtoAiReviewResponse reviewRepository(ProtoAiReviewRequest protoAiReview
2828
}
2929

3030
public record ProtoAiReviewRequest(
31-
String repositoryName,
32-
List<String> targetBcp47tags,
33-
int sourceTextMaxCountPerLocale,
34-
List<Long> tmTextUnitIds,
35-
boolean useBatch,
36-
String useModel) {}
31+
String repositoryName,
32+
List<String> targetBcp47tags,
33+
int sourceTextMaxCountPerLocale,
34+
List<Long> tmTextUnitIds,
35+
boolean useBatch,
36+
String useModel) {}
3737

3838
public record ProtoAiReviewResponse(PollableTask pollableTask) {}
3939
}
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
package com.box.l10n.mojito.service.oaireview;
2+
3+
import com.box.l10n.mojito.openai.OpenAIClient.CreateBatchResponse;
4+
import com.box.l10n.mojito.openai.OpenAIClient.RetrieveBatchResponse;
5+
import com.box.l10n.mojito.quartz.QuartzPollableJob;
6+
import com.box.l10n.mojito.service.pollableTask.PollableFuture;
7+
import com.box.l10n.mojito.service.pollableTask.PollableTaskService;
8+
import java.util.ArrayList;
9+
import java.util.HashMap;
10+
import java.util.HashSet;
11+
import java.util.List;
12+
import java.util.Map;
13+
import java.util.Set;
14+
import org.slf4j.Logger;
15+
import org.slf4j.LoggerFactory;
16+
import org.springframework.beans.factory.annotation.Autowired;
17+
import org.springframework.stereotype.Component;
18+
19+
@Component
20+
public class AiReviewBatchesImportJob
21+
extends QuartzPollableJob<
22+
AiReviewBatchesImportJob.AiReviewBatchesImportInput,
23+
AiReviewBatchesImportJob.AiReviewBatchesImportOutput> {
24+
25+
static Logger logger = LoggerFactory.getLogger(AiReviewBatchesImportJob.class);
26+
27+
@Autowired AiReviewService aiReviewService;
28+
@Autowired private PollableTaskService pollableTaskService;
29+
30+
public record AiReviewBatchesImportInput(
31+
List<CreateBatchResponse> createBatchResponses, List<String> processed, int attempt) {}
32+
33+
public record AiReviewBatchesImportOutput(
34+
List<RetrieveBatchResponse> retrieveBatchResponses,
35+
List<String> processed,
36+
Map<String, String> failedToImport,
37+
Long nextJob) {}
38+
39+
@Override
40+
public AiReviewBatchesImportOutput call(AiReviewBatchesImportInput aiReviewBatchesImportInput)
41+
throws Exception {
42+
43+
List<RetrieveBatchResponse> retrieveBatchResponses = new ArrayList<>();
44+
Set<String> processed = new HashSet<>(aiReviewBatchesImportInput.processed());
45+
Map<String, String> failedImport = new HashMap<>();
46+
47+
logger.info(
48+
"Batches already processed: {} and total: {}",
49+
processed.size(),
50+
aiReviewBatchesImportInput.createBatchResponses.size());
51+
52+
for (CreateBatchResponse createBatchResponse :
53+
aiReviewBatchesImportInput.createBatchResponses()) {
54+
55+
logger.debug(
56+
"Retrieve current status of batch, regardless if it was already processed: {}",
57+
createBatchResponse.id());
58+
RetrieveBatchResponse retrieveBatchResponse =
59+
aiReviewService.retrieveBatchWithRetry(createBatchResponse);
60+
retrieveBatchResponses.add(retrieveBatchResponse);
61+
62+
if (!processed.contains(createBatchResponse.id())) {
63+
if ("completed".equals(retrieveBatchResponse.status())) {
64+
logger.info("Completed batch: {}", retrieveBatchResponse.id());
65+
try {
66+
aiReviewService.importBatch(retrieveBatchResponse);
67+
} catch (Throwable t) {
68+
logger.error("Failed to import batch: {}, skip", createBatchResponse.id(), t);
69+
failedImport.put(createBatchResponse.id(), t.getMessage());
70+
processed.add(createBatchResponse.id());
71+
}
72+
processed.add(createBatchResponse.id());
73+
} else if ("failed".equals(retrieveBatchResponse.status())) {
74+
logger.error("Batch failed, skipping it: {}", retrieveBatchResponse);
75+
processed.add(createBatchResponse.id());
76+
} else if ("expired".equals(retrieveBatchResponse.status())) {
77+
logger.info("Batch expired, skipping it: {}", retrieveBatchResponse);
78+
processed.add(createBatchResponse.id());
79+
} else if ("cancelled".equals(retrieveBatchResponse.status())) {
80+
logger.info("Batch cancelled, skipping it: {}", retrieveBatchResponse);
81+
processed.add(createBatchResponse.id());
82+
} else {
83+
logger.info("Batch is still processing will process later: {}", retrieveBatchResponse);
84+
}
85+
}
86+
}
87+
88+
PollableFuture<AiReviewBatchesImportOutput> pollableFuture = null;
89+
90+
if (processed.size() >= aiReviewBatchesImportInput.createBatchResponses().size()) {
91+
logger.info(
92+
"Everything has been processed ({}/{}), don't reschedule",
93+
processed.size(),
94+
aiReviewBatchesImportInput.createBatchResponses().size());
95+
} else {
96+
logger.info("Schedule new job to process remaining batches, processed: {}", processed);
97+
pollableFuture =
98+
aiReviewService.aiReviewBatchesImportAsync(
99+
new AiReviewBatchesImportInput(
100+
aiReviewBatchesImportInput.createBatchResponses(),
101+
processed.stream().toList(),
102+
aiReviewBatchesImportInput.attempt() + 1),
103+
getCurrentPollableTask().getParentTask());
104+
logger.info(
105+
"New job created with pollableTask id: {}", pollableFuture.getPollableTask().getId());
106+
}
107+
108+
return new AiReviewBatchesImportOutput(
109+
retrieveBatchResponses,
110+
processed.stream().toList(),
111+
failedImport,
112+
pollableFuture != null ? pollableFuture.getPollableTask().getId() : null);
113+
}
114+
}

webapp/src/main/java/com/box/l10n/mojito/service/oaireview/AiReviewJob.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ public class AiReviewJob extends QuartzPollableJob<AiReviewInput, Void> {
1616

1717
@Override
1818
public Void call(AiReviewInput aiReviewJobInput) throws Exception {
19-
aiReviewService.aiReview(aiReviewJobInput);
19+
aiReviewService.aiReview(aiReviewJobInput, getCurrentPollableTask());
2020
return null;
2121
}
2222
}

0 commit comments

Comments
 (0)