Skip to content

Added support for OpenAI Text to Audio (Speech API) #317

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.openai;

import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.openai.api.OpenAiAudioApi;
import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.AudioResponseFormat;
import org.springframework.ai.openai.api.common.OpenAiApiException;
import org.springframework.ai.openai.audio.speech.*;
import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata;
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;

import java.time.Duration;

/**
 * OpenAI audio speech client implementation backed by {@link OpenAiAudioApi}.
 *
 * <p>
 * Every request is built from the client's default options, optionally overridden by the
 * options carried on the individual {@link SpeechPrompt}.
 *
 * @author Ahmed Yousri
 * @see OpenAiAudioApi
 */
public class OpenAiAudioSpeechClient implements SpeechClient, StreamingSpeechClient {

	private final Logger logger = LoggerFactory.getLogger(getClass());

	/** Options applied to every request unless the prompt overrides them. */
	private final OpenAiAudioSpeechOptions defaultOptions;

	/** Speed used by the single-argument constructor's default options. */
	private static final Float SPEED = 1.0f;

	/**
	 * Retry policy for the blocking {@link #call(SpeechPrompt)} path: up to 10 attempts
	 * on {@link OpenAiApiException}, with exponential back-off starting at 2 seconds,
	 * multiplied by 5 on each attempt and capped at 3 minutes. Kept public for backward
	 * compatibility with existing callers.
	 */
	public final RetryTemplate retryTemplate = RetryTemplate.builder()
		.maxAttempts(10)
		.retryOn(OpenAiApiException.class)
		.exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
		.build();

	private final OpenAiAudioApi audioApi;

	/**
	 * Creates a client with default options: model {@code TTS_1}, MP3 output, the
	 * {@code ALLOY} voice and a speed of 1.0.
	 * @param audioApi the low-level audio API to delegate to; must not be null
	 */
	public OpenAiAudioSpeechClient(OpenAiAudioApi audioApi) {
		this(audioApi,
				OpenAiAudioSpeechOptions.builder()
					.withModel(OpenAiAudioApi.TtsModel.TTS_1.getValue())
					.withResponseFormat(AudioResponseFormat.MP3)
					.withVoice(OpenAiAudioApi.SpeechRequest.Voice.ALLOY)
					.withSpeed(SPEED)
					.build());
	}

	/**
	 * Creates a client with caller-supplied default options.
	 * @param audioApi the low-level audio API to delegate to; must not be null
	 * @param options the default options for every request; must not be null
	 * @throws IllegalArgumentException if either argument is null
	 */
	public OpenAiAudioSpeechClient(OpenAiAudioApi audioApi, OpenAiAudioSpeechOptions options) {
		Assert.notNull(audioApi, "OpenAiAudioApi must not be null");
		Assert.notNull(options, "OpenAiSpeechOptions must not be null");
		this.audioApi = audioApi;
		this.defaultOptions = options;
	}

	/**
	 * Synthesizes the given text using the default options.
	 * @param text the text to convert to speech
	 * @return the raw audio bytes of the first result
	 */
	@Override
	public byte[] call(String text) {
		SpeechPrompt speechRequest = new SpeechPrompt(text);
		return call(speechRequest).getResult().getOutput();
	}

	/**
	 * Synthesizes speech for the given prompt, retrying according to
	 * {@link #retryTemplate}.
	 * @param speechPrompt the prompt carrying the text and optional per-request options
	 * @return the audio plus rate-limit metadata, or a response wrapping an empty byte
	 * array (without metadata) when the API returns no body
	 */
	@Override
	public SpeechResponse call(SpeechPrompt speechPrompt) {

		return this.retryTemplate.execute(ctx -> {

			OpenAiAudioApi.SpeechRequest speechRequest = createRequestBody(speechPrompt);

			ResponseEntity<byte[]> speechEntity = this.audioApi.createSpeech(speechRequest);
			var speech = speechEntity.getBody();

			if (speech == null) {
				logger.warn("No speech response returned for speechRequest: {}", speechRequest);
				return new SpeechResponse(new Speech(new byte[0]));
			}

			RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(speechEntity);

			return new SpeechResponse(new Speech(speech), new OpenAiAudioSpeechResponseMetadata(rateLimits));

		});
	}

