Skip to content

Commit 2ce21fb

Browse files
committed
Added support for OpenAI Text to Audio (Speech API)
1 parent ecc81b5 commit 2ce21fb

File tree

20 files changed

+1122
-3
lines changed

20 files changed

+1122
-3
lines changed
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Copyright 2023-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.openai;
17+
18+
import org.slf4j.Logger;
19+
import org.slf4j.LoggerFactory;
20+
import org.springframework.ai.chat.metadata.RateLimit;
21+
import org.springframework.ai.model.ModelOptionsUtils;
22+
import org.springframework.ai.openai.api.OpenAiApi;
23+
import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException;
24+
import org.springframework.ai.openai.metadata.OpenAiSpeechResponseMetadata;
25+
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
26+
import org.springframework.ai.speech.*;
27+
import org.springframework.http.ResponseEntity;
28+
import org.springframework.retry.support.RetryTemplate;
29+
import org.springframework.util.Assert;
30+
31+
import java.time.Duration;
32+
import java.util.Objects;
33+
34+
/**
35+
* {@link SpeechClient} implementation for {@literal OpenAI} backed by {@link OpenAiApi}.
36+
*
37+
* @author Ahmed Yousri
38+
* @see SpeechClient
39+
* @see OpenAiApi
40+
*/
41+
public class OpenAiSpeechClient implements SpeechClient {
42+
43+
private final Logger logger = LoggerFactory.getLogger(getClass());
44+
45+
private OpenAiSpeechOptions defaultOptions = OpenAiSpeechOptions.builder()
46+
.withModel("tts-1")
47+
.withResponseFormat("mp3")
48+
.withSpeed(1.0f)
49+
.withVoice("alloy")
50+
.build();
51+
52+
public final RetryTemplate retryTemplate = RetryTemplate.builder()
53+
.maxAttempts(10)
54+
.retryOn(OpenAiApiException.class)
55+
.exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
56+
.build();
57+
58+
private final OpenAiApi openAiApi;
59+
60+
public OpenAiSpeechClient(OpenAiApi openAiApi) {
61+
Assert.notNull(openAiApi, "OpenAiApi must not be null");
62+
this.openAiApi = openAiApi;
63+
}
64+
65+
public OpenAiSpeechClient withDefaultOptions(OpenAiSpeechOptions options) {
66+
this.defaultOptions = options;
67+
return this;
68+
}
69+
70+
@Override
71+
public SpeechResponse call(SpeechRequest request) {
72+
73+
return this.retryTemplate.execute(ctx -> {
74+
75+
OpenAiApi.SpeechRequest requestBody = createRequestBody(request);
76+
ResponseEntity<OpenAiApi.SpeechResponse> SpeechEntity = this.openAiApi.textToSpeechEntityJson(requestBody);
77+
var speech = SpeechEntity.getBody();
78+
79+
if (speech == null) {
80+
logger.warn("No speech response returned for speechRequest: {}", request);
81+
}
82+
83+
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(SpeechEntity);
84+
85+
return new SpeechResponse(convertResponse(speech), new OpenAiSpeechResponseMetadata(rateLimits));
86+
87+
});
88+
}
89+
90+
private OpenAiApi.SpeechRequest createRequestBody(SpeechRequest speechRequest) {
91+
92+
OpenAiApi.SpeechRequest request = new OpenAiApi.SpeechRequest();
93+
94+
if (this.defaultOptions != null) {
95+
request = ModelOptionsUtils.merge(request, this.defaultOptions, OpenAiApi.SpeechRequest.class);
96+
}
97+
98+
if (speechRequest.getOptions() != null) {
99+
if (speechRequest.getOptions() instanceof SpeechOptions runtimeOptions) {
100+
OpenAiSpeechOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
101+
SpeechOptions.class, OpenAiSpeechOptions.class);
102+
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, OpenAiApi.SpeechRequest.class);
103+
}
104+
else {
105+
throw new IllegalArgumentException("Prompt options are not of type SpeechOptions: "
106+
+ speechRequest.getOptions().getClass().getSimpleName());
107+
}
108+
}
109+
110+
return request;
111+
}
112+
113+
private Speech convertResponse(OpenAiApi.SpeechResponse speechResponse) {
114+
return new Speech(Objects.requireNonNull(speechResponse).audio());
115+
}
116+
117+
}
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai;
18+
19+
import com.fasterxml.jackson.annotation.JsonInclude;
20+
import com.fasterxml.jackson.annotation.JsonInclude.Include;
21+
import com.fasterxml.jackson.annotation.JsonProperty;
22+
import org.springframework.ai.speech.SpeechOptions;
23+
24+
/**
25+
* @author Ahmed Yousri
26+
*/
27+
@JsonInclude(Include.NON_NULL)
28+
public class OpenAiSpeechOptions implements SpeechOptions {
29+
30+
// @formatter:off
31+
/**
32+
*
33+
* One of the available TTS models
34+
*/
35+
private @JsonProperty("model") String model;
36+
/**
37+
*
38+
* The text to generate audio for.
39+
*/
40+
private @JsonProperty("input") String input;
41+
42+
/**
43+
* The voice to use when generating the audio.
44+
*/
45+
private @JsonProperty("voice") String voice;
46+
47+
/**
48+
* The format to audio in.
49+
*/
50+
private @JsonProperty("response_format") String responseFormat;
51+
52+
/**
53+
* The speed of the generated audi.
54+
*/
55+
private @JsonProperty("speed") Float speed;
56+
57+
58+
public static Builder builder() {
59+
return new Builder();
60+
}
61+
62+
public static class Builder {
63+
64+
protected OpenAiSpeechOptions options;
65+
66+
public Builder() {
67+
this.options = new OpenAiSpeechOptions();
68+
}
69+
70+
public Builder(OpenAiSpeechOptions options) {
71+
this.options = options;
72+
}
73+
74+
public Builder withModel(String model) {
75+
options.model = model;
76+
return this;
77+
}
78+
79+
public Builder withInput(String input) {
80+
options.input = input;
81+
return this;
82+
}
83+
84+
public Builder withVoice(String voice) {
85+
options.voice = voice;
86+
return this;
87+
}
88+
89+
public Builder withResponseFormat(String responseFormat) {
90+
options.responseFormat = responseFormat;
91+
return this;
92+
}
93+
94+
public Builder withSpeed(Float speed) {
95+
options.speed = speed;
96+
return this;
97+
}
98+
99+
public OpenAiSpeechOptions build() {
100+
return this.options;
101+
}
102+
103+
}
104+
105+
public String getModel() {
106+
return model;
107+
}
108+
109+
public void setModel(String model) {
110+
this.model = model;
111+
}
112+
113+
public String getInput() {
114+
return input;
115+
}
116+
117+
public void setInput(String input) {
118+
this.input = input;
119+
}
120+
121+
public String getVoice() {
122+
return voice;
123+
}
124+
125+
public void setVoice(String voice) {
126+
this.voice = voice;
127+
}
128+
129+
public String getResponseFormat() {
130+
return responseFormat;
131+
}
132+
133+
public void setResponseFormat(String responseFormat) {
134+
this.responseFormat = responseFormat;
135+
}
136+
137+
public Float getSpeed() {
138+
return speed;
139+
}
140+
141+
public void setSpeed(Float speed) {
142+
this.speed = speed;
143+
}
144+
145+
@Override
146+
public int hashCode() {
147+
final int prime = 31;
148+
int result = 1;
149+
result = prime * result + ((model == null) ? 0 : model.hashCode());
150+
result = prime * result + ((input == null) ? 0 : input.hashCode());
151+
result = prime * result + ((voice == null) ? 0 : voice.hashCode());
152+
result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode());
153+
result = prime * result + ((speed == null) ? 0 : speed.hashCode());
154+
return result;
155+
}
156+
157+
@Override
158+
public boolean equals(Object obj) {
159+
if (this == obj)
160+
return true;
161+
if (obj == null)
162+
return false;
163+
if (getClass() != obj.getClass())
164+
return false;
165+
OpenAiSpeechOptions other = (OpenAiSpeechOptions) obj;
166+
if (model == null) {
167+
if (other.model != null)
168+
return false;
169+
} else if (!model.equals(other.model))
170+
return false;
171+
if (input == null) {
172+
if (other.input != null)
173+
return false;
174+
} else if (!input.equals(other.input))
175+
return false;
176+
if (voice == null) {
177+
if (other.voice != null)
178+
return false;
179+
} else if (!voice.equals(other.voice))
180+
return false;
181+
if (responseFormat == null) {
182+
if (other.responseFormat != null)
183+
return false;
184+
} else if (!responseFormat.equals(other.responseFormat))
185+
return false;
186+
if (speed == null) {
187+
if (other.speed != null)
188+
return false;
189+
} else if (!speed.equals(other.speed))
190+
return false;
191+
return true;
192+
}
193+
194+
}

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,58 @@ public record ResponseFormat(
389389
}
390390
}
391391

