Skip to content

Commit 6393b21

Browse files
committed
Added support for OpenAI Text to Audio (Speech API)
1 parent c036931 commit 6393b21

File tree

22 files changed

+1496
-2
lines changed

22 files changed

+1496
-2
lines changed
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai;
18+
19+
import org.apache.commons.lang3.StringUtils;
20+
import org.slf4j.Logger;
21+
import org.slf4j.LoggerFactory;
22+
import org.springframework.ai.chat.metadata.RateLimit;
23+
import org.springframework.ai.openai.api.OpenAiAudioApi;
24+
import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.AudioResponseFormat;
25+
import org.springframework.ai.openai.api.common.OpenAiApiException;
26+
import org.springframework.ai.openai.audio.speech.*;
27+
import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata;
28+
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
29+
import org.springframework.http.ResponseEntity;
30+
import org.springframework.retry.support.RetryTemplate;
31+
import org.springframework.util.Assert;
32+
import reactor.core.publisher.Flux;
33+
34+
import java.time.Duration;
35+
36+
/**
37+
* OpenAI audio speech client implementation for backed by {@link OpenAiAudioApi}.
38+
*
39+
* @author Ahmed Yousri
40+
* @see OpenAiAudioApi
41+
*/
42+
public class OpenAiAudioSpeechClient implements SpeechClient, StreamingSpeechClient {
43+
44+
private final Logger logger = LoggerFactory.getLogger(getClass());
45+
46+
private final OpenAiAudioSpeechOptions defaultOptions;
47+
48+
private static final Float SPEED = 1.0f;
49+
50+
public final RetryTemplate retryTemplate = RetryTemplate.builder()
51+
.maxAttempts(10)
52+
.retryOn(OpenAiApiException.class)
53+
.exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
54+
.build();
55+
56+
private final OpenAiAudioApi audioApi;
57+
58+
public OpenAiAudioSpeechClient(OpenAiAudioApi audioApi) {
59+
this(audioApi,
60+
OpenAiAudioSpeechOptions.builder()
61+
.withModel(OpenAiAudioApi.TtsModel.TTS_1.getValue())
62+
.withResponseFormat(AudioResponseFormat.MP3)
63+
.withVoice(OpenAiAudioApi.SpeechRequest.Voice.ALLOY)
64+
.withSpeed(SPEED)
65+
.build());
66+
}
67+
68+
public OpenAiAudioSpeechClient(OpenAiAudioApi audioApi, OpenAiAudioSpeechOptions options) {
69+
Assert.notNull(audioApi, "OpenAiAudioApi must not be null");
70+
Assert.notNull(options, "OpenAiSpeechOptions must not be null");
71+
this.audioApi = audioApi;
72+
this.defaultOptions = options;
73+
}
74+
75+
@Override
76+
public byte[] call(String text) {
77+
SpeechPrompt speechRequest = new SpeechPrompt(text);
78+
return call(speechRequest).getResult().getOutput();
79+
}
80+
81+
@Override
82+
public SpeechResponse call(SpeechPrompt speechPrompt) {
83+
84+
return this.retryTemplate.execute(ctx -> {
85+
86+
OpenAiAudioApi.SpeechRequest speechRequest = createRequestBody(speechPrompt);
87+
88+
ResponseEntity<byte[]> speechEntity = this.audioApi.createSpeech(speechRequest);
89+
var speech = speechEntity.getBody();
90+
91+
if (speech == null) {
92+
logger.warn("No speech response returned for speechRequest: {}", speechRequest);
93+
return new SpeechResponse(new Speech(new byte[0]));
94+
}
95+
96+
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(speechEntity);
97+
98+
return new SpeechResponse(new Speech(speech), new OpenAiAudioSpeechResponseMetadata(rateLimits));
99+
100+
});
101+
}
102+
103+
/**
104+
* Streams the audio response for the given speech prompt.
105+
* @param prompt The speech prompt containing the text and options for speech
106+
* synthesis.
107+
* @return A Flux of SpeechResponse objects containing the streamed audio and
108+
* metadata.
109+
*/
110+
111+
@Override
112+
public Flux<SpeechResponse> stream(SpeechPrompt prompt) {
113+
return this.audioApi.stream(this.createRequestBody(prompt))
114+
.map(entity -> new SpeechResponse(new Speech(entity.getBody()), new OpenAiAudioSpeechResponseMetadata(
115+
OpenAiResponseHeaderExtractor.extractAiResponseHeaders(entity))));
116+
}
117+
118+
private OpenAiAudioApi.SpeechRequest createRequestBody(SpeechPrompt request) {
119+
OpenAiAudioSpeechOptions options = this.defaultOptions;
120+
121+
if (request.getOptions() != null) {
122+
if (request.getOptions() instanceof OpenAiAudioSpeechOptions runtimeOptions) {
123+
options = this.merge(options, runtimeOptions);
124+
}
125+
else {
126+
throw new IllegalArgumentException("Prompt options are not of type SpeechOptions: "
127+
+ request.getOptions().getClass().getSimpleName());
128+
}
129+
}
130+
131+
String input = StringUtils.isNotBlank(options.getInput()) ? options.getInput()
132+
: request.getInstructions().get(0).getText();
133+
134+
OpenAiAudioApi.SpeechRequest.Builder requestBuilder = OpenAiAudioApi.SpeechRequest.builder()
135+
.withModel(options.getModel())
136+
.withInput(input)
137+
.withVoice(options.getVoice())
138+
.withResponseFormat(options.getResponseFormat())
139+
.withSpeed(options.getSpeed());
140+
141+
return requestBuilder.build();
142+
}
143+
144+
private OpenAiAudioSpeechOptions merge(OpenAiAudioSpeechOptions source, OpenAiAudioSpeechOptions target) {
145+
OpenAiAudioSpeechOptions.Builder mergedBuilder = OpenAiAudioSpeechOptions.builder();
146+
147+
mergedBuilder.withModel(source.getModel() != null ? source.getModel() : target.getModel());
148+
mergedBuilder.withInput(source.getInput() != null ? source.getInput() : target.getInput());
149+
mergedBuilder.withVoice(source.getVoice() != null ? source.getVoice() : target.getVoice());
150+
mergedBuilder.withResponseFormat(
151+
source.getResponseFormat() != null ? source.getResponseFormat() : target.getResponseFormat());
152+
mergedBuilder.withSpeed(source.getSpeed() != null ? source.getSpeed() : target.getSpeed());
153+
154+
return mergedBuilder.build();
155+
}
156+
157+
}
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai;
18+
19+
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.springframework.ai.model.ModelOptions;
import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.AudioResponseFormat;
import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.Voice;

