Skip to content

Commit 8039a5f

Browse files
Mohammed, Ahmed yousri salama (Canada)Mohammed, Ahmed yousri salama (Canada)
authored and committed
Added support for OpenAI Text to Audio (Speech API) stream
1 parent a9844c9 commit 8039a5f

File tree

5 files changed

+90
-15
lines changed

5 files changed

+90
-15
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiSpeechClient.java

Lines changed: 28 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -27,9 +27,9 @@
2727
import org.springframework.http.ResponseEntity;
2828
import org.springframework.retry.support.RetryTemplate;
2929
import org.springframework.util.Assert;
30+
import reactor.core.publisher.Flux;
3031

3132
import java.time.Duration;
32-
import java.util.Objects;
3333

3434
/**
3535
* {@link SpeechClient} implementation for {@literal OpenAI} backed by {@link OpenAiApi}.
@@ -38,7 +38,7 @@
3838
* @see SpeechClient
3939
* @see OpenAiApi
4040
*/
41-
public class OpenAiSpeechClient implements SpeechClient {
41+
public class OpenAiSpeechClient implements SpeechClient, StreamingSpeechClient {
4242

4343
private final Logger logger = LoggerFactory.getLogger(getClass());
4444

@@ -72,19 +72,7 @@ public SpeechResponse call(SpeechPrompt speechPrompt) {
7272

7373
return this.retryTemplate.execute(ctx -> {
7474

75-
String instructions = speechPrompt.getInstructions().get(0).getText();
76-
77-
OpenAiApi.SpeechRequest speechRequest = new OpenAiApi.SpeechRequest(instructions);
78-
79-
if (this.defaultOptions != null) {
80-
speechRequest = ModelOptionsUtils.merge(this.defaultOptions, speechRequest,
81-
OpenAiApi.SpeechRequest.class);
82-
}
83-
84-
if (speechPrompt.getOptions() != null) {
85-
speechRequest = ModelOptionsUtils.merge(toOpenAiSpeechOptions(speechPrompt.getOptions()), speechRequest,
86-
OpenAiApi.SpeechRequest.class);
87-
}
75+
OpenAiApi.SpeechRequest speechRequest = createRequest(speechPrompt);
8876

8977
ResponseEntity<OpenAiApi.SpeechResponse> SpeechEntity = this.openAiApi
9078
.textToSpeechEntityJson(speechRequest);
@@ -102,6 +90,23 @@ public SpeechResponse call(SpeechPrompt speechPrompt) {
10290
});
10391
}
10492

93+
/**
 * Builds the {@link OpenAiApi.SpeechRequest} for the given prompt.
 * <p>
 * The request is seeded with the text of the prompt's first instruction, then
 * merged with this client's default options (when set) and finally with the
 * prompt's own options (when present).
 * @param speechPrompt prompt supplying the instruction text and optional runtime options
 * @return the request after the applicable merges have been applied
 */
private OpenAiApi.SpeechRequest createRequest(SpeechPrompt speechPrompt) {
94+
// Only the first instruction is read; assumes the prompt carries at least one — TODO confirm.
String instructions = speechPrompt.getInstructions().get(0).getText();
95+
96+
OpenAiApi.SpeechRequest speechRequest = new OpenAiApi.SpeechRequest(instructions);
97+
98+
// NOTE(review): which side wins on conflicting fields depends on ModelOptionsUtils.merge — confirm.
if (this.defaultOptions != null) {
99+
speechRequest = ModelOptionsUtils.merge(this.defaultOptions, speechRequest,
100+
OpenAiApi.SpeechRequest.class);
101+
}
102+
103+
// Runtime options are converted to OpenAI-specific options before being merged in.
if (speechPrompt.getOptions() != null) {
104+
speechRequest = ModelOptionsUtils.merge(toOpenAiSpeechOptions(speechPrompt.getOptions()), speechRequest,
105+
OpenAiApi.SpeechRequest.class);
106+
}
107+
return speechRequest;
108+
}
109+
105110
private Speech convertResponse(OpenAiApi.SpeechResponse speechResponse) {
106111
return new Speech(speechResponse.audio());
107112
}
@@ -142,4 +147,12 @@ private OpenAiSpeechOptions toOpenAiSpeechOptions(SpeechOptions runtimeSpeechOpt
142147
return openAiSpeechOptionBuilder.build();
143148
}
144149

150+
/**
 * Streams the speech generated for the given prompt.
 * <p>
 * Each element emitted by {@code OpenAiApi#textToSpeechStreaming} is mapped to a
 * {@link SpeechResponse} holding a {@link Speech} built from the element's body and
 * metadata extracted from the element's response headers.
 * @param prompt prompt used to build the speech request
 * @return a {@link Flux} of speech responses, one per element emitted by the API call
 */
@Override
151+
public Flux<SpeechResponse> stream(SpeechPrompt prompt) {
152+
// NOTE(review): unlike call(), this path does not go through retryTemplate — confirm that is intended.
return this.openAiApi.textToSpeechStreaming(this.createRequest(prompt))
153+
.map(entity -> new SpeechResponse(
154+
new Speech(entity.getBody()),
155+
new OpenAiSpeechResponseMetadata(OpenAiResponseHeaderExtractor.extractAiResponseHeaders(entity))
156+
));
157+
}
145158
}

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -451,6 +451,19 @@ public ResponseEntity<SpeechResponse> textToSpeechEntityJson(OpenAiApi.SpeechReq
451451
SpeechResponse speechResponse = new SpeechResponse(responseEntity.getBody());
452452
return new ResponseEntity<>(speechResponse, headers, responseEntity.getStatusCode());
453453
}
454+
455+
/**
 * Posts the speech request to {@code /v1/audio/speech} and exposes the response body
 * as a stream of byte arrays, each wrapped in a {@link ResponseEntity} that carries
 * the response headers.
 * @param speechRequest the request to send as the POST body
 * @return a {@link Flux} emitting one {@link ResponseEntity} per decoded byte array
 */
public Flux<ResponseEntity<byte[]>> textToSpeechStreaming(OpenAiApi.SpeechRequest speechRequest) {
456+
457+
return webClient.post()
458+
.uri("/v1/audio/speech")
459+
.body(Mono.just(speechRequest), SpeechRequest.class)
460+
.accept(MediaType.APPLICATION_OCTET_STREAM)
461+
.exchangeToFlux(clientResponse -> {
462+
// The same header snapshot is attached to every emitted element.
HttpHeaders headers = clientResponse.headers().asHttpHeaders();
463+
// NOTE(review): every element is wrapped with ResponseEntity.ok(); the actual status of
// clientResponse is not inspected here — confirm how error responses should surface.
return clientResponse.bodyToFlux(byte[].class)
464+
.map(bytes -> ResponseEntity.ok().headers(headers).body(bytes));
465+
});
466+
}
454467
/**
455468
* Message comprising the conversation.
456469
*

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/speech/OpenAiSpeechClientIT.java

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
import org.springframework.ai.speech.SpeechPrompt;
1010
import org.springframework.ai.speech.SpeechResponse;
1111
import org.springframework.boot.test.context.SpringBootTest;
12+
import reactor.core.publisher.Flux;
1213

1314
import static org.assertj.core.api.Assertions.assertThat;
1415

@@ -34,4 +35,20 @@ void speechTest() {
3435

3536
}
3637

38+
/**
 * Integration test for the streaming speech client: builds speech options and a
 * prompt, requests a stream, and checks that both the returned {@link Flux} and the
 * list collected from it are non-null.
 */
@Test
39+
void speechStreamingTest() {
40+
SpeechOptions speechOptions = SpeechOptionsBuilder.builder()
41+
.withVoice("shimmer")
42+
.withSpeed(1.0f)
43+
.withResponseFormat("mp3")
44+
.withModel("tts-1-hd")
45+
.build();
46+
SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!",
47+
speechOptions);
48+
Flux<SpeechResponse> response = streamingSpeechClient.stream(speechPrompt);
49+
assertThat(response).isNotNull();
50+
// NOTE(review): collectList().block() is invoked twice on the same Flux (here and in the
// println below); if the Flux is cold, each call subscribes separately and may issue a
// second request — confirm, and consider collecting once.
assertThat(response.collectList().block()).isNotNull();
51+
System.out.println(response.collectList().block());
52+
}
53+
3754
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
import org.springframework.ai.chat.messages.SystemMessage;
1616
import org.springframework.ai.image.ImageClient;
1717
import org.springframework.ai.speech.SpeechClient;
18+
import org.springframework.ai.speech.StreamingSpeechClient;
1819
import org.springframework.beans.factory.annotation.Autowired;
1920
import org.springframework.beans.factory.annotation.Value;
2021
import org.springframework.core.io.Resource;
@@ -35,6 +36,9 @@ public abstract class AbstractIT {
3536
@Autowired
3637
protected SpeechClient openAiSpeechClient;
3738

39+
@Autowired
40+
protected StreamingSpeechClient streamingSpeechClient;
41+
3842
@Autowired
3943
protected StreamingChatClient openStreamingChatClient;
4044

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* Copyright 2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.speech;
18+
19+
import org.springframework.ai.model.StreamingModelClient;
20+
import reactor.core.publisher.Flux;
21+
22+
/**
 * Streaming counterpart for speech generation: a {@link StreamingModelClient}
 * specialised to take a {@link SpeechPrompt} and emit {@link SpeechResponse} elements.
 */
@FunctionalInterface
23+
public interface StreamingSpeechClient extends StreamingModelClient<SpeechPrompt, SpeechResponse> {
24+
25+
/**
 * Generates speech for the given prompt as a reactive stream.
 * @param prompt the speech prompt to process
 * @return a {@link Flux} of speech responses
 */
@Override
26+
Flux<SpeechResponse> stream(SpeechPrompt prompt);
27+
28+
}

0 commit comments

Comments
 (0)