Skip to content

Commit 043881f

Browse files
committed
Implement first version of no-batch ai translate api
1 parent 7f8870e commit 043881f

File tree

10 files changed

+702
-219
lines changed

10 files changed

+702
-219
lines changed

cli/src/main/java/com/box/l10n/mojito/cli/command/RepositoryAiTranslationCommand.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ public class RepositoryAiTranslationCommand extends Command {
4343
@Parameter(
4444
names = {Param.REPOSITORY_LOCALES_LONG, Param.REPOSITORY_LOCALES_SHORT},
4545
variableArity = true,
46-
required = true,
47-
description = "List of locales (bcp47 tags) to machine translate")
46+
description =
47+
"List of locales (bcp47 tags) to translate, if not provided translate all locales in the repository")
4848
List<String> locales;
4949

5050
@Parameter(
@@ -55,6 +55,12 @@ public class RepositoryAiTranslationCommand extends Command {
5555
+ "sending too many strings to MT)")
5656
int sourceTextMaxCount = 100;
5757

58+
@Parameter(
59+
names = {"--use-batch"},
60+
arity = 1,
61+
description = "To use the batch API or not")
62+
boolean useBatch = false;
63+
5864
@Autowired CommandHelper commandHelper;
5965

6066
@Autowired RepositoryAiTranslateClient repositoryAiTranslateClient;
@@ -75,13 +81,13 @@ public void execute() throws CommandException {
7581
.reset()
7682
.a(" for locales: ")
7783
.fg(Color.CYAN)
78-
.a(locales.stream().collect(Collectors.joining(", ", "[", "]")))
84+
.a(locales == null ? "<all>" : locales.stream().collect(Collectors.joining(", ", "[", "]")))
7985
.println(2);
8086

8187
ProtoAiTranslateResponse protoAiTranslateResponse =
8288
repositoryAiTranslateClient.translateRepository(
8389
new RepositoryAiTranslateClient.ProtoAiTranslateRequest(
84-
repositoryParam, locales, sourceTextMaxCount));
90+
repositoryParam, locales, sourceTextMaxCount, useBatch));
8591

8692
PollableTask pollableTask = protoAiTranslateResponse.pollableTask();
8793
commandHelper.waitForPollableTask(pollableTask.getId());