import java.util.Objects;
24+
25+
/**
26+
* Options for OpenAI text to audio - speech synthesis.
27+
*
28+
* @author Ahmed Yousri
29+
*/
30+
@JsonInclude(JsonInclude.Include.NON_NULL)
31+
public class OpenAiAudioSpeechOptions implements ModelOptions {
32+
33+
/**
34+
* ID of the model to use for generating the audio. One of the available TTS models:
35+
* tts-1 or tts-1-hd.
36+
*/
37+
@JsonProperty("model")
38+
private String model;
39+
40+
/**
41+
* The input text to synthesize. Must be at most 4096 tokens long.
42+
*/
43+
@JsonProperty("input")
44+
private String input;
45+
46+
/**
47+
* The voice to use for synthesis. One of the available voices for the chosen model:
48+
* 'alloy', 'echo', 'fable', 'onyx', 'nova', and 'shimmer'.
49+
*/
50+
@JsonProperty("voice")
51+
private Voice voice;
52+
53+
/**
54+
* The format of the audio output. Supported formats are mp3, opus, aac, and flac.
55+
* Defaults to mp3.
56+
*/
57+
@JsonProperty("response_format")
58+
private AudioResponseFormat responseFormat;
59+
60+
/**
61+
* The speed of the voice synthesis. The acceptable range is from 0.0 (slowest) to 1.0
62+
* (fastest).
63+
*/
64+
@JsonProperty("speed")
65+
private Float speed;
66+
67+
public static Builder builder() {
68+
return new Builder();
69+
}
70+
71+
public static class Builder {
72+
73+
private final OpenAiAudioSpeechOptions options = new OpenAiAudioSpeechOptions();
74+
75+
public Builder withModel(String model) {
76+
options.model = model;
77+
return this;
78+
}
79+
80+
public Builder withInput(String input) {
81+
options.input = input;
82+
return this;
83+
}
84+
85+
public Builder withVoice(Voice voice) {
86+
options.voice = voice;
87+
return this;
88+
}
89+
90+
public Builder withResponseFormat(AudioResponseFormat responseFormat) {
91+
options.responseFormat = responseFormat;
92+
return this;
93+
}
94+
95+
public Builder withSpeed(Float speed) {
96+
options.speed = speed;
97+
return this;
98+
}
99+
100+
public OpenAiAudioSpeechOptions build() {
101+
return options;
102+
}
103+
104+
}
105+
106+
public String getModel() {
107+
return model;
108+
}
109+
110+
public String getInput() {
111+
return input;
112+
}
113+
114+
public Voice getVoice() {
115+
return voice;
116+
}
117+
118+
public AudioResponseFormat getResponseFormat() {
119+
return responseFormat;
120+
}
121+
122+
public Float getSpeed() {
123+
return speed;
124+
}
125+
126+
@Override
127+
public int hashCode() {
128+
final int prime = 31;
129+
int result = 1;
130+
result = prime * result + ((model == null) ? 0 : model.hashCode());
131+
result = prime * result + ((input == null) ? 0 : input.hashCode());
132+
result = prime * result + ((voice == null) ? 0 : voice.hashCode());
133+
result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode());
134+
result = prime * result + ((speed == null) ? 0 : speed.hashCode());
135+
return result;
136+
}
137+
138+
public void setModel(String model) {
139+
this.model = model;
140+
}
141+
142+
public void setInput(String input) {
143+
this.input = input;
144+
}
145+
146+
public void setVoice(Voice voice) {
147+
this.voice = voice;
148+
}
149+
150+
public void setResponseFormat(AudioResponseFormat responseFormat) {
151+
this.responseFormat = responseFormat;
152+
}
153+
154+
public void setSpeed(Float speed) {
155+
this.speed = speed;
156+
}
157+
158+
@Override
159+
public boolean equals(Object obj) {
160+
if (this == obj)
161+
return true;
162+
if (obj == null)
163+
return false;
164+
if (getClass() != obj.getClass())
165+
return false;
166+
OpenAiAudioSpeechOptions other = (OpenAiAudioSpeechOptions) obj;
167+
if (model == null) {
168+
if (other.model != null)
169+
return false;
170+
}
171+
else if (!model.equals(other.model))
172+
return false;
173+
if (input == null) {
174+
if (other.input != null)
175+
return false;
176+
}
177+
else if (!input.equals(other.input))
178+
return false;
179+
if (voice == null) {
180+
if (other.voice != null)
181+
return false;
182+
}
183+
else if (!voice.equals(other.voice))
184+
return false;
185+
if (responseFormat == null) {
186+
if (other.responseFormat != null)
187+
return false;
188+
}
189+
else if (!responseFormat.equals(other.responseFormat))
190+
return false;
191+
if (speed == null) {
192+
return other.speed == null;
193+
}
194+
else
195+
return speed.equals(other.speed);
196+
}
197+
198+
@Override
199+
public String toString() {
200+
return "OpenAiAudioSpeechOptions{" + "model='" + model + '\'' + ", input='" + input + '\'' + ", voice='" + voice
201+
+ '\'' + ", responseFormat='" + responseFormat + '\'' + ", speed=" + speed + '}';
202+
}
203+
204+
}

0 commit comments

Comments
 (0)