Skip to content

Commit 27278e8

Browse files
committed
wip
1 parent 1363982 commit 27278e8

File tree

7 files changed

+695
-16
lines changed

7 files changed

+695
-16
lines changed

common/src/main/java/com/box/l10n/mojito/openai/OpenAIClient.java

Lines changed: 288 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
import com.fasterxml.jackson.databind.node.ArrayNode;
1212
import com.fasterxml.jackson.databind.node.ObjectNode;
1313
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
14-
import com.fasterxml.jackson.module.jsonSchema.JsonSchema;
1514
import com.fasterxml.jackson.module.jsonSchema.JsonSchemaGenerator;
15+
import java.io.IOException;
1616
import java.net.URI;
1717
import java.net.http.HttpClient;
1818
import java.net.http.HttpRequest;
@@ -23,6 +23,7 @@
2323
import java.util.List;
2424
import java.util.Map;
2525
import java.util.Objects;
26+
import java.util.UUID;
2627
import java.util.concurrent.CompletableFuture;
2728
import java.util.function.Predicate;
2829
import java.util.stream.Collectors;
@@ -144,7 +145,7 @@ public CompletableFuture<ChatCompletionsResponse> getChatCompletions(
144145
httpResponse.body(), ChatCompletionsResponse.class);
145146
} catch (JsonProcessingException e) {
146147
throw new OpenAIClientResponseException(
147-
"Can't deserialize ChatCompletionsResponse", e, httpResponse);
148+
"Can't deserialize ChatCompletionsResponse", e, httpResponse);
148149
}
149150
}
150151
});
@@ -182,7 +183,7 @@ public CompletableFuture<Stream<ChatCompletionsStreamResponse>> streamChatComple
182183
httpResponse -> {
183184
if (httpResponse.statusCode() != 200) {
184185
throw new OpenAIClientResponseException(
185-
"ChatCompletion stream failed", httpResponse);
186+
"ChatCompletion stream failed", httpResponse);
186187
}
187188
return httpResponse
188189
.body()
@@ -192,9 +193,9 @@ public CompletableFuture<Stream<ChatCompletionsStreamResponse>> streamChatComple
192193
body -> {
193194
if (!body.startsWith("data: ")) {
194195
throw new OpenAIClientResponseException(
195-
"Only support \"data\" lines in stream are supported, got: \"%s\""
196-
.formatted(body),
197-
httpResponse);
196+
"Only support \"data\" lines in stream are supported, got: \"%s\""
197+
.formatted(body),
198+
httpResponse);
198199
}
199200

200201
String jsonPart = body.substring(5);
@@ -203,9 +204,9 @@ public CompletableFuture<Stream<ChatCompletionsStreamResponse>> streamChatComple
203204
jsonPart, ChatCompletionsStreamResponse.class);
204205
} catch (JsonProcessingException e) {
205206
throw new OpenAIClientResponseException(
206-
"Can't deserialize ChatCompletionsStreamResponse",
207-
e,
208-
httpResponse);
207+
"Can't deserialize ChatCompletionsStreamResponse",
208+
e,
209+
httpResponse);
209210
}
210211
});
211212
});
@@ -551,7 +552,7 @@ public CompletableFuture<EmbeddingResponse> getEmbedding(EmbeddingRequest embedd
551552
return objectMapper.readValue(httpResponse.body(), EmbeddingResponse.class);
552553
} catch (JsonProcessingException e) {
553554
throw new OpenAIClientResponseException(
554-
"Can't deserialize EmbeddingResponse", e, httpResponse);
555+
"Can't deserialize EmbeddingResponse", e, httpResponse);
555556
}
556557
});
557558

@@ -612,19 +613,292 @@ public record Usage(
612613
@JsonProperty("total_tokens") int totalTokens) {}
613614
}
614615

