Skip to content

Commit c628c9b

Browse files
committed
wip ai translate
1 parent f516c6c commit c628c9b

File tree

5 files changed

+239
-28
lines changed

5 files changed

+239
-28
lines changed

cli/src/main/java/com/box/l10n/mojito/cli/command/RepositoryAiTranslationCommand.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@ public class RepositoryAiTranslationCommand extends Command {
4343
@Parameter(
4444
names = {Param.REPOSITORY_LOCALES_LONG, Param.REPOSITORY_LOCALES_SHORT},
4545
variableArity = true,
46-
required = true,
47-
description = "List of locales (bcp47 tags) to machine translate")
46+
description = "List of locales (bcp47 tags) to translate, if not provided translate all locales in the repository")
4847
List<String> locales;
4948

5049
@Parameter(
@@ -55,6 +54,13 @@ public class RepositoryAiTranslationCommand extends Command {
5554
+ "sending too many strings to MT)")
5655
int sourceTextMaxCount = 100;
5756

57+
@Parameter(
58+
names = {"--use-batch"},
59+
arity = 1,
60+
description =
61+
"To use the batch API or not")
62+
boolean useBatch = false;
63+
5864
@Autowired CommandHelper commandHelper;
5965

6066
@Autowired RepositoryAiTranslateClient repositoryAiTranslateClient;
@@ -75,13 +81,13 @@ public void execute() throws CommandException {
7581
.reset()
7682
.a(" for locales: ")
7783
.fg(Color.CYAN)
78-
.a(locales.stream().collect(Collectors.joining(", ", "[", "]")))
84+
.a(locales == null ? "<all>" : locales.stream().collect(Collectors.joining(", ", "[", "]")))
7985
.println(2);
8086

8187
ProtoAiTranslateResponse protoAiTranslateResponse =
8288
repositoryAiTranslateClient.translateRepository(
8389
new RepositoryAiTranslateClient.ProtoAiTranslateRequest(
84-
repositoryParam, locales, sourceTextMaxCount));
90+
repositoryParam, locales, sourceTextMaxCount, useBatch));
8591

8692
PollableTask pollableTask = protoAiTranslateResponse.pollableTask();
8793
commandHelper.waitForPollableTask(pollableTask.getId());

restclient/src/main/java/com/box/l10n/mojito/rest/client/RepositoryAiTranslateClient.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public ProtoAiTranslateResponse translateRepository(
2929
}
3030

3131
/**
 * Client-side request body for the repository AI translation call.
 *
 * <p>The CLI command builds it from the repository name, the optional locale list, the per-locale
 * source text limit and the {@code --use-batch} flag. {@code targetBcp47tags} may be {@code null}
 * when no locale was given on the command line (the command then prints {@code <all>}).
 */
public record ProtoAiTranslateRequest(
32-
String repositoryName, List<String> targetBcp47tags, int sourceTextMaxCountPerLocale) {}
32+
String repositoryName, List<String> targetBcp47tags, int sourceTextMaxCountPerLocale, boolean useBatch) {}
3333

3434
/** Response of the AI translation call; carries the pollable task the CLI waits on by id. */
public record ProtoAiTranslateResponse(PollableTask pollableTask) {}
3535
}

webapp/src/main/java/com/box/l10n/mojito/rest/textunit/AiTranslateWS.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,18 @@ public ProtoAiTranslateResponse aiTranslate(
4242
new AiTranslateInput(
4343
protoAiTranslateRequest.repositoryName(),
4444
protoAiTranslateRequest.targetBcp47tags(),
45-
protoAiTranslateRequest.sourceTextMaxCountPerLocale()));
45+
protoAiTranslateRequest.sourceTextMaxCountPerLocale(),
46+
protoAiTranslateRequest.useBatch()));
4647

4748
return new ProtoAiTranslateResponse(pollableFuture.getPollableTask());
4849
}
4950

