Skip to content

Commit 0713f58

Browse files
committed
Implement first version of no-batch ai translate api
1 parent 9af6ee2 commit 0713f58

File tree

9 files changed

+464
-37
lines changed

9 files changed

+464
-37
lines changed

cli/src/main/java/com/box/l10n/mojito/cli/command/RepositoryAiTranslationCommand.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ public class RepositoryAiTranslationCommand extends Command {
4343
@Parameter(
4444
names = {Param.REPOSITORY_LOCALES_LONG, Param.REPOSITORY_LOCALES_SHORT},
4545
variableArity = true,
46-
required = true,
47-
description = "List of locales (bcp47 tags) to machine translate")
46+
description =
47+
"List of locales (bcp47 tags) to translate, if not provided translate all locales in the repository")
4848
List<String> locales;
4949

5050
@Parameter(
@@ -55,6 +55,12 @@ public class RepositoryAiTranslationCommand extends Command {
5555
+ "sending too many strings to MT)")
5656
int sourceTextMaxCount = 100;
5757

58+
@Parameter(
59+
names = {"--use-batch"},
60+
arity = 1,
61+
description = "To use the batch API or not")
62+
boolean useBatch = false;
63+
5864
@Autowired CommandHelper commandHelper;
5965

6066
@Autowired RepositoryAiTranslateClient repositoryAiTranslateClient;
@@ -75,13 +81,13 @@ public void execute() throws CommandException {
7581
.reset()
7682
.a(" for locales: ")
7783
.fg(Color.CYAN)
78-
.a(locales.stream().collect(Collectors.joining(", ", "[", "]")))
84+
.a(locales == null ? "<all>" : locales.stream().collect(Collectors.joining(", ", "[", "]")))
7985
.println(2);
8086

8187
ProtoAiTranslateResponse protoAiTranslateResponse =
8288
repositoryAiTranslateClient.translateRepository(
8389
new RepositoryAiTranslateClient.ProtoAiTranslateRequest(
84-
repositoryParam, locales, sourceTextMaxCount));
90+
repositoryParam, locales, sourceTextMaxCount, useBatch));
8591

8692
PollableTask pollableTask = protoAiTranslateResponse.pollableTask();
8793
commandHelper.waitForPollableTask(pollableTask.getId());