616+
public UploadFileResponse uploadFile(UploadFileRequest uploadFileRequest) {
617+
618+
final String boundary = UUID.randomUUID().toString();
619+
620+
String body = uploadFileRequest.getMultipartBody(boundary);
621+
622+
HttpRequest request =
623+
HttpRequest.newBuilder()
624+
.uri(getUriForEndpoint(UploadFileRequest.ENDPOINT))
625+
.header("Authorization", "Bearer " + apiKey)
626+
.header("Content-Type", "multipart/form-data; boundary=" + boundary)
627+
.POST(HttpRequest.BodyPublishers.ofString(body))
628+
.build();
629+
630+
HttpResponse<String> response;
631+
try {
632+
response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
633+
System.out.printf("response: %s%n", response.body());
634+
} catch (IOException | InterruptedException e) {
635+
throw new RuntimeException(e);
636+
}
637+
638+
UploadFileResponse uploadFileResponse;
639+
try {
640+
uploadFileResponse = objectMapper.readValue(response.body(), UploadFileResponse.class);
641+
} catch (JsonProcessingException e) {
642+
throw new RuntimeException(e);
643+
}
644+
return uploadFileResponse;
645+
}
646+
647+
public record UploadFileRequest(
648+
String purpose, String filename, String content, String contentType) {
649+
650+
public static UploadFileRequest forBatch(String filename, String content) {
651+
return new UploadFileRequest(Purpose.BATCH.toString(), filename, content, "application/json");
652+
}
653+
654+
enum Purpose {
655+
BATCH("batch"),
656+
ASSISTANTS("assistants"),
657+
FINE_TUNE("fine-tune"),
658+
VISION("vision");
659+
660+
private final String purposeCode;
661+
662+
Purpose(String purposeCode) {
663+
this.purposeCode = purposeCode;
664+
}
665+
666+
public String getPurposeCode() {
667+
return purposeCode;
668+
}
669+
670+
@Override
671+
public String toString() {
672+
return purposeCode;
673+
}
674+
675+
public static Purpose fromCode(String purposeCode) {
676+
for (Purpose purpose : Purpose.values()) {
677+
if (purpose.purposeCode.equalsIgnoreCase(purposeCode)) {
678+
return purpose;
679+
}
680+
}
681+
throw new IllegalArgumentException("Unknown purpose: " + purposeCode);
682+
}
683+
}
684+
685+
static final String ENDPOINT = "/v1/files";
686+
687+
String getMultipartBody(String boundary) {
688+
String body =
689+
"""
690+
--%1$s\r
691+
Content-Disposition: form-data; name="purpose"\r
692+
\r
693+
%5$s\r
694+
--%1$s\r
695+
Content-Disposition: form-data; name="file"; filename="%2$s"\r
696+
Content-Type: %3$s\r
697+
\r
698+
%4$s\r
699+
--%1$s--\r
700+
"""
701+
.formatted(boundary, filename, contentType, content, purpose);
702+
return body;
703+
}
704+
}
705+
706+
public record UploadFileResponse(
707+
String object,
708+
String id,
709+
String purpose,
710+
String filename,
711+
int bytes,
712+
@JsonProperty("created_at") long createdAt,
713+
String status,
714+
@JsonProperty("status_details") String statusDetails) {}
715+
716+
public record BatchFileLine(
717+
@JsonProperty("custom_id") String customId, String method, String url, Object body) {
718+
719+
public static BatchFileLine forChatCompletion(
720+
String customId, ChatCompletionsRequest chatCompletionsRequest) {
721+
return new BatchFileLine(customId, "POST", "/v1/chat/completions", chatCompletionsRequest);
722+
}
723+
}
724+
725+
public DownloadFileContentResponse downloadFileContent(
726+
DownloadFileContentRequest downloadFileContentRequest) {
727+
HttpResponse<String> response;
728+
HttpRequest request =
729+
HttpRequest.newBuilder()
730+
.uri(
731+
getUriForEndpoint(
732+
DownloadFileContentRequest.ENDPOINT.formatted(
733+
downloadFileContentRequest.fileId())))
734+
.header("Authorization", "Bearer " + apiKey)
735+
.header("Content-Type", "application/json")
736+
.GET()
737+
.build();
738+
739+
try {
740+
response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
741+
} catch (IOException | InterruptedException e) {
742+
throw new RuntimeException(
743+
"Error while sending the request to download the file: "
744+
+ downloadFileContentRequest.fileId(),
745+
e);
746+
}
747+
748+
if (response.statusCode() != 200) {
749+
throw new OpenAIClientResponseException("Can't download file content", response);
750+
}
751+
752+
return new DownloadFileContentResponse(response.body());
753+
}
754+
755+
public record DownloadFileContentRequest(String fileId) {
756+
static final String ENDPOINT = "/v1/files/%s/content";
757+
}
758+
759+
public record DownloadFileContentResponse(String content) {}
760+
761+
public CreateBatchResponse createBatch(CreateBatchRequest createBatchRequest) {
762+
763+
String jsonBody;
764+
try {
765+
jsonBody = objectMapper.writeValueAsString(createBatchRequest);
766+
} catch (JsonProcessingException e) {
767+
throw new RuntimeException(e);
768+
}
769+
770+
HttpRequest request =
771+
HttpRequest.newBuilder()
772+
.uri(getUriForEndpoint(CreateBatchRequest.ENDPOINT))
773+
.header("Authorization", "Bearer " + apiKey)
774+
.header("Content-Type", "application/json")
775+
.POST(HttpRequest.BodyPublishers.ofString(jsonBody))
776+
.build();
777+
778+
HttpResponse<String> response;
779+
try {
780+
response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
781+
} catch (IOException | InterruptedException e) {
782+
throw new RuntimeException(e);
783+
}
784+
785+
if (response.statusCode() != 200) {
786+
throw new RuntimeException("Can't create batch");
787+
}
788+
789+
CreateBatchResponse createBatchResponse;
790+
try {
791+
createBatchResponse = objectMapper.readValue(response.body(), CreateBatchResponse.class);
792+
} catch (JsonProcessingException e) {
793+
throw new RuntimeException(e);
794+
}
795+
796+
return createBatchResponse;
797+
}
798+
799+
public record CreateBatchRequest(
800+
@JsonProperty("input_file_id") String inputFileId,
801+
String endpoint,
802+
@JsonProperty("completion_window") String completionWindow) {
803+
804+
public static final String ENDPOINT = "/v1/batches";
805+
806+
public static CreateBatchRequest forChatCompletion(String fileId) {
807+
return new CreateBatchRequest(fileId, "/v1/chat/completions", "24h");
808+
}
809+
}
810+
811+
public record CreateBatchResponse(
812+
String id,
813+
String object,
814+
String endpoint,
815+
String errors,
816+
@JsonProperty("input_file_id") String inputFileId,
817+
@JsonProperty("completion_window") String completionWindow,
818+
String status,
819+
@JsonProperty("output_file_id") String outputFileId,
820+
@JsonProperty("error_file_id") String errorFileId,
821+
@JsonProperty("created_at") long createdAt,
822+
@JsonProperty("in_progress_at") Long inProgressAt,
823+
@JsonProperty("expires_at") long expiresAt,
824+
@JsonProperty("completed_at") Long completedAt,
825+
@JsonProperty("failed_at") Long failedAt,
826+
@JsonProperty("expired_at") Long expiredAt,
827+
@JsonProperty("request_counts") RequestCounts requestCounts,
828+
Object metadata) {
829+
record RequestCounts(int total, int completed, int failed) {}
830+
}
831+
832+
public RetrieveBatchResponse retrieveBatch(RetrieveBatchRequest retrieveBatchRequest) {
833+
HttpRequest request =
834+
HttpRequest.newBuilder()
835+
.uri(
836+
getUriForEndpoint(
837+
RetrieveBatchRequest.ENDPOINT.formatted(retrieveBatchRequest.batchId())))
838+
.header("Authorization", "Bearer " + apiKey)
839+
.header("Content-Type", "application/json")
840+
.GET()
841+
.build();
842+
843+
HttpResponse<String> response;
844+
try {
845+
response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
846+
} catch (IOException | InterruptedException e) {
847+
throw new RuntimeException(e);
848+
}
849+
850+
if (response.statusCode() != 200) {
851+
throw new OpenAIClientResponseException("Can't retrieve batch", response);
852+
}
853+
854+
RetrieveBatchResponse retrieveBatchResponse;
855+
try {
856+
retrieveBatchResponse = objectMapper.readValue(response.body(), RetrieveBatchResponse.class);
857+
} catch (JsonProcessingException e) {
858+
throw new RuntimeException(e);
859+
}
860+
861+
return retrieveBatchResponse;
862+
}
863+
864+
public record RetrieveBatchRequest(String batchId) {
865+
public static final String ENDPOINT = "/v1/batches/%s";
866+
}
867+
868+
public record RetrieveBatchResponse(
869+
String id,
870+
String object,
871+
String endpoint,
872+
String errors,
873+
@JsonProperty("input_file_id") String inputFileId,
874+
@JsonProperty("completion_window") String completionWindow,
875+
String status,
876+
@JsonProperty("output_file_id") String outputFileId,
877+
@JsonProperty("error_file_id") String errorFileId,
878+
@JsonProperty("created_at") long createdAt,
879+
@JsonProperty("in_progress_at") Long inProgressAt,
880+
@JsonProperty("expires_at") long expiresAt,
881+
@JsonProperty("completed_at") Long completedAt,
882+
@JsonProperty("failed_at") Long failedAt,
883+
@JsonProperty("expired_at") Long expiredAt,
884+
@JsonProperty("request_counts") RequestCounts requestCounts,
885+
Object metadata) {
886+
record RequestCounts(int total, int completed, int failed) {}
887+
}
888+
615889
private URI getUriForEndpoint(String endpoint) {
616890
return URI.create(host).resolve(endpoint);
617891
}
618892

619-
public class OpenAIClientResponseException extends RuntimeException {
620-
HttpResponse httpResponse;
893+
public static class OpenAIClientResponseException extends RuntimeException {
894+
HttpResponse<?> httpResponse;
621895

622-
public OpenAIClientResponseException(String message, HttpResponse httpResponse) {
896+
public OpenAIClientResponseException(String message, HttpResponse<?> httpResponse) {
623897
super(message);
624898
this.httpResponse = Objects.requireNonNull(httpResponse);
625899
}
626900

627-
public OpenAIClientResponseException(String message, Exception e, HttpResponse httpResponse) {
901+
public OpenAIClientResponseException(String message, Exception e, HttpResponse<?> httpResponse) {
628902
super(message, e);
629903
this.httpResponse = Objects.requireNonNull(httpResponse);
630904
}

0 commit comments

Comments
 (0)