Skip to content

Commit 848100d

Browse files
committed
Added support for OpenAI Text to Audio (Speech API)
1 parent 5552a11 commit 848100d

File tree

23 files changed

+1297
-4
lines changed

23 files changed

+1297
-4
lines changed
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
/*
2+
* Copyright 2023-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.openai;
17+
18+
import org.slf4j.Logger;
19+
import org.slf4j.LoggerFactory;
20+
import org.springframework.ai.chat.metadata.RateLimit;
21+
import org.springframework.ai.model.ModelOptionsUtils;
22+
import org.springframework.ai.openai.api.OpenAiApi;
23+
import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException;
24+
import org.springframework.ai.openai.metadata.OpenAiSpeechResponseMetadata;
25+
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
26+
import org.springframework.ai.speech.*;
27+
import org.springframework.http.ResponseEntity;
28+
import org.springframework.retry.support.RetryTemplate;
29+
import org.springframework.util.Assert;
30+
import reactor.core.publisher.Flux;
31+
32+
import java.time.Duration;
33+
34+
/**
35+
* {@link SpeechClient} implementation for {@literal OpenAI} backed by {@link OpenAiApi}.
36+
*
37+
* @author Ahmed Yousri
38+
* @see SpeechClient
39+
* @see OpenAiApi
40+
*/
41+
public class OpenAiSpeechClient implements SpeechClient, StreamingSpeechClient {
42+
43+
private final Logger logger = LoggerFactory.getLogger(getClass());
44+
45+
private OpenAiSpeechOptions defaultOptions = OpenAiSpeechOptions.builder()
46+
.withModel("tts-1")
47+
.withResponseFormat("mp3")
48+
.withSpeed(1.0f)
49+
.withVoice("alloy")
50+
.build();
51+
52+
public final RetryTemplate retryTemplate = RetryTemplate.builder()
53+
.maxAttempts(10)
54+
.retryOn(OpenAiApiException.class)
55+
.exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
56+
.build();
57+
58+
private final OpenAiApi openAiApi;
59+
60+
public OpenAiSpeechClient(OpenAiApi openAiApi) {
61+
Assert.notNull(openAiApi, "OpenAiApi must not be null");
62+
this.openAiApi = openAiApi;
63+
}
64+
65+
public OpenAiSpeechClient withDefaultOptions(OpenAiSpeechOptions options) {
66+
this.defaultOptions = options;
67+
return this;
68+
}
69+
70+
@Override
71+
public SpeechResponse call(SpeechPrompt speechPrompt) {
72+
73+
return this.retryTemplate.execute(ctx -> {
74+
75+
OpenAiApi.SpeechRequest speechRequest = createRequest(speechPrompt);
76+
77+
ResponseEntity<OpenAiApi.SpeechResponse> SpeechEntity = this.openAiApi
78+
.textToSpeechEntityJson(speechRequest);
79+
var speech = SpeechEntity.getBody();
80+
81+
if (speech == null) {
82+
logger.warn("No speech response returned for speechRequest: {}", speechRequest);
83+
return new SpeechResponse(convertResponse(OpenAiApi.SpeechResponse.NULL));
84+
}
85+
86+
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(SpeechEntity);
87+
88+
return new SpeechResponse(convertResponse(speech), new OpenAiSpeechResponseMetadata(rateLimits));
89+
90+
});
91+
}
92+
93+
private OpenAiApi.SpeechRequest createRequest(SpeechPrompt speechPrompt) {
94+
String instructions = speechPrompt.getInstructions().get(0).getText();
95+
96+
OpenAiApi.SpeechRequest speechRequest = new OpenAiApi.SpeechRequest(instructions);
97+
98+
if (this.defaultOptions != null) {
99+
speechRequest = ModelOptionsUtils.merge(this.defaultOptions, speechRequest, OpenAiApi.SpeechRequest.class);
100+
}
101+
102+
if (speechPrompt.getOptions() != null) {
103+
speechRequest = ModelOptionsUtils.merge(toOpenAiSpeechOptions(speechPrompt.getOptions()), speechRequest,
104+
OpenAiApi.SpeechRequest.class);
105+
}
106+
return speechRequest;
107+
}
108+
109+
private Speech convertResponse(OpenAiApi.SpeechResponse speechResponse) {
110+
return new Speech(speechResponse.audio());
111+
}
112+
113+
private OpenAiSpeechOptions toOpenAiSpeechOptions(SpeechOptions runtimeSpeechOptions) {
114+
OpenAiSpeechOptions.Builder openAiSpeechOptionBuilder = OpenAiSpeechOptions.builder();
115+
if (runtimeSpeechOptions != null) {
116+
// Handle portable speech options
117+
if (runtimeSpeechOptions.getModel() != null) {
118+
openAiSpeechOptionBuilder.withModel(runtimeSpeechOptions.getModel());
119+
}
120+
if (runtimeSpeechOptions.getResponseFormat() != null) {
121+
openAiSpeechOptionBuilder.withResponseFormat(runtimeSpeechOptions.getResponseFormat());
122+
}
123+
if (runtimeSpeechOptions.getSpeed() != null) {
124+
openAiSpeechOptionBuilder.withSpeed(runtimeSpeechOptions.getSpeed());
125+
}
126+
if (runtimeSpeechOptions.getVoice() != null) {
127+
openAiSpeechOptionBuilder.withVoice(runtimeSpeechOptions.getVoice());
128+
}
129+
// Handle OpenAI specific speech options
130+
if (runtimeSpeechOptions instanceof OpenAiSpeechOptions) {
131+
OpenAiSpeechOptions runtimeOpenAiSpeechOptions = (OpenAiSpeechOptions) runtimeSpeechOptions;
132+
if (runtimeOpenAiSpeechOptions.getModel() != null) {
133+
openAiSpeechOptionBuilder.withModel(runtimeOpenAiSpeechOptions.getModel());
134+
}
135+
if (runtimeOpenAiSpeechOptions.getSpeed() != null) {
136+
openAiSpeechOptionBuilder.withSpeed(runtimeOpenAiSpeechOptions.getSpeed());
137+
}
138+
if (runtimeOpenAiSpeechOptions.getVoice() != null) {
139+
openAiSpeechOptionBuilder.withVoice(runtimeOpenAiSpeechOptions.getVoice());
140+
}
141+
if (runtimeOpenAiSpeechOptions.getResponseFormat() != null) {
142+
openAiSpeechOptionBuilder.withResponseFormat(runtimeOpenAiSpeechOptions.getResponseFormat());
143+
}
144+
}
145+
}
146+
return openAiSpeechOptionBuilder.build();
147+
}
148+
149+
@Override
150+
public Flux<SpeechResponse> stream(SpeechPrompt prompt) {
151+
return this.openAiApi.textToSpeechStreaming(this.createRequest(prompt))
152+
.map(entity -> new SpeechResponse(new Speech(entity.getBody()),
153+
new OpenAiSpeechResponseMetadata(OpenAiResponseHeaderExtractor.extractAiResponseHeaders(entity))));
154+
}
155+
156+
}
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai;
18+
19+
import com.fasterxml.jackson.annotation.JsonInclude;
20+
import com.fasterxml.jackson.annotation.JsonInclude.Include;
21+
import com.fasterxml.jackson.annotation.JsonProperty;
22+
import org.springframework.ai.speech.SpeechOptions;
23+
24+
/**
25+
* @author Ahmed Yousri
26+
*/
27+
@JsonInclude(Include.NON_NULL)
28+
public class OpenAiSpeechOptions implements SpeechOptions {
29+
30+
// @formatter:off
31+
/**
32+
*
33+
* One of the available TTS models
34+
*/
35+
private @JsonProperty("model") String model;
36+
37+
/**
38+
* The voice to use when generating the audio.
39+
*/
40+
private @JsonProperty("voice") String voice;
41+
42+
/**
43+
* The format to audio in.
44+
*/
45+
private @JsonProperty("response_format") String responseFormat;
46+
47+
/**
48+
* The speed of the generated audi.
49+
*/
50+
private @JsonProperty("speed") Float speed;
51+
52+
53+
public static Builder builder() {
54+
return new Builder();
55+
}
56+
57+
public static class Builder {
58+
59+
protected OpenAiSpeechOptions options;
60+
61+
public Builder() {
62+
this.options = new OpenAiSpeechOptions();
63+
}
64+
65+
public Builder(OpenAiSpeechOptions options) {
66+
this.options = options;
67+
}
68+
69+
public Builder withModel(String model) {
70+
options.model = model;
71+
return this;
72+
}
73+
74+
public Builder withVoice(String voice) {
75+
options.voice = voice;
76+
return this;
77+
}
78+
79+
public Builder withResponseFormat(String responseFormat) {
80+
options.responseFormat = responseFormat;
81+
return this;
82+
}
83+
84+
public Builder withSpeed(Float speed) {
85+
options.speed = speed;
86+
return this;
87+
}
88+
89+
public OpenAiSpeechOptions build() {
90+
return this.options;
91+
}
92+
93+
}
94+
95+
public String getModel() {
96+
return model;
97+
}
98+
99+
public void setModel(String model) {
100+
this.model = model;
101+
}
102+
103+
public String getVoice() {
104+
return voice;
105+
}
106+
107+
public void setVoice(String voice) {
108+
this.voice = voice;
109+
}
110+
111+
public String getResponseFormat() {
112+
return responseFormat;
113+
}
114+
115+
public void setResponseFormat(String responseFormat) {
116+
this.responseFormat = responseFormat;
117+
}
118+
119+
public Float getSpeed() {
120+
return speed;
121+
}
122+
123+
124+
public void setSpeed(Float speed) {
125+
this.speed = speed;
126+
}
127+
128+
@Override
129+
public int hashCode() {
130+
final int prime = 31;
131+
int result = 1;
132+
result = prime * result + ((model == null) ? 0 : model.hashCode());
133+
result = prime * result + ((voice == null) ? 0 : voice.hashCode());
134+
result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode());
135+
result = prime * result + ((speed == null) ? 0 : speed.hashCode());
136+
return result;
137+
}
138+
139+
@Override
140+
public boolean equals(Object obj) {
141+
if (this == obj)
142+
return true;
143+
if (obj == null)
144+
return false;
145+
if (getClass() != obj.getClass())
146+
return false;
147+
OpenAiSpeechOptions other = (OpenAiSpeechOptions) obj;
148+
if (model == null) {
149+
if (other.model != null)
150+
return false;
151+
} else if (!model.equals(other.model))
152+
return false;
153+
if (voice == null) {
154+
if (other.voice != null)
155+
return false;
156+
} else if (!voice.equals(other.voice))
157+
return false;
158+
if (responseFormat == null) {
159+
if (other.responseFormat != null)
160+
return false;
161+
} else if (!responseFormat.equals(other.responseFormat))
162+
return false;
163+
if (speed == null) {
164+
if (other.speed != null)
165+
return false;
166+
} else if (!speed.equals(other.speed))
167+
return false;
168+
return true;
169+
}
170+
171+
}

0 commit comments

Comments
 (0)