30
30
import org .springframework .ai .openai .api .common .ApiUtils ;
31
31
import org .springframework .boot .context .properties .bind .ConstructorBinding ;
32
32
import org .springframework .core .ParameterizedTypeReference ;
33
+ import org .springframework .http .MediaType ;
33
34
import org .springframework .http .ResponseEntity ;
34
35
import org .springframework .util .Assert ;
35
36
import org .springframework .util .CollectionUtils ;
37
+ import org .springframework .util .MultiValueMap ;
36
38
import org .springframework .web .client .RestClient ;
37
39
import org .springframework .web .reactive .function .client .WebClient ;
38
40
42
44
* OpenAI Embedding API: https://platform.openai.com/docs/api-reference/embeddings.
43
45
*
44
46
* @author Christian Tzolov
47
+ * @author Michael Lavelle
45
48
*/
46
49
public class OpenAiApi {
47
50
@@ -50,6 +53,9 @@ public class OpenAiApi {
50
53
private static final Predicate <String > SSE_DONE_PREDICATE = "[DONE]" ::equals ;
51
54
52
55
private final RestClient restClient ;
56
+
57
+ private final RestClient multipartRestClient ;
58
+
53
59
private final WebClient webClient ;
54
60
55
61
/**
@@ -86,6 +92,15 @@ public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClie
86
92
.defaultStatusHandler (ApiUtils .DEFAULT_RESPONSE_ERROR_HANDLER )
87
93
.build ();
88
94
95
+ this .multipartRestClient = restClientBuilder
96
+ .baseUrl (baseUrl )
97
+ .defaultHeaders (multipartFormDataHeaders -> {
98
+ multipartFormDataHeaders .setBearerAuth (openAiToken );
99
+ multipartFormDataHeaders .setContentType (MediaType .MULTIPART_FORM_DATA );
100
+ })
101
+ .defaultStatusHandler (ApiUtils .DEFAULT_RESPONSE_ERROR_HANDLER )
102
+ .build ();
103
+
89
104
this .webClient = WebClient .builder ()
90
105
.baseUrl (baseUrl )
91
106
.defaultHeaders (ApiUtils .getJsonContentHeaders (openAiToken ))
@@ -97,7 +112,7 @@ public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClie
97
112
* <a href="https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo">GPT-4 and GPT-4 Turbo</a> and
98
113
* <a href="https://platform.openai.com/docs/models/gpt-3-5-turbo">GPT-3.5 Turbo</a>.
99
114
*/
100
- enum ChatModel {
115
+ public enum ChatModel {
101
116
/**
102
117
* (New) GPT-4 Turbo - latest GPT-4 model intended to reduce cases
103
118
* of “laziness” where the model doesn’t complete a task.
@@ -169,42 +184,6 @@ public String getValue() {
169
184
}
170
185
}
171
186
172
- /**
173
- * OpenAI Embeddings Models:
174
- * <a href="https://platform.openai.com/docs/models/embeddings">Embeddings</a>.
175
- */
176
- enum EmbeddingModel {
177
-
178
- /**
179
- * Most capable embedding model for both english and non-english tasks.
180
- * DIMENSION: 3072
181
- */
182
- TEXT_EMBEDDING_3_LARGE ("text-embedding-3-large" ),
183
-
184
- /**
185
- * Increased performance over 2nd generation ada embedding model.
186
- * DIMENSION: 1536
187
- */
188
- TEXT_EMBEDDING_3_SMALL ("text-embedding-3-small" ),
189
-
190
- /**
191
- * Most capable 2nd generation embedding model, replacing 16 first
192
- * generation models.
193
- * DIMENSION: 1536
194
- */
195
- TEXT_EMBEDDING_ADA_002 ("text-embedding-ada-002" );
196
-
197
- public final String value ;
198
-
199
- EmbeddingModel (String value ) {
200
- this .value = value ;
201
- }
202
-
203
- public String getValue () {
204
- return value ;
205
- }
206
- }
207
-
208
187
/**
209
188
* Represents a tool the model may call. Currently, only functions are supported as a tool.
210
189
*
@@ -708,6 +687,44 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
708
687
.map (content -> ModelOptionsUtils .jsonToObject (content , ChatCompletionChunk .class ));
709
688
}
710
689
690
+ // Embeddings API
691
+
692
+ /**
693
+ * OpenAI Embeddings Models:
694
+ * <a href="https://platform.openai.com/docs/models/embeddings">Embeddings</a>.
695
+ */
696
+ public enum EmbeddingModel {
697
+
698
+ /**
699
+ * Most capable embedding model for both english and non-english tasks.
700
+ * DIMENSION: 3072
701
+ */
702
+ TEXT_EMBEDDING_3_LARGE ("text-embedding-3-large" ),
703
+
704
+ /**
705
+ * Increased performance over 2nd generation ada embedding model.
706
+ * DIMENSION: 1536
707
+ */
708
+ TEXT_EMBEDDING_3_SMALL ("text-embedding-3-small" ),
709
+
710
+ /**
711
+ * Most capable 2nd generation embedding model, replacing 16 first
712
+ * generation models.
713
+ * DIMENSION: 1536
714
+ */
715
+ TEXT_EMBEDDING_ADA_002 ("text-embedding-ada-002" );
716
+
717
+ public final String value ;
718
+
719
+ EmbeddingModel (String value ) {
720
+ this .value = value ;
721
+ }
722
+
723
+ public String getValue () {
724
+ return value ;
725
+ }
726
+ }
727
+
711
728
/**
712
729
* Represents an embedding vector returned by embedding endpoint.
713
730
*
@@ -824,5 +841,87 @@ public <T> ResponseEntity<EmbeddingList<Embedding>> embeddings(EmbeddingRequest<
824
841
.toEntity (new ParameterizedTypeReference <>() {
825
842
});
826
843
}
844
+
845
+ // Transcription API
846
+
847
+ // @JsonInclude(Include.NON_NULL)
848
+ // public record Transcription(
849
+ // @JsonProperty("text") String text) {
850
+ // }
851
+
852
+ // /**
853
+ // *
854
+ // * @param model ID of the model to use.
855
+ // * @param language The language of the input audio. Supplying the input language in ISO-639-1 format will improve accuracy and latency.
856
+ // * @param prompt An optional text to guide the model's style or continue a previous audio segment. The prompt should match the audio language.
857
+ // * @param responseFormat An object specifying the format that the model must output.
858
+ // * @param temperature What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output
859
+ // * more random, while lower values like 0.2 will make it more focused and deterministic. */
860
+ // @JsonInclude(Include.NON_NULL)
861
+ // public record TranscriptionRequest (
862
+ // @JsonProperty("model") String model,
863
+ // @JsonProperty("language") String language,
864
+ // @JsonProperty("prompt") String prompt,
865
+ // @JsonProperty("response_format") ResponseFormat responseFormat,
866
+ // @JsonProperty("temperature") Float temperature) {
867
+
868
+ // /**
869
+ // * Shortcut constructor for a transcription request with the given model and temperature
870
+ // *
871
+ // * @param model ID of the model to use.
872
+ // * @param temperature What sampling temperature to use, between 0 and 1.
873
+ // */
874
+ // public TranscriptionRequest(String model, Float temperature) {
875
+ // this(model, null, null, null, temperature);
876
+ // }
877
+
878
+ // public TranscriptionRequest() {
879
+ // this(null, null, null, null, null);
880
+ // }
881
+
882
+ // /**
883
+ // * An object specifying the format that the model must output.
884
+ // * @param type Must be one of 'text' or 'json_object'.
885
+ // */
886
+ // @JsonInclude(Include.NON_NULL)
887
+ // public record ResponseFormat(
888
+ // @JsonProperty("type") String type) {
889
+ // }
890
+ // }
891
+
892
+ // /**
893
+ // * Creates a model response for the given transcription.
894
+ // *
895
+ // * @param transcriptionRequest The transcription request.
896
+ // * @return Entity response with {@link Transcription} as a body and HTTP status code and headers.
897
+ // */
898
+ // public ResponseEntity<Transcription> transcriptionEntityJson(MultiValueMap<String, Object> transcriptionRequest) {
899
+
900
+ // Assert.notNull(transcriptionRequest, "The request body can not be null.");
901
+
902
+ // return this.multipartRestClient.post()
903
+ // .uri("/v1/audio/transcriptions")
904
+ // .body(transcriptionRequest)
905
+ // .retrieve()
906
+ // .toEntity(Transcription.class);
907
+ // }
908
+
909
+ // /**
910
+ // * Creates a model response for the given transcription.
911
+ // *
912
+ // * @param transcriptionRequest The transcription request.
913
+ // * @return Entity response with {@link String} as a body and HTTP status code and headers.
914
+ // */
915
+ // public ResponseEntity<String> transcriptionEntityText(MultiValueMap<String, Object> transcriptionRequest) {
916
+
917
+ // Assert.notNull(transcriptionRequest, "The request body can not be null.");
918
+
919
+ // return this.multipartRestClient.post()
920
+ // .uri("/v1/audio/transcriptions")
921
+ // .body(transcriptionRequest)
922
+ // .accept(MediaType.TEXT_PLAIN)
923
+ // .retrieve()
924
+ // .toEntity(String.class);
925
+ // }
827
926
}
828
927
// @formatter:on
0 commit comments