2121import com .box .l10n .mojito .entity .RepositoryLocale ;
2222import com .box .l10n .mojito .json .ObjectMapper ;
2323import com .box .l10n .mojito .openai .OpenAIClient ;
24+ import com .box .l10n .mojito .openai .OpenAIClient .ChatCompletionsResponse ;
2425import com .box .l10n .mojito .openai .OpenAIClient .CreateBatchResponse ;
2526import com .box .l10n .mojito .openai .OpenAIClient .RequestBatchFileLine ;
2627import com .box .l10n .mojito .quartz .QuartzJobInfo ;
4142import com .fasterxml .jackson .databind .SerializationFeature ;
4243import com .fasterxml .jackson .databind .node .ObjectNode ;
4344import com .fasterxml .jackson .datatype .jsr310 .JavaTimeModule ;
45+ import java .io .IOException ;
46+ import java .time .Duration ;
4447import java .util .ArrayDeque ;
4548import java .util .List ;
4649import java .util .Map ;
4750import java .util .Objects ;
4851import java .util .Set ;
4952import java .util .UUID ;
53+ import java .util .concurrent .CompletableFuture ;
54+ import java .util .concurrent .CompletionException ;
55+ import java .util .concurrent .TimeoutException ;
5056import java .util .function .Function ;
5157import java .util .stream .Collectors ;
5258import org .slf4j .Logger ;
5359import org .slf4j .LoggerFactory ;
5460import org .springframework .beans .factory .annotation .Autowired ;
5561import org .springframework .beans .factory .annotation .Qualifier ;
5662import org .springframework .stereotype .Service ;
63+ import reactor .core .publisher .Flux ;
5764import reactor .core .publisher .Mono ;
65+ import reactor .util .function .Tuple2 ;
66+ import reactor .util .function .Tuples ;
67+ import reactor .util .retry .Retry ;
5868import reactor .util .retry .RetryBackoffSpec ;
5969
6070@ Service
@@ -109,7 +119,10 @@ public AiTranslateService(
109119 }
110120
/**
 * Input of an AI translation job.
 *
 * @param repositoryName name used to look up the repository to translate
 * @param targetBcp47tags BCP 47 tags of the locales to translate; {@code null} disables the
 *     locale filtering
 * @param sourceTextMaxCountPerLocale used as the search limit when fetching the strings to
 *     translate for a locale
 * @param useBatch {@code true} to go through the batch flow, {@code false} for the non-batch flow
 */
public record AiTranslateInput(
    String repositoryName,
    List<String> targetBcp47tags,
    int sourceTextMaxCountPerLocale,
    boolean useBatch) {}