common/src/main/java/com/box/l10n/mojito/openai/OpenAIClient.java

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import java.util.Objects;
2626
import java.util.UUID;
2727
import java.util.concurrent.CompletableFuture;
28+
import java.util.concurrent.Executor;
29+
import java.util.concurrent.ForkJoinPool;
2830
import java.util.function.Predicate;
2931
import java.util.stream.Collectors;
3032
import java.util.stream.Stream;
@@ -39,11 +41,19 @@ public class OpenAIClient {
3941

4042
final HttpClient httpClient;
4143

42-
OpenAIClient(String apiKey, String host, ObjectMapper objectMapper, HttpClient httpClient) {
44+
final Executor asyncExecutor;
45+
46+
OpenAIClient(
47+
String apiKey,
48+
String host,
49+
ObjectMapper objectMapper,
50+
HttpClient httpClient,
51+
Executor asyncExecutor) {
4352
this.apiKey = Objects.requireNonNull(apiKey);
4453
this.host = Objects.requireNonNull(host);
4554
this.objectMapper = Objects.requireNonNull(objectMapper);
4655
this.httpClient = Objects.requireNonNull(httpClient);
56+
this.asyncExecutor = Objects.requireNonNull(asyncExecutor);
4757
}
4858

4959
public static class Builder {
@@ -56,6 +66,8 @@ public static class Builder {
5666

5767
private HttpClient httpClient;
5868

69+
private Executor asyncExecutor;
70+
5971
public Builder() {}
6072

6173
public Builder apiKey(String apiKey) {
@@ -78,6 +90,11 @@ public Builder httpClient(HttpClient httpClient) {
7890
return this;
7991
}
8092

93+
public Builder asyncExecutor(Executor asyncExecutor) {
94+
this.asyncExecutor = asyncExecutor;
95+
return this;
96+
}
97+
8198
public OpenAIClient build() {
8299
if (apiKey == null) {
83100
throw new IllegalStateException("API key must be provided");
@@ -89,11 +106,16 @@ public OpenAIClient build() {
89106
if (httpClient == null) {
90107
httpClient = createHttpClient();
91108
}
92-
return new OpenAIClient(apiKey, host, objectMapper, httpClient);
109+
110+
if (asyncExecutor == null) {
111+
asyncExecutor = ForkJoinPool.commonPool();
112+
}
113+
114+
return new OpenAIClient(apiKey, host, objectMapper, httpClient, asyncExecutor);
93115
}
94116

95117
private HttpClient createHttpClient() {
96-
return HttpClient.newHttpClient();
118+
return HttpClient.newBuilder().build();
97119
}
98120

99121
private ObjectMapper createObjectMapper() {
@@ -135,7 +157,7 @@ public CompletableFuture<ChatCompletionsResponse> getChatCompletions(
135157
CompletableFuture<ChatCompletionsResponse> chatCompletionsResponse =
136158
httpClient
137159
.sendAsync(request, HttpResponse.BodyHandlers.ofString())
138-
.thenApply(
160+
.thenApplyAsync(
139161
httpResponse -> {
140162
if (httpResponse.statusCode() != 200) {
141163
throw new OpenAIClientResponseException("ChatCompletion failed", httpResponse);
@@ -148,7 +170,8 @@ public CompletableFuture<ChatCompletionsResponse> getChatCompletions(
148170
"Can't deserialize ChatCompletionsResponse", e, httpResponse);
149171
}
150172
}
151-
});
173+
},
174+
asyncExecutor);
152175

153176
return chatCompletionsResponse;
154177
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package com.box.l10n.mojito.openai;
2+
3+
import com.google.common.base.Function;
4+
import java.net.http.HttpClient;
5+
import java.util.concurrent.CompletableFuture;
6+
import java.util.concurrent.ExecutorService;
7+
import java.util.concurrent.Executors;
8+
import java.util.concurrent.Semaphore;
9+
import java.util.concurrent.ThreadLocalRandom;
10+
import org.slf4j.Logger;
11+
import org.slf4j.LoggerFactory;
12+
13+
public class OpenAIClientPool {
14+
15+
static Logger logger = LoggerFactory.getLogger(OpenAIClientPool.class);
16+
17+
int numberOfClients;
18+
OpenAIClientWithSemaphore[] openAIClientWithSemaphores;
19+
20+
/**
21+
* Pool to parallelize slower requests (1s+) over HTTP/2 connections.
22+
*
23+
* @param numberOfClients Number of OpenAIClient instances with independent HttpClients.
24+
* @param numberOfParallelRequestPerClient Maximum parallel requests per client, controlled by a
25+
* semaphore to prevent overload.
26+
* @param sizeOfAsyncProcessors Shared async processors across all HttpClients to limit threads,
27+
* as request time is the main bottleneck.
28+
* @param apiKey API key for authentication.
29+
*/
30+
public OpenAIClientPool(
31+
int numberOfClients,
32+
int numberOfParallelRequestPerClient,
33+
int sizeOfAsyncProcessors,
34+
String apiKey) {
35+
ExecutorService asyncExecutor = Executors.newWorkStealingPool(sizeOfAsyncProcessors);
36+
this.numberOfClients = numberOfClients;
37+
this.openAIClientWithSemaphores = new OpenAIClientWithSemaphore[numberOfClients];
38+
for (int i = 0; i < numberOfClients; i++) {
39+
this.openAIClientWithSemaphores[i] =
40+
new OpenAIClientWithSemaphore(
41+
OpenAIClient.builder()
42+
.apiKey(apiKey)
43+
.asyncExecutor(asyncExecutor)
44+
.httpClient(HttpClient.newBuilder().executor(asyncExecutor).build())
45+
.build(),
46+
new Semaphore(numberOfParallelRequestPerClient));
47+
}
48+
}
49+
50+
public <T> CompletableFuture<T> submit(Function<OpenAIClient, CompletableFuture<T>> f) {
51+
52+
while (true) {
53+
for (OpenAIClientWithSemaphore openAIClientWithSemaphore : openAIClientWithSemaphores) {
54+
if (openAIClientWithSemaphore.semaphore().tryAcquire()) {
55+
return f.apply(openAIClientWithSemaphore.openAIClient())
56+
.whenComplete((o, e) -> openAIClientWithSemaphore.semaphore().release());
57+
}
58+
}
59+
60+
try {
61+
logger.debug("can't directly acquire any semaphore, do blocking");
62+
int randomSemaphoreIndex =
63+
ThreadLocalRandom.current().nextInt(openAIClientWithSemaphores.length);
64+
OpenAIClientWithSemaphore randomClientWithSemaphore =
65+
this.openAIClientWithSemaphores[randomSemaphoreIndex];
66+
randomClientWithSemaphore.semaphore().acquire();
67+
return f.apply(randomClientWithSemaphore.openAIClient())
68+
.whenComplete((o, e) -> randomClientWithSemaphore.semaphore().release());
69+
} catch (InterruptedException e) {
70+
Thread.currentThread().interrupt();
71+
throw new RuntimeException("Can't submit task to the OpenAIClientPool", e);
72+
}
73+
}
74+
}
75+
76+
record OpenAIClientWithSemaphore(OpenAIClient openAIClient, Semaphore semaphore) {}
77+
}
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
package com.box.l10n.mojito.openai;
2+
3+
import static com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsRequest.SystemMessage.systemMessageBuilder;
4+
import static com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsRequest.UserMessage.userMessageBuilder;
5+
import static com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsRequest.chatCompletionsRequest;
6+
7+
import com.box.l10n.mojito.openai.OpenAIClient.ChatCompletionsResponse;
8+
import com.google.common.base.Stopwatch;
9+
import java.util.ArrayList;
10+
import java.util.List;
11+
import java.util.concurrent.CompletableFuture;
12+
import java.util.concurrent.TimeUnit;
13+
import java.util.concurrent.atomic.AtomicInteger;
14+
import java.util.stream.Collectors;
15+
import org.junit.Ignore;
16+
import org.junit.jupiter.api.Test;
17+
import org.slf4j.Logger;
18+
import org.slf4j.LoggerFactory;
19+
20+
@Ignore
21+
class OpenAIClientPoolTest {
22+
23+
static Logger logger = LoggerFactory.getLogger(OpenAIClientPoolTest.class);
24+
25+
static final String API_KEY;
26+
27+
static {
28+
try {
29+
// API_KEY =
30+
//
31+
// Files.readString(Paths.get(System.getProperty("user.home")).resolve(".keys/openai"))
32+
// .trim();
33+
API_KEY = "test-api-key";
34+
} catch (Throwable e) {
35+
throw new RuntimeException(e);
36+
}
37+
}
38+
39+
@Test
40+
void test2() {
41+
int numberOfClients = 10;
42+
int numberOfParallelRequestPerClient = 50;
43+
int numberOfRequests = 10000;
44+
int sizeOfAsyncProcessors = 10;
45+
int totalExecutions = numberOfClients * numberOfParallelRequestPerClient;
46+
47+
OpenAIClientPool openAIClientPool =
48+
new OpenAIClientPool(
49+
numberOfClients, numberOfParallelRequestPerClient, sizeOfAsyncProcessors, API_KEY);
50+
51+
AtomicInteger responseCounter = new AtomicInteger();
52+
AtomicInteger submitted = new AtomicInteger();
53+
Stopwatch stopwatch = Stopwatch.createStarted();
54+
55+
ArrayList<Long> submissionTimes = new ArrayList<>();
56+
ArrayList<Long> responseTimes = new ArrayList<>();
57+
58+
List<CompletableFuture<ChatCompletionsResponse>> responses = new ArrayList<>();
59+
for (int i = 0; i < numberOfRequests; i++) {
60+
String message = "Is %d prime?".formatted(i);
61+
Stopwatch requestStopwatch = Stopwatch.createStarted();
62+
OpenAIClient.ChatCompletionsRequest chatCompletionsRequest =
63+
chatCompletionsRequest()
64+
.model("gpt-4o-2024-08-06")
65+
.messages(
66+
List.of(
67+
systemMessageBuilder()
68+
.content("You're an engine designed to check prime numbers")
69+
.build(),
70+
userMessageBuilder().content(message).build()))
71+
.build();
72+
73+
CompletableFuture<ChatCompletionsResponse> response =
74+
openAIClientPool.submit(
75+
openAIClient -> {
76+
CompletableFuture<ChatCompletionsResponse> chatCompletions =
77+
openAIClient.getChatCompletions(chatCompletionsRequest);
78+
submissionTimes.add(requestStopwatch.elapsed(TimeUnit.SECONDS));
79+
if (submitted.incrementAndGet() % 100 == 0) {
80+
logger.info(
81+
"--> request per second: "
82+
+ submitted.get() / (stopwatch.elapsed(TimeUnit.SECONDS) + 0.00001)
83+
+ ", submission count: "
84+
+ submitted.get()
85+
+ ", future response count: "
86+
+ responses.size()
87+
+ ", last submissions took: "
88+
+ submissionTimes.subList(
89+
Math.max(0, submissionTimes.size() - 100), submissionTimes.size()));
90+
}
91+
return chatCompletions;
92+
});
93+
94+
response.thenApply(
95+
chatCompletionsResponse -> {
96+
responseTimes.add(requestStopwatch.elapsed(TimeUnit.MILLISECONDS));
97+
if (responseCounter.incrementAndGet() % 10 == 0) {
98+
double avg =
99+
responseTimes.stream().collect(Collectors.averagingLong(Long::longValue));
100+
logger.info(
101+
"<-- response per second: "
102+
+ responseCounter.get() / stopwatch.elapsed(TimeUnit.SECONDS)
103+
+ ", average response time: "
104+
+ Math.round(avg)
105+
+ " (rps: "
106+
+ Math.round(totalExecutions / (avg / 1000.0))
107+
+ "), response count from counter: "
108+
+ responseCounter.get()
109+
+ ", last elapsed times: "
110+
+ responseTimes.subList(responseTimes.size() - 20, responseTimes.size()));
111+
}
112+
return chatCompletionsResponse;
113+
});
114+
115+
responses.add(response);
116+
}
117+
118+
Stopwatch started = Stopwatch.createStarted();
119+
CompletableFuture.allOf(responses.toArray(new CompletableFuture[responses.size()])).join();
120+
logger.info("Waiting for join: " + started.elapsed());
121+
122+
double avg = responseTimes.stream().collect(Collectors.averagingLong(Long::longValue));
123+
logger.info(
124+
"Total time: "
125+
+ stopwatch.elapsed().toString()
126+
+ ", request per second: "
127+
+ Math.round((double) numberOfRequests / stopwatch.elapsed(TimeUnit.SECONDS))
128+
+ ", average response time: "
129+
+ Math.round(avg)
130+
+ " (theory rps: "
131+
+ Math.round(totalExecutions / (avg / 1000.0))
132+
+ ")");
133+
}
134+
}

0 commit comments

Comments
 (0)