Skip to content

Commit e60967b

Browse files
hemeda3Mohammed, Ahmed yousri salama (Canada)
authored andcommitted
Added support for OpenAI Text to Audio (Speech API)
1 parent 43dbaf4 commit e60967b

File tree

21 files changed

+1152
-3
lines changed

21 files changed

+1152
-3
lines changed
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/*
2+
* Copyright 2023-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.openai;
17+
18+
import org.slf4j.Logger;
19+
import org.slf4j.LoggerFactory;
20+
import org.springframework.ai.chat.metadata.RateLimit;
21+
import org.springframework.ai.model.ModelOptionsUtils;
22+
import org.springframework.ai.openai.api.OpenAiApi;
23+
import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException;
24+
import org.springframework.ai.openai.metadata.OpenAiSpeechResponseMetadata;
25+
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
26+
import org.springframework.ai.speech.*;
27+
import org.springframework.http.ResponseEntity;
28+
import org.springframework.retry.support.RetryTemplate;
29+
import org.springframework.util.Assert;
30+
31+
import java.time.Duration;
32+
import java.util.Objects;
33+
34+
/**
35+
* {@link SpeechClient} implementation for {@literal OpenAI} backed by {@link OpenAiApi}.
36+
*
37+
* @author Ahmed Yousri
38+
* @see SpeechClient
39+
* @see OpenAiApi
40+
*/
41+
public class OpenAiSpeechClient implements SpeechClient {
42+
43+
private final Logger logger = LoggerFactory.getLogger(getClass());
44+
45+
private OpenAiSpeechOptions defaultOptions = OpenAiSpeechOptions.builder()
46+
.withModel("tts-1")
47+
.withResponseFormat("mp3")
48+
.withSpeed(1.0f)
49+
.withVoice("alloy")
50+
.build();
51+
52+
public final RetryTemplate retryTemplate = RetryTemplate.builder()
53+
.maxAttempts(10)
54+
.retryOn(OpenAiApiException.class)
55+
.exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
56+
.build();
57+
58+
private final OpenAiApi openAiApi;
59+
60+
public OpenAiSpeechClient(OpenAiApi openAiApi) {
61+
Assert.notNull(openAiApi, "OpenAiApi must not be null");
62+
this.openAiApi = openAiApi;
63+
}
64+
65+
public OpenAiSpeechClient withDefaultOptions(OpenAiSpeechOptions options) {
66+
this.defaultOptions = options;
67+
return this;
68+
}
69+
70+
@Override
71+
public SpeechResponse call(SpeechPrompt speechPrompt) {
72+
73+
return this.retryTemplate.execute(ctx -> {
74+
75+
String instructions = speechPrompt.getInstructions().get(0).getText();
76+
77+
OpenAiApi.SpeechRequest speechRequest = new OpenAiApi.SpeechRequest(instructions);
78+
79+
if (this.defaultOptions != null) {
80+
speechRequest = ModelOptionsUtils.merge(this.defaultOptions, speechRequest,
81+
OpenAiApi.SpeechRequest.class);
82+
}
83+
84+
if (speechPrompt.getOptions() != null) {
85+
speechRequest = ModelOptionsUtils.merge(toOpenAiSpeechOptions(speechPrompt.getOptions()), speechRequest,
86+
OpenAiApi.SpeechRequest.class);
87+
}
88+
89+
ResponseEntity<OpenAiApi.SpeechResponse> SpeechEntity = this.openAiApi
90+
.textToSpeechEntityJson(speechRequest);
91+
var speech = SpeechEntity.getBody();
92+
93+
if (speech == null) {
94+
logger.warn("No speech response returned for speechRequest: {}", speechRequest);
95+
return new SpeechResponse(convertResponse(OpenAiApi.SpeechResponse.NULL));
96+
}
97+
98+
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(SpeechEntity);
99+
100+
return new SpeechResponse(convertResponse(speech), new OpenAiSpeechResponseMetadata(rateLimits));
101+
102+
});
103+
}
104+
105+
private Speech convertResponse(OpenAiApi.SpeechResponse speechResponse) {
106+
return new Speech(speechResponse.audio());
107+
}
108+
109+
private OpenAiSpeechOptions toOpenAiSpeechOptions(SpeechOptions runtimeSpeechOptions) {
110+
OpenAiSpeechOptions.Builder openAiSpeechOptionBuilder = OpenAiSpeechOptions.builder();
111+
if (runtimeSpeechOptions != null) {
112+
// Handle portable speech options
113+
if (runtimeSpeechOptions.getModel() != null) {
114+
openAiSpeechOptionBuilder.withModel(runtimeSpeechOptions.getModel());
115+
}
116+
if (runtimeSpeechOptions.getResponseFormat() != null) {
117+
openAiSpeechOptionBuilder.withResponseFormat(runtimeSpeechOptions.getResponseFormat());
118+
}
119+
if (runtimeSpeechOptions.getSpeed() != null) {
120+
openAiSpeechOptionBuilder.withSpeed(runtimeSpeechOptions.getSpeed());
121+
}
122+
if (runtimeSpeechOptions.getVoice() != null) {
123+
openAiSpeechOptionBuilder.withVoice(runtimeSpeechOptions.getVoice());
124+
}
125+
// Handle OpenAI specific speech options
126+
if (runtimeSpeechOptions instanceof OpenAiSpeechOptions) {
127+
OpenAiSpeechOptions runtimeOpenAiSpeechOptions = (OpenAiSpeechOptions) runtimeSpeechOptions;
128+
if (runtimeOpenAiSpeechOptions.getModel() != null) {
129+
openAiSpeechOptionBuilder.withModel(runtimeOpenAiSpeechOptions.getModel());
130+
}
131+
if (runtimeOpenAiSpeechOptions.getSpeed() != null) {
132+
openAiSpeechOptionBuilder.withSpeed(runtimeOpenAiSpeechOptions.getSpeed());
133+
}
134+
if (runtimeOpenAiSpeechOptions.getVoice() != null) {
135+
openAiSpeechOptionBuilder.withVoice(runtimeOpenAiSpeechOptions.getVoice());
136+
}
137+
if (runtimeOpenAiSpeechOptions.getResponseFormat() != null) {
138+
openAiSpeechOptionBuilder.withResponseFormat(runtimeOpenAiSpeechOptions.getResponseFormat());
139+
}
140+
}
141+
}
142+
return openAiSpeechOptionBuilder.build();
143+
}
144+
145+
}
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai;
18+
19+
import com.fasterxml.jackson.annotation.JsonInclude;
20+
import com.fasterxml.jackson.annotation.JsonInclude.Include;
21+
import com.fasterxml.jackson.annotation.JsonProperty;
22+
import org.springframework.ai.speech.SpeechOptions;
23+
24+
/**
25+
* @author Ahmed Yousri
26+
*/
27+
@JsonInclude(Include.NON_NULL)
28+
public class OpenAiSpeechOptions implements SpeechOptions {
29+
30+
// @formatter:off
31+
/**
32+
*
33+
* One of the available TTS models
34+
*/
35+
private @JsonProperty("model") String model;
36+
37+
/**
38+
* The voice to use when generating the audio.
39+
*/
40+
private @JsonProperty("voice") String voice;
41+
42+
/**
43+
* The format to audio in.
44+
*/
45+
private @JsonProperty("response_format") String responseFormat;
46+
47+
/**
48+
* The speed of the generated audi.
49+
*/
50+
private @JsonProperty("speed") Float speed;
51+
52+
53+
public static Builder builder() {
54+
return new Builder();
55+
}
56+
57+
public static class Builder {
58+
59+
protected OpenAiSpeechOptions options;
60+
61+
public Builder() {
62+
this.options = new OpenAiSpeechOptions();
63+
}
64+
65+
public Builder(OpenAiSpeechOptions options) {
66+
this.options = options;
67+
}
68+
69+
public Builder withModel(String model) {
70+
options.model = model;
71+
return this;
72+
}
73+
74+
public Builder withVoice(String voice) {
75+
options.voice = voice;
76+
return this;
77+
}
78+
79+
public Builder withResponseFormat(String responseFormat) {
80+
options.responseFormat = responseFormat;
81+
return this;
82+
}
83+
84+
public Builder withSpeed(Float speed) {
85+
options.speed = speed;
86+
return this;
87+
}
88+
89+
public OpenAiSpeechOptions build() {
90+
return this.options;
91+
}
92+
93+
}
94+
95+
public String getModel() {
96+
return model;
97+
}
98+
99+
public void setModel(String model) {
100+
this.model = model;
101+
}
102+
103+
public String getVoice() {
104+
return voice;
105+
}
106+
107+
public void setVoice(String voice) {
108+
this.voice = voice;
109+
}
110+
111+
public String getResponseFormat() {
112+
return responseFormat;
113+
}
114+
115+
public void setResponseFormat(String responseFormat) {
116+
this.responseFormat = responseFormat;
117+
}
118+
119+
public Float getSpeed() {
120+
return speed;
121+
}
122+
123+
124+
public void setSpeed(Float speed) {
125+
this.speed = speed;
126+
}
127+
128+
@Override
129+
public int hashCode() {
130+
final int prime = 31;
131+
int result = 1;
132+
result = prime * result + ((model == null) ? 0 : model.hashCode());
133+
result = prime * result + ((voice == null) ? 0 : voice.hashCode());
134+
result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode());
135+
result = prime * result + ((speed == null) ? 0 : speed.hashCode());
136+
return result;
137+
}
138+
139+
@Override
140+
public boolean equals(Object obj) {
141+
if (this == obj)
142+
return true;
143+
if (obj == null)
144+
return false;
145+
if (getClass() != obj.getClass())
146+
return false;
147+
OpenAiSpeechOptions other = (OpenAiSpeechOptions) obj;
148+
if (model == null) {
149+
if (other.model != null)
150+
return false;
151+
} else if (!model.equals(other.model))
152+
return false;
153+
if (voice == null) {
154+
if (other.voice != null)
155+
return false;
156+
} else if (!voice.equals(other.voice))
157+
return false;
158+
if (responseFormat == null) {
159+
if (other.responseFormat != null)
160+
return false;
161+
} else if (!responseFormat.equals(other.responseFormat))
162+
return false;
163+
if (speed == null) {
164+
if (other.speed != null)
165+
return false;
166+
} else if (!speed.equals(other.speed))
167+
return false;
168+
return true;
169+
}
170+
171+
}

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,68 @@ public record ResponseFormat(
389389
}
390390
}
391391