113126
114127 public PollableFuture <Void > aiTranslateAsync (AiTranslateInput aiTranslateInput ) {
115128
@@ -124,26 +137,168 @@ public PollableFuture<Void> aiTranslateAsync(AiTranslateInput aiTranslateInput)
124137 }
125138
126139 public void aiTranslate (AiTranslateInput aiTranslateInput ) throws AiTranslateException {
140+ if (aiTranslateInput .useBatch ()) {
141+ aiTranslateBatch (aiTranslateInput );
142+ } else {
143+ aiTranlsateNoBatch (aiTranslateInput );
144+ }
145+ }
127146
128- Repository repository = repositoryRepository . findByName ( aiTranslateInput . repositoryName ());
147+ public void aiTranlsateNoBatch ( AiTranslateInput aiTranslateInput ) {
129148
130- if (repository == null ) {
131- throw new RepositoryNameNotFoundException (
132- String .format (
133- "Repository with name '%s' can not be found!" , aiTranslateInput .repositoryName ()));
149+ Repository repository = getRepository (aiTranslateInput );
150+
151+ logger .debug ("Start AI Translation (no batch) for repository: {}" , repository .getName ());
152+
153+ Set <RepositoryLocale > filteredRepositoryLocales =
154+ getFilteredRepositoryLocales (aiTranslateInput , repository );
155+
156+ filteredRepositoryLocales .forEach (
157+ repositoryLocale -> {
158+ asyncProcessLocale (repositoryLocale , aiTranslateInput .sourceTextMaxCountPerLocale ());
159+ });
160+ }
161+
162+ void asyncProcessLocale (RepositoryLocale repositoryLocale , int sourceTextMaxCountPerLocale ) {
163+
164+ Repository repository = repositoryLocale .getRepository ();
165+
166+ logger .debug (
167+ "Get untranslated strings for locale: '{}' in repository: '{}'" ,
168+ repositoryLocale .getLocale ().getBcp47Tag (),
169+ repository .getName ());
170+
171+ TextUnitSearcherParameters textUnitSearcherParameters = new TextUnitSearcherParameters ();
172+ textUnitSearcherParameters .setRepositoryIds (repository .getId ());
173+ textUnitSearcherParameters .setStatusFilter (StatusFilter .FOR_TRANSLATION );
174+ textUnitSearcherParameters .setLocaleId (repositoryLocale .getLocale ().getId ());
175+ textUnitSearcherParameters .setLimit (sourceTextMaxCountPerLocale );
176+
177+ List <TextUnitDTO > textUnitDTOS = textUnitSearcher .search (textUnitSearcherParameters );
178+
179+ if (textUnitDTOS .isEmpty ()) {
180+ logger .debug (
181+ "Nothing to translate for locale: {}" , repositoryLocale .getLocale ().getBcp47Tag ());
182+ return ;
134183 }
135184
185+ logger .info (
186+ "Starting parallel processing for each string in locale: {}, count: {}" ,
187+ repositoryLocale .getLocale ().getBcp47Tag (),
188+ textUnitDTOS .size ());
189+
190+ int maxConcurrency = 10 ;
191+
192+ Flux .fromIterable (textUnitDTOS )
193+ .flatMap (
194+ textUnitDTO ->
195+ getChatCompletionForTextUnitDTO (textUnitDTO )
196+ .retryWhen (
197+ Retry .backoff (5 , Duration .ofSeconds (1 ))
198+ .filter (this ::isRetriableException )
199+ .doBeforeRetry (
200+ retrySignal -> {
201+ logger .warn (
202+ "Retrying request for TextUnitDTO {} due to {}" ,
203+ textUnitDTO .getTmTextUnitId (),
204+ retrySignal .failure ().getMessage ());
205+ }))
206+ .onErrorResume (
207+ error -> {
208+ logger .error (
209+ "Request for TextUnitDTO {} failed after retries: {}" ,
210+ textUnitDTO .getTmTextUnitId (),
211+ error .getMessage ());
212+ return Mono .empty ();
213+ }),
214+ maxConcurrency )
215+ .collectList ()
216+ .flatMap (this ::processAggregatedResults )
217+ .block (); // Blocking here for simplicity; consider using async handling in real
218+ // applications
219+
220+ logger .info ("Done submitting for processing" );
221+ }
222+
223+ private Mono <Void > processAggregatedResults (
224+ List <Tuple2 <TextUnitDTO , ChatCompletionsResponse >> results ) {
225+
226+ // Process each result
227+ List <TextUnitDTO > forImport =
228+ results .stream ()
229+ .map (
230+ tuple -> {
231+ TextUnitDTO textUnitDTO = tuple .getT1 ();
232+ ChatCompletionsResponse chatCompletionsResponse = tuple .getT2 ();
233+
234+ String completionOutputAsJson =
235+ chatCompletionsResponse .choices ().getFirst ().message ().content ();
236+
237+ CompletionOutput completionOutput =
238+ objectMapper .readValueUnchecked (
239+ completionOutputAsJson , CompletionOutput .class );
240+
241+ textUnitDTO .setTarget (completionOutput .target ().content ());
242+ textUnitDTO .setTargetComment ("ai-translate" );
243+ return textUnitDTO ;
244+ })
245+ .collect (Collectors .toList ());
246+
247+ // Import the translations
248+ textUnitBatchImporterService .importTextUnits (
249+ forImport ,
250+ TextUnitBatchImporterService .IntegrityChecksType .ALWAYS_USE_INTEGRITY_CHECKER_STATUS );
251+
252+ return Mono .empty ();
253+ }
254+
255+ private Mono <Tuple2 <TextUnitDTO , ChatCompletionsResponse >> getChatCompletionForTextUnitDTO (
256+ TextUnitDTO textUnitDTO ) {
257+
258+ CompletionInput completionInput =
259+ new CompletionInput (
260+ textUnitDTO .getTargetLocale (), textUnitDTO .getSource (), textUnitDTO .getComment ());
261+
262+ String inputAsJsonString = objectMapper .writeValueAsStringUnchecked (completionInput );
263+ ObjectNode jsonSchema = createJsonSchema (CompletionOutput .class );
264+
265+ ChatCompletionsRequest chatCompletionsRequest =
266+ chatCompletionsRequest ()
267+ .model ("gpt-4o-2024-08-06" )
268+ .maxTokens (16384 )
269+ .messages (
270+ List .of (
271+ systemMessageBuilder ().content (PROMPT ).build (),
272+ userMessageBuilder ().content (inputAsJsonString ).build ()))
273+ .responseFormat (
274+ new ChatCompletionsRequest .JsonFormat (
275+ "json_schema" ,
276+ new ChatCompletionsRequest .JsonFormat .JsonSchema (
277+ true , "request_json_format" , jsonSchema )))
278+ .build ();
279+
280+ CompletableFuture <ChatCompletionsResponse > futureResult =
281+ getOpenAIClient ().getChatCompletions (chatCompletionsRequest );
282+
283+ Mono <ChatCompletionsResponse > resultMono = Mono .fromFuture (futureResult );
284+
285+ return resultMono .map (chatCompletionResult -> Tuples .of (textUnitDTO , chatCompletionResult ));
286+ }
287+
288+ private boolean isRetriableException (Throwable throwable ) {
289+ Throwable cause = throwable instanceof CompletionException ? throwable .getCause () : throwable ;
290+ return cause instanceof IOException || cause instanceof TimeoutException ;
291+ }
292+
293+ public void aiTranslateBatch (AiTranslateInput aiTranslateInput ) throws AiTranslateException {
294+
295+ Repository repository = getRepository (aiTranslateInput );
296+
136297 logger .debug ("Start AI Translation for repository: {}" , repository .getName ());
137298
138299 try {
139300 Set <RepositoryLocale > repositoryLocalesWithoutRootLocale =
140- repositoryService .getRepositoryLocalesWithoutRootLocale (repository ).stream ()
141- .filter (
142- rl ->
143- aiTranslateInput .targetBcp47tags == null
144- || aiTranslateInput .targetBcp47tags .contains (
145- rl .getLocale ().getBcp47Tag ()))
146- .collect (Collectors .toSet ());
301+ getFilteredRepositoryLocales (aiTranslateInput , repository );
147302
148303 logger .debug ("Create batches for repository: {}" , repository .getName ());
149304 ArrayDeque <CreateBatchResponse > batches =
@@ -167,6 +322,27 @@ public void aiTranslate(AiTranslateInput aiTranslateInput) throws AiTranslateExc
167322 }
168323 }
169324
325+ private Set <RepositoryLocale > getFilteredRepositoryLocales (
326+ AiTranslateInput aiTranslateInput , Repository repository ) {
327+ return repositoryService .getRepositoryLocalesWithoutRootLocale (repository ).stream ()
328+ .filter (
329+ rl ->
330+ aiTranslateInput .targetBcp47tags == null
331+ || aiTranslateInput .targetBcp47tags .contains (rl .getLocale ().getBcp47Tag ()))
332+ .collect (Collectors .toSet ());
333+ }
334+
335+ private Repository getRepository (AiTranslateInput aiTranslateInput ) {
336+ Repository repository = repositoryRepository .findByName (aiTranslateInput .repositoryName ());
337+
338+ if (repository == null ) {
339+ throw new RepositoryNameNotFoundException (
340+ String .format (
341+ "Repository with name '%s' can not be found!" , aiTranslateInput .repositoryName ()));
342+ }
343+ return repository ;
344+ }
345+
170346 void importBatch (RetrieveBatchResponse retrieveBatchResponse ) {
171347
172348 logger .info ("Importing batch: {}" , retrieveBatchResponse .id ());
@@ -208,7 +384,7 @@ void importBatch(RetrieveBatchResponse retrieveBatchResponse) {
208384 "Response batch file line failed: " + chatCompletionResponseBatchFileLine );
209385 }
210386
211- String aiTranslateOutputAsJson =
387+ String completionOutputAsJson =
212388 chatCompletionResponseBatchFileLine
213389 .response ()
214390 .chatCompletionsResponse ()
@@ -217,14 +393,14 @@ void importBatch(RetrieveBatchResponse retrieveBatchResponse) {
217393 .message ()
218394 .content ();
219395
220- AiTranslateOutput aiTranslateOutput =
396+ CompletionOutput completionOutput =
221397 objectMapper .readValueUnchecked (
222- aiTranslateOutputAsJson , AiTranslateOutput .class );
398+ completionOutputAsJson , CompletionOutput .class );
223399
224400 TextUnitDTO textUnitDTO =
225401 tmTextUnitIdToTextUnitDTOs .get (
226402 Long .valueOf (chatCompletionResponseBatchFileLine .customId ()));
227- textUnitDTO .setTarget (aiTranslateOutput .target ().content ());
403+ textUnitDTO .setTarget (completionOutput .target ().content ());
228404 textUnitDTO .setTargetComment ("ai-translate" );
229405 return textUnitDTO ;
230406 })
@@ -301,7 +477,7 @@ String generateBatchFileContent(List<TextUnitDTO> textUnitDTOS) {
301477
302478 String inputAsJsonString = objectMapper .writeValueAsStringUnchecked (completionInput );
303479
304- ObjectNode jsonSchema = createJsonSchema (AiTranslateOutput .class );
480+ ObjectNode jsonSchema = createJsonSchema (CompletionOutput .class );
305481
306482 ChatCompletionsRequest chatCompletionsRequest =
307483 chatCompletionsRequest ()
@@ -378,7 +554,7 @@ RetrieveBatchResponse retrieveBatchWithRetry(CreateBatchResponse batch) {
378554
379555 record CompletionInput (String locale , String source , String sourceDescription ) {}
380556
381- record AiTranslateOutput (
557+ record CompletionOutput (
382558 String source ,
383559 Target target ,
384560 DescriptionRating descriptionRating ,
0 commit comments