392+
393+
394+
@JsonInclude(JsonInclude.Include.NON_NULL)
395+
public record SpeechRequest(
396+
@JsonProperty("model") String model,
397+
@JsonProperty("input") String input,
398+
@JsonProperty("voice") String voice,
399+
@JsonProperty("response_format") String responseFormat,
400+
@JsonProperty("speed") Double speed) {
401+
402+
403+
public SpeechRequest(String model, String input, String voice) {
404+
this(model, input, voice, "mp3", 1.0); // Defaults to "mp3" format and "1.0" speed
405+
}
406+
407+
public SpeechRequest() {
408+
this(null, null, null, "mp3", 1.0);
409+
}
410+
411+
public SpeechRequest(String model, String input, String voice, String responseFormat, Double speed) {
412+
this.model = model;
413+
this.input = input;
414+
this.voice = voice;
415+
this.responseFormat = responseFormat;
416+
this.speed = speed;
417+
}
418+
}
419+
420+
@JsonInclude(Include.NON_NULL)
421+
public record SpeechResponse(
422+
@JsonProperty("audio") byte[] audio) {
423+
}
424+
/**
425+
* Creates a model response for the given text-to-speech request.
426+
*
427+
* @param speechRequest The text-to-speech request.
428+
* @return Entity response with the generated speech as a body and HTTP status code and headers.
429+
*/
430+
public ResponseEntity<SpeechResponse> textToSpeechEntityJson(OpenAiApi.SpeechRequest speechRequest) {
431+
Assert.notNull(speechRequest, "The request body cannot be null.");
432+
433+
var responseEntity = this.restClient.post()
434+
.uri("/v1/audio/speech")
435+
.body(speechRequest)
436+
.accept(MediaType.APPLICATION_OCTET_STREAM)
437+
.retrieve()
438+
.toEntity(byte[].class);
439+
HttpHeaders headers = new HttpHeaders();
440+
headers.addAll(responseEntity.getHeaders());
441+
SpeechResponse speechResponse = new SpeechResponse(responseEntity.getBody());
442+
return new ResponseEntity<>(speechResponse, headers, responseEntity.getStatusCode());
443+
}
392444
/**
393445
* Message comprising the conversation.
394446
*
@@ -494,6 +546,7 @@ public enum ChatCompletionFinishReason {
494546
@JsonProperty("function_call") FUNCTION_CALL
495547
}
496548

549+
497550
/**
498551
* Represents a chat completion response returned by model, based on the provided input.
499552
*

0 commit comments

Comments
 (0)