5051
/**
 * Request body accepted by the AI translate web service.
 *
 * <p>{@code repositoryName}, {@code targetBcp47tags}, {@code sourceTextMaxCountPerLocale} and
 * {@code useBatch} are copied as-is into the {@code AiTranslateInput} handed to the service.
 *
 * <p>NOTE(review): {@code allLocales} is not forwarded to {@code AiTranslateInput} in the code
 * visible here, and the REST client's request record has no such component. Confirm whether it
 * is still needed or should be wired through.
 */
public record ProtoAiTranslateRequest(
51-
String repositoryName, List<String> targetBcp47tags, int sourceTextMaxCountPerLocale) {}
52+
String repositoryName,
53+
List<String> targetBcp47tags,
54+
int sourceTextMaxCountPerLocale,
55+
boolean useBatch,
56+
boolean allLocales) {}
5257

5358
/** Response body of the AI translate endpoint: the pollable task of the started translation. */
public record ProtoAiTranslateResponse(PollableTask pollableTask) {}
5459
}

webapp/src/main/java/com/box/l10n/mojito/service/oaitranslate/AiTranslateService.java

Lines changed: 195 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import com.box.l10n.mojito.entity.RepositoryLocale;
2222
import com.box.l10n.mojito.json.ObjectMapper;
2323
import com.box.l10n.mojito.openai.OpenAIClient;
24+
import com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsResponse;
2425
import com.box.l10n.mojito.openai.OpenAIClient.CreateBatchResponse;
2526
import com.box.l10n.mojito.openai.OpenAIClient.RequestBatchFileLine;
2627
import com.box.l10n.mojito.quartz.QuartzJobInfo;
@@ -41,20 +42,29 @@
4142
import com.fasterxml.jackson.databind.SerializationFeature;
4243
import com.fasterxml.jackson.databind.node.ObjectNode;
4344
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
45+
import java.io.IOException;
46+
import java.time.Duration;
4447
import java.util.ArrayDeque;
4548
import java.util.List;
4649
import java.util.Map;
4750
import java.util.Objects;
4851
import java.util.Set;
4952
import java.util.UUID;
53+
import java.util.concurrent.CompletableFuture;
54+
import java.util.concurrent.CompletionException;
55+
import java.util.concurrent.TimeoutException;
5056
import java.util.function.Function;
5157
import java.util.stream.Collectors;
5258
import org.slf4j.Logger;
5359
import org.slf4j.LoggerFactory;
5460
import org.springframework.beans.factory.annotation.Autowired;
5561
import org.springframework.beans.factory.annotation.Qualifier;
5662
import org.springframework.stereotype.Service;
63+
import reactor.core.publisher.Flux;
5764
import reactor.core.publisher.Mono;
65+
import reactor.util.function.Tuple2;
66+
import reactor.util.function.Tuples;
67+
import reactor.util.retry.Retry;
5868
import reactor.util.retry.RetryBackoffSpec;
5969

6070
@Service
@@ -109,7 +119,10 @@ public AiTranslateService(
109119
}
110120

111121
/**
 * Input of an AI translation run.
 *
 * <p>{@code targetBcp47tags} restricts the run to the listed locales; {@code null} means every
 * non-root locale of the repository. {@code sourceTextMaxCountPerLocale} is used as the search
 * limit when fetching untranslated text units for a locale. {@code useBatch} selects the batch
 * code path ({@code aiTranslateBatch}) instead of the per-string one.
 */
public record AiTranslateInput(
112-
String repositoryName, List<String> targetBcp47tags, int sourceTextMaxCountPerLocale) {}
122+
String repositoryName,
123+
List<String> targetBcp47tags,
124+
int sourceTextMaxCountPerLocale,
125+
boolean useBatch) {}
113126

114127
public PollableFuture<Void> aiTranslateAsync(AiTranslateInput aiTranslateInput) {
115128

@@ -124,26 +137,168 @@ public PollableFuture<Void> aiTranslateAsync(AiTranslateInput aiTranslateInput)
124137
}
125138

