diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiTranscriptionClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiTranscriptionClient.java new file mode 100644 index 00000000000..a5e043b7b3d --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiTranscriptionClient.java @@ -0,0 +1,168 @@ +package org.springframework.ai.openai;/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.ChatOptions; +import org.springframework.ai.chat.metadata.RateLimit; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException; +import org.springframework.ai.openai.metadata.OpenAiTranscriptionResponseMetadata; +import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor; +import org.springframework.ai.transcription.*; +import org.springframework.core.io.Resource; +import org.springframework.http.ResponseEntity; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +import java.time.Duration; +import java.util.List; + +/** + * {@link TranscriptionClient} implementation for {@literal OpenAI} backed by + * {@link OpenAiApi}. + * + * @author Michael Lavelle + * @see TranscriptionClient + * @see OpenAiApi + */ +public class OpenAiTranscriptionClient implements TranscriptionClient { + + private final Logger logger = LoggerFactory.getLogger(getClass()); + + private OpenAiTranscriptionOptions defaultOptions = OpenAiTranscriptionOptions.builder() + .withModel("whisper-1") + .withTemperature(0.7f) + .build(); + + public final RetryTemplate retryTemplate = RetryTemplate.builder() + .maxAttempts(10) + .retryOn(OpenAiApiException.class) + .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000)) + .build(); + + private final OpenAiApi openAiApi; + + public OpenAiTranscriptionClient(OpenAiApi openAiApi) { + Assert.notNull(openAiApi, "OpenAiApi must not be null"); + this.openAiApi = openAiApi; + } + + public OpenAiTranscriptionClient withDefaultOptions(OpenAiTranscriptionOptions options) { + this.defaultOptions = options; + return this; + } + + @Override + public TranscriptionResponse call(TranscriptionRequest request) { + + return this.retryTemplate.execute(ctx -> { + Resource audioResource = request.getInstructions(); + + MultiValueMap requestBody = createRequestBody(request); + + boolean jsonResponse = !requestBody.containsKey("response_format") + || requestBody.get("response_format").contains("json"); + + if (jsonResponse) { + + ResponseEntity transcriptionEntity = this.openAiApi + .transcriptionEntityJson(requestBody); + + var transcription = transcriptionEntity.getBody(); + + if (transcription == null) { + logger.warn("No transcription returned for request: {}", audioResource); + return new TranscriptionResponse(null); + } + + Transcript transcript = new Transcript(transcription.text()); + + RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(transcriptionEntity); + + return new TranscriptionResponse(transcript, + OpenAiTranscriptionResponseMetadata.from(transcriptionEntity.getBody()) + .withRateLimit(rateLimits)); + + } + else { + ResponseEntity transcriptionEntity = this.openAiApi.transcriptionEntityText(requestBody); + + var transcription = transcriptionEntity.getBody(); + + if (transcription == null) { + logger.warn("No transcription returned for request: {}", audioResource); + return new TranscriptionResponse(null); + } + + Transcript transcript = new Transcript(transcription); + + RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(transcriptionEntity); + + return new TranscriptionResponse(transcript, + OpenAiTranscriptionResponseMetadata.from(transcriptionEntity.getBody()) + .withRateLimit(rateLimits)); + + } + + }); + } + + private MultiValueMap createRequestBody(TranscriptionRequest transcriptionRequest) { + + OpenAiApi.TranscriptionRequest request = new OpenAiApi.TranscriptionRequest(); + + if (this.defaultOptions != null) { + request = ModelOptionsUtils.merge(request, this.defaultOptions, OpenAiApi.TranscriptionRequest.class); + } + + if (transcriptionRequest.getOptions() != null) { + if (transcriptionRequest.getOptions() instanceof TranscriptionOptions runtimeOptions) { + OpenAiTranscriptionOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, + TranscriptionOptions.class, OpenAiTranscriptionOptions.class); + request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, OpenAiApi.TranscriptionRequest.class); + } + else { + throw new IllegalArgumentException("Prompt options are not of type TranscriptionOptions: " + + transcriptionRequest.getOptions().getClass().getSimpleName()); + } + } + MultiValueMap requestBody = new LinkedMultiValueMap<>(); + if (request.responseFormat() != null) { + requestBody.add("response_format", request.responseFormat().type()); + } + if (request.prompt() != null) { + requestBody.add("prompt", request.prompt()); + } + if (request.temperature() != null) { + requestBody.add("temperature", request.temperature()); + } + if (request.language() != null) { + requestBody.add("language", request.language()); + } + if (request.model() != null) { + requestBody.add("model", request.model()); + } + if (transcriptionRequest.getInstructions() != null) { + requestBody.add("file", transcriptionRequest.getInstructions()); + } + return requestBody; + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiTranscriptionOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiTranscriptionOptions.java new file mode 100644 index 00000000000..699819a2a03 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiTranscriptionOptions.java @@ -0,0 +1,190 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.openai.api.OpenAiApi.TranscriptionRequest.ResponseFormat; +import org.springframework.ai.transcription.TranscriptionOptions; + +/** + * @author Michael Lavelle + */ +@JsonInclude(Include.NON_NULL) +public class OpenAiTranscriptionOptions implements TranscriptionOptions { + + // @formatter:off + /** + * ID of the model to use. + */ + private @JsonProperty("model") String model; + + /** + * An object specifying the format that the model must output. Setting to { "type": + * "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. + */ + private @JsonProperty("response_format") ResponseFormat responseFormat; + + private @JsonProperty("prompt") String prompt; + + private @JsonProperty("language") String language; + + /** + * What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output + * more random, while lower values like 0.2 will make it more focused and deterministic. + */ + private @JsonProperty("temperature") Float temperature = 0.8f; + + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + protected OpenAiTranscriptionOptions options; + + public Builder() { + this.options = new OpenAiTranscriptionOptions(); + } + + public Builder(OpenAiTranscriptionOptions options) { + this.options = options; + } + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withLanguage(String language) { + this.options.language = language; + return this; + } + + public Builder withPrompt(String prompt) { + this.options.prompt = prompt; + return this; + } + + public Builder withResponseFormat(ResponseFormat responseFormat) { + this.options.responseFormat = responseFormat; + return this; + } + + public Builder withTemperature(Float temperature) { + this.options.temperature = temperature; + return this; + } + + public OpenAiTranscriptionOptions build() { + return this.options; + } + + } + + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + public String getLanguage() { + return this.language; + } + + public void setLanguage(String language) { + this.language = language; + } + + public String getPrompt() { + return this.prompt; + } + + public void setPrompt(String prompt) { + this.prompt = prompt; + } + + public Float getTemperature() { + return this.temperature; + } + + public void setTemperature(Float temperature) { + this.temperature = temperature; + } + + + public ResponseFormat getResponseFormat() { + return this.responseFormat; + } + + public void setResponseFormat(ResponseFormat responseFormat) { + this.responseFormat = responseFormat; + } + + + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((model == null) ? 0 : model.hashCode()); + result = prime * result + ((prompt == null) ? 0 : prompt.hashCode()); + result = prime * result + ((language == null) ? 0 : language.hashCode()); + result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + OpenAiTranscriptionOptions other = (OpenAiTranscriptionOptions) obj; + if (this.model == null) { + if (other.model != null) + return false; + } + else if (!model.equals(other.model)) + return false; + if (this.prompt == null) { + if (other.prompt != null) + return false; + } + else if (!this.prompt.equals(other.prompt)) + return false; + if (this.language == null) { + if (other.language != null) + return false; + } + else if (!this.language.equals(other.language)) + return false; + if (this.responseFormat == null) { + if (other.responseFormat != null) + return false; + } + else if (!this.responseFormat.equals(other.responseFormat)) + return false; + return true; + } +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 097234a04ff..c80ce5262ca 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -28,6 +28,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.util.MultiValueMap; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -57,6 +58,9 @@ public class OpenAiApi { private static final Predicate SSE_DONE_PREDICATE = "[DONE]"::equals; private final RestClient restClient; + + private final RestClient multipartFormEncodingRestClient; + private final WebClient webClient; private final ObjectMapper objectMapper; @@ -95,6 +99,11 @@ public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClie headers.setContentType(MediaType.APPLICATION_JSON); }; + Consumer multipartFormDataContentHeaders = multipartFormDataheaders -> { + multipartFormDataheaders.setBearerAuth(openAiToken); + multipartFormDataheaders.setContentType(MediaType.MULTIPART_FORM_DATA); + }; + var responseErrorHandler = new ResponseErrorHandler() { @Override @@ -121,6 +130,12 @@ public void handleError(ClientHttpResponse response) throws IOException { .defaultStatusHandler(responseErrorHandler) .build(); + this.multipartFormEncodingRestClient = restClientBuilder + .baseUrl(baseUrl) + .defaultHeaders(multipartFormDataContentHeaders) + .defaultStatusHandler(responseErrorHandler) + .build(); + this.webClient = WebClient.builder() .baseUrl(baseUrl) .defaultHeaders(jsonContentHeaders) @@ -389,6 +404,46 @@ public record ResponseFormat( } } + /** + * + * @param model ID of the model to use. + * @param language The language of the input audio. Supplying the input language in ISO-639-1 format will improve accuracy and latency. + * @param prompt An optional text to guide the model's style or continue a previous audio segment. The prompt should match the audio language. + * @param responseFormat An object specifying the format that the model must output. + * @param temperature What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output + * more random, while lower values like 0.2 will make it more focused and deterministic. */ + @JsonInclude(Include.NON_NULL) + public record TranscriptionRequest ( + @JsonProperty("model") String model, + @JsonProperty("language") String language, + @JsonProperty("prompt") String prompt, + @JsonProperty("response_format") ResponseFormat responseFormat, + @JsonProperty("temperature") Float temperature) { + + /** + * Shortcut constructor for a transcription request with the given model and temperature + * + * @param model ID of the model to use. + * @param temperature What sampling temperature to use, between 0 and 1. + */ + public TranscriptionRequest(String model, Float temperature) { + this(model, null, null, null, temperature); + } + + public TranscriptionRequest() { + this(null, null, null, null, null); + } + + /** + * An object specifying the format that the model must output. + * @param type Must be one of 'text' or 'json_object'. + */ + @JsonInclude(Include.NON_NULL) + public record ResponseFormat( + @JsonProperty("type") String type) { + } + } + /** * Message comprising the conversation. * @@ -497,6 +552,11 @@ public enum ChatCompletionFinishReason { @JsonProperty("function_call") FUNCTION_CALL } + @JsonInclude(Include.NON_NULL) + public record Transcription( + @JsonProperty("text") String text) { + } + /** * Represents a chat completion response returned by model, based on the provided input. * @@ -658,6 +718,41 @@ public ResponseEntity chatCompletionEntity(ChatCompletionRequest .toEntity(ChatCompletion.class); } + /** + * Creates a model response for the given transcription. + * + * @param transcriptionRequest The transcription request. + * @return Entity response with {@link Transcription} as a body and HTTP status code and headers. + */ + public ResponseEntity transcriptionEntityJson(MultiValueMap transcriptionRequest) { + + Assert.notNull(transcriptionRequest, "The request body can not be null."); + + return this.multipartFormEncodingRestClient.post() + .uri("/v1/audio/transcriptions") + .body(transcriptionRequest) + .retrieve() + .toEntity(Transcription.class); + } + + /** + * Creates a model response for the given transcription. + * + * @param transcriptionRequest The transcription request. + * @return Entity response with {@link String} as a body and HTTP status code and headers. + */ + public ResponseEntity transcriptionEntityText(MultiValueMap transcriptionRequest) { + + Assert.notNull(transcriptionRequest, "The request body can not be null."); + + return this.multipartFormEncodingRestClient.post() + .uri("/v1/audio/transcriptions") + .body(transcriptionRequest) + .accept(MediaType.TEXT_PLAIN) + .retrieve() + .toEntity(String.class); + } + /** * Creates a streaming chat response for the given chat conversation. * diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiTranscriptionResponseMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiTranscriptionResponseMetadata.java new file mode 100644 index 00000000000..612c6c67bd9 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiTranscriptionResponseMetadata.java @@ -0,0 +1,76 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.metadata; + +import org.springframework.ai.chat.metadata.RateLimit; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.transcription.metadata.TranscriptionResponseMetadata; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * {@link TranscriptionResponseMetadata} implementation for {@literal OpenAI}. + * + * @author MichaelLavelle + * @see TranscriptionResponseMetadata + * @see RateLimit + */ +public class OpenAiTranscriptionResponseMetadata implements TranscriptionResponseMetadata { + + protected static final String AI_METADATA_STRING = "{ @type: %1$s, rateLimit: %4$s }"; + + public static OpenAiTranscriptionResponseMetadata from(OpenAiApi.Transcription result) { + Assert.notNull(result, "OpenAI Transcription must not be null"); + OpenAiTranscriptionResponseMetadata transcriptionResponseMetadata = new OpenAiTranscriptionResponseMetadata(); + return transcriptionResponseMetadata; + } + + public static OpenAiTranscriptionResponseMetadata from(String result) { + Assert.notNull(result, "OpenAI Transcription must not be null"); + OpenAiTranscriptionResponseMetadata transcriptionResponseMetadata = new OpenAiTranscriptionResponseMetadata(); + return transcriptionResponseMetadata; + } + + @Nullable + private RateLimit rateLimit; + + protected OpenAiTranscriptionResponseMetadata() { + this(null); + } + + protected OpenAiTranscriptionResponseMetadata(@Nullable OpenAiRateLimit rateLimit) { + this.rateLimit = rateLimit; + } + + @Override + @Nullable + public RateLimit getRateLimit() { + RateLimit rateLimit = this.rateLimit; + return rateLimit != null ? rateLimit : RateLimit.NULL; + } + + public OpenAiTranscriptionResponseMetadata withRateLimit(RateLimit rateLimit) { + this.rateLimit = rateLimit; + return this; + } + + @Override + public String toString() { + return AI_METADATA_STRING.formatted(getClass().getName(), getRateLimit()); + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java index 21c28652202..ebcec840a8a 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java @@ -52,7 +52,7 @@ public class OpenAiResponseHeaderExtractor { private static final Logger logger = LoggerFactory.getLogger(OpenAiResponseHeaderExtractor.class); - public static RateLimit extractAiResponseHeaders(ResponseEntity response) { + public static RateLimit extractAiResponseHeaders(ResponseEntity response) { Long requestsLimit = getHeaderAsLong(response, REQUESTS_LIMIT_HEADER.getName()); Long requestsRemaining = getHeaderAsLong(response, REQUESTS_REMAINING_HEADER.getName()); @@ -66,7 +66,7 @@ public static RateLimit extractAiResponseHeaders(ResponseEntity tokensReset); } - private static Duration getHeaderAsDuration(ResponseEntity response, String headerName) { + private static Duration getHeaderAsDuration(ResponseEntity response, String headerName) { var headers = response.getHeaders(); if (headers.containsKey(headerName)) { var values = headers.get(headerName); @@ -77,7 +77,7 @@ private static Duration getHeaderAsDuration(ResponseEntity respo return null; } - private static Long getHeaderAsLong(ResponseEntity response, String headerName) { + private static Long getHeaderAsLong(ResponseEntity response, String headerName) { var headers = response.getHeaders(); if (headers.containsKey(headerName)) { var values = headers.get(headerName); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java index d43a25dde1d..dddf4b2a2f4 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java @@ -35,6 +35,12 @@ public OpenAiChatClient openAiChatClient(OpenAiApi api) { return openAiChatClient; } + @Bean + public OpenAiTranscriptionClient openAiTranscriptionClient(OpenAiApi api) { + OpenAiTranscriptionClient openAiTranscriptionClient = new OpenAiTranscriptionClient(api); + return openAiTranscriptionClient; + } + @Bean public OpenAiImageClient openAiImageClient(OpenAiImageApi imageApi) { OpenAiImageClient openAiImageClient = new OpenAiImageClient(imageApi); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java index 2e878f802b7..5afe8947b37 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java @@ -14,6 +14,7 @@ import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.image.ImageClient; +import org.springframework.ai.transcription.TranscriptionClient; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.core.io.Resource; @@ -28,6 +29,9 @@ public abstract class AbstractIT { @Autowired protected ChatClient openAiChatClient; + @Autowired + protected TranscriptionClient openAiTranscriptionClient; + @Autowired protected ImageClient openaiImageClient; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transcription/OpenAiTranscriptionClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transcription/OpenAiTranscriptionClientIT.java new file mode 100644 index 00000000000..472cc9abce0 --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transcription/OpenAiTranscriptionClientIT.java @@ -0,0 +1,51 @@ +package org.springframework.ai.openai.transcription; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.openai.OpenAiTestConfiguration; +import org.springframework.ai.openai.OpenAiTranscriptionOptions; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.testutils.AbstractIT; +import org.springframework.ai.transcription.TranscriptionOptions; +import org.springframework.ai.transcription.TranscriptionOptionsBuilder; +import org.springframework.ai.transcription.TranscriptionRequest; +import org.springframework.ai.transcription.TranscriptionResponse; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.core.io.Resource; + +import static org.assertj.core.api.Assertions.assertThat; + +@SpringBootTest(classes = OpenAiTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") +class OpenAiTranscriptionClientIT extends AbstractIT { + + @Value("classpath:/speech/jfk.flac") + private Resource audioFile; + + @Test + void transcriptionTest() { + TranscriptionOptions transcriptionOptions = TranscriptionOptionsBuilder.builder().withTemperature(0f).build(); + TranscriptionRequest transcriptionRequest = new TranscriptionRequest(audioFile, transcriptionOptions); + TranscriptionResponse response = openAiTranscriptionClient.call(transcriptionRequest); + assertThat(response.getResults()).hasSize(1); + assertThat(response.getResults().get(0).getOutput().toLowerCase().contains("fellow")).isTrue(); + } + + @Test + void transcriptionTestWithOptions() { + OpenAiApi.TranscriptionRequest.ResponseFormat responseFormat = new OpenAiApi.TranscriptionRequest.ResponseFormat( + "vtt"); + TranscriptionOptions transcriptionOptions = OpenAiTranscriptionOptions.builder() + .withLanguage("en") + .withPrompt("Ask not this, but ask that") + .withTemperature(0f) + .withResponseFormat(responseFormat) + .build(); + TranscriptionRequest transcriptionRequest = new TranscriptionRequest(audioFile, transcriptionOptions); + TranscriptionResponse response = openAiTranscriptionClient.call(transcriptionRequest); + assertThat(response.getResults()).hasSize(1); + assertThat(response.getResults().get(0).getOutput().toLowerCase().contains("fellow")).isTrue(); + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transcription/OpenAiTranscriptionClientWithTranscriptionResponseMetadataTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transcription/OpenAiTranscriptionClientWithTranscriptionResponseMetadataTests.java new file mode 100644 index 00000000000..b1783c1b397 --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transcription/OpenAiTranscriptionClientWithTranscriptionResponseMetadataTests.java @@ -0,0 +1,163 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.transcription; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.metadata.RateLimit; +import org.springframework.ai.openai.OpenAiTranscriptionClient; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders; +import org.springframework.ai.transcription.TranscriptionRequest; +import org.springframework.ai.transcription.TranscriptionResponse; +import org.springframework.ai.transcription.metadata.TranscriptionMetadata; +import org.springframework.ai.transcription.metadata.TranscriptionResponseMetadata; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.autoconfigure.web.client.RestClientTest; +import org.springframework.context.annotation.Bean; +import org.springframework.core.io.Resource; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.test.web.client.MockRestServiceServer; +import org.springframework.web.client.RestClient; + +import java.time.Duration; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.*; +import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess; + +/** + * @author Michael Lavelle + */ +@RestClientTest(OpenAiTranscriptionClientWithTranscriptionResponseMetadataTests.Config.class) +public class OpenAiTranscriptionClientWithTranscriptionResponseMetadataTests { + + private static String TEST_API_KEY = "sk-1234567890"; + + @Autowired + private OpenAiTranscriptionClient openAiTranscriptionClient; + + @Autowired + private MockRestServiceServer server; + + @AfterEach + void resetMockServer() { + server.reset(); + } + + @Test + void aiResponseContainsAiMetadata() { + + prepareMock(); + + Resource audioFile = null; + + TranscriptionRequest transcriptionRequest = new TranscriptionRequest(audioFile); + + TranscriptionResponse response = this.openAiTranscriptionClient.call(transcriptionRequest); + + assertThat(response).isNotNull(); + + TranscriptionResponseMetadata transcriptionResponseMetadata = response.getMetadata(); + + assertThat(transcriptionResponseMetadata).isNotNull(); + + RateLimit rateLimit = transcriptionResponseMetadata.getRateLimit(); + + Duration expectedRequestsReset = Duration.ofDays(2L) + .plus(Duration.ofHours(16L)) + .plus(Duration.ofMinutes(15)) + .plus(Duration.ofSeconds(29L)); + + Duration expectedTokensReset = Duration.ofHours(27L) + .plus(Duration.ofSeconds(55L)) + .plus(Duration.ofMillis(451L)); + + assertThat(rateLimit).isNotNull(); + assertThat(rateLimit.getRequestsLimit()).isEqualTo(4000L); + assertThat(rateLimit.getRequestsRemaining()).isEqualTo(999); + assertThat(rateLimit.getRequestsReset()).isEqualTo(expectedRequestsReset); + assertThat(rateLimit.getTokensLimit()).isEqualTo(725_000L); + assertThat(rateLimit.getTokensRemaining()).isEqualTo(112_358L); + assertThat(rateLimit.getTokensReset()).isEqualTo(expectedTokensReset); + + response.getResults().forEach(transcript -> { + TranscriptionMetadata transcriptionMetadata = transcript.getMetadata(); + assertThat(transcriptionMetadata).isNotNull(); + }); + } + + private void prepareMock() { + + HttpHeaders httpHeaders = new HttpHeaders(); + httpHeaders.set(OpenAiApiResponseHeaders.REQUESTS_LIMIT_HEADER.getName(), "4000"); + httpHeaders.set(OpenAiApiResponseHeaders.REQUESTS_REMAINING_HEADER.getName(), "999"); + httpHeaders.set(OpenAiApiResponseHeaders.REQUESTS_RESET_HEADER.getName(), "2d16h15m29s"); + httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_LIMIT_HEADER.getName(), "725000"); + httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_REMAINING_HEADER.getName(), "112358"); + httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_RESET_HEADER.getName(), "27h55s451ms"); + + server.expect(requestTo("/v1/audio/transcriptions")) + .andExpect(method(HttpMethod.POST)) + .andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY)) + .andRespond(withSuccess(getJson(), MediaType.APPLICATION_JSON).headers(httpHeaders)); + + } + + private String getJson() { + return """ + { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-3.5-turbo-0613", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "I surrender!" + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 9, + "completion_tokens": 12, + "total_tokens": 21 + } + } + """; + } + + @SpringBootConfiguration + static class Config { + + @Bean + public OpenAiApi chatCompletionApi(RestClient.Builder builder) { + return new OpenAiApi("", TEST_API_KEY, builder); + } + + @Bean + public OpenAiTranscriptionClient openAiClient(OpenAiApi openAiApi) { + return new OpenAiTranscriptionClient(openAiApi); + } + + } + +} diff --git a/models/spring-ai-openai/src/test/resources/speech/jfk.flac b/models/spring-ai-openai/src/test/resources/speech/jfk.flac new file mode 100644 index 00000000000..e44b7c13897 Binary files /dev/null and b/models/spring-ai-openai/src/test/resources/speech/jfk.flac differ diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transcription/Transcript.java b/spring-ai-core/src/main/java/org/springframework/ai/transcription/Transcript.java new file mode 100644 index 00000000000..c5600a93e2d --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/transcription/Transcript.java @@ -0,0 +1,73 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.transcription; + +import org.springframework.ai.model.ModelResult; +import org.springframework.ai.transcription.metadata.TranscriptionMetadata; +import org.springframework.lang.Nullable; + +import java.util.Objects; + +/** + * Represents a response returned by the AI. + */ +public class Transcript implements ModelResult { + + private String text; + + private TranscriptionMetadata transcriptionMetadata; + + public Transcript(String text) { + this.text = text; + } + + @Override + public String getOutput() { + return this.text; + } + + @Override + public TranscriptionMetadata getMetadata() { + TranscriptionMetadata chatGenerationMetadata = this.transcriptionMetadata; + return transcriptionMetadata != null ? transcriptionMetadata : TranscriptionMetadata.NULL; + } + + public Transcript withTranscriptionMetadata(@Nullable TranscriptionMetadata transcriptionMetadata) { + this.transcriptionMetadata = transcriptionMetadata; + return this; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof Transcript that)) + return false; + return Objects.equals(text, that.text) && Objects.equals(transcriptionMetadata, that.transcriptionMetadata); + } + + @Override + public int hashCode() { + return Objects.hash(text, transcriptionMetadata); + } + + @Override + public String toString() { + return "Transcript{" + "text=" + text + ", transcriptionMetadata=" + transcriptionMetadata + '}'; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transcription/TranscriptionClient.java b/spring-ai-core/src/main/java/org/springframework/ai/transcription/TranscriptionClient.java new file mode 100644 index 00000000000..2af8862d783 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/transcription/TranscriptionClient.java @@ -0,0 +1,32 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.transcription; + +import org.springframework.ai.model.ModelClient; +import org.springframework.core.io.Resource; + +@FunctionalInterface +public interface TranscriptionClient extends ModelClient { + + default String call(Resource audioResource) { + TranscriptionRequest transcriptionRequest = new TranscriptionRequest(audioResource); + return call(transcriptionRequest).getResult().getOutput(); + } + + TranscriptionResponse call(TranscriptionRequest request); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transcription/TranscriptionOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/transcription/TranscriptionOptions.java new file mode 100644 index 00000000000..ed362bb62f3 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/transcription/TranscriptionOptions.java @@ -0,0 +1,30 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.transcription; + +import org.springframework.ai.model.ModelOptions; + +/** + * The ChatOptions represent the common options, portable across different chat models. + */ +public interface TranscriptionOptions extends ModelOptions { + + Float getTemperature(); + + void setTemperature(Float temperature); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transcription/TranscriptionOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/transcription/TranscriptionOptionsBuilder.java new file mode 100644 index 00000000000..fa4e2291ca2 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/transcription/TranscriptionOptionsBuilder.java @@ -0,0 +1,54 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.transcription; + +public class TranscriptionOptionsBuilder { + + private class TranscriptionOptionsImpl implements TranscriptionOptions { + + private Float temperature; + + @Override + public Float getTemperature() { + return temperature; + } + + public void setTemperature(Float temperature) { + this.temperature = temperature; + } + + } + + private final TranscriptionOptionsImpl options = new TranscriptionOptionsImpl(); + + private TranscriptionOptionsBuilder() { + } + + public static TranscriptionOptionsBuilder builder() { + return new TranscriptionOptionsBuilder(); + } + + public TranscriptionOptionsBuilder withTemperature(Float temperature) { + options.setTemperature(temperature); + return this; + } + + public TranscriptionOptions build() { + return options; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transcription/TranscriptionRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/transcription/TranscriptionRequest.java new file mode 100644 index 00000000000..d759b56dcde --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/transcription/TranscriptionRequest.java @@ -0,0 +1,32 @@ +package org.springframework.ai.transcription; + +import org.springframework.ai.model.ModelOptions; +import org.springframework.ai.model.ModelRequest; +import org.springframework.core.io.Resource; + +public class TranscriptionRequest implements ModelRequest { + + private Resource audioResource; + + private ModelOptions modelOptions; + + public TranscriptionRequest(Resource audioResource) { + this.audioResource = audioResource; + } + + public TranscriptionRequest(Resource audioResource, ModelOptions modelOptions) { + this.audioResource = audioResource; + this.modelOptions = modelOptions; + } + + @Override + public Resource getInstructions() { + return audioResource; + } + + @Override + public ModelOptions getOptions() { + return modelOptions; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transcription/TranscriptionResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/transcription/TranscriptionResponse.java new file mode 100644 index 00000000000..a3cc541aff9 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/transcription/TranscriptionResponse.java @@ -0,0 +1,39 @@ +package org.springframework.ai.transcription; + +import org.springframework.ai.model.ModelResponse; +import org.springframework.ai.transcription.metadata.TranscriptionResponseMetadata; + +import java.util.Arrays; +import java.util.List; + +public class TranscriptionResponse implements ModelResponse { + + private Transcript transcript; + + private TranscriptionResponseMetadata transcriptionResponseMetadata; + + public TranscriptionResponse(Transcript transcript) { + this(transcript, TranscriptionResponseMetadata.NULL); + } + + public TranscriptionResponse(Transcript transcript, TranscriptionResponseMetadata transcriptionResponseMetadata) { + this.transcript = transcript; + this.transcriptionResponseMetadata = transcriptionResponseMetadata; + } + + @Override + public Transcript getResult() { + return transcript; + } + + @Override + public List getResults() { + return Arrays.asList(transcript); + } + + @Override + public TranscriptionResponseMetadata getMetadata() { + return transcriptionResponseMetadata; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transcription/metadata/TranscriptionMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/transcription/metadata/TranscriptionMetadata.java new file mode 100644 index 00000000000..fbfc05bb166 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/transcription/metadata/TranscriptionMetadata.java @@ -0,0 +1,18 @@ +package org.springframework.ai.transcription.metadata; + +import org.springframework.ai.model.ResultMetadata; + +public interface TranscriptionMetadata extends ResultMetadata { + + TranscriptionMetadata NULL = TranscriptionMetadata.create(); + + /** + * Factory method used to construct a new {@link TranscriptionMetadata} + * @return a new {@link TranscriptionMetadata} + */ + static TranscriptionMetadata create() { + return new TranscriptionMetadata() { + }; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transcription/metadata/TranscriptionResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/transcription/metadata/TranscriptionResponseMetadata.java new file mode 100644 index 00000000000..701327f3449 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/transcription/metadata/TranscriptionResponseMetadata.java @@ -0,0 +1,42 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.transcription.metadata; + +import org.springframework.ai.chat.metadata.RateLimit; +import org.springframework.ai.model.ResponseMetadata; + +/** + * Abstract Data Type (ADT) modeling common AI provider metadata returned in an AI + * response. + * + * @author Michael Lavelle + */ +public interface TranscriptionResponseMetadata extends ResponseMetadata { + + TranscriptionResponseMetadata NULL = new TranscriptionResponseMetadata() { + }; + + /** + * Returns AI provider specific metadata on rate limits. + * @return AI provider specific metadata on rate limits. + * @see RateLimit + */ + default RateLimit getRateLimit() { + return RateLimit.NULL; + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/transcription/TranscriptionClientTests.java b/spring-ai-core/src/test/java/org/springframework/ai/transcription/TranscriptionClientTests.java new file mode 100644 index 00000000000..63db8e5e58b --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/transcription/TranscriptionClientTests.java @@ -0,0 +1,77 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.transcription; + +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.springframework.core.io.Resource; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +/** + * Unit Tests for {@link TranscriptionClient}. + * + * @author Michael Lavelle + */ +class TranscriptionClientTests { + + @Test + void transcrbeRequestReturnsResponseCorrectly() { + + Resource mockAudioFile = Mockito.mock(Resource.class); + + TranscriptionClient mockClient = Mockito.mock(TranscriptionClient.class); + + String mockTranscription = "All your bases are belong to us"; + + // Create a mock Transcript + Transcript transcript = Mockito.mock(Transcript.class); + when(transcript.getOutput()).thenReturn(mockTranscription); + + // Create a mock TranscriptionResponse with the mock Transcript + TranscriptionResponse response = Mockito.mock(TranscriptionResponse.class); + when(response.getResult()).thenReturn(transcript); + + // Transcript transcript = spy(new Transcript(responseMessage)); + // TranscriptionResponse response = spy(new + // TranscriptionResponse(Collections.singletonList(transcript))); + + doCallRealMethod().when(mockClient).call(any(Resource.class)); + + doAnswer(invocationOnMock -> { + + TranscriptionRequest transcriptionRequest = invocationOnMock.getArgument(0); + + assertThat(transcriptionRequest).isNotNull(); + assertThat(transcriptionRequest.getInstructions()).isEqualTo(mockAudioFile); + + return response; + + }).when(mockClient).call(any(TranscriptionRequest.class)); + + assertThat(mockClient.call(mockAudioFile)).isEqualTo(mockTranscription); + + verify(mockClient, times(1)).call(eq(mockAudioFile)); + verify(mockClient, times(1)).call(isA(TranscriptionRequest.class)); + verify(response, times(1)).getResult(); + verify(transcript, times(1)).getOutput(); + verifyNoMoreInteractions(mockClient, transcript, response); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java index 684ab22947a..46f8df90ba6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java @@ -21,6 +21,7 @@ import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.ai.openai.OpenAiEmbeddingClient; import org.springframework.ai.openai.OpenAiImageClient; +import org.springframework.ai.openai.OpenAiTranscriptionClient; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiImageApi; import org.springframework.boot.autoconfigure.AutoConfiguration; @@ -36,7 +37,7 @@ @AutoConfiguration @ConditionalOnClass(OpenAiApi.class) @EnableConfigurationProperties({ OpenAiConnectionProperties.class, OpenAiChatProperties.class, - OpenAiEmbeddingProperties.class, OpenAiImageProperties.class }) + OpenAiEmbeddingProperties.class, OpenAiImageProperties.class, OpenAiTranscriptionProperties.class }) @ImportRuntimeHints(NativeHints.class) /** * @author Christian Tzolov @@ -65,6 +66,28 @@ public OpenAiChatClient openAiChatClient(OpenAiConnectionProperties commonProper return openAiChatClient; } + @Bean + @ConditionalOnMissingBean + public OpenAiTranscriptionClient openAiTranscriptionClient(OpenAiConnectionProperties commonProperties, + OpenAiTranscriptionProperties transcriptionProperties) { + + String apiKey = StringUtils.hasText(transcriptionProperties.getApiKey()) ? transcriptionProperties.getApiKey() + : commonProperties.getApiKey(); + + String baseUrl = StringUtils.hasText(transcriptionProperties.getBaseUrl()) + ? transcriptionProperties.getBaseUrl() : commonProperties.getBaseUrl(); + + Assert.hasText(apiKey, "OpenAI API key must be set"); + Assert.hasText(baseUrl, "OpenAI base URL must be set"); + + var openAiApi = new OpenAiApi(baseUrl, apiKey, RestClient.builder()); + + OpenAiTranscriptionClient openAiChatClient = new OpenAiTranscriptionClient(openAiApi) + .withDefaultOptions(transcriptionProperties.getOptions()); + + return openAiChatClient; + } + @Bean @ConditionalOnMissingBean public EmbeddingClient openAiEmbeddingClient(OpenAiConnectionProperties commonProperties, diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiTranscriptionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiTranscriptionProperties.java new file mode 100644 index 00000000000..68133327f9b --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiTranscriptionProperties.java @@ -0,0 +1,46 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.openai; + +import org.springframework.ai.openai.OpenAiTranscriptionOptions; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +@ConfigurationProperties(OpenAiTranscriptionProperties.CONFIG_PREFIX) +public class OpenAiTranscriptionProperties extends OpenAiParentProperties { + + public static final String CONFIG_PREFIX = "spring.ai.openai.transcription"; + + public static final String DEFAULT_TRANSCRIPTION_MODEL = "whisper-1"; + + private static final Double DEFAULT_TEMPERATURE = 0.7; + + @NestedConfigurationProperty + private OpenAiTranscriptionOptions options = OpenAiTranscriptionOptions.builder() + .withModel(DEFAULT_TRANSCRIPTION_MODEL) + .withTemperature(DEFAULT_TEMPERATURE.floatValue()) + .build(); + + public OpenAiTranscriptionOptions getOptions() { + return options; + } + + public void setOptions(OpenAiTranscriptionOptions options) { + this.options = options; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java index 6827fb23340..1d4ee743f33 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java @@ -28,6 +28,9 @@ import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImageResponse; import org.springframework.ai.openai.OpenAiImageClient; +import org.springframework.ai.openai.OpenAiTranscriptionClient; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; import reactor.core.publisher.Flux; import org.springframework.ai.chat.ChatResponse; @@ -58,6 +61,17 @@ void generate() { }); } + @Test + void transcribe() { + contextRunner.run(context -> { + OpenAiTranscriptionClient client = context.getBean(OpenAiTranscriptionClient.class); + Resource audioFile = new ClassPathResource("/speech/jfk.flac"); + String response = client.call(audioFile); + assertThat(response).isNotEmpty(); + logger.info("Response: " + response); + }); + } + @Test void generateStreaming() { contextRunner.run(context -> { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java index c411a2e5285..3316899af53 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java @@ -20,6 +20,7 @@ import org.junit.jupiter.api.Test; +import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoice; import org.springframework.ai.openai.api.OpenAiApi.FunctionTool.Type; @@ -64,6 +65,32 @@ public void chatProperties() { }); } + @Test + public void transcriptionProperties() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.openai.base-url=TEST_BASE_URL", + "spring.ai.openai.api-key=abc123", + "spring.ai.openai.transcription.options.model=MODEL_XYZ", + "spring.ai.openai.transcription.options.temperature=0.55") + // @formatter:on + .withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class)) + .run(context -> { + var transcriptionProperties = context.getBean(OpenAiTranscriptionProperties.class); + var connectionProperties = context.getBean(OpenAiConnectionProperties.class); + + assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + + assertThat(transcriptionProperties.getApiKey()).isNull(); + assertThat(transcriptionProperties.getBaseUrl()).isNull(); + + assertThat(transcriptionProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + assertThat(transcriptionProperties.getOptions().getTemperature()).isEqualTo(0.55f); + }); + } + @Test public void chatOverrideConnectionProperties() { @@ -92,6 +119,34 @@ public void chatOverrideConnectionProperties() { }); } + @Test + public void transcriptionOverrideConnectionProperties() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.openai.base-url=TEST_BASE_URL", + "spring.ai.openai.api-key=abc123", + "spring.ai.openai.transcription.base-url=TEST_BASE_URL2", + "spring.ai.openai.transcription.api-key=456", + "spring.ai.openai.transcription.options.model=MODEL_XYZ", + "spring.ai.openai.transcription.options.temperature=0.55") + // @formatter:on + .withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class)) + .run(context -> { + var transcriptionProperties = context.getBean(OpenAiTranscriptionProperties.class); + var connectionProperties = context.getBean(OpenAiConnectionProperties.class); + + assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + + assertThat(transcriptionProperties.getApiKey()).isEqualTo("456"); + assertThat(transcriptionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2"); + + assertThat(transcriptionProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + assertThat(transcriptionProperties.getOptions().getTemperature()).isEqualTo(0.55f); + }); + } + @Test public void embeddingProperties() { @@ -282,6 +337,41 @@ public void chatOptionsTest() { }); } + @Test + public void transcriptionOptionsTest() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.openai.api-key=API_KEY", + "spring.ai.openai.base-url=TEST_BASE_URL", + + "spring.ai.openai.transcription.options.model=MODEL_XYZ", + "spring.ai.openai.transcription.options.language=en", + "spring.ai.openai.transcription.options.prompt=Er, yes, I think so", + "spring.ai.openai.transcription.options.responseFormat.type=json", + "spring.ai.openai.transcription.options.temperature=0.55" + ) + // @formatter:on + .withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class)) + .run(context -> { + var transcriptionProperties = context.getBean(OpenAiTranscriptionProperties.class); + var connectionProperties = context.getBean(OpenAiConnectionProperties.class); + var embeddingProperties = context.getBean(OpenAiEmbeddingProperties.class); + + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY"); + + assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("text-embedding-ada-002"); + + assertThat(transcriptionProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + assertThat(transcriptionProperties.getOptions().getLanguage()).isEqualTo("en"); + assertThat(transcriptionProperties.getOptions().getPrompt()).isEqualTo("Er, yes, I think so"); + assertThat(transcriptionProperties.getOptions().getResponseFormat()) + .isEqualTo(new OpenAiApi.TranscriptionRequest.ResponseFormat("json")); + assertThat(transcriptionProperties.getOptions().getTemperature()).isEqualTo(0.55f); + }); + } + @Test public void embeddingOptionsTest() { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/resources/speech/jfk.flac b/spring-ai-spring-boot-autoconfigure/src/test/resources/speech/jfk.flac new file mode 100644 index 00000000000..e44b7c13897 Binary files /dev/null and b/spring-ai-spring-boot-autoconfigure/src/test/resources/speech/jfk.flac differ