|  | 
| 11 | 11 | import com.fasterxml.jackson.databind.node.ArrayNode; | 
| 12 | 12 | import com.fasterxml.jackson.databind.node.ObjectNode; | 
| 13 | 13 | import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; | 
| 14 |  | -import com.fasterxml.jackson.module.jsonSchema.JsonSchema; | 
| 15 | 14 | import com.fasterxml.jackson.module.jsonSchema.JsonSchemaGenerator; | 
|  | 15 | +import java.io.IOException; | 
| 16 | 16 | import java.net.URI; | 
| 17 | 17 | import java.net.http.HttpClient; | 
| 18 | 18 | import java.net.http.HttpRequest; | 
|  | 
| 23 | 23 | import java.util.List; | 
| 24 | 24 | import java.util.Map; | 
| 25 | 25 | import java.util.Objects; | 
|  | 26 | +import java.util.UUID; | 
| 26 | 27 | import java.util.concurrent.CompletableFuture; | 
| 27 | 28 | import java.util.function.Predicate; | 
| 28 | 29 | import java.util.stream.Collectors; | 
| @@ -612,19 +613,309 @@ public record Usage( | 
| 612 | 613 |         @JsonProperty("total_tokens") int totalTokens) {} | 
| 613 | 614 |   } | 
| 614 | 615 | 
 | 