126139
public void aiTranslate(AiTranslateInput aiTranslateInput) throws AiTranslateException {
140+
if (aiTranslateInput.useBatch()) {
141+
aiTranslateBatch(aiTranslateInput);
142+
} else {
143+
aiTranlsateNoBatch(aiTranslateInput);
144+
}
145+
}
127146

128-
Repository repository = repositoryRepository.findByName(aiTranslateInput.repositoryName());
147+
/**
 * Translates a repository without the batch API: resolves the repository, selects the requested
 * locales and processes them one after the other via {@code asyncProcessLocale}.
 *
 * <p>NOTE(review): the method name is misspelled ("Tranlsate"); it is public, so renaming must be
 * done together with its callers.
 *
 * @param aiTranslateInput repository name, optional locale filter and per-locale text unit limit
 */
public void aiTranlsateNoBatch(AiTranslateInput aiTranslateInput) {
129148

130-
if (repository == null) {
131-
throw new RepositoryNameNotFoundException(
132-
String.format(
133-
"Repository with name '%s' can not be found!", aiTranslateInput.repositoryName()));
149+
Repository repository = getRepository(aiTranslateInput);
150+
151+
logger.debug("Start AI Translation (no batch) for repository: {}", repository.getName());
152+
153+
Set<RepositoryLocale> filteredRepositoryLocales =
154+
getFilteredRepositoryLocales(aiTranslateInput, repository);
155+
156+
// Locales are handled sequentially; each call returns once that locale is done.
filteredRepositoryLocales.forEach(
157+
repositoryLocale -> {
158+
asyncProcessLocale(repositoryLocale, aiTranslateInput.sourceTextMaxCountPerLocale());
159+
});
160+
}
161+
162+
void asyncProcessLocale(RepositoryLocale repositoryLocale, int sourceTextMaxCountPerLocale) {
163+
164+
Repository repository = repositoryLocale.getRepository();
165+
166+
logger.debug(
167+
"Get untranslated strings for locale: '{}' in repository: '{}'",
168+
repositoryLocale.getLocale().getBcp47Tag(),
169+
repository.getName());
170+
171+
TextUnitSearcherParameters textUnitSearcherParameters = new TextUnitSearcherParameters();
172+
textUnitSearcherParameters.setRepositoryIds(repository.getId());
173+
textUnitSearcherParameters.setStatusFilter(StatusFilter.FOR_TRANSLATION);
174+
textUnitSearcherParameters.setLocaleId(repositoryLocale.getLocale().getId());
175+
textUnitSearcherParameters.setLimit(sourceTextMaxCountPerLocale);
176+
177+
List<TextUnitDTO> textUnitDTOS = textUnitSearcher.search(textUnitSearcherParameters);
178+
179+
if (textUnitDTOS.isEmpty()) {
180+
logger.debug(
181+
"Nothing to translate for locale: {}", repositoryLocale.getLocale().getBcp47Tag());
182+
return;
134183
}
135184

185+
logger.info(
186+
"Starting parallel processing for each string in locale: {}, count: {}",
187+
repositoryLocale.getLocale().getBcp47Tag(),
188+
textUnitDTOS.size());
189+
190+
int maxConcurrency = 10;
191+
192+
Flux.fromIterable(textUnitDTOS)
193+
.flatMap(
194+
textUnitDTO ->
195+
getChatCompletionForTextUnitDTO(textUnitDTO)
196+
.retryWhen(
197+
Retry.backoff(5, Duration.ofSeconds(1))
198+
.filter(this::isRetriableException)
199+
.doBeforeRetry(
200+
retrySignal -> {
201+
logger.warn(
202+
"Retrying request for TextUnitDTO {} due to {}",
203+
textUnitDTO.getTmTextUnitId(),
204+
retrySignal.failure().getMessage());
205+
}))
206+
.onErrorResume(
207+
error -> {
208+
logger.error(
209+
"Request for TextUnitDTO {} failed after retries: {}",
210+
textUnitDTO.getTmTextUnitId(),
211+
error.getMessage());
212+
return Mono.empty();
213+
}),
214+
maxConcurrency)
215+
.collectList()
216+
.flatMap(this::processAggregatedResults)
217+
.block(); // Blocking here for simplicity; consider using async handling in real
218+
// applications
219+
220+
logger.info("Done submitting for processing");
221+
}
222+
223+
private Mono<Void> processAggregatedResults(
224+
List<Tuple2<TextUnitDTO, ChatCompletionsResponse>> results) {
225+
226+
// Process each result
227+
List<TextUnitDTO> forImport =
228+
results.stream()
229+
.map(
230+
tuple -> {
231+
TextUnitDTO textUnitDTO = tuple.getT1();
232+
ChatCompletionsResponse chatCompletionsResponse = tuple.getT2();
233+
234+
String completionOutputAsJson =
235+
chatCompletionsResponse.choices().getFirst().message().content();
236+
237+
CompletionOutput completionOutput =
238+
objectMapper.readValueUnchecked(
239+
completionOutputAsJson, CompletionOutput.class);
240+
241+
textUnitDTO.setTarget(completionOutput.target().content());
242+
textUnitDTO.setTargetComment("ai-translate");
243+
return textUnitDTO;
244+
})
245+
.collect(Collectors.toList());
246+
247+
// Import the translations
248+
textUnitBatchImporterService.importTextUnits(
249+
forImport,
250+
TextUnitBatchImporterService.IntegrityChecksType.ALWAYS_USE_INTEGRITY_CHECKER_STATUS);
251+
252+
return Mono.empty();
253+
}
254+
255+
private Mono<Tuple2<TextUnitDTO, ChatCompletionsResponse>> getChatCompletionForTextUnitDTO(
256+
TextUnitDTO textUnitDTO) {
257+
258+
CompletionInput completionInput =
259+
new CompletionInput(
260+
textUnitDTO.getTargetLocale(), textUnitDTO.getSource(), textUnitDTO.getComment());
261+
262+
String inputAsJsonString = objectMapper.writeValueAsStringUnchecked(completionInput);
263+
ObjectNode jsonSchema = createJsonSchema(CompletionOutput.class);
264+
265+
ChatCompletionsRequest chatCompletionsRequest =
266+
chatCompletionsRequest()
267+
.model("gpt-4o-2024-08-06")
268+
.maxTokens(16384)
269+
.messages(
270+
List.of(
271+
systemMessageBuilder().content(PROMPT).build(),
272+
userMessageBuilder().content(inputAsJsonString).build()))
273+
.responseFormat(
274+
new ChatCompletionsRequest.JsonFormat(
275+
"json_schema",
276+
new ChatCompletionsRequest.JsonFormat.JsonSchema(
277+
true, "request_json_format", jsonSchema)))
278+
.build();
279+
280+
CompletableFuture<ChatCompletionsResponse> futureResult =
281+
getOpenAIClient().getChatCompletions(chatCompletionsRequest);
282+
283+
Mono<ChatCompletionsResponse> resultMono = Mono.fromFuture(futureResult);
284+
285+
return resultMono.map(chatCompletionResult -> Tuples.of(textUnitDTO, chatCompletionResult));
286+
}
287+
288+
private boolean isRetriableException(Throwable throwable) {
289+
Throwable cause = throwable instanceof CompletionException ? throwable.getCause() : throwable;
290+
return cause instanceof IOException || cause instanceof TimeoutException;
291+
}
292+
293+
public void aiTranslateBatch(AiTranslateInput aiTranslateInput) throws AiTranslateException {
294+
295+
Repository repository = getRepository(aiTranslateInput);
296+
136297
logger.debug("Start AI Translation for repository: {}", repository.getName());
137298

138299
try {
139300
Set<RepositoryLocale> repositoryLocalesWithoutRootLocale =
140-
repositoryService.getRepositoryLocalesWithoutRootLocale(repository).stream()
141-
.filter(
142-
rl ->
143-
aiTranslateInput.targetBcp47tags == null
144-
|| aiTranslateInput.targetBcp47tags.contains(
145-
rl.getLocale().getBcp47Tag()))
146-
.collect(Collectors.toSet());
301+
getFilteredRepositoryLocales(aiTranslateInput, repository);
147302

148303
logger.debug("Create batches for repository: {}", repository.getName());
149304
ArrayDeque<CreateBatchResponse> batches =
@@ -167,6 +322,27 @@ public void aiTranslate(AiTranslateInput aiTranslateInput) throws AiTranslateExc
167322
}
168323
}
169324

325+
private Set<RepositoryLocale> getFilteredRepositoryLocales(
326+
AiTranslateInput aiTranslateInput, Repository repository) {
327+
return repositoryService.getRepositoryLocalesWithoutRootLocale(repository).stream()
328+
.filter(
329+
rl ->
330+
aiTranslateInput.targetBcp47tags == null
331+
|| aiTranslateInput.targetBcp47tags.contains(rl.getLocale().getBcp47Tag()))
332+
.collect(Collectors.toSet());
333+
}
334+
335+
private Repository getRepository(AiTranslateInput aiTranslateInput) {
336+
Repository repository = repositoryRepository.findByName(aiTranslateInput.repositoryName());
337+
338+
if (repository == null) {
339+
throw new RepositoryNameNotFoundException(
340+
String.format(
341+
"Repository with name '%s' can not be found!", aiTranslateInput.repositoryName()));
342+
}
343+
return repository;
344+
}
345+
170346
void importBatch(RetrieveBatchResponse retrieveBatchResponse) {
171347

172348
logger.info("Importing batch: {}", retrieveBatchResponse.id());
@@ -208,7 +384,7 @@ void importBatch(RetrieveBatchResponse retrieveBatchResponse) {
208384
"Response batch file line failed: " + chatCompletionResponseBatchFileLine);
209385
}
210386

211-
String aiTranslateOutputAsJson =
387+
String completionOutputAsJson =
212388
chatCompletionResponseBatchFileLine
213389
.response()
214390
.chatCompletionsResponse()
@@ -217,14 +393,14 @@ void importBatch(RetrieveBatchResponse retrieveBatchResponse) {
217393
.message()
218394
.content();
219395

220-
AiTranslateOutput aiTranslateOutput =
396+
CompletionOutput completionOutput =
221397
objectMapper.readValueUnchecked(
222-
aiTranslateOutputAsJson, AiTranslateOutput.class);
398+
completionOutputAsJson, CompletionOutput.class);
223399

224400
TextUnitDTO textUnitDTO =
225401
tmTextUnitIdToTextUnitDTOs.get(
226402
Long.valueOf(chatCompletionResponseBatchFileLine.customId()));
227-
textUnitDTO.setTarget(aiTranslateOutput.target().content());
403+
textUnitDTO.setTarget(completionOutput.target().content());
228404
textUnitDTO.setTargetComment("ai-translate");
229405
return textUnitDTO;
230406
})
@@ -301,7 +477,7 @@ String generateBatchFileContent(List<TextUnitDTO> textUnitDTOS) {
301477

302478
String inputAsJsonString = objectMapper.writeValueAsStringUnchecked(completionInput);
303479

304-
ObjectNode jsonSchema = createJsonSchema(AiTranslateOutput.class);
480+
ObjectNode jsonSchema = createJsonSchema(CompletionOutput.class);
305481

306482
ChatCompletionsRequest chatCompletionsRequest =
307483
chatCompletionsRequest()
@@ -378,7 +554,7 @@ RetrieveBatchResponse retrieveBatchWithRetry(CreateBatchResponse batch) {
378554

379555
/**
 * Payload serialized to JSON and sent as the user message of a chat completion request: the
 * target locale, the source string and the source comment of a text unit.
 */
record CompletionInput(String locale, String source, String sourceDescription) {}
380556

381-
record AiTranslateOutput(
557+
record CompletionOutput(
382558
String source,
383559
Target target,
384560
DescriptionRating descriptionRating,

0 commit comments

Comments
 (0)