392+
393+
394+
@JsonInclude(JsonInclude.Include.NON_NULL)
395+
public record SpeechRequest(
396+
@JsonProperty("input") String prompt,
397+
398+
@JsonProperty("model") String model,
399+
@JsonProperty("voice") String voice,
400+
@JsonProperty("response_format") String responseFormat,
401+
@JsonProperty("speed") Double speed) {
402+
403+
public static SpeechRequest NULL = new SpeechRequest();
404+
405+
public SpeechRequest(String model, String input, String voice) {
406+
this(model, input, voice, "mp3", 1.0); // Defaults to "mp3" format and "1.0" speed
407+
}
408+
409+
public SpeechRequest( String input) {
410+
this(input,"tts-1", "alloy", "mp3", 1.0); // Defaults to "mp3" format and "1.0" speed
411+
}
412+
413+
public SpeechRequest() {
414+
this(null, null, null, "mp3", 1.0);
415+
}
416+
417+
public SpeechRequest(String prompt, String model, String voice, String responseFormat, Double speed) {
418+
this.model = model;
419+
this.prompt = prompt;
420+
this.voice = voice;
421+
this.responseFormat = responseFormat;
422+
this.speed = speed;
423+
}
424+
}
425+
426+
@JsonInclude(Include.NON_NULL)
427+
public record SpeechResponse(
428+
429+
@JsonProperty("audio") byte[] audio) {
430+
public static SpeechResponse NULL = new SpeechResponse(new byte[0]);
431+
432+
}
433+
434+
/**
435+
* Creates a model response for the given text-to-speech request.
436+
*
437+
* @param speechRequest The text-to-speech request.
438+
* @return Entity response with the generated speech as a body and HTTP status code and headers.
439+
*/
440+
public ResponseEntity<SpeechResponse> textToSpeechEntityJson(OpenAiApi.SpeechRequest speechRequest) {
441+
Assert.notNull(speechRequest, "The request body cannot be null.");
442+
443+
var responseEntity = this.restClient.post()
444+
.uri("/v1/audio/speech")
445+
.body(speechRequest)
446+
.accept(MediaType.APPLICATION_OCTET_STREAM)
447+
.retrieve()
448+
.toEntity(byte[].class);
449+
HttpHeaders headers = new HttpHeaders();
450+
headers.addAll(responseEntity.getHeaders());
451+
SpeechResponse speechResponse = new SpeechResponse(responseEntity.getBody());
452+
return new ResponseEntity<>(speechResponse, headers, responseEntity.getStatusCode());
453+
}
392454
/**
393455
* Message comprising the conversation.
394456
*
@@ -494,6 +556,7 @@ public enum ChatCompletionFinishReason {
494556
@JsonProperty("function_call") FUNCTION_CALL
495557
}
496558

559+
497560
/**
498561
* Represents a chat completion response returned by model, based on the provided input.
499562
*

0 commit comments

Comments
 (0)