|  | 616 | +  public UploadFileResponse uploadFile(UploadFileRequest uploadFileRequest) { | 
|  | 617 | + | 
|  | 618 | +    final String boundary = UUID.randomUUID().toString(); | 
|  | 619 | + | 
|  | 620 | +    String body = uploadFileRequest.getMultipartBody(boundary); | 
|  | 621 | + | 
|  | 622 | +    HttpRequest request = | 
|  | 623 | +        HttpRequest.newBuilder() | 
|  | 624 | +            .uri(getUriForEndpoint(UploadFileRequest.ENDPOINT)) | 
|  | 625 | +            .header("Authorization", "Bearer " + apiKey) | 
|  | 626 | +            .header("Content-Type", "multipart/form-data; boundary=" + boundary) | 
|  | 627 | +            .POST(HttpRequest.BodyPublishers.ofString(body)) | 
|  | 628 | +            .build(); | 
|  | 629 | + | 
|  | 630 | +    HttpResponse<String> response; | 
|  | 631 | +    try { | 
|  | 632 | +      response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); | 
|  | 633 | +    } catch (IOException | InterruptedException e) { | 
|  | 634 | +      throw new RuntimeException(e); | 
|  | 635 | +    } | 
|  | 636 | + | 
|  | 637 | +    if (response.statusCode() != 200) { | 
|  | 638 | +      throw new OpenAIClientResponseException("Can't upload file", response); | 
|  | 639 | +    } | 
|  | 640 | + | 
|  | 641 | +    UploadFileResponse uploadFileResponse; | 
|  | 642 | +    try { | 
|  | 643 | +      uploadFileResponse = objectMapper.readValue(response.body(), UploadFileResponse.class); | 
|  | 644 | +    } catch (JsonProcessingException e) { | 
|  | 645 | +      throw new RuntimeException(e); | 
|  | 646 | +    } | 
|  | 647 | +    return uploadFileResponse; | 
|  | 648 | +  } | 
|  | 649 | + | 
|  | 650 | +  public record UploadFileRequest( | 
|  | 651 | +      String purpose, String filename, String content, String contentType) { | 
|  | 652 | + | 
|  | 653 | +    static final String ENDPOINT = "/v1/files"; | 
|  | 654 | + | 
|  | 655 | +    public static UploadFileRequest forBatch(String filename, String content) { | 
|  | 656 | +      return new UploadFileRequest(Purpose.BATCH.toString(), filename, content, "application/json"); | 
|  | 657 | +    } | 
|  | 658 | + | 
|  | 659 | +    enum Purpose { | 
|  | 660 | +      BATCH("batch"), | 
|  | 661 | +      ASSISTANTS("assistants"), | 
|  | 662 | +      FINE_TUNE("fine-tune"), | 
|  | 663 | +      VISION("vision"); | 
|  | 664 | + | 
|  | 665 | +      private final String purposeCode; | 
|  | 666 | + | 
|  | 667 | +      Purpose(String purposeCode) { | 
|  | 668 | +        this.purposeCode = purposeCode; | 
|  | 669 | +      } | 
|  | 670 | + | 
|  | 671 | +      public String getPurposeCode() { | 
|  | 672 | +        return purposeCode; | 
|  | 673 | +      } | 
|  | 674 | + | 
|  | 675 | +      @Override | 
|  | 676 | +      public String toString() { | 
|  | 677 | +        return purposeCode; | 
|  | 678 | +      } | 
|  | 679 | + | 
|  | 680 | +      public static Purpose fromCode(String purposeCode) { | 
|  | 681 | +        for (Purpose purpose : Purpose.values()) { | 
|  | 682 | +          if (purpose.purposeCode.equalsIgnoreCase(purposeCode)) { | 
|  | 683 | +            return purpose; | 
|  | 684 | +          } | 
|  | 685 | +        } | 
|  | 686 | +        throw new IllegalArgumentException("Unknown purpose: " + purposeCode); | 
|  | 687 | +      } | 
|  | 688 | +    } | 
|  | 689 | + | 
|  | 690 | +    String getMultipartBody(String boundary) { | 
|  | 691 | +      String body = | 
|  | 692 | +          """ | 
|  | 693 | +      --%1$s\r | 
|  | 694 | +      Content-Disposition: form-data; name="purpose"\r | 
|  | 695 | +      \r | 
|  | 696 | +      %5$s\r | 
|  | 697 | +      --%1$s\r | 
|  | 698 | +      Content-Disposition: form-data; name="file"; filename="%2$s"\r | 
|  | 699 | +      Content-Type: %3$s\r | 
|  | 700 | +      \r | 
|  | 701 | +      %4$s\r | 
|  | 702 | +      --%1$s--\r | 
|  | 703 | +      """ | 
|  | 704 | +              .formatted(boundary, filename, contentType, content, purpose); | 
|  | 705 | +      return body; | 
|  | 706 | +    } | 
|  | 707 | +  } | 
|  | 708 | + | 
|  | 709 | +  public record UploadFileResponse( | 
|  | 710 | +      String object, | 
|  | 711 | +      String id, | 
|  | 712 | +      String purpose, | 
|  | 713 | +      String filename, | 
|  | 714 | +      int bytes, | 
|  | 715 | +      @JsonProperty("created_at") long createdAt, | 
|  | 716 | +      String status, | 
|  | 717 | +      @JsonProperty("status_details") String statusDetails) {} | 
|  | 718 | + | 
|  | 719 | +  public record RequestBatchFileLine( | 
|  | 720 | +      @JsonProperty("custom_id") String customId, String method, String url, Object body) { | 
|  | 721 | + | 
|  | 722 | +    public static RequestBatchFileLine forChatCompletion( | 
|  | 723 | +        String customId, ChatCompletionsRequest chatCompletionsRequest) { | 
|  | 724 | +      return new RequestBatchFileLine( | 
|  | 725 | +          customId, "POST", "/v1/chat/completions", chatCompletionsRequest); | 
|  | 726 | +    } | 
|  | 727 | +  } | 
|  | 728 | + | 
|  | 729 | +  public record ChatCompletionResponseBatchFileLine( | 
|  | 730 | +      String id, @JsonProperty("custom_id") String customId, Response response) { | 
|  | 731 | +    public record Response( | 
|  | 732 | +        @JsonProperty("status_code") int statusCode, | 
|  | 733 | +        @JsonProperty("request_id") String requestId, | 
|  | 734 | +        @JsonProperty("body") ChatCompletionsResponse chatCompletionsResponse) {} | 
|  | 735 | +  } | 
|  | 736 | + | 
|  | 737 | +  public DownloadFileContentResponse downloadFileContent( | 
|  | 738 | +      DownloadFileContentRequest downloadFileContentRequest) { | 
|  | 739 | +    HttpResponse<String> response; | 
|  | 740 | +    HttpRequest request = | 
|  | 741 | +        HttpRequest.newBuilder() | 
|  | 742 | +            .uri( | 
|  | 743 | +                getUriForEndpoint( | 
|  | 744 | +                    DownloadFileContentRequest.ENDPOINT.formatted( | 
|  | 745 | +                        downloadFileContentRequest.fileId()))) | 
|  | 746 | +            .header("Authorization", "Bearer " + apiKey) | 
|  | 747 | +            .header("Content-Type", "application/json") | 
|  | 748 | +            .GET() | 
|  | 749 | +            .build(); | 
|  | 750 | + | 
|  | 751 | +    try { | 
|  | 752 | +      response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); | 
|  | 753 | +    } catch (IOException | InterruptedException e) { | 
|  | 754 | +      throw new RuntimeException( | 
|  | 755 | +          "Error while sending the request to download the file: " | 
|  | 756 | +              + downloadFileContentRequest.fileId(), | 
|  | 757 | +          e); | 
|  | 758 | +    } | 
|  | 759 | + | 
|  | 760 | +    if (response.statusCode() != 200) { | 
|  | 761 | +      throw new OpenAIClientResponseException("Can't download file content", response); | 
|  | 762 | +    } | 
|  | 763 | + | 
|  | 764 | +    return new DownloadFileContentResponse(response.body()); | 
|  | 765 | +  } | 
|  | 766 | + | 
|  | 767 | +  public record DownloadFileContentRequest(String fileId) { | 
|  | 768 | +    static final String ENDPOINT = "/v1/files/%s/content"; | 
|  | 769 | +  } | 
|  | 770 | + | 
|  | 771 | +  public record DownloadFileContentResponse(String content) {} | 
|  | 772 | + | 
|  | 773 | +  public CreateBatchResponse createBatch(CreateBatchRequest createBatchRequest) { | 
|  | 774 | + | 
|  | 775 | +    String jsonBody; | 
|  | 776 | +    try { | 
|  | 777 | +      jsonBody = objectMapper.writeValueAsString(createBatchRequest); | 
|  | 778 | +    } catch (JsonProcessingException e) { | 
|  | 779 | +      throw new RuntimeException(e); | 
|  | 780 | +    } | 
|  | 781 | + | 
|  | 782 | +    HttpRequest request = | 
|  | 783 | +        HttpRequest.newBuilder() | 
|  | 784 | +            .uri(getUriForEndpoint(CreateBatchRequest.ENDPOINT)) | 
|  | 785 | +            .header("Authorization", "Bearer " + apiKey) | 
|  | 786 | +            .header("Content-Type", "application/json") | 
|  | 787 | +            .POST(HttpRequest.BodyPublishers.ofString(jsonBody)) | 
|  | 788 | +            .build(); | 
|  | 789 | + | 
|  | 790 | +    HttpResponse<String> response; | 
|  | 791 | +    try { | 
|  | 792 | +      response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); | 
|  | 793 | +    } catch (IOException | InterruptedException e) { | 
|  | 794 | +      throw new RuntimeException(e); | 
|  | 795 | +    } | 
|  | 796 | + | 
|  | 797 | +    if (response.statusCode() != 200) { | 
|  | 798 | +      throw new OpenAIClientResponseException("Can't create batch", response); | 
|  | 799 | +    } | 
|  | 800 | + | 
|  | 801 | +    System.out.println(response.body()); | 
|  | 802 | + | 
|  | 803 | +    CreateBatchResponse createBatchResponse; | 
|  | 804 | +    try { | 
|  | 805 | +      createBatchResponse = objectMapper.readValue(response.body(), CreateBatchResponse.class); | 
|  | 806 | +    } catch (JsonProcessingException e) { | 
|  | 807 | +      throw new RuntimeException(e); | 
|  | 808 | +    } | 
|  | 809 | + | 
|  | 810 | +    return createBatchResponse; | 
|  | 811 | +  } | 
|  | 812 | + | 
|  | 813 | +  public record CreateBatchRequest( | 
|  | 814 | +      @JsonProperty("input_file_id") String inputFileId, | 
|  | 815 | +      String endpoint, | 
|  | 816 | +      @JsonProperty("completion_window") String completionWindow, | 
|  | 817 | +      Map<String, String> metadata) { | 
|  | 818 | + | 
|  | 819 | +    public static final String ENDPOINT = "/v1/batches"; | 
|  | 820 | + | 
|  | 821 | +    public static CreateBatchRequest forChatCompletion( | 
|  | 822 | +        String fileId, Map<String, String> metadata) { | 
|  | 823 | +      return new CreateBatchRequest(fileId, "/v1/chat/completions", "24h", metadata); | 
|  | 824 | +    } | 
|  | 825 | +  } | 
|  | 826 | + | 
|  | 827 | +  public record CreateBatchResponse( | 
|  | 828 | +      String id, | 
|  | 829 | +      String object, | 
|  | 830 | +      String endpoint, | 
|  | 831 | +      String errors, | 
|  | 832 | +      @JsonProperty("input_file_id") String inputFileId, | 
|  | 833 | +      @JsonProperty("completion_window") String completionWindow, | 
|  | 834 | +      String status, | 
|  | 835 | +      @JsonProperty("output_file_id") String outputFileId, | 
|  | 836 | +      @JsonProperty("error_file_id") String errorFileId, | 
|  | 837 | +      @JsonProperty("created_at") long createdAt, | 
|  | 838 | +      @JsonProperty("in_progress_at") Long inProgressAt, | 
|  | 839 | +      @JsonProperty("expires_at") long expiresAt, | 
|  | 840 | +      @JsonProperty("completed_at") Long completedAt, | 
|  | 841 | +      @JsonProperty("failed_at") Long failedAt, | 
|  | 842 | +      @JsonProperty("expired_at") Long expiredAt, | 
|  | 843 | +      @JsonProperty("request_counts") RequestCounts requestCounts, | 
|  | 844 | +      Map<String, String> metadata) { | 
|  | 845 | +    record RequestCounts(int total, int completed, int failed) {} | 
|  | 846 | +  } | 
|  | 847 | + | 
|  | 848 | +  public RetrieveBatchResponse retrieveBatch(RetrieveBatchRequest retrieveBatchRequest) { | 
|  | 849 | +    HttpRequest request = | 
|  | 850 | +        HttpRequest.newBuilder() | 
|  | 851 | +            .uri( | 
|  | 852 | +                getUriForEndpoint( | 
|  | 853 | +                    RetrieveBatchRequest.ENDPOINT.formatted(retrieveBatchRequest.batchId()))) | 
|  | 854 | +            .header("Authorization", "Bearer " + apiKey) | 
|  | 855 | +            .header("Content-Type", "application/json") | 
|  | 856 | +            .GET() | 
|  | 857 | +            .build(); | 
|  | 858 | + | 
|  | 859 | +    HttpResponse<String> response; | 
|  | 860 | +    try { | 
|  | 861 | +      response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); | 
|  | 862 | +    } catch (IOException | InterruptedException e) { | 
|  | 863 | +      throw new RuntimeException(e); | 
|  | 864 | +    } | 
|  | 865 | + | 
|  | 866 | +    if (response.statusCode() != 200) { | 
|  | 867 | +      throw new OpenAIClientResponseException("Can't retrieve batch", response); | 
|  | 868 | +    } | 
|  | 869 | + | 
|  | 870 | +    RetrieveBatchResponse retrieveBatchResponse; | 
|  | 871 | +    try { | 
|  | 872 | +      retrieveBatchResponse = objectMapper.readValue(response.body(), RetrieveBatchResponse.class); | 
|  | 873 | +    } catch (JsonProcessingException e) { | 
|  | 874 | +      throw new RuntimeException(e); | 
|  | 875 | +    } | 
|  | 876 | + | 
|  | 877 | +    return retrieveBatchResponse; | 
|  | 878 | +  } | 
|  | 879 | + | 
|  | 880 | +  public record RetrieveBatchRequest(String batchId) { | 
|  | 881 | +    public static final String ENDPOINT = "/v1/batches/%s"; | 
|  | 882 | +  } | 
|  | 883 | + | 
|  | 884 | +  public record RetrieveBatchResponse( | 
|  | 885 | +      String id, | 
|  | 886 | +      String object, | 
|  | 887 | +      String endpoint, | 
|  | 888 | +      String errors, | 
|  | 889 | +      @JsonProperty("input_file_id") String inputFileId, | 
|  | 890 | +      @JsonProperty("completion_window") String completionWindow, | 
|  | 891 | +      String status, | 
|  | 892 | +      @JsonProperty("output_file_id") String outputFileId, | 
|  | 893 | +      @JsonProperty("error_file_id") String errorFileId, | 
|  | 894 | +      @JsonProperty("created_at") long createdAt, | 
|  | 895 | +      @JsonProperty("in_progress_at") Long inProgressAt, | 
|  | 896 | +      @JsonProperty("expires_at") long expiresAt, | 
|  | 897 | +      @JsonProperty("completed_at") Long completedAt, | 
|  | 898 | +      @JsonProperty("failed_at") Long failedAt, | 
|  | 899 | +      @JsonProperty("expired_at") Long expiredAt, | 
|  | 900 | +      @JsonProperty("request_counts") RequestCounts requestCounts, | 
|  | 901 | +      Map<String, String> metadata) { | 
|  | 902 | +    record RequestCounts(int total, int completed, int failed) {} | 
|  | 903 | +  } | 
|  | 904 | + | 
| 615 | 905 |   private URI getUriForEndpoint(String endpoint) { | 
| 616 | 906 |     return URI.create(host).resolve(endpoint); | 
| 617 | 907 |   } | 
| 618 | 908 | 
 | 
| 619 |  | -  public class OpenAIClientResponseException extends RuntimeException { | 
| 620 |  | -    HttpResponse httpResponse; | 
|  | 909 | +  public static class OpenAIClientResponseException extends RuntimeException { | 
|  | 910 | +    HttpResponse<?> httpResponse; | 
| 621 | 911 | 
 | 
| 622 |  | -    public OpenAIClientResponseException(String message, HttpResponse httpResponse) { | 
|  | 912 | +    public OpenAIClientResponseException(String message, HttpResponse<?> httpResponse) { | 
| 623 | 913 |       super(message); | 
| 624 | 914 |       this.httpResponse = Objects.requireNonNull(httpResponse); | 
| 625 | 915 |     } | 
| 626 | 916 | 
 | 
| 627 |  | -    public OpenAIClientResponseException(String message, Exception e, HttpResponse httpResponse) { | 
|  | 917 | +    public OpenAIClientResponseException( | 
|  | 918 | +        String message, Exception e, HttpResponse<?> httpResponse) { | 
| 628 | 919 |       super(message, e); | 
| 629 | 920 |       this.httpResponse = Objects.requireNonNull(httpResponse); | 
| 630 | 921 |     } | 
|  | 
0 commit comments