|  | 
|  | 1 | +package com.box.l10n.mojito.service.oaireview; | 
|  | 2 | + | 
|  | 3 | +import com.box.l10n.mojito.openai.OpenAIClient.CreateBatchResponse; | 
|  | 4 | +import com.box.l10n.mojito.openai.OpenAIClient.RetrieveBatchResponse; | 
|  | 5 | +import com.box.l10n.mojito.quartz.QuartzPollableJob; | 
|  | 6 | +import com.box.l10n.mojito.service.pollableTask.PollableFuture; | 
|  | 7 | +import java.util.ArrayList; | 
|  | 8 | +import java.util.HashSet; | 
|  | 9 | +import java.util.List; | 
|  | 10 | +import java.util.Set; | 
|  | 11 | +import org.slf4j.Logger; | 
|  | 12 | +import org.slf4j.LoggerFactory; | 
|  | 13 | +import org.springframework.beans.factory.annotation.Autowired; | 
|  | 14 | +import org.springframework.stereotype.Component; | 
|  | 15 | + | 
|  | 16 | +@Component | 
|  | 17 | +public class AiReviewBatchesImportJob | 
|  | 18 | +    extends QuartzPollableJob< | 
|  | 19 | +        AiReviewBatchesImportJob.AiReviewBatchesImportInput, | 
|  | 20 | +        AiReviewBatchesImportJob.AiReviewBatchesImportOutput> { | 
|  | 21 | + | 
|  | 22 | +  static Logger logger = LoggerFactory.getLogger(AiReviewBatchesImportJob.class); | 
|  | 23 | + | 
|  | 24 | +  @Autowired AiReviewService aiReviewService; | 
|  | 25 | + | 
|  | 26 | +  public record AiReviewBatchesImportInput( | 
|  | 27 | +      List<CreateBatchResponse> createBatchResponses, List<String> processed, int attempt) {} | 
|  | 28 | + | 
|  | 29 | +  public record AiReviewBatchesImportOutput( | 
|  | 30 | +      List<RetrieveBatchResponse> retrieveBatchResponses, List<String> processed, Long nextJob) {} | 
|  | 31 | + | 
|  | 32 | +  @Override | 
|  | 33 | +  public AiReviewBatchesImportOutput call(AiReviewBatchesImportInput aiReviewBatchesImportInput) | 
|  | 34 | +      throws Exception { | 
|  | 35 | + | 
|  | 36 | +    List<RetrieveBatchResponse> retrieveBatchResponses = new ArrayList<>(); | 
|  | 37 | +    Set<String> processed = new HashSet<>(aiReviewBatchesImportInput.processed()); | 
|  | 38 | + | 
|  | 39 | +    List<CreateBatchResponse> toImport = | 
|  | 40 | +        aiReviewBatchesImportInput.createBatchResponses().stream() | 
|  | 41 | +            .filter(b -> !processed.contains(b.id())) | 
|  | 42 | +            .toList(); | 
|  | 43 | + | 
|  | 44 | +    logger.info( | 
|  | 45 | +        "{} batches to import, already processed: {} and total: {}", | 
|  | 46 | +        toImport.size(), | 
|  | 47 | +        processed.size(), | 
|  | 48 | +        aiReviewBatchesImportInput.createBatchResponses.size()); | 
|  | 49 | + | 
|  | 50 | +    for (CreateBatchResponse createBatchResponse : toImport) { | 
|  | 51 | + | 
|  | 52 | +      logger.debug("Retrieve current status of batch: {}", createBatchResponse.id()); | 
|  | 53 | +      RetrieveBatchResponse retrieveBatchResponse = | 
|  | 54 | +          aiReviewService.retrieveBatchWithRetry(createBatchResponse); | 
|  | 55 | +      retrieveBatchResponses.add(retrieveBatchResponse); | 
|  | 56 | + | 
|  | 57 | +      if ("completed".equals(retrieveBatchResponse.status())) { | 
|  | 58 | +        logger.info("Completed batch: {}", retrieveBatchResponse.id()); | 
|  | 59 | +        aiReviewService.importBatch(retrieveBatchResponse); | 
|  | 60 | +        processed.add(createBatchResponse.id()); | 
|  | 61 | +      } else if ("failed".equals(retrieveBatchResponse.status())) { | 
|  | 62 | +        logger.error("Batch failed, skipping it: {}", retrieveBatchResponse); | 
|  | 63 | +        processed.add(createBatchResponse.id()); | 
|  | 64 | +      } else if ("expired".equals(retrieveBatchResponse.status())) { | 
|  | 65 | +        logger.info("Batch expired, skipping it: {}", retrieveBatchResponse); | 
|  | 66 | +        processed.add(createBatchResponse.id()); | 
|  | 67 | +      } else if ("cancelled".equals(retrieveBatchResponse.status())) { | 
|  | 68 | +        logger.info("Batch cancelled, skipping it: {}", retrieveBatchResponse); | 
|  | 69 | +        processed.add(createBatchResponse.id()); | 
|  | 70 | +      } else { | 
|  | 71 | +        logger.info("Batch is still processing will process later: {}", retrieveBatchResponse); | 
|  | 72 | +      } | 
|  | 73 | +    } | 
|  | 74 | + | 
|  | 75 | +    PollableFuture<AiReviewBatchesImportOutput> pollableFuture = null; | 
|  | 76 | + | 
|  | 77 | +    if (processed.size() >= aiReviewBatchesImportInput.createBatchResponses().size()) { | 
|  | 78 | +      logger.info( | 
|  | 79 | +          "Everything has been processed ({}/{}), don't reschedule", | 
|  | 80 | +          processed.size(), | 
|  | 81 | +          aiReviewBatchesImportInput.createBatchResponses().size()); | 
|  | 82 | +    } else { | 
|  | 83 | +      logger.info("Schedule new job to process remaining batches, processed: {}", processed); | 
|  | 84 | +      pollableFuture = | 
|  | 85 | +          aiReviewService.aiReviewBatchesImportAsync( | 
|  | 86 | +              new AiReviewBatchesImportInput( | 
|  | 87 | +                  aiReviewBatchesImportInput.createBatchResponses(), | 
|  | 88 | +                  processed.stream().toList(), | 
|  | 89 | +                  aiReviewBatchesImportInput.attempt() + 1), | 
|  | 90 | +              getCurrentPollableTask().getParentTask()); | 
|  | 91 | +    } | 
|  | 92 | + | 
|  | 93 | +    return new AiReviewBatchesImportOutput( | 
|  | 94 | +        retrieveBatchResponses, | 
|  | 95 | +        processed.stream().toList(), | 
|  | 96 | +        pollableFuture != null ? pollableFuture.getPollableTask().getId() : null); | 
|  | 97 | +  } | 
|  | 98 | +} | 
0 commit comments