	/**
	 * Streams the audio response for the given speech prompt.
	 * @param prompt The speech prompt containing the text and options for speech
	 * synthesis.
	 * @return A Flux of SpeechResponse objects containing the streamed audio and
	 * metadata.
	 */
	@Override
	public Flux<SpeechResponse> stream(SpeechPrompt prompt) {
		return this.audioApi.stream(this.createRequestBody(prompt))
			.map(entity -> new SpeechResponse(new Speech(entity.getBody()), new OpenAiAudioSpeechResponseMetadata(
					OpenAiResponseHeaderExtractor.extractAiResponseHeaders(entity))));
	}

	/**
	 * Builds the API request for a prompt by layering the prompt's options over the
	 * client defaults.
	 * @param request the prompt to translate
	 * @return the request to send to {@link OpenAiAudioApi}
	 * @throws IllegalArgumentException if the prompt carries options that are not
	 * {@link OpenAiAudioSpeechOptions}
	 */
	private OpenAiAudioApi.SpeechRequest createRequestBody(SpeechPrompt request) {
		OpenAiAudioSpeechOptions options = this.defaultOptions;

		if (request.getOptions() != null) {
			if (request.getOptions() instanceof OpenAiAudioSpeechOptions runtimeOptions) {
				// Per-request options must win; the defaults only fill in what is unset.
				options = this.merge(runtimeOptions, this.defaultOptions);
			}
			else {
				throw new IllegalArgumentException("Prompt options are not of type SpeechOptions: "
						+ request.getOptions().getClass().getSimpleName());
			}
		}

		// An input set explicitly on the options takes priority over the prompt text.
		String input = StringUtils.isNotBlank(options.getInput()) ? options.getInput()
				: request.getInstructions().get(0).getText();

		OpenAiAudioApi.SpeechRequest.Builder requestBuilder = OpenAiAudioApi.SpeechRequest.builder()
			.withModel(options.getModel())
			.withInput(input)
			.withVoice(options.getVoice())
			.withResponseFormat(options.getResponseFormat())
			.withSpeed(options.getSpeed());

		return requestBuilder.build();
	}

	/**
	 * Combines two option sets property by property into a new instance.
	 * @param preferred the options whose non-null values take precedence
	 * @param fallback the options consulted for every property {@code preferred} leaves
	 * null
	 * @return a new options object; neither argument is modified
	 */
	private OpenAiAudioSpeechOptions merge(OpenAiAudioSpeechOptions preferred, OpenAiAudioSpeechOptions fallback) {
		OpenAiAudioSpeechOptions.Builder mergedBuilder = OpenAiAudioSpeechOptions.builder();

		mergedBuilder.withModel(preferred.getModel() != null ? preferred.getModel() : fallback.getModel());
		mergedBuilder.withInput(preferred.getInput() != null ? preferred.getInput() : fallback.getInput());
		mergedBuilder.withVoice(preferred.getVoice() != null ? preferred.getVoice() : fallback.getVoice());
		mergedBuilder.withResponseFormat(
				preferred.getResponseFormat() != null ? preferred.getResponseFormat() : fallback.getResponseFormat());
		mergedBuilder.withSpeed(preferred.getSpeed() != null ? preferred.getSpeed() : fallback.getSpeed());

		return mergedBuilder.build();
	}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.openai;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.springframework.ai.model.ModelOptions;
import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.AudioResponseFormat;
import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.Voice;

import java.util.Objects;

/**
 * Options for OpenAI text to audio - speech synthesis.
 *
 * <p>
 * All properties are nullable; null properties are omitted when the object is serialized
 * to JSON.
 *
 * @author Ahmed Yousri
 */
@JsonInclude(JsonInclude.Include.NON_NULL)
public class OpenAiAudioSpeechOptions implements ModelOptions {

