Skip to content

Commit 3f6d7f2

Browse files
committed
Add upload/download file and create batch in OpenAIClient
1 parent 6ed5890 commit 3f6d7f2

File tree

2 files changed

+638
-10
lines changed

2 files changed

+638
-10
lines changed

common/src/main/java/com/box/l10n/mojito/openai/OpenAIClient.java

Lines changed: 296 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
import com.fasterxml.jackson.databind.node.ArrayNode;
1212
import com.fasterxml.jackson.databind.node.ObjectNode;
1313
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
14-
import com.fasterxml.jackson.module.jsonSchema.JsonSchema;
1514
import com.fasterxml.jackson.module.jsonSchema.JsonSchemaGenerator;
15+
import java.io.IOException;
1616
import java.net.URI;
1717
import java.net.http.HttpClient;
1818
import java.net.http.HttpRequest;
@@ -23,6 +23,7 @@
2323
import java.util.List;
2424
import java.util.Map;
2525
import java.util.Objects;
26+
import java.util.UUID;
2627
import java.util.concurrent.CompletableFuture;
2728
import java.util.function.Predicate;
2829
import java.util.stream.Collectors;
@@ -612,19 +613,309 @@ public record Usage(
612613
@JsonProperty("total_tokens") int totalTokens) {}
613614
}
614615