common/src/main/java/com/box/l10n/mojito/openai/OpenAIClient.java

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import java.util.Objects;
2626
import java.util.UUID;
2727
import java.util.concurrent.CompletableFuture;
28+
import java.util.concurrent.Executor;
29+
import java.util.concurrent.ForkJoinPool;
2830
import java.util.function.Predicate;
2931
import java.util.stream.Collectors;
3032
import java.util.stream.Stream;
@@ -39,11 +41,19 @@ public class OpenAIClient {
3941

4042
final HttpClient httpClient;
4143

42-
OpenAIClient(String apiKey, String host, ObjectMapper objectMapper, HttpClient httpClient) {
44+
final Executor asyncExecutor;
45+
46+
OpenAIClient(
47+
String apiKey,
48+
String host,
49+
ObjectMapper objectMapper,
50+
HttpClient httpClient,
51+
Executor asyncExecutor) {
4352
this.apiKey = Objects.requireNonNull(apiKey);
4453
this.host = Objects.requireNonNull(host);
4554
this.objectMapper = Objects.requireNonNull(objectMapper);
4655
this.httpClient = Objects.requireNonNull(httpClient);
56+
this.asyncExecutor = Objects.requireNonNull(asyncExecutor);
4757
}
4858

4959
public static class Builder {
@@ -56,6 +66,8 @@ public static class Builder {
5666

5767
private HttpClient httpClient;
5868

69+
private Executor asyncExecutor;
70+
5971
public Builder() {}
6072

6173
public Builder apiKey(String apiKey) {
@@ -78,6 +90,11 @@ public Builder httpClient(HttpClient httpClient) {
7890
return this;
7991
}
8092

93+
public Builder asyncExecutor(Executor asyncExecutor) {
94+
this.asyncExecutor = asyncExecutor;
95+
return this;
96+
}
97+
8198
public OpenAIClient build() {
8299
if (apiKey == null) {
83100
throw new IllegalStateException("API key must be provided");
@@ -89,11 +106,16 @@ public OpenAIClient build() {
89106
if (httpClient == null) {
90107
httpClient = createHttpClient();
91108
}
92-
return new OpenAIClient(apiKey, host, objectMapper, httpClient);
109+
110+
if (asyncExecutor == null) {
111+
asyncExecutor = ForkJoinPool.commonPool();
112+
}
113+
114+
return new OpenAIClient(apiKey, host, objectMapper, httpClient, asyncExecutor);
93115
}
94116

95117
private HttpClient createHttpClient() {
96-
return HttpClient.newHttpClient();
118+
return HttpClient.newBuilder().build();
97119
}
98120

99121
private ObjectMapper createObjectMapper() {
@@ -135,7 +157,7 @@ public CompletableFuture<ChatCompletionsResponse> getChatCompletions(
135157
CompletableFuture<ChatCompletionsResponse> chatCompletionsResponse =
136158
httpClient
137159
.sendAsync(request, HttpResponse.BodyHandlers.ofString())
138-
.thenApply(
160+
.thenApplyAsync(
139161
httpResponse -> {
140162
if (httpResponse.statusCode() != 200) {
141163
throw new OpenAIClientResponseException("ChatCompletion failed", httpResponse);
@@ -148,7 +170,8 @@ public CompletableFuture<ChatCompletionsResponse> getChatCompletions(
148170
"Can't deserialize ChatCompletionsResponse", e, httpResponse);
149171
}
150172
}
151-
});
173+
},
174+
asyncExecutor);
152175

153176
return chatCompletionsResponse;
154177
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package com.box.l10n.mojito.openai;
2+
3+
import com.google.common.base.Function;
4+
import java.net.http.HttpClient;
5+
import java.util.concurrent.CompletableFuture;
6+
import java.util.concurrent.ExecutorService;
7+
import java.util.concurrent.Executors;
8+
import java.util.concurrent.Semaphore;
9+
import java.util.concurrent.ThreadLocalRandom;
10+
import org.slf4j.Logger;
11+
import org.slf4j.LoggerFactory;
12+
13+
public class OpenAIClientPool {
14+
15+
static Logger logger = LoggerFactory.getLogger(OpenAIClientPool.class);
16+
17+
int numberOfClients;
18+
OpenAIClientWithSemaphore[] openAIClientWithSemaphores;
19+
20+
/**
21+
* Pool to parallelize slower requests (1s+) over HTTP/2 connections.
22+
*
23+
* @param numberOfClients Number of OpenAIClient instances with independent HttpClients.
24+
* @param numberOfParallelRequestPerClient Maximum parallel requests per client, controlled by a
25+
* semaphore to prevent overload.
26+
* @param sizeOfAsyncProcessors Shared async processors across all HttpClients to limit threads,
27+
* as request time is the main bottleneck.
28+
* @param apiKey API key for authentication.
29+
*/
30+
public OpenAIClientPool(
31+
int numberOfClients,
32+
int numberOfParallelRequestPerClient,
33+
int sizeOfAsyncProcessors,
34+
String apiKey) {
35+
ExecutorService asyncExecutor = Executors.newWorkStealingPool(sizeOfAsyncProcessors);
36+
this.numberOfClients = numberOfClients;
37+
this.openAIClientWithSemaphores = new OpenAIClientWithSemaphore[numberOfClients];
38+
for (int i = 0; i < numberOfClients; i++) {
39+
this.openAIClientWithSemaphores[i] =
40+
new OpenAIClientWithSemaphore(
41+
OpenAIClient.builder()
42+
.apiKey(apiKey)
43+
.asyncExecutor(asyncExecutor)
44+
.httpClient(HttpClient.newBuilder().executor(asyncExecutor).build())
45+
.build(),
46+
new Semaphore(numberOfParallelRequestPerClient));
47+
}
48+
}
49+
50+
public <T> CompletableFuture<T> submit(Function<OpenAIClient, CompletableFuture<T>> f) {
51+
52+
while (true) {
53+
for (OpenAIClientWithSemaphore openAIClientWithSemaphore : openAIClientWithSemaphores) {
54+
if (openAIClientWithSemaphore.semaphore().tryAcquire()) {
55+
return f.apply(openAIClientWithSemaphore.openAIClient())
56+
.whenComplete((o, e) -> openAIClientWithSemaphore.semaphore().release());
57+
}
58+
}
59+
60+
try {
61+
logger.debug("can't directly acquire any semaphore, do blocking");
62+
int randomSemaphoreIndex =
63+
ThreadLocalRandom.current().nextInt(openAIClientWithSemaphores.length);
64+
OpenAIClientWithSemaphore randomClientWithSemaphore =
65+
this.openAIClientWithSemaphores[randomSemaphoreIndex];
66+
randomClientWithSemaphore.semaphore().acquire();
67+
return f.apply(randomClientWithSemaphore.openAIClient())
68+
.whenComplete((o, e) -> randomClientWithSemaphore.semaphore().release());
69+
} catch (InterruptedException e) {
70+
Thread.currentThread().interrupt();
71+
throw new RuntimeException("Can't submit task to the OpenAIClientPool", e);
72+
}
73+
}
74+
}
75+
76+
record OpenAIClientWithSemaphore(OpenAIClient openAIClient, Semaphore semaphore) {}
77+
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package com.box.l10n.mojito.openai;
2+
3+
import static com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsRequest.SystemMessage.systemMessageBuilder;
4+
import static com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsRequest.UserMessage.userMessageBuilder;
5+
import static com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsRequest.chatCompletionsRequest;
6+
7+
import com.box.l10n.mojito.io.Files;
8+
import com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsResponse;
9+
import com.google.common.base.Stopwatch;
10+
import java.nio.file.Paths;
11+
import java.util.ArrayList;
12+
import java.util.List;
13+
import java.util.concurrent.CompletableFuture;
14+
import java.util.concurrent.TimeUnit;
15+
import java.util.concurrent.atomic.AtomicInteger;
16+
import org.junit.Ignore;
17+
import org.junit.jupiter.api.Test;
18+
19+
@Ignore
20+
class OpenAIClientPoolTest {
21+
22+
static final String API_KEY;
23+
24+
static {
25+
try {
26+
API_KEY =
27+
Files.readString(Paths.get(System.getProperty("user.home")).resolve(".keys/openai"))
28+
.trim();
29+
// API_KEY = "test-api-key";
30+
} catch (Throwable e) {
31+
throw new RuntimeException(e);
32+
}
33+
}
34+
35+
@Test
36+
void test2() {
37+
OpenAIClientPool openAIClientPool = new OpenAIClientPool(100, 50, 5, API_KEY);
38+
39+
AtomicInteger counter = new AtomicInteger();
40+
Stopwatch stopwatch = Stopwatch.createStarted();
41+
42+
List<CompletableFuture<ChatCompletionsResponse>> responses = new ArrayList<>();
43+
for (int i = 0; i < 10000; i++) {
44+
String message = "Is %d prime?".formatted(i);
45+
OpenAIClient.ChatCompletionsRequest chatCompletionsRequest =
46+
chatCompletionsRequest()
47+
.model("gpt-4o-2024-08-06")
48+
.messages(
49+
List.of(
50+
systemMessageBuilder()
51+
.content("You're an engine designed to check prime numbers")
52+
.build(),
53+
userMessageBuilder().content(message).build()))
54+
.build();
55+
Stopwatch requestStopwatch = Stopwatch.createStarted();
56+
CompletableFuture<ChatCompletionsResponse> response =
57+
openAIClientPool.submit(
58+
openAIClient -> openAIClient.getChatCompletions(chatCompletionsRequest));
59+
response.thenApply(
60+
chatCompletionsResponse -> {
61+
// System.out.println(message + " --> " +
62+
// chatCompletionsResponse.choices().get(0).message().content());
63+
if (counter.get() % 100 == 0) {
64+
System.out.println("QPS: " + counter.get() / stopwatch.elapsed(TimeUnit.SECONDS));
65+
System.out.println("elapsed: " + requestStopwatch.elapsed(TimeUnit.SECONDS));
66+
}
67+
return chatCompletionsResponse;
68+
});
69+
responses.add(response);
70+
counter.incrementAndGet();
71+
}
72+
73+
for (CompletableFuture<ChatCompletionsResponse> future : responses) {
74+
future.join();
75+
}
76+
}
77+
}

restclient/src/main/java/com/box/l10n/mojito/rest/client/RepositoryAiTranslateClient.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ public ProtoAiTranslateResponse translateRepository(
2929
}
3030

3131
public record ProtoAiTranslateRequest(
32-
String repositoryName, List<String> targetBcp47tags, int sourceTextMaxCountPerLocale) {}
32+
String repositoryName,
33+
List<String> targetBcp47tags,
34+
int sourceTextMaxCountPerLocale,
35+
boolean useBatch) {}
3336

3437
public record ProtoAiTranslateResponse(PollableTask pollableTask) {}
3538
}

webapp/src/main/java/com/box/l10n/mojito/rest/textunit/AiTranslateWS.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,18 @@ public ProtoAiTranslateResponse aiTranslate(
4242
new AiTranslateInput(
4343
protoAiTranslateRequest.repositoryName(),
4444
protoAiTranslateRequest.targetBcp47tags(),
45-
protoAiTranslateRequest.sourceTextMaxCountPerLocale()));
45+
protoAiTranslateRequest.sourceTextMaxCountPerLocale(),
46+
protoAiTranslateRequest.useBatch()));
4647

4748
return new ProtoAiTranslateResponse(pollableFuture.getPollableTask());
4849
}
4950

5051
public record ProtoAiTranslateRequest(
51-
String repositoryName, List<String> targetBcp47tags, int sourceTextMaxCountPerLocale) {}
52+
String repositoryName,
53+
List<String> targetBcp47tags,
54+
int sourceTextMaxCountPerLocale,
55+
boolean useBatch,
56+
boolean allLocales) {}
5257

5358
public record ProtoAiTranslateResponse(PollableTask pollableTask) {}
5459
}

webapp/src/main/java/com/box/l10n/mojito/service/oaitranslate/AiTranslateConfig.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import com.box.l10n.mojito.json.ObjectMapper;
44
import com.box.l10n.mojito.openai.OpenAIClient;
5+
import com.box.l10n.mojito.openai.OpenAIClientPool;
56
import java.time.Duration;
67
import org.springframework.beans.factory.annotation.Qualifier;
78
import org.springframework.context.annotation.Bean;
@@ -28,6 +29,17 @@ OpenAIClient openAIClient() {
2829
return new OpenAIClient.Builder().apiKey(openaiClientToken).build();
2930
}
3031

32+
@Bean
33+
@Qualifier("AiTranslate")
34+
OpenAIClientPool openAIClientPool() {
35+
String openaiClientToken = aiTranslateConfigurationProperties.getOpenaiClientToken();
36+
if (openaiClientToken == null) {
37+
return null;
38+
}
39+
return new OpenAIClientPool(
40+
10, 50, 5, aiTranslateConfigurationProperties.getOpenaiClientToken());
41+
}
42+
3143
@Bean
3244
@Qualifier("AiTranslate")
3345
ObjectMapper objectMapper() {

0 commit comments

Comments
 (0)