1111import com .fasterxml .jackson .databind .node .ArrayNode ;
1212import com .fasterxml .jackson .databind .node .ObjectNode ;
1313import com .fasterxml .jackson .datatype .jsr310 .JavaTimeModule ;
14- import com .fasterxml .jackson .module .jsonSchema .JsonSchema ;
1514import com .fasterxml .jackson .module .jsonSchema .JsonSchemaGenerator ;
15+ import java .io .IOException ;
1616import java .net .URI ;
1717import java .net .http .HttpClient ;
1818import java .net .http .HttpRequest ;
2323import java .util .List ;
2424import java .util .Map ;
2525import java .util .Objects ;
26+ import java .util .UUID ;
2627import java .util .concurrent .CompletableFuture ;
2728import java .util .function .Predicate ;
2829import java .util .stream .Collectors ;
@@ -144,7 +145,7 @@ public CompletableFuture<ChatCompletionsResponse> getChatCompletions(
144145 httpResponse .body (), ChatCompletionsResponse .class );
145146 } catch (JsonProcessingException e ) {
146147 throw new OpenAIClientResponseException (
147- "Can't deserialize ChatCompletionsResponse" , e , httpResponse );
148+ "Can't deserialize ChatCompletionsResponse" , e , httpResponse );
148149 }
149150 }
150151 });
@@ -182,7 +183,7 @@ public CompletableFuture<Stream<ChatCompletionsStreamResponse>> streamChatComple
182183 httpResponse -> {
183184 if (httpResponse .statusCode () != 200 ) {
184185 throw new OpenAIClientResponseException (
185- "ChatCompletion stream failed" , httpResponse );
186+ "ChatCompletion stream failed" , httpResponse );
186187 }
187188 return httpResponse
188189 .body ()
@@ -192,9 +193,9 @@ public CompletableFuture<Stream<ChatCompletionsStreamResponse>> streamChatComple
192193 body -> {
193194 if (!body .startsWith ("data: " )) {
194195 throw new OpenAIClientResponseException (
195- "Only support \" data\" lines in stream are supported, got: \" %s\" "
196- .formatted (body ),
197- httpResponse );
196+ "Only support \" data\" lines in stream are supported, got: \" %s\" "
197+ .formatted (body ),
198+ httpResponse );
198199 }
199200
200201 String jsonPart = body .substring (5 );
@@ -203,9 +204,9 @@ public CompletableFuture<Stream<ChatCompletionsStreamResponse>> streamChatComple
203204 jsonPart , ChatCompletionsStreamResponse .class );
204205 } catch (JsonProcessingException e ) {
205206 throw new OpenAIClientResponseException (
206- "Can't deserialize ChatCompletionsStreamResponse" ,
207- e ,
208- httpResponse );
207+ "Can't deserialize ChatCompletionsStreamResponse" ,
208+ e ,
209+ httpResponse );
209210 }
210211 });
211212 });
@@ -551,7 +552,7 @@ public CompletableFuture<EmbeddingResponse> getEmbedding(EmbeddingRequest embedd
551552 return objectMapper .readValue (httpResponse .body (), EmbeddingResponse .class );
552553 } catch (JsonProcessingException e ) {
553554 throw new OpenAIClientResponseException (
554- "Can't deserialize EmbeddingResponse" , e , httpResponse );
555+ "Can't deserialize EmbeddingResponse" , e , httpResponse );
555556 }
556557 });
557558
@@ -612,19 +613,292 @@ public record Usage(
612613 @ JsonProperty ("total_tokens" ) int totalTokens ) {}
613614 }
614615
616+ public UploadFileResponse uploadFile (UploadFileRequest uploadFileRequest ) {
617+
618+ final String boundary = UUID .randomUUID ().toString ();
619+
620+ String body = uploadFileRequest .getMultipartBody (boundary );
621+
622+ HttpRequest request =
623+ HttpRequest .newBuilder ()
624+ .uri (getUriForEndpoint (UploadFileRequest .ENDPOINT ))
625+ .header ("Authorization" , "Bearer " + apiKey )
626+ .header ("Content-Type" , "multipart/form-data; boundary=" + boundary )
627+ .POST (HttpRequest .BodyPublishers .ofString (body ))
628+ .build ();
629+
630+ HttpResponse <String > response ;
631+ try {
632+ response = httpClient .send (request , HttpResponse .BodyHandlers .ofString ());
633+ System .out .printf ("response: %s%n" , response .body ());
634+ } catch (IOException | InterruptedException e ) {
635+ throw new RuntimeException (e );
636+ }
637+
638+ UploadFileResponse uploadFileResponse ;
639+ try {
640+ uploadFileResponse = objectMapper .readValue (response .body (), UploadFileResponse .class );
641+ } catch (JsonProcessingException e ) {
642+ throw new RuntimeException (e );
643+ }
644+ return uploadFileResponse ;
645+ }
646+
647+ public record UploadFileRequest (
648+ String purpose , String filename , String content , String contentType ) {
649+
650+ public static UploadFileRequest forBatch (String filename , String content ) {
651+ return new UploadFileRequest (Purpose .BATCH .toString (), filename , content , "application/json" );
652+ }
653+
654+ enum Purpose {
655+ BATCH ("batch" ),
656+ ASSISTANTS ("assistants" ),
657+ FINE_TUNE ("fine-tune" ),
658+ VISION ("vision" );
659+
660+ private final String purposeCode ;
661+
662+ Purpose (String purposeCode ) {
663+ this .purposeCode = purposeCode ;
664+ }
665+
666+ public String getPurposeCode () {
667+ return purposeCode ;
668+ }
669+
670+ @ Override
671+ public String toString () {
672+ return purposeCode ;
673+ }
674+
675+ public static Purpose fromCode (String purposeCode ) {
676+ for (Purpose purpose : Purpose .values ()) {
677+ if (purpose .purposeCode .equalsIgnoreCase (purposeCode )) {
678+ return purpose ;
679+ }
680+ }
681+ throw new IllegalArgumentException ("Unknown purpose: " + purposeCode );
682+ }
683+ }
684+
685+ static final String ENDPOINT = "/v1/files" ;
686+
687+ String getMultipartBody (String boundary ) {
688+ String body =
689+ """
690+ --%1$s\r
691+ Content-Disposition: form-data; name="purpose"\r
692+ \r
693+ %5$s\r
694+ --%1$s\r
695+ Content-Disposition: form-data; name="file"; filename="%2$s"\r
696+ Content-Type: %3$s\r
697+ \r
698+ %4$s\r
699+ --%1$s--\r
700+ """
701+ .formatted (boundary , filename , contentType , content , purpose );
702+ return body ;
703+ }
704+ }
705+
706+ public record UploadFileResponse (
707+ String object ,
708+ String id ,
709+ String purpose ,
710+ String filename ,
711+ int bytes ,
712+ @ JsonProperty ("created_at" ) long createdAt ,
713+ String status ,
714+ @ JsonProperty ("status_details" ) String statusDetails ) {}
715+
716+ public record BatchFileLine (
717+ @ JsonProperty ("custom_id" ) String customId , String method , String url , Object body ) {
718+
719+ public static BatchFileLine forChatCompletion (
720+ String customId , ChatCompletionsRequest chatCompletionsRequest ) {
721+ return new BatchFileLine (customId , "POST" , "/v1/chat/completions" , chatCompletionsRequest );
722+ }
723+ }
724+
725+ public DownloadFileContentResponse downloadFileContent (
726+ DownloadFileContentRequest downloadFileContentRequest ) {
727+ HttpResponse <String > response ;
728+ HttpRequest request =
729+ HttpRequest .newBuilder ()
730+ .uri (
731+ getUriForEndpoint (
732+ DownloadFileContentRequest .ENDPOINT .formatted (
733+ downloadFileContentRequest .fileId ())))
734+ .header ("Authorization" , "Bearer " + apiKey )
735+ .header ("Content-Type" , "application/json" )
736+ .GET ()
737+ .build ();
738+
739+ try {
740+ response = httpClient .send (request , HttpResponse .BodyHandlers .ofString ());
741+ } catch (IOException | InterruptedException e ) {
742+ throw new RuntimeException (
743+ "Error while sending the request to download the file: "
744+ + downloadFileContentRequest .fileId (),
745+ e );
746+ }
747+
748+ if (response .statusCode () != 200 ) {
749+ throw new OpenAIClientResponseException ("Can't download file content" , response );
750+ }
751+
752+ return new DownloadFileContentResponse (response .body ());
753+ }
754+
755+ public record DownloadFileContentRequest (String fileId ) {
756+ static final String ENDPOINT = "/v1/files/%s/content" ;
757+ }
758+
759+ public record DownloadFileContentResponse (String content ) {}
760+
761+ public CreateBatchResponse createBatch (CreateBatchRequest createBatchRequest ) {
762+
763+ String jsonBody ;
764+ try {
765+ jsonBody = objectMapper .writeValueAsString (createBatchRequest );
766+ } catch (JsonProcessingException e ) {
767+ throw new RuntimeException (e );
768+ }
769+
770+ HttpRequest request =
771+ HttpRequest .newBuilder ()
772+ .uri (getUriForEndpoint (CreateBatchRequest .ENDPOINT ))
773+ .header ("Authorization" , "Bearer " + apiKey )
774+ .header ("Content-Type" , "application/json" )
775+ .POST (HttpRequest .BodyPublishers .ofString (jsonBody ))
776+ .build ();
777+
778+ HttpResponse <String > response ;
779+ try {
780+ response = httpClient .send (request , HttpResponse .BodyHandlers .ofString ());
781+ } catch (IOException | InterruptedException e ) {
782+ throw new RuntimeException (e );
783+ }
784+
785+ if (response .statusCode () != 200 ) {
786+ throw new RuntimeException ("Can't create batch" );
787+ }
788+
789+ CreateBatchResponse createBatchResponse ;
790+ try {
791+ createBatchResponse = objectMapper .readValue (response .body (), CreateBatchResponse .class );
792+ } catch (JsonProcessingException e ) {
793+ throw new RuntimeException (e );
794+ }
795+
796+ return createBatchResponse ;
797+ }
798+
799+ public record CreateBatchRequest (
800+ @ JsonProperty ("input_file_id" ) String inputFileId ,
801+ String endpoint ,
802+ @ JsonProperty ("completion_window" ) String completionWindow ) {
803+
804+ public static final String ENDPOINT = "/v1/batches" ;
805+
806+ public static CreateBatchRequest forChatCompletion (String fileId ) {
807+ return new CreateBatchRequest (fileId , "/v1/chat/completions" , "24h" );
808+ }
809+ }
810+
811+ public record CreateBatchResponse (
812+ String id ,
813+ String object ,
814+ String endpoint ,
815+ String errors ,
816+ @ JsonProperty ("input_file_id" ) String inputFileId ,
817+ @ JsonProperty ("completion_window" ) String completionWindow ,
818+ String status ,
819+ @ JsonProperty ("output_file_id" ) String outputFileId ,
820+ @ JsonProperty ("error_file_id" ) String errorFileId ,
821+ @ JsonProperty ("created_at" ) long createdAt ,
822+ @ JsonProperty ("in_progress_at" ) Long inProgressAt ,
823+ @ JsonProperty ("expires_at" ) long expiresAt ,
824+ @ JsonProperty ("completed_at" ) Long completedAt ,
825+ @ JsonProperty ("failed_at" ) Long failedAt ,
826+ @ JsonProperty ("expired_at" ) Long expiredAt ,
827+ @ JsonProperty ("request_counts" ) RequestCounts requestCounts ,
828+ Object metadata ) {
829+ record RequestCounts (int total , int completed , int failed ) {}
830+ }
831+
832+ public RetrieveBatchResponse retrieveBatch (RetrieveBatchRequest retrieveBatchRequest ) {
833+ HttpRequest request =
834+ HttpRequest .newBuilder ()
835+ .uri (
836+ getUriForEndpoint (
837+ RetrieveBatchRequest .ENDPOINT .formatted (retrieveBatchRequest .batchId ())))
838+ .header ("Authorization" , "Bearer " + apiKey )
839+ .header ("Content-Type" , "application/json" )
840+ .GET ()
841+ .build ();
842+
843+ HttpResponse <String > response ;
844+ try {
845+ response = httpClient .send (request , HttpResponse .BodyHandlers .ofString ());
846+ } catch (IOException | InterruptedException e ) {
847+ throw new RuntimeException (e );
848+ }
849+
850+ if (response .statusCode () != 200 ) {
851+ throw new OpenAIClientResponseException ("Can't retrieve batch" , response );
852+ }
853+
854+ RetrieveBatchResponse retrieveBatchResponse ;
855+ try {
856+ retrieveBatchResponse = objectMapper .readValue (response .body (), RetrieveBatchResponse .class );
857+ } catch (JsonProcessingException e ) {
858+ throw new RuntimeException (e );
859+ }
860+
861+ return retrieveBatchResponse ;
862+ }
863+
864+ public record RetrieveBatchRequest (String batchId ) {
865+ public static final String ENDPOINT = "/v1/batches/%s" ;
866+ }
867+
868+ public record RetrieveBatchResponse (
869+ String id ,
870+ String object ,
871+ String endpoint ,
872+ String errors ,
873+ @ JsonProperty ("input_file_id" ) String inputFileId ,
874+ @ JsonProperty ("completion_window" ) String completionWindow ,
875+ String status ,
876+ @ JsonProperty ("output_file_id" ) String outputFileId ,
877+ @ JsonProperty ("error_file_id" ) String errorFileId ,
878+ @ JsonProperty ("created_at" ) long createdAt ,
879+ @ JsonProperty ("in_progress_at" ) Long inProgressAt ,
880+ @ JsonProperty ("expires_at" ) long expiresAt ,
881+ @ JsonProperty ("completed_at" ) Long completedAt ,
882+ @ JsonProperty ("failed_at" ) Long failedAt ,
883+ @ JsonProperty ("expired_at" ) Long expiredAt ,
884+ @ JsonProperty ("request_counts" ) RequestCounts requestCounts ,
885+ Object metadata ) {
886+ record RequestCounts (int total , int completed , int failed ) {}
887+ }
888+
615889 private URI getUriForEndpoint (String endpoint ) {
616890 return URI .create (host ).resolve (endpoint );
617891 }
618892
619- public class OpenAIClientResponseException extends RuntimeException {
620- HttpResponse httpResponse ;
893+ public static class OpenAIClientResponseException extends RuntimeException {
894+ HttpResponse <?> httpResponse ;
621895
622- public OpenAIClientResponseException (String message , HttpResponse httpResponse ) {
896+ public OpenAIClientResponseException (String message , HttpResponse <?> httpResponse ) {
623897 super (message );
624898 this .httpResponse = Objects .requireNonNull (httpResponse );
625899 }
626900
627- public OpenAIClientResponseException (String message , Exception e , HttpResponse httpResponse ) {
901+ public OpenAIClientResponseException (String message , Exception e , HttpResponse <?> httpResponse ) {
628902 super (message , e );
629903 this .httpResponse = Objects .requireNonNull (httpResponse );
630904 }
0 commit comments