diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechClient.java new file mode 100644 index 00000000000..7e79ad34875 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechClient.java @@ -0,0 +1,157 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai; + +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.metadata.RateLimit; +import org.springframework.ai.openai.api.OpenAiAudioApi; +import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.AudioResponseFormat; +import org.springframework.ai.openai.api.common.OpenAiApiException; +import org.springframework.ai.openai.audio.speech.*; +import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata; +import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor; +import org.springframework.http.ResponseEntity; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; +import reactor.core.publisher.Flux; + +import java.time.Duration; + +/** + * OpenAI audio speech client implementation backed by {@link OpenAiAudioApi}. 
+ * + * @author Ahmed Yousri + * @see OpenAiAudioApi + */ +public class OpenAiAudioSpeechClient implements SpeechClient, StreamingSpeechClient { + + private final Logger logger = LoggerFactory.getLogger(getClass()); + + private final OpenAiAudioSpeechOptions defaultOptions; + + private static final Float SPEED = 1.0f; + + public final RetryTemplate retryTemplate = RetryTemplate.builder() + .maxAttempts(10) + .retryOn(OpenAiApiException.class) + .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000)) + .build(); + + private final OpenAiAudioApi audioApi; + + public OpenAiAudioSpeechClient(OpenAiAudioApi audioApi) { + this(audioApi, + OpenAiAudioSpeechOptions.builder() + .withModel(OpenAiAudioApi.TtsModel.TTS_1.getValue()) + .withResponseFormat(AudioResponseFormat.MP3) + .withVoice(OpenAiAudioApi.SpeechRequest.Voice.ALLOY) + .withSpeed(SPEED) + .build()); + } + + public OpenAiAudioSpeechClient(OpenAiAudioApi audioApi, OpenAiAudioSpeechOptions options) { + Assert.notNull(audioApi, "OpenAiAudioApi must not be null"); + Assert.notNull(options, "OpenAiSpeechOptions must not be null"); + this.audioApi = audioApi; + this.defaultOptions = options; + } + + @Override + public byte[] call(String text) { + SpeechPrompt speechRequest = new SpeechPrompt(text); + return call(speechRequest).getResult().getOutput(); + } + + @Override + public SpeechResponse call(SpeechPrompt speechPrompt) { + + return this.retryTemplate.execute(ctx -> { + + OpenAiAudioApi.SpeechRequest speechRequest = createRequestBody(speechPrompt); + + ResponseEntity speechEntity = this.audioApi.createSpeech(speechRequest); + var speech = speechEntity.getBody(); + + if (speech == null) { + logger.warn("No speech response returned for speechRequest: {}", speechRequest); + return new SpeechResponse(new Speech(new byte[0])); + } + + RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(speechEntity); + + return new SpeechResponse(new Speech(speech), new 
OpenAiAudioSpeechResponseMetadata(rateLimits)); + + }); + } + + /** + * Streams the audio response for the given speech prompt. + * @param prompt The speech prompt containing the text and options for speech + * synthesis. + * @return A Flux of SpeechResponse objects containing the streamed audio and + * metadata. + */ + + @Override + public Flux stream(SpeechPrompt prompt) { + return this.audioApi.stream(this.createRequestBody(prompt)) + .map(entity -> new SpeechResponse(new Speech(entity.getBody()), new OpenAiAudioSpeechResponseMetadata( + OpenAiResponseHeaderExtractor.extractAiResponseHeaders(entity)))); + } + + private OpenAiAudioApi.SpeechRequest createRequestBody(SpeechPrompt request) { + OpenAiAudioSpeechOptions options = this.defaultOptions; + // Per-request options take precedence; the client defaults only fill unset values. + if (request.getOptions() != null) { + if (request.getOptions() instanceof OpenAiAudioSpeechOptions runtimeOptions) { + options = this.merge(runtimeOptions, options); + } + else { + throw new IllegalArgumentException("Prompt options are not of type SpeechOptions: " + + request.getOptions().getClass().getSimpleName()); + } + } + // An explicit, non-blank input in the merged options overrides the prompt's first message. + String input = StringUtils.isNotBlank(options.getInput()) ? options.getInput() + : request.getInstructions().get(0).getText(); + + OpenAiAudioApi.SpeechRequest.Builder requestBuilder = OpenAiAudioApi.SpeechRequest.builder() + .withModel(options.getModel()) + .withInput(input) + .withVoice(options.getVoice()) + .withResponseFormat(options.getResponseFormat()) + .withSpeed(options.getSpeed()); + + return requestBuilder.build(); + } + + private OpenAiAudioSpeechOptions merge(OpenAiAudioSpeechOptions source, OpenAiAudioSpeechOptions target) { + OpenAiAudioSpeechOptions.Builder mergedBuilder = OpenAiAudioSpeechOptions.builder(); + // For each property, a non-null value on source wins; target supplies the fallback. + mergedBuilder.withModel(source.getModel() != null ? source.getModel() : target.getModel()); + mergedBuilder.withInput(source.getInput() != null ? source.getInput() : target.getInput()); + mergedBuilder.withVoice(source.getVoice() != null ? 
source.getVoice() : target.getVoice()); + mergedBuilder.withResponseFormat( + source.getResponseFormat() != null ? source.getResponseFormat() : target.getResponseFormat()); + mergedBuilder.withSpeed(source.getSpeed() != null ? source.getSpeed() : target.getSpeed()); + + return mergedBuilder.build(); + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechOptions.java new file mode 100644 index 00000000000..e88ac688017 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechOptions.java @@ -0,0 +1,204 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.model.ModelOptions; +import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.AudioResponseFormat; +import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.Voice; + +/** + * Options for OpenAI text to audio - speech synthesis. + * + * @author Ahmed Yousri + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class OpenAiAudioSpeechOptions implements ModelOptions { + + /** + * ID of the model to use for generating the audio. 
One of the available TTS models: + * tts-1 or tts-1-hd. + */ + @JsonProperty("model") + private String model; + + /** + * The input text to synthesize. Per the OpenAI API reference, at most 4096 characters. + */ + @JsonProperty("input") + private String input; + + /** + * The voice to use for synthesis. One of the available voices for the chosen model: + * 'alloy', 'echo', 'fable', 'onyx', 'nova', and 'shimmer'. + */ + @JsonProperty("voice") + private Voice voice; + + /** + * The format of the audio output. Supported formats are mp3, opus, aac, and flac. + * Defaults to mp3. + */ + @JsonProperty("response_format") + private AudioResponseFormat responseFormat; + + /** + * The speed of the voice synthesis. Per the OpenAI API reference, the acceptable range + * is 0.25 (slowest) to 4.0 (fastest). + */ + @JsonProperty("speed") + private Float speed; + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private final OpenAiAudioSpeechOptions options = new OpenAiAudioSpeechOptions(); + + public Builder withModel(String model) { + options.model = model; + return this; + } + + public Builder withInput(String input) { + options.input = input; + return this; + } + + public Builder withVoice(Voice voice) { + options.voice = voice; + return this; + } + + public Builder withResponseFormat(AudioResponseFormat responseFormat) { + options.responseFormat = responseFormat; + return this; + } + + public Builder withSpeed(Float speed) { + options.speed = speed; + return this; + } + // NOTE(review): returns the builder's own instance, so reusing a builder after build() aliases the result. + public OpenAiAudioSpeechOptions build() { + return options; + } + + } + + public String getModel() { + return model; + } + + public String getInput() { + return input; + } + + public Voice getVoice() { + return voice; + } + + public AudioResponseFormat getResponseFormat() { + return responseFormat; + } + + public Float getSpeed() { + return speed; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((model == null) ? 
0 : model.hashCode()); + result = prime * result + ((input == null) ? 0 : input.hashCode()); + result = prime * result + ((voice == null) ? 0 : voice.hashCode()); + result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); + result = prime * result + ((speed == null) ? 0 : speed.hashCode()); + return result; + } + + public void setModel(String model) { + this.model = model; + } + + public void setInput(String input) { + this.input = input; + } + + public void setVoice(Voice voice) { + this.voice = voice; + } + + public void setResponseFormat(AudioResponseFormat responseFormat) { + this.responseFormat = responseFormat; + } + + public void setSpeed(Float speed) { + this.speed = speed; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + OpenAiAudioSpeechOptions other = (OpenAiAudioSpeechOptions) obj; + if (model == null) { + if (other.model != null) + return false; + } + else if (!model.equals(other.model)) + return false; + if (input == null) { + if (other.input != null) + return false; + } + else if (!input.equals(other.input)) + return false; + if (voice == null) { + if (other.voice != null) + return false; + } + else if (!voice.equals(other.voice)) + return false; + if (responseFormat == null) { + if (other.responseFormat != null) + return false; + } + else if (!responseFormat.equals(other.responseFormat)) + return false; + if (speed == null) { + return other.speed == null; + } + else + return speed.equals(other.speed); + } + + @Override + public String toString() { + return "OpenAiAudioSpeechOptions{" + "model='" + model + '\'' + ", input='" + input + '\'' + ", voice='" + voice + + '\'' + ", responseFormat='" + responseFormat + '\'' + ", speed=" + speed + '}'; + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java 
b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java index 65e5ca310af..3daa981ec2f 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java @@ -23,12 +23,17 @@ import org.springframework.ai.retry.RetryUtils; import org.springframework.core.io.ByteArrayResource; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; /** * Turn audio into text or text into audio. Based on @@ -41,12 +46,15 @@ public class OpenAiAudioApi { private final RestClient restClient; + private final WebClient webClient; + /** * Create an new audio api. * @param openAiToken OpenAI apiKey. 
*/ public OpenAiAudioApi(String openAiToken) { - this(ApiUtils.DEFAULT_BASE_URL, openAiToken, RestClient.builder(), RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); + this(ApiUtils.DEFAULT_BASE_URL, openAiToken, RestClient.builder(), WebClient.builder(), + RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); } /** @@ -62,6 +70,30 @@ public OpenAiAudioApi(String baseUrl, String openAiToken, RestClient.Builder res this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(headers -> { headers.setBearerAuth(openAiToken); }).defaultStatusHandler(responseErrorHandler).build(); + + this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(headers -> { + headers.setBearerAuth(openAiToken); + }).defaultHeaders(ApiUtils.getJsonContentHeaders(openAiToken)).build(); + } + + /** + * Create a new audio api. + * @param baseUrl api base URL. + * @param openAiToken OpenAI apiKey. + * @param restClientBuilder RestClient builder. + * @param webClientBuilder WebClient builder. + * @param responseErrorHandler Response error handler. + */ + public OpenAiAudioApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder, + WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { + + this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(headers -> { + headers.setBearerAuth(openAiToken); + }).defaultStatusHandler(responseErrorHandler).build(); + + this.webClient = webClientBuilder.baseUrl(baseUrl).defaultHeaders(headers -> { + headers.setBearerAuth(openAiToken); + }).defaultHeaders(ApiUtils.getJsonContentHeaders(openAiToken)).build(); } /** @@ -570,6 +602,35 @@ public ResponseEntity createSpeech(SpeechRequest requestBody) { return this.restClient.post().uri("/v1/audio/speech").body(requestBody).retrieve().toEntity(byte[].class); } + /** + * Streams audio generated from the input text. + * + * This method sends a POST request to the OpenAI API to generate audio from the + * provided text. 
The audio is streamed back as a Flux of ResponseEntity objects, each + * containing a byte array of the audio data. An error status (4xx/5xx) terminates the + * Flux with an error instead of emitting the error body as audio. + * @param requestBody The request body containing the details for the audio + * generation, such as the input text, model, voice, and response format. + * @return A Flux of ResponseEntity objects, each containing a byte array of the audio + * data. + */ + public Flux<ResponseEntity<byte[]>> stream(SpeechRequest requestBody) { + + return webClient.post() + .uri("/v1/audio/speech") + .body(Mono.just(requestBody), SpeechRequest.class) + .accept(MediaType.APPLICATION_OCTET_STREAM) + .exchangeToFlux(clientResponse -> { + // exchangeToFlux leaves status handling to the caller, so fail on error statuses explicitly. + if (clientResponse.statusCode().isError()) { + return clientResponse.<ResponseEntity<byte[]>>createError().flux(); + } + HttpHeaders headers = clientResponse.headers().asHttpHeaders(); + return clientResponse.bodyToFlux(byte[].class) + .map(bytes -> ResponseEntity.ok().headers(headers).body(bytes)); + }); + } + /** * Transcribes audio into the input language. * @param requestBody The request body. diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/Speech.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/Speech.java new file mode 100644 index 00000000000..b3934182fc3 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/Speech.java @@ -0,0 +1,73 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.openai.audio.speech; + +import org.springframework.ai.model.ModelResult; +import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechMetadata; +import org.springframework.lang.Nullable; + +import java.util.Arrays; +import java.util.Objects; + +/** + * A single speech result: the generated audio bytes plus optional result metadata. + * @author Ahmed Yousri + */ +public class Speech implements ModelResult { + + private final byte[] audio; + + private OpenAiAudioSpeechMetadata speechMetadata; + + public Speech(byte[] audio) { + this.audio = audio; + } + + @Override + public byte[] getOutput() { + return this.audio; + } + + @Override + public OpenAiAudioSpeechMetadata getMetadata() { + return speechMetadata != null ? speechMetadata : OpenAiAudioSpeechMetadata.NULL; + } + + public Speech withSpeechMetadata(@Nullable OpenAiAudioSpeechMetadata speechMetadata) { + this.speechMetadata = speechMetadata; + return this; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof Speech that)) + return false; + return Arrays.equals(audio, that.audio) && Objects.equals(speechMetadata, that.speechMetadata); + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(audio), speechMetadata); + } + + @Override + public String toString() { + return "Speech{audio=" + (audio != null ? audio.length + " bytes" : "null") + ", speechMetadata=" + speechMetadata + '}'; + } + +} \ No newline at end of file diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechClient.java new file mode 100644 index 00000000000..d15d04d7bff --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechClient.java @@ -0,0 +1,49 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.audio.speech; + +import org.springframework.ai.model.ModelClient; + +/** + * The {@link SpeechClient} interface provides a way to interact with the OpenAI Text-to-Speech (TTS) API. + * It allows you to convert text input into lifelike spoken audio. + * + * @author Ahmed Yousri + */ +@FunctionalInterface +public interface SpeechClient extends ModelClient { + + /** + * Generates spoken audio from the provided text message. + * + * @param message the text message to be converted to audio + * @return the resulting audio bytes + */ + default byte[] call(String message) { + SpeechPrompt prompt = new SpeechPrompt(message); + return call(prompt).getResult().getOutput(); + } + + /** + * Sends a speech request to the OpenAI TTS API and returns the resulting speech response. + * + * @param request the speech prompt containing the input text and other parameters + * @return the speech response containing the generated audio + */ + SpeechResponse call(SpeechPrompt request); + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechMessage.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechMessage.java new file mode 100644 index 00000000000..5dc5a5b1845 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechMessage.java @@ -0,0 +1,52 @@ +/* + * Copyright 2023 - 2024 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.openai.audio.speech; + +/** + * The {@link SpeechMessage} class represents a single text message to be converted to speech by the OpenAI TTS API. + * + * @author Ahmed Yousri + */ +public class SpeechMessage { + private String text; + + /** + * Constructs a new {@link SpeechMessage} object with the given text. + * + * @param text the text to be converted to speech + */ + public SpeechMessage(String text) { + this.text = text; + } + + /** + * Returns the text of this speech message. + * + * @return the text of this speech message + */ + public String getText() { + return text; + } + + /** + * Sets the text of this speech message. + * + * @param text the new text for this speech message + */ + public void setText(String text) { + this.text = text; + } +} \ No newline at end of file diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechPrompt.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechPrompt.java new file mode 100644 index 00000000000..aea41209564 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechPrompt.java @@ -0,0 +1,68 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.openai.audio.speech; + +import org.springframework.ai.model.ModelOptions; +import org.springframework.ai.model.ModelRequest; +import org.springframework.ai.openai.OpenAiAudioSpeechOptions; + +import java.util.Collections; +import java.util.List; + +/** + * The {@link SpeechPrompt} class represents a request to the OpenAI Text-to-Speech (TTS) API. + * It contains a list of {@link SpeechMessage} objects, each representing a piece of text to be converted to speech. + * + * @author Ahmed Yousri + */ +public class SpeechPrompt implements ModelRequest> { + + private OpenAiAudioSpeechOptions speechOptions; + + private final List messages; + + public SpeechPrompt(List messages) { + this.messages = messages; + } + + public SpeechPrompt(List messages, OpenAiAudioSpeechOptions modelOptions) { + this.messages = messages; + this.speechOptions = modelOptions; + } + + public SpeechPrompt(SpeechMessage speechMessage, OpenAiAudioSpeechOptions speechOptions) { + this(Collections.singletonList(speechMessage), speechOptions); + } + + public SpeechPrompt(String instructions, OpenAiAudioSpeechOptions speechOptions) { + this(new SpeechMessage(instructions), speechOptions); + } + + public SpeechPrompt(String instructions) { + this(new SpeechMessage(instructions), OpenAiAudioSpeechOptions.builder().build()); + } + + @Override + public List getInstructions() { + return this.messages; + } + + @Override + public ModelOptions getOptions() { + return speechOptions; + } + +} diff --git 
a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechResponse.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechResponse.java new file mode 100644 index 00000000000..e5e69696c40 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechResponse.java @@ -0,0 +1,59 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.openai.audio.speech; + +import org.springframework.ai.model.ModelResponse; +import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata; + +import java.util.Collections; +import java.util.List; + +/** + * @author Ahmed Yousri + */ + +public class SpeechResponse implements ModelResponse { + + private final Speech speech; + + private final OpenAiAudioSpeechResponseMetadata speechResponseMetadata; + + public SpeechResponse(Speech speech) { + this(speech, OpenAiAudioSpeechResponseMetadata.NULL); + } + + public SpeechResponse(Speech speech, OpenAiAudioSpeechResponseMetadata speechResponseMetadata) { + this.speech = speech; + this.speechResponseMetadata = speechResponseMetadata; + } + + @Override + public Speech getResult() { + return speech; + } + + @Override + public List getResults() { + return Collections.singletonList(speech); + } + + @Override + public OpenAiAudioSpeechResponseMetadata getMetadata() { + return speechResponseMetadata; + } + +} \ No newline at end of file diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/StreamingSpeechClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/StreamingSpeechClient.java new file mode 100644 index 00000000000..e81d7fe1115 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/StreamingSpeechClient.java @@ -0,0 +1,51 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.audio.speech; + +import org.springframework.ai.model.StreamingModelClient; +import reactor.core.publisher.Flux; + +/** + * The {@link StreamingSpeechClient} interface provides a way to interact with the OpenAI Text-to-Speech (TTS) API + * using a streaming approach, allowing you to receive the generated audio in a real-time fashion. + * + * @author Ahmed Yousri + */ +@FunctionalInterface +public interface StreamingSpeechClient extends StreamingModelClient { + + /** + * Generates a stream of audio bytes from the provided text message. + * + * @param message the text message to be converted to audio + * @return a Flux of audio bytes representing the generated speech + */ + default Flux stream(String message) { + SpeechPrompt prompt = new SpeechPrompt(message); + return stream(prompt).map(SpeechResponse::getResult).map(Speech::getOutput); + } + + /** + * Sends a speech request to the OpenAI TTS API and returns a stream of the resulting speech responses. + * + * @param prompt the speech prompt containing the input text and other parameters + * @return a Flux of speech responses, each containing a portion of the generated audio + */ + @Override + Flux stream(SpeechPrompt prompt); + +} \ No newline at end of file diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechMetadata.java new file mode 100644 index 00000000000..85289d85408 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechMetadata.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023 - 2024 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.metadata.audio; + +import org.springframework.ai.model.ResultMetadata; + +public interface OpenAiAudioSpeechMetadata extends ResultMetadata { + + OpenAiAudioSpeechMetadata NULL = OpenAiAudioSpeechMetadata.create(); + + /** + * Factory method used to construct a new {@link OpenAiAudioSpeechMetadata} + * @return a new {@link OpenAiAudioSpeechMetadata} + */ + static OpenAiAudioSpeechMetadata create() { + return new OpenAiAudioSpeechMetadata() { + }; + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechResponseMetadata.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechResponseMetadata.java new file mode 100644 index 00000000000..1c6260ffd06 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechResponseMetadata.java @@ -0,0 +1,78 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.metadata.audio; + +import org.springframework.ai.chat.metadata.EmptyRateLimit; +import org.springframework.ai.chat.metadata.RateLimit; +import org.springframework.ai.model.ResponseMetadata; +import org.springframework.ai.openai.api.OpenAiAudioApi; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Audio speech metadata implementation for {@literal OpenAI}. + * + * @author Ahmed Yousri + * @see RateLimit + */ +public class OpenAiAudioSpeechResponseMetadata implements ResponseMetadata { + + protected static final String AI_METADATA_STRING = "{ @type: %1$s, requestsLimit: %2$s }"; + + public static final OpenAiAudioSpeechResponseMetadata NULL = new OpenAiAudioSpeechResponseMetadata() { + }; + + public static OpenAiAudioSpeechResponseMetadata from(OpenAiAudioApi.StructuredResponse result) { + Assert.notNull(result, "OpenAI speech must not be null"); + OpenAiAudioSpeechResponseMetadata speechResponseMetadata = new OpenAiAudioSpeechResponseMetadata(); + return speechResponseMetadata; + } + + public static OpenAiAudioSpeechResponseMetadata from(String result) { + Assert.notNull(result, "OpenAI speech must not be null"); + OpenAiAudioSpeechResponseMetadata speechResponseMetadata = new OpenAiAudioSpeechResponseMetadata(); + return speechResponseMetadata; + } + + @Nullable + private RateLimit rateLimit; + + public OpenAiAudioSpeechResponseMetadata() { + this(null); + } + + public OpenAiAudioSpeechResponseMetadata(@Nullable RateLimit rateLimit) { + this.rateLimit = 
rateLimit; + } + + @Nullable + public RateLimit getRateLimit() { + RateLimit rateLimit = this.rateLimit; + return rateLimit != null ? rateLimit : new EmptyRateLimit(); + } + + public OpenAiAudioSpeechResponseMetadata withRateLimit(RateLimit rateLimit) { + this.rateLimit = rateLimit; + return this; + } + + @Override + public String toString() { + return AI_METADATA_STRING.formatted(getClass().getName(), getRateLimit()); + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java index e3a70725f3f..3a235cf7aa5 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java @@ -62,6 +62,12 @@ public OpenAiAudioTranscriptionClient openAiTranscriptionClient(OpenAiAudioApi a return openAiTranscriptionClient; } + @Bean + public OpenAiAudioSpeechClient openAiAudioSpeechClient(OpenAiAudioApi api) { + OpenAiAudioSpeechClient openAiAudioSpeechClient = new OpenAiAudioSpeechClient(api); + return openAiAudioSpeechClient; + } + @Bean public OpenAiImageClient openAiImageClient(OpenAiImageApi imageApi) { OpenAiImageClient openAiImageClient = new OpenAiImageClient(imageApi); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechClientIT.java new file mode 100644 index 00000000000..aca93deca69 --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechClientIT.java @@ -0,0 +1,115 @@ +/* + * Copyright 2023 - 2024 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.audio.speech; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.openai.OpenAiAudioSpeechOptions; +import org.springframework.ai.openai.OpenAiTestConfiguration; +import org.springframework.ai.openai.api.OpenAiAudioApi; +import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata; +import org.springframework.ai.openai.testutils.AbstractIT; +import org.springframework.boot.test.context.SpringBootTest; +import reactor.core.publisher.Flux; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +@SpringBootTest(classes = OpenAiTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") +class OpenAiSpeechClientIT extends AbstractIT { + + private static final Float SPEED = 1.0f; + + @Test + void shouldSuccessfullyStreamAudioBytesForEmptyMessage() { + Flux response = openAiAudioSpeechClient + .stream("Today is a wonderful day to build something people love!"); + assertThat(response).isNotNull(); + assertThat(response.collectList().block()).isNotNull(); + System.out.println(response.collectList().block()); + } + + @Test + void shouldProduceAudioBytesDirectlyFromMessage() { + byte[] audioBytes = openAiAudioSpeechClient.call("Today is a wonderful day to build something people love!"); + 
assertThat(audioBytes).hasSizeGreaterThan(0); + + } + + @Test + void shouldGenerateNonEmptyMp3AudioFromSpeechPrompt() { + OpenAiAudioSpeechOptions speechOptions = OpenAiAudioSpeechOptions.builder() + .withVoice(OpenAiAudioApi.SpeechRequest.Voice.ALLOY) + .withSpeed(SPEED) + .withResponseFormat(OpenAiAudioApi.SpeechRequest.AudioResponseFormat.MP3) + .withModel(OpenAiAudioApi.TtsModel.TTS_1.value) + .build(); + SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!", + speechOptions); + SpeechResponse response = openAiAudioSpeechClient.call(speechPrompt); + byte[] audioBytes = response.getResult().getOutput(); + assertThat(response.getResults()).hasSize(1); + assertThat(response.getResults().get(0).getOutput()).isNotEmpty(); + assertThat(audioBytes).hasSizeGreaterThan(0); + + } + + @Test + void speechRateLimitTest() { + OpenAiAudioSpeechOptions speechOptions = OpenAiAudioSpeechOptions.builder() + .withVoice(OpenAiAudioApi.SpeechRequest.Voice.ALLOY) + .withSpeed(SPEED) + .withResponseFormat(OpenAiAudioApi.SpeechRequest.AudioResponseFormat.MP3) + .withModel(OpenAiAudioApi.TtsModel.TTS_1.value) + .build(); + SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!", + speechOptions); + SpeechResponse response = openAiAudioSpeechClient.call(speechPrompt); + OpenAiAudioSpeechResponseMetadata metadata = response.getMetadata(); + assertThat(metadata).isNotNull(); + assertThat(metadata.getRateLimit()).isNotNull(); + assertThat(metadata.getRateLimit().getRequestsLimit()).isPositive(); + assertThat(metadata.getRateLimit().getRequestsLimit()).isPositive(); + + } + + @Test + void shouldStreamNonEmptyResponsesForValidSpeechPrompts() { + + + OpenAiAudioSpeechOptions speechOptions = OpenAiAudioSpeechOptions.builder() + .withVoice(OpenAiAudioApi.SpeechRequest.Voice.ALLOY) + .withSpeed(SPEED) + .withResponseFormat(OpenAiAudioApi.SpeechRequest.AudioResponseFormat.MP3) + 
.withModel(OpenAiAudioApi.TtsModel.TTS_1.value) + .build(); + + SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!", + speechOptions); + Flux responseFlux = openAiAudioSpeechClient.stream(speechPrompt); + assertThat(responseFlux).isNotNull(); + List responses = responseFlux.collectList().block(); + assertThat(responses).isNotNull(); + responses.forEach(response -> { + System.out.println("Audio data chunk size: " + response.getResult().getOutput().length); + assertThat(response.getResult().getOutput()).isNotEmpty(); + }); + } + +} \ No newline at end of file diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechClientWithSpeechResponseMetadataTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechClientWithSpeechResponseMetadataTests.java new file mode 100644 index 00000000000..d2627eba95f --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechClientWithSpeechResponseMetadataTests.java @@ -0,0 +1,133 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.openai.audio.speech; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.springframework.ai.openai.OpenAiAudioSpeechClient; +import org.springframework.ai.openai.OpenAiAudioSpeechOptions; +import org.springframework.ai.openai.api.OpenAiAudioApi; +import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata; +import org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.autoconfigure.web.client.RestClientTest; +import org.springframework.context.annotation.Bean; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.test.web.client.MockRestServiceServer; +import org.springframework.web.client.RestClient; + +import java.time.Duration; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.*; +import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess; + +/** + * @author Ahmed Yousri + */ +@RestClientTest(OpenAiSpeechClientWithSpeechResponseMetadataTests.Config.class) +public class OpenAiSpeechClientWithSpeechResponseMetadataTests { + + private static String TEST_API_KEY = "sk-1234567890"; + + private static final Float SPEED = 1.0f; + + @Autowired + private OpenAiAudioSpeechClient openAiSpeechClient; + + @Autowired + private MockRestServiceServer server; + + @AfterEach + void resetMockServer() { + server.reset(); + } + + @Test + void aiResponseContainsImageResponseMetadata() { + + prepareMock(); + + OpenAiAudioSpeechOptions speechOptions = OpenAiAudioSpeechOptions.builder() + 
.withVoice(OpenAiAudioApi.SpeechRequest.Voice.ALLOY) + .withSpeed(SPEED) + .withResponseFormat(OpenAiAudioApi.SpeechRequest.AudioResponseFormat.MP3) + .withModel(OpenAiAudioApi.TtsModel.TTS_1.value) + .build(); + + SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!", + speechOptions); + SpeechResponse response = openAiSpeechClient.call(speechPrompt); + + byte[] audioBytes = response.getResult().getOutput(); + assertThat(audioBytes).hasSizeGreaterThan(0); + + OpenAiAudioSpeechResponseMetadata speechResponseMetadata = response.getMetadata(); + assertThat(speechResponseMetadata).isNotNull(); + var requestLimit = speechResponseMetadata.getRateLimit(); + Long requestsLimit = requestLimit.getRequestsLimit(); + Long tokensLimit = requestLimit.getTokensLimit(); + Long tokensRemaining = requestLimit.getTokensRemaining(); + Long requestsRemaining = requestLimit.getRequestsRemaining(); + Duration requestsReset = requestLimit.getRequestsReset(); + assertThat(requestsLimit).isNotNull(); + assertThat(requestsLimit).isEqualTo(4000L); + assertThat(tokensLimit).isEqualTo(725000L); + assertThat(tokensRemaining).isEqualTo(112358L); + assertThat(requestsRemaining).isEqualTo(999L); + assertThat(requestsReset).isEqualTo(Duration.parse("PT64H15M29S")); + + } + + private void prepareMock() { + + HttpHeaders httpHeaders = new HttpHeaders(); + httpHeaders.set(OpenAiApiResponseHeaders.REQUESTS_LIMIT_HEADER.getName(), "4000"); + httpHeaders.set(OpenAiApiResponseHeaders.REQUESTS_REMAINING_HEADER.getName(), "999"); + httpHeaders.set(OpenAiApiResponseHeaders.REQUESTS_RESET_HEADER.getName(), "2d16h15m29s"); + httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_LIMIT_HEADER.getName(), "725000"); + httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_REMAINING_HEADER.getName(), "112358"); + httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_RESET_HEADER.getName(), "27h55s451ms"); + httpHeaders.setContentType(MediaType.APPLICATION_OCTET_STREAM); + + 
server.expect(requestTo("/v1/audio/speech")) + .andExpect(method(HttpMethod.POST)) + .andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY)) + .andRespond(withSuccess("Audio bytes as string", MediaType.APPLICATION_OCTET_STREAM).headers(httpHeaders)); + + } + + @SpringBootConfiguration + static class Config { + + @Bean + public OpenAiAudioSpeechClient openAiAudioSpeechClient(OpenAiAudioApi openAiAudioApi) { + return new OpenAiAudioSpeechClient(openAiAudioApi); + } + + @Bean + public OpenAiAudioApi openAiAudioApi(RestClient.Builder builder) { + return new OpenAiAudioApi("", TEST_API_KEY, builder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); + } + + } + +} \ No newline at end of file diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java index b8e1c86a27e..9c2b8780a4b 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java @@ -29,6 +29,7 @@ import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.image.ImageClient; +import org.springframework.ai.openai.OpenAiAudioSpeechClient; import org.springframework.ai.openai.OpenAiAudioTranscriptionClient; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; @@ -47,6 +48,9 @@ public abstract class AbstractIT { @Autowired protected OpenAiAudioTranscriptionClient openAiTranscriptionClient; + @Autowired + protected OpenAiAudioSpeechClient openAiAudioSpeechClient; + @Autowired protected ImageClient openaiImageClient; diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index 5687ff79ab9..2d356e7cc85 100644 
--- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -37,6 +37,8 @@ *** xref:api/image/stabilityai-image.adoc[Stability] ** xref:api/transcriptions.adoc[] *** xref:api/transcriptions/openai-transcriptions.adoc[OpenAI] +** xref:api/speech.adoc[] +*** xref:api/speech/openai-speech.adoc[OpenAI] ** xref:api/vectordbs.adoc[] *** xref:api/vectordbs/azure.adoc[] *** xref:api/vectordbs/chroma.adoc[] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/speech.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/speech.adoc new file mode 100644 index 00000000000..eedfd2fda9b --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/speech.adoc @@ -0,0 +1,5 @@ +[[Speech]] += Speech API + +Spring AI provides support for OpenAI's Speech API. +When additional providers for Speech are implemented, a common `SpeechClient` and `StreamingSpeechClient` interface will be extracted. \ No newline at end of file diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/speech/openai-speech.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/speech/openai-speech.adoc new file mode 100644 index 00000000000..f1ae2da357e --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/speech/openai-speech.adoc @@ -0,0 +1,144 @@ += OpenAI Text-to-Speech (TTS) Integration + +== Introduction + +The Audio API provides a speech endpoint based on OpenAI's TTS (text-to-speech) model. It can be used to: + +- Narrate a written blog post +- Produce spoken audio in multiple languages +- Give real-time audio output using streaming + +== Prerequisites + +. Create an OpenAI account and obtain an API key. You can sign up at the https://platform.openai.com/signup[OpenAI signup page] and generate an API key on the https://platform.openai.com/account/api-keys[API Keys page]. +. Add the `spring-ai-openai` dependency to your project's build file. 
Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section for more information. + +== Auto-configuration + +Spring AI provides Spring Boot auto-configuration for the OpenAI Text-to-Speech Client. +To enable it, add the following dependency to your project's Maven `pom.xml` file: + +[source,xml] +---- +<dependency> + <groupId>org.springframework.ai</groupId> + <artifactId>spring-ai-openai-spring-boot-starter</artifactId> +</dependency> +---- + +or to your Gradle `build.gradle` build file: + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-openai-spring-boot-starter' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +=== TTS Properties + +The prefix `spring.ai.openai.audio.speech` is used as the property prefix that lets you configure the OpenAI Text-to-Speech client. + +[cols="3,5,2"] +|==== +| Property | Description | Default + +| spring.ai.openai.audio.speech.options.model | ID of the model to use. Only tts-1 is currently available. | tts-1 +| spring.ai.openai.audio.speech.options.voice | The voice to use for the TTS output. Available options are: alloy, echo, fable, onyx, nova, and shimmer. | alloy +| spring.ai.openai.audio.speech.options.response-format | The format of the audio output. Supported formats are mp3, opus, aac, flac, wav, and pcm. | mp3 +| spring.ai.openai.audio.speech.options.speed | The speed of the voice synthesis. The acceptable range is from 0.25 (slowest) to 4.0 (fastest); 1.0 is normal speed. | 1.0 +|==== + +== Runtime Options [[speech-options]] + +The `OpenAiAudioSpeechOptions` class provides the options to use when making a text-to-speech request. +On start-up, the options specified by `spring.ai.openai.audio.speech` are used but you can override these at runtime. 
+ +For example: + +[source,java] +---- +OpenAiAudioSpeechOptions speechOptions = OpenAiAudioSpeechOptions.builder() + .withModel("tts-1") + .withVoice(OpenAiAudioApi.SpeechRequest.Voice.ALLOY) + .withResponseFormat(OpenAiAudioApi.SpeechRequest.AudioResponseFormat.MP3) + .withSpeed(1.0f) + .build(); + +SpeechPrompt speechPrompt = new SpeechPrompt("Hello, this is a text-to-speech example.", speechOptions); +SpeechResponse response = openAiAudioSpeechClient.call(speechPrompt); +---- + +== Manual Configuration + +Add the `spring-ai-openai` dependency to your project's Maven `pom.xml` file: + +[source,xml] +---- +<dependency> + <groupId>org.springframework.ai</groupId> + <artifactId>spring-ai-openai</artifactId> +</dependency> +---- + +or to your Gradle `build.gradle` build file: + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-openai' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +Next, create an `OpenAiAudioSpeechClient`: + +[source,java] +---- +var openAiAudioApi = new OpenAiAudioApi(System.getenv("OPENAI_API_KEY")); + +var openAiAudioSpeechClient = new OpenAiAudioSpeechClient(openAiAudioApi); + +var speechOptions = OpenAiAudioSpeechOptions.builder() + .withResponseFormat(OpenAiAudioApi.SpeechRequest.AudioResponseFormat.MP3) + .withSpeed(1.0f) + .withModel(OpenAiAudioApi.TtsModel.TTS_1.value) + .build(); + +var speechPrompt = new SpeechPrompt("Hello, this is a text-to-speech example.", speechOptions); +SpeechResponse response = openAiAudioSpeechClient.call(speechPrompt); + +// Accessing metadata (rate limit info) +OpenAiAudioSpeechResponseMetadata metadata = response.getMetadata(); + +byte[] responseAsBytes = response.getResult().getOutput(); +---- + +== Streaming Real-time Audio + +The Speech API provides support for real-time audio streaming using chunked transfer encoding. This means that the audio can be played before the full file has been generated and made accessible. 
+ +[source,java] +---- +var openAiAudioApi = new OpenAiAudioApi(System.getenv("OPENAI_API_KEY")); + +var openAiAudioSpeechClient = new OpenAiAudioSpeechClient(openAiAudioApi); + +OpenAiAudioSpeechOptions speechOptions = OpenAiAudioSpeechOptions.builder() + .withVoice(OpenAiAudioApi.SpeechRequest.Voice.ALLOY) + .withSpeed(1.0f) + .withResponseFormat(OpenAiAudioApi.SpeechRequest.AudioResponseFormat.MP3) + .withModel(OpenAiAudioApi.TtsModel.TTS_1.value) + .build(); + +SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!", speechOptions); + +Flux responseStream = openAiAudioSpeechClient.stream(speechPrompt); +---- + +== Example Code + +* The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechClientIT.java[OpenAiSpeechClientIT.java] test provides some general examples of how to use the library. diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioSpeechProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioSpeechProperties.java new file mode 100644 index 00000000000..662df9c70ea --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioSpeechProperties.java @@ -0,0 +1,57 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.openai; + +import org.springframework.ai.openai.OpenAiAudioSpeechOptions; +import org.springframework.ai.openai.api.OpenAiAudioApi; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * @author Ahmed Yousri + */ + +@ConfigurationProperties(OpenAiAudioSpeechProperties.CONFIG_PREFIX) +public class OpenAiAudioSpeechProperties extends OpenAiParentProperties { + + public static final String CONFIG_PREFIX = "spring.ai.openai.audio.speech"; + + public static final String DEFAULT_SPEECH_MODEL = OpenAiAudioApi.TtsModel.TTS_1.getValue(); + + private static final Float SPEED = 1.0f; + + private static final OpenAiAudioApi.SpeechRequest.Voice VOICE = OpenAiAudioApi.SpeechRequest.Voice.ALLOY; + + private static final OpenAiAudioApi.SpeechRequest.AudioResponseFormat DEFAULT_RESPONSE_FORMAT = OpenAiAudioApi.SpeechRequest.AudioResponseFormat.MP3; + + @NestedConfigurationProperty + private OpenAiAudioSpeechOptions options = OpenAiAudioSpeechOptions.builder() + .withModel(DEFAULT_SPEECH_MODEL) + .withResponseFormat(DEFAULT_RESPONSE_FORMAT) + .withVoice(VOICE) + .withSpeed(SPEED) + .build(); + + public OpenAiAudioSpeechOptions getOptions() { + return options; + } + + public void setOptions(OpenAiAudioSpeechOptions options) { + this.options = options; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java index ae1daff1f00..848456468f2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java +++ 
b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java @@ -24,6 +24,7 @@ import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.ai.openai.OpenAiEmbeddingClient; import org.springframework.ai.openai.OpenAiImageClient; +import org.springframework.ai.openai.OpenAiAudioSpeechClient; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.ai.openai.api.OpenAiImageApi; @@ -48,7 +49,8 @@ @AutoConfiguration(after = { RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class }) @ConditionalOnClass(OpenAiApi.class) @EnableConfigurationProperties({ OpenAiConnectionProperties.class, OpenAiChatProperties.class, - OpenAiEmbeddingProperties.class, OpenAiImageProperties.class, OpenAiAudioTranscriptionProperties.class }) + OpenAiEmbeddingProperties.class, OpenAiImageProperties.class, OpenAiAudioTranscriptionProperties.class, + OpenAiAudioSpeechProperties.class }) public class OpenAiAutoConfiguration { @Bean @@ -142,6 +144,28 @@ public OpenAiAudioTranscriptionClient openAiAudioTranscriptionClient(OpenAiConne return openAiChatClient; } + @Bean + @ConditionalOnMissingBean + public OpenAiAudioSpeechClient openAiAudioSpeechClient(OpenAiConnectionProperties commonProperties, + OpenAiAudioSpeechProperties speechProperties, ResponseErrorHandler responseErrorHandler) { + + String apiKey = StringUtils.hasText(speechProperties.getApiKey()) ? speechProperties.getApiKey() + : commonProperties.getApiKey(); + + String baseUrl = StringUtils.hasText(speechProperties.getBaseUrl()) ? 
speechProperties.getBaseUrl() + : commonProperties.getBaseUrl(); + + Assert.hasText(apiKey, "OpenAI API key must be set"); + Assert.hasText(baseUrl, "OpenAI base URL must be set"); + + var openAiAudioApi = new OpenAiAudioApi(baseUrl, apiKey, RestClient.builder(), responseErrorHandler); + + OpenAiAudioSpeechClient openAiSpeechClient = new OpenAiAudioSpeechClient(openAiAudioApi, + speechProperties.getOptions()); + + return openAiSpeechClient; + } + @Bean @ConditionalOnMissingBean public FunctionCallbackContext springAiFunctionManager(ApplicationContext context) { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java index 41bbfb8c7b1..0690985eabb 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.autoconfigure.openai; +import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; @@ -30,6 +31,7 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; import reactor.core.publisher.Flux; +import org.springframework.ai.openai.OpenAiAudioSpeechClient; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.ChatResponse; @@ -74,6 +76,32 @@ void transcribe() { }); } + @Test + void synthesize() { + contextRunner.run(context -> { + OpenAiAudioSpeechClient client = context.getBean(OpenAiAudioSpeechClient.class); + byte[] response = client.call("H"); + assertThat(response).isNotNull(); + assertThat(verifyMp3FrameHeader(response)) + .withFailMessage("Expected MP3 frame header to 
be present in the response, but it was not found.") + .isTrue(); + assertThat(response.length).isNotEqualTo(0); + + logger.info("Response: " + Arrays.toString(response)); + }); + } + + public boolean verifyMp3FrameHeader(byte[] audioResponse) { + // Check if the response is null or too short to contain a frame header + if (audioResponse == null || audioResponse.length < 2) { + return false; + } + // Check for the MP3 frame header + // 0xFFE0 masks the 11-bit frame sync: byte 0 must be 0xFF and the top 3 bits of + // byte 1 must all be 1; the low 5 bits of byte 1 are not checked + return (audioResponse[0] & 0xFF) == 0xFF && (audioResponse[1] & 0xE0) == 0xE0; + } + + @Test void generateStreaming() { contextRunner.run(context -> { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java index e14a6a562de..79606e9bfd2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java @@ -155,6 +155,101 @@ public void transcriptionOverrideConnectionProperties() { }); } + @Test + public void speechProperties() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.openai.base-url=TEST_BASE_URL", + "spring.ai.openai.api-key=abc123", + "spring.ai.openai.audio.speech.options.model=TTS_1", + "spring.ai.openai.audio.speech.options.voice=alloy", + "spring.ai.openai.audio.speech.options.response-format=mp3", + "spring.ai.openai.audio.speech.options.speed=0.75") + // @formatter:on + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .run(context -> { + var speechProperties = 
context.getBean(OpenAiAudioSpeechProperties.class); + var connectionProperties = context.getBean(OpenAiConnectionProperties.class); + + assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + + assertThat(speechProperties.getApiKey()).isNull(); + assertThat(speechProperties.getBaseUrl()).isNull(); + + assertThat(speechProperties.getOptions().getModel()).isEqualTo("TTS_1"); + assertThat(speechProperties.getOptions().getVoice()) + .isEqualTo(OpenAiAudioApi.SpeechRequest.Voice.ALLOY); + assertThat(speechProperties.getOptions().getResponseFormat()) + .isEqualTo(OpenAiAudioApi.SpeechRequest.AudioResponseFormat.MP3); + assertThat(speechProperties.getOptions().getSpeed()).isEqualTo(0.75f); + }); + } + + @Test + public void speechPropertiesTest() { + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.openai.base-url=TEST_BASE_URL", + "spring.ai.openai.api-key=abc123", + "spring.ai.openai.audio.speech.options.model=TTS_1", + "spring.ai.openai.audio.speech.options.voice=alloy", + "spring.ai.openai.audio.speech.options.response-format=mp3", + "spring.ai.openai.audio.speech.options.speed=0.75") + // @formatter:on + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .run(context -> { + var speechProperties = context.getBean(OpenAiAudioSpeechProperties.class); + var connectionProperties = context.getBean(OpenAiConnectionProperties.class); + + assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + + assertThat(speechProperties.getOptions().getModel()).isEqualTo("TTS_1"); + assertThat(speechProperties.getOptions().getVoice()) + .isEqualTo(OpenAiAudioApi.SpeechRequest.Voice.ALLOY); + assertThat(speechProperties.getOptions().getResponseFormat()) + 
.isEqualTo(OpenAiAudioApi.SpeechRequest.AudioResponseFormat.MP3); + assertThat(speechProperties.getOptions().getSpeed()).isEqualTo(0.75f); + }); + } + + @Test + public void speechOverrideConnectionPropertiesTest() { + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.openai.base-url=TEST_BASE_URL", + "spring.ai.openai.api-key=abc123", + "spring.ai.openai.audio.speech.base-url=TEST_BASE_URL2", + "spring.ai.openai.audio.speech.api-key=456", + "spring.ai.openai.audio.speech.options.model=TTS_2", + "spring.ai.openai.audio.speech.options.voice=echo", + "spring.ai.openai.audio.speech.options.response-format=opus", + "spring.ai.openai.audio.speech.options.speed=0.5") + // @formatter:on + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .run(context -> { + var speechProperties = context.getBean(OpenAiAudioSpeechProperties.class); + var connectionProperties = context.getBean(OpenAiConnectionProperties.class); + + assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + + assertThat(speechProperties.getApiKey()).isEqualTo("456"); + assertThat(speechProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2"); + + assertThat(speechProperties.getOptions().getModel()).isEqualTo("TTS_2"); + assertThat(speechProperties.getOptions().getVoice()).isEqualTo(OpenAiAudioApi.SpeechRequest.Voice.ECHO); + assertThat(speechProperties.getOptions().getResponseFormat()) + .isEqualTo(OpenAiAudioApi.SpeechRequest.AudioResponseFormat.OPUS); + assertThat(speechProperties.getOptions().getSpeed()).isEqualTo(0.5f); + }); + } + @Test public void embeddingProperties() {