Skip to content

Commit 7d04167

Browse files
michaellavelletzolov
authored andcommitted
Adding support for OpenAI Audio transcriptions
- Make it use the new OpenAiAudioApi. - Remove trascription code from spring-ai-core. Too early to generalize. Move all related code under the spring-ai-openai project. - Fix missing licenses and javadocs. - Add 'Audio' prefix for Transcription classes and packages. - Add missing auto-configuraiotn and tests.
1 parent db383f8 commit 7d04167

File tree

23 files changed

+1354
-54
lines changed

23 files changed

+1354
-54
lines changed

models/spring-ai-openai/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,6 @@
22

33
[OpenAI Embedding Documentation](https://docs.spring.io/spring-ai/reference/api/embeddings/openai-embeddings.html)
44

5+
[OpenAI Image Generation](https://docs.spring.io/spring-ai/reference/api/clients/image/openai-image.html)
6+
7+
[OpenAI Transcription Generation](TODO)

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 136 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@
3030
import org.springframework.ai.openai.api.common.ApiUtils;
3131
import org.springframework.boot.context.properties.bind.ConstructorBinding;
3232
import org.springframework.core.ParameterizedTypeReference;
33+
import org.springframework.http.MediaType;
3334
import org.springframework.http.ResponseEntity;
3435
import org.springframework.util.Assert;
3536
import org.springframework.util.CollectionUtils;
37+
import org.springframework.util.MultiValueMap;
3638
import org.springframework.web.client.RestClient;
3739
import org.springframework.web.reactive.function.client.WebClient;
3840

@@ -42,6 +44,7 @@
4244
* OpenAI Embedding API: https://platform.openai.com/docs/api-reference/embeddings.
4345
*
4446
* @author Christian Tzolov
47+
* @author Michael Lavelle
4548
*/
4649
public class OpenAiApi {
4750

@@ -50,6 +53,9 @@ public class OpenAiApi {
5053
private static final Predicate<String> SSE_DONE_PREDICATE = "[DONE]"::equals;
5154

5255
private final RestClient restClient;
56+
57+
private final RestClient multipartRestClient;
58+
5359
private final WebClient webClient;
5460

5561
/**
@@ -86,6 +92,15 @@ public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClie
8692
.defaultStatusHandler(ApiUtils.DEFAULT_RESPONSE_ERROR_HANDLER)
8793
.build();
8894

95+
this.multipartRestClient = restClientBuilder
96+
.baseUrl(baseUrl)
97+
.defaultHeaders(multipartFormDataHeaders -> {
98+
multipartFormDataHeaders.setBearerAuth(openAiToken);
99+
multipartFormDataHeaders.setContentType(MediaType.MULTIPART_FORM_DATA);
100+
})
101+
.defaultStatusHandler(ApiUtils.DEFAULT_RESPONSE_ERROR_HANDLER)
102+
.build();
103+
89104
this.webClient = WebClient.builder()
90105
.baseUrl(baseUrl)
91106
.defaultHeaders(ApiUtils.getJsonContentHeaders(openAiToken))
@@ -97,7 +112,7 @@ public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClie
97112
* <a href="https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo">GPT-4 and GPT-4 Turbo</a> and
98113
* <a href="https://platform.openai.com/docs/models/gpt-3-5-turbo">GPT-3.5 Turbo</a>.
99114
*/
100-
enum ChatModel {
115+
public enum ChatModel {
101116
/**
102117
* (New) GPT-4 Turbo - latest GPT-4 model intended to reduce cases
103118
* of “laziness” where the model doesn’t complete a task.
@@ -169,42 +184,6 @@ public String getValue() {
169184
}
170185
}
171186

172-
/**
173-
* OpenAI Embeddings Models:
174-
* <a href="https://platform.openai.com/docs/models/embeddings">Embeddings</a>.
175-
*/
176-
enum EmbeddingModel {
177-
178-
/**
179-
* Most capable embedding model for both english and non-english tasks.
180-
* DIMENSION: 3072
181-
*/
182-
TEXT_EMBEDDING_3_LARGE("text-embedding-3-large"),
183-
184-
/**
185-
* Increased performance over 2nd generation ada embedding model.
186-
* DIMENSION: 1536
187-
*/
188-
TEXT_EMBEDDING_3_SMALL("text-embedding-3-small"),
189-
190-
/**
191-
* Most capable 2nd generation embedding model, replacing 16 first
192-
* generation models.
193-
* DIMENSION: 1536
194-
*/
195-
TEXT_EMBEDDING_ADA_002("text-embedding-ada-002");
196-
197-
public final String value;
198-
199-
EmbeddingModel(String value) {
200-
this.value = value;
201-
}
202-
203-
public String getValue() {
204-
return value;
205-
}
206-
}
207-
208187
/**
209188
* Represents a tool the model may call. Currently, only functions are supported as a tool.
210189
*
@@ -708,6 +687,44 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
708687
.map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class));
709688
}
710689

690+
// Embeddings API
691+
692+
/**
693+
* OpenAI Embeddings Models:
694+
* <a href="https://platform.openai.com/docs/models/embeddings">Embeddings</a>.
695+
*/
696+
public enum EmbeddingModel {
697+
698+
/**
699+
* Most capable embedding model for both english and non-english tasks.
700+
* DIMENSION: 3072
701+
*/
702+
TEXT_EMBEDDING_3_LARGE("text-embedding-3-large"),
703+
704+
/**
705+
* Increased performance over 2nd generation ada embedding model.
706+
* DIMENSION: 1536
707+
*/
708+
TEXT_EMBEDDING_3_SMALL("text-embedding-3-small"),
709+
710+
/**
711+
* Most capable 2nd generation embedding model, replacing 16 first
712+
* generation models.
713+
* DIMENSION: 1536
714+
*/
715+
TEXT_EMBEDDING_ADA_002("text-embedding-ada-002");
716+
717+
public final String value;
718+
719+
EmbeddingModel(String value) {
720+
this.value = value;
721+
}
722+
723+
public String getValue() {
724+
return value;
725+
}
726+
}
727+
711728
/**
712729
* Represents an embedding vector returned by embedding endpoint.
713730
*
@@ -824,5 +841,87 @@ public <T> ResponseEntity<EmbeddingList<Embedding>> embeddings(EmbeddingRequest<
824841
.toEntity(new ParameterizedTypeReference<>() {
825842
});
826843
}
844+
845+
// Transcription API
846+
847+
// @JsonInclude(Include.NON_NULL)
848+
// public record Transcription(
849+
// @JsonProperty("text") String text) {
850+
// }
851+
852+
// /**
853+
// *
854+
// * @param model ID of the model to use.
855+
// * @param language The language of the input audio. Supplying the input language in ISO-639-1 format will improve accuracy and latency.
856+
// * @param prompt An optional text to guide the model's style or continue a previous audio segment. The prompt should match the audio language.
857+
// * @param responseFormat An object specifying the format that the model must output.
858+
// * @param temperature What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output
859+
// * more random, while lower values like 0.2 will make it more focused and deterministic. */
860+
// @JsonInclude(Include.NON_NULL)
861+
// public record TranscriptionRequest (
862+
// @JsonProperty("model") String model,
863+
// @JsonProperty("language") String language,
864+
// @JsonProperty("prompt") String prompt,
865+
// @JsonProperty("response_format") ResponseFormat responseFormat,
866+
// @JsonProperty("temperature") Float temperature) {
867+
868+
// /**
869+
// * Shortcut constructor for a transcription request with the given model and temperature
870+
// *
871+
// * @param model ID of the model to use.
872+
// * @param temperature What sampling temperature to use, between 0 and 1.
873+
// */
874+
// public TranscriptionRequest(String model, Float temperature) {
875+
// this(model, null, null, null, temperature);
876+
// }
877+
878+
// public TranscriptionRequest() {
879+
// this(null, null, null, null, null);
880+
// }
881+
882+
// /**
883+
// * An object specifying the format that the model must output.
884+
// * @param type Must be one of 'text' or 'json_object'.
885+
// */
886+
// @JsonInclude(Include.NON_NULL)
887+
// public record ResponseFormat(
888+
// @JsonProperty("type") String type) {
889+
// }
890+
// }
891+
892+
// /**
893+
// * Creates a model response for the given transcription.
894+
// *
895+
// * @param transcriptionRequest The transcription request.
896+
// * @return Entity response with {@link Transcription} as a body and HTTP status code and headers.
897+
// */
898+
// public ResponseEntity<Transcription> transcriptionEntityJson(MultiValueMap<String, Object> transcriptionRequest) {
899+
900+
// Assert.notNull(transcriptionRequest, "The request body can not be null.");
901+
902+
// return this.multipartRestClient.post()
903+
// .uri("/v1/audio/transcriptions")
904+
// .body(transcriptionRequest)
905+
// .retrieve()
906+
// .toEntity(Transcription.class);
907+
// }
908+
909+
// /**
910+
// * Creates a model response for the given transcription.
911+
// *
912+
// * @param transcriptionRequest The transcription request.
913+
// * @return Entity response with {@link String} as a body and HTTP status code and headers.
914+
// */
915+
// public ResponseEntity<String> transcriptionEntityText(MultiValueMap<String, Object> transcriptionRequest) {
916+
917+
// Assert.notNull(transcriptionRequest, "The request body can not be null.");
918+
919+
// return this.multipartRestClient.post()
920+
// .uri("/v1/audio/transcriptions")
921+
// .body(transcriptionRequest)
922+
// .accept(MediaType.TEXT_PLAIN)
923+
// .retrieve()
924+
// .toEntity(String.class);
925+
// }
827926
}
828927
// @formatter:on

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ public record TranscriptionRequest(
280280
@JsonProperty("model") String model,
281281
@JsonProperty("language") String language,
282282
@JsonProperty("prompt") String prompt,
283-
@JsonProperty("response_format") TextualResponseFormat responseFormat,
283+
@JsonProperty("response_format") TranscriptResponseFormat responseFormat,
284284
@JsonProperty("temperature") Float temperature,
285285
@JsonProperty("timestamp_granularities") GranularityType granularityType) {
286286
// @formatter:on
@@ -318,7 +318,7 @@ public static class Builder {
318318

319319
private String prompt;
320320

321-
private TextualResponseFormat responseFormat = TextualResponseFormat.JSON;
321+
private TranscriptResponseFormat responseFormat = TranscriptResponseFormat.JSON;
322322

323323
private Float temperature;
324324

@@ -344,7 +344,7 @@ public Builder withPrompt(String prompt) {
344344
return this;
345345
}
346346

347-
public Builder withResponseFormat(TextualResponseFormat response_format) {
347+
public Builder withResponseFormat(TranscriptResponseFormat response_format) {
348348
this.responseFormat = response_format;
349349
return this;
350350
}
@@ -375,7 +375,7 @@ public TranscriptionRequest build() {
375375
* The format of the transcript and translation outputs, in one of these options:
376376
* json, text, srt, verbose_json, or vtt. Defaults to json.
377377
*/
378-
public enum TextualResponseFormat {
378+
public enum TranscriptResponseFormat {
379379

380380
// @formatter:off
381381
@JsonProperty("json") JSON("json", StructuredResponse.class),
@@ -393,7 +393,7 @@ public boolean isJsonType() {
393393
return this == JSON || this == VERBOSE_JSON;
394394
}
395395

396-
TextualResponseFormat(String value, Class<?> responseType) {
396+
TranscriptResponseFormat(String value, Class<?> responseType) {
397397
this.value = value;
398398
this.responseType = responseType;
399399
}
@@ -429,7 +429,7 @@ public record TranslationRequest(
429429
@JsonProperty("file") byte[] file,
430430
@JsonProperty("model") String model,
431431
@JsonProperty("prompt") String prompt,
432-
@JsonProperty("response_format") TextualResponseFormat responseFormat,
432+
@JsonProperty("response_format") TranscriptResponseFormat responseFormat,
433433
@JsonProperty("temperature") Float temperature) {
434434
// @formatter:on
435435

@@ -445,7 +445,7 @@ public static class Builder {
445445

446446
private String prompt;
447447

448-
private TextualResponseFormat responseFormat = TextualResponseFormat.JSON;
448+
private TranscriptResponseFormat responseFormat = TranscriptResponseFormat.JSON;
449449

450450
private Float temperature;
451451

@@ -464,7 +464,7 @@ public Builder withPrompt(String prompt) {
464464
return this;
465465
}
466466

467-
public Builder withResponseFormat(TextualResponseFormat responseFormat) {
467+
public Builder withResponseFormat(TranscriptResponseFormat responseFormat) {
468468
this.responseFormat = responseFormat;
469469
return this;
470470
}
@@ -601,7 +601,7 @@ public String getFilename() {
601601
multipartBody.add("response_format", requestBody.responseFormat().getValue());
602602
multipartBody.add("temperature", requestBody.temperature());
603603
if (requestBody.granularityType() != null) {
604-
Assert.isTrue(requestBody.responseFormat() == TextualResponseFormat.VERBOSE_JSON,
604+
Assert.isTrue(requestBody.responseFormat() == TranscriptResponseFormat.VERBOSE_JSON,
605605
"response_format must be set to verbose_json to use timestamp granularities.");
606606
multipartBody.add("timestamp_granularities[]", requestBody.granularityType().getValue());
607607
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai.audio.transcription;
18+
19+
import org.springframework.ai.model.ModelResult;
20+
import org.springframework.ai.openai.metadata.audio.OpenAiAudioTranscriptionMetadata;
21+
import org.springframework.lang.Nullable;
22+
23+
import java.util.Objects;
24+
25+
/**
26+
* Represents a response returned by the AI.
27+
*
28+
* @author Michael Lavelle
29+
* @since 0.8.1
30+
*/
31+
public class AudioTranscription implements ModelResult<String> {
32+
33+
private String text;
34+
35+
private OpenAiAudioTranscriptionMetadata transcriptionMetadata;
36+
37+
public AudioTranscription(String text) {
38+
this.text = text;
39+
}
40+
41+
@Override
42+
public String getOutput() {
43+
return this.text;
44+
}
45+
46+
@Override
47+
public OpenAiAudioTranscriptionMetadata getMetadata() {
48+
return transcriptionMetadata != null ? transcriptionMetadata : OpenAiAudioTranscriptionMetadata.NULL;
49+
}
50+
51+
public AudioTranscription withTranscriptionMetadata(
52+
@Nullable OpenAiAudioTranscriptionMetadata transcriptionMetadata) {
53+
this.transcriptionMetadata = transcriptionMetadata;
54+
return this;
55+
}
56+
57+
@Override
58+
public boolean equals(Object o) {
59+
if (this == o)
60+
return true;
61+
if (!(o instanceof AudioTranscription that))
62+
return false;
63+
return Objects.equals(text, that.text) && Objects.equals(transcriptionMetadata, that.transcriptionMetadata);
64+
}
65+
66+
@Override
67+
public int hashCode() {
68+
return Objects.hash(text, transcriptionMetadata);
69+
}
70+
71+
@Override
72+
public String toString() {
73+
return "Transcript{" + "text=" + text + ", transcriptionMetadata=" + transcriptionMetadata + '}';
74+
}
75+
76+
}

0 commit comments

Comments
 (0)