	/**
	 * ID of the model to use for generating the audio. One of the available TTS models:
	 * tts-1 or tts-1-hd.
	 */
	@JsonProperty("model")
	private String model;

	/**
	 * The input text to synthesize. Must be at most 4096 characters long.
	 */
	@JsonProperty("input")
	private String input;

	/**
	 * The voice to use for synthesis. One of the available voices for the chosen model:
	 * 'alloy', 'echo', 'fable', 'onyx', 'nova', and 'shimmer'.
	 */
	@JsonProperty("voice")
	private Voice voice;

	/**
	 * The format of the audio output. Supported formats are mp3, opus, aac, and flac.
	 * Defaults to mp3.
	 */
	@JsonProperty("response_format")
	private AudioResponseFormat responseFormat;

	/**
	 * The speed of the generated audio. The OpenAI API accepts values from 0.25
	 * (slowest) to 4.0 (fastest); 1.0 is the default.
	 */
	@JsonProperty("speed")
	private Float speed;

	/**
	 * Creates a new builder with every property unset.
	 * @return a fresh builder
	 */
	public static Builder builder() {
		return new Builder();
	}

	/**
	 * Fluent builder for {@link OpenAiAudioSpeechOptions}. The builder populates a
	 * single options instance, and {@link #build()} returns that same instance on every
	 * call.
	 */
	public static class Builder {

		private final OpenAiAudioSpeechOptions options = new OpenAiAudioSpeechOptions();

		public Builder withModel(String model) {
			options.model = model;
			return this;
		}

		public Builder withInput(String input) {
			options.input = input;
			return this;
		}

		public Builder withVoice(Voice voice) {
			options.voice = voice;
			return this;
		}

		public Builder withResponseFormat(AudioResponseFormat responseFormat) {
			options.responseFormat = responseFormat;
			return this;
		}

		public Builder withSpeed(Float speed) {
			options.speed = speed;
			return this;
		}

		/**
		 * Returns the options instance this builder has been populating.
		 * @return the configured options
		 */
		public OpenAiAudioSpeechOptions build() {
			return options;
		}

	}

	public String getModel() {
		return model;
	}

	public String getInput() {
		return input;
	}

	public Voice getVoice() {
		return voice;
	}

	public AudioResponseFormat getResponseFormat() {
		return responseFormat;
	}

	public Float getSpeed() {
		return speed;
	}

	/**
	 * Hashes all five properties; consistent with {@link #equals(Object)}.
	 */
	@Override
	public int hashCode() {
		return Objects.hash(model, input, voice, responseFormat, speed);
	}

	public void setModel(String model) {
		this.model = model;
	}

	public void setInput(String input) {
		this.input = input;
	}

	public void setVoice(Voice voice) {
		this.voice = voice;
	}

	public void setResponseFormat(AudioResponseFormat responseFormat) {
		this.responseFormat = responseFormat;
	}

	public void setSpeed(Float speed) {
		this.speed = speed;
	}

	/**
	 * Two option objects are equal when they are of exactly the same class and all five
	 * properties are equal (null matching only null).
	 */
	@Override
	public boolean equals(Object obj) {
		if (this == obj) {
			return true;
		}
		if (obj == null || getClass() != obj.getClass()) {
			return false;
		}
		OpenAiAudioSpeechOptions other = (OpenAiAudioSpeechOptions) obj;
		return Objects.equals(model, other.model) && Objects.equals(input, other.input)
				&& Objects.equals(voice, other.voice) && Objects.equals(responseFormat, other.responseFormat)
				&& Objects.equals(speed, other.speed);
	}

	@Override
	public String toString() {
		return "OpenAiAudioSpeechOptions{" + "model='" + model + '\'' + ", input='" + input + '\'' + ", voice='" + voice
				+ '\'' + ", responseFormat='" + responseFormat + '\'' + ", speed=" + speed + '}';
	}

}
Loading