616+
public UploadFileResponse uploadFile(UploadFileRequest uploadFileRequest) {
617+
618+
final String boundary = UUID.randomUUID().toString();
619+
620+
String body = uploadFileRequest.getMultipartBody(boundary);
621+
622+
HttpRequest request =
623+
HttpRequest.newBuilder()
624+
.uri(getUriForEndpoint(UploadFileRequest.ENDPOINT))
625+
.header("Authorization", "Bearer " + apiKey)
626+
.header("Content-Type", "multipart/form-data; boundary=" + boundary)
627+
.POST(HttpRequest.BodyPublishers.ofString(body))
628+
.build();
629+
630+
HttpResponse<String> response;
631+
try {
632+
response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
633+
} catch (IOException | InterruptedException e) {
634+
throw new RuntimeException(e);
635+
}
636+
637+
if (response.statusCode() != 200) {
638+
throw new OpenAIClientResponseException("Can't upload file", response);
639+
}
640+
641+
UploadFileResponse uploadFileResponse;
642+
try {
643+
uploadFileResponse = objectMapper.readValue(response.body(), UploadFileResponse.class);
644+
} catch (JsonProcessingException e) {
645+
throw new RuntimeException(e);
646+
}
647+
return uploadFileResponse;
648+
}
649+
650+
public record UploadFileRequest(
651+
String purpose, String filename, String content, String contentType) {
652+
653+
static final String ENDPOINT = "/v1/files";
654+
655+
public static UploadFileRequest forBatch(String filename, String content) {
656+
return new UploadFileRequest(Purpose.BATCH.toString(), filename, content, "application/json");
657+
}
658+
659+
enum Purpose {
660+
BATCH("batch"),
661+
ASSISTANTS("assistants"),
662+
FINE_TUNE("fine-tune"),
663+
VISION("vision");
664+
665+
private final String purposeCode;
666+
667+
Purpose(String purposeCode) {
668+
this.purposeCode = purposeCode;
669+
}
670+
671+
public String getPurposeCode() {
672+
return purposeCode;
673+
}
674+
675+
@Override
676+
public String toString() {
677+
return purposeCode;
678+
}
679+
680+
public static Purpose fromCode(String purposeCode) {
681+
for (Purpose purpose : Purpose.values()) {
682+
if (purpose.purposeCode.equalsIgnoreCase(purposeCode)) {
683+
return purpose;
684+
}
685+
}
686+
throw new IllegalArgumentException("Unknown purpose: " + purposeCode);
687+
}
688+
}
689+
690+
String getMultipartBody(String boundary) {
691+
String body =
692+
"""
693+
--%1$s\r
694+
Content-Disposition: form-data; name="purpose"\r
695+
\r
696+
%5$s\r
697+
--%1$s\r
698+
Content-Disposition: form-data; name="file"; filename="%2$s"\r
699+
Content-Type: %3$s\r
700+
\r
701+
%4$s\r
702+
--%1$s--\r
703+
"""
704+
.formatted(boundary, filename, contentType, content, purpose);
705+
return body;
706+
}
707+
}
708+
709+
public record UploadFileResponse(
710+
String object,
711+
String id,
712+
String purpose,
713+
String filename,
714+
int bytes,
715+
@JsonProperty("created_at") long createdAt,
716+
String status,
717+
@JsonProperty("status_details") String statusDetails) {}
718+
719+
public record RequestBatchFileLine(
720+
@JsonProperty("custom_id") String customId, String method, String url, Object body) {
721+
722+
public static RequestBatchFileLine forChatCompletion(
723+
String customId, ChatCompletionsRequest chatCompletionsRequest) {
724+
return new RequestBatchFileLine(
725+
customId, "POST", "/v1/chat/completions", chatCompletionsRequest);
726+
}
727+
}
728+
729+
public record ChatCompletionResponseBatchFileLine(
730+
String id, @JsonProperty("custom_id") String customId, Response response) {
731+
public record Response(
732+
@JsonProperty("status_code") int statusCode,
733+
@JsonProperty("request_id") String requestId,
734+
@JsonProperty("body") ChatCompletionsResponse chatCompletionsResponse) {}
735+
}
736+
737+
public DownloadFileContentResponse downloadFileContent(
738+
DownloadFileContentRequest downloadFileContentRequest) {
739+
HttpResponse<String> response;
740+
HttpRequest request =
741+
HttpRequest.newBuilder()
742+
.uri(
743+
getUriForEndpoint(
744+
DownloadFileContentRequest.ENDPOINT.formatted(
745+
downloadFileContentRequest.fileId())))
746+
.header("Authorization", "Bearer " + apiKey)
747+
.header("Content-Type", "application/json")
748+
.GET()
749+
.build();
750+
751+
try {
752+
response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
753+
} catch (IOException | InterruptedException e) {
754+
throw new RuntimeException(
755+
"Error while sending the request to download the file: "
756+
+ downloadFileContentRequest.fileId(),
757+
e);
758+
}
759+
760+
if (response.statusCode() != 200) {
761+
throw new OpenAIClientResponseException("Can't download file content", response);
762+
}
763+
764+
return new DownloadFileContentResponse(response.body());
765+
}
766+
767+
public record DownloadFileContentRequest(String fileId) {
768+
static final String ENDPOINT = "/v1/files/%s/content";
769+
}
770+
771+
public record DownloadFileContentResponse(String content) {}
772+
773+
public CreateBatchResponse createBatch(CreateBatchRequest createBatchRequest) {
774+
775+
String jsonBody;
776+
try {
777+
jsonBody = objectMapper.writeValueAsString(createBatchRequest);
778+
} catch (JsonProcessingException e) {
779+
throw new RuntimeException(e);
780+
}
781+
782+
HttpRequest request =
783+
HttpRequest.newBuilder()
784+
.uri(getUriForEndpoint(CreateBatchRequest.ENDPOINT))
785+
.header("Authorization", "Bearer " + apiKey)
786+
.header("Content-Type", "application/json")
787+
.POST(HttpRequest.BodyPublishers.ofString(jsonBody))
788+
.build();
789+
790+
HttpResponse<String> response;
791+
try {
792+
response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
793+
} catch (IOException | InterruptedException e) {
794+
throw new RuntimeException(e);
795+
}
796+
797+
if (response.statusCode() != 200) {
798+
throw new OpenAIClientResponseException("Can't create batch", response);
799+
}
800+
801+
System.out.println(response.body());
802+
803+
CreateBatchResponse createBatchResponse;
804+
try {
805+
createBatchResponse = objectMapper.readValue(response.body(), CreateBatchResponse.class);
806+
} catch (JsonProcessingException e) {
807+
throw new RuntimeException(e);
808+
}
809+
810+
return createBatchResponse;
811+
}
812+
813+
public record CreateBatchRequest(
814+
@JsonProperty("input_file_id") String inputFileId,
815+
String endpoint,
816+
@JsonProperty("completion_window") String completionWindow,
817+
Map<String, String> metadata) {
818+
819+
public static final String ENDPOINT = "/v1/batches";
820+
821+
public static CreateBatchRequest forChatCompletion(
822+
String fileId, Map<String, String> metadata) {
823+
return new CreateBatchRequest(fileId, "/v1/chat/completions", "24h", metadata);
824+
}
825+
}
826+
827+
public record CreateBatchResponse(
828+
String id,
829+
String object,
830+
String endpoint,
831+
String errors,
832+
@JsonProperty("input_file_id") String inputFileId,
833+
@JsonProperty("completion_window") String completionWindow,
834+
String status,
835+
@JsonProperty("output_file_id") String outputFileId,
836+
@JsonProperty("error_file_id") String errorFileId,
837+
@JsonProperty("created_at") long createdAt,
838+
@JsonProperty("in_progress_at") Long inProgressAt,
839+
@JsonProperty("expires_at") long expiresAt,
840+
@JsonProperty("completed_at") Long completedAt,
841+
@JsonProperty("failed_at") Long failedAt,
842+
@JsonProperty("expired_at") Long expiredAt,
843+
@JsonProperty("request_counts") RequestCounts requestCounts,
844+
Map<String, String> metadata) {
845+
record RequestCounts(int total, int completed, int failed) {}
846+
}
847+
848+
public RetrieveBatchResponse retrieveBatch(RetrieveBatchRequest retrieveBatchRequest) {
849+
HttpRequest request =
850+
HttpRequest.newBuilder()
851+
.uri(
852+
getUriForEndpoint(
853+
RetrieveBatchRequest.ENDPOINT.formatted(retrieveBatchRequest.batchId())))
854+
.header("Authorization", "Bearer " + apiKey)
855+
.header("Content-Type", "application/json")
856+
.GET()
857+
.build();
858+
859+
HttpResponse<String> response;
860+
try {
861+
response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
862+
} catch (IOException | InterruptedException e) {
863+
throw new RuntimeException(e);
864+
}
865+
866+
if (response.statusCode() != 200) {
867+
throw new OpenAIClientResponseException("Can't retrieve batch", response);
868+
}
869+
870+
RetrieveBatchResponse retrieveBatchResponse;
871+
try {
872+
retrieveBatchResponse = objectMapper.readValue(response.body(), RetrieveBatchResponse.class);
873+
} catch (JsonProcessingException e) {
874+
throw new RuntimeException(e);
875+
}
876+
877+
return retrieveBatchResponse;
878+
}
879+
880+
public record RetrieveBatchRequest(String batchId) {
881+
public static final String ENDPOINT = "/v1/batches/%s";
882+
}
883+
884+
public record RetrieveBatchResponse(
885+
String id,
886+
String object,
887+
String endpoint,
888+
String errors,
889+
@JsonProperty("input_file_id") String inputFileId,
890+
@JsonProperty("completion_window") String completionWindow,
891+
String status,
892+
@JsonProperty("output_file_id") String outputFileId,
893+
@JsonProperty("error_file_id") String errorFileId,
894+
@JsonProperty("created_at") long createdAt,
895+
@JsonProperty("in_progress_at") Long inProgressAt,
896+
@JsonProperty("expires_at") long expiresAt,
897+
@JsonProperty("completed_at") Long completedAt,
898+
@JsonProperty("failed_at") Long failedAt,
899+
@JsonProperty("expired_at") Long expiredAt,
900+
@JsonProperty("request_counts") RequestCounts requestCounts,
901+
Map<String, String> metadata) {
902+
record RequestCounts(int total, int completed, int failed) {}
903+
}
904+
615905
private URI getUriForEndpoint(String endpoint) {
616906
return URI.create(host).resolve(endpoint);
617907
}
618908

619-
public class OpenAIClientResponseException extends RuntimeException {
620-
HttpResponse httpResponse;
909+
public static class OpenAIClientResponseException extends RuntimeException {
910+
HttpResponse<?> httpResponse;
621911

622-
public OpenAIClientResponseException(String message, HttpResponse httpResponse) {
912+
public OpenAIClientResponseException(String message, HttpResponse<?> httpResponse) {
623913
super(message);
624914
this.httpResponse = Objects.requireNonNull(httpResponse);
625915
}
626916

627-
public OpenAIClientResponseException(String message, Exception e, HttpResponse httpResponse) {
917+
public OpenAIClientResponseException(
918+
String message, Exception e, HttpResponse<?> httpResponse) {
628919
super(message, e);
629920
this.httpResponse = Objects.requireNonNull(httpResponse);
630921
}

0 commit comments

Comments
 (0)