Skip to content

Commit 54930af

Browse files
committed
Make the OpenAiAudioApi return ResponseEntity<value>
1 parent 5e68158 commit 54930af

File tree

2 files changed

+44
-24
lines changed

2 files changed

+44
-24
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import org.springframework.ai.openai.api.common.ApiUtils;
2626
import org.springframework.core.io.ByteArrayResource;
27+
import org.springframework.http.ResponseEntity;
2728
import org.springframework.util.Assert;
2829
import org.springframework.util.LinkedMultiValueMap;
2930
import org.springframework.util.MultiValueMap;
@@ -561,18 +562,19 @@ public record Segment(
561562
/**
562563
* Request to generate audio from the input text.
563564
* @param requestBody The request body.
564-
* @return The audio file in bytes.
565+
* @return Response entity containing the audio binary.
565566
*/
566-
public byte[] createSpeech(SpeechRequest requestBody) {
567-
return this.restClient.post().uri("/v1/audio/speech").body(requestBody).retrieve().body(byte[].class);
567+
public ResponseEntity<byte[]> createSpeech(SpeechRequest requestBody) {
568+
return this.restClient.post().uri("/v1/audio/speech").body(requestBody).retrieve().toEntity(byte[].class);
568569
}
569570

570571
/**
571572
* Transcribes audio into the input language.
572573
* @param requestBody The request body.
573-
* @return The transcribed text.
574+
* @return Response entity containing the transcribed text in either json or text
575+
* format.
574576
*/
575-
public Object createTranscription(TranscriptionRequest requestBody) {
577+
public ResponseEntity<?> createTranscription(TranscriptionRequest requestBody) {
576578
return createTranscription(requestBody, requestBody.responseFormat().getResponseType());
577579
}
578580

@@ -582,9 +584,9 @@ public Object createTranscription(TranscriptionRequest requestBody) {
582584
* @param <T> The response type.
583585
* @param requestBody The request body.
584586
* @param responseType The response type class.
585-
* @return The transcribed text.
587+
* @return Response entity containing the transcribed text in the responseType format.
586588
*/
587-
public <T> T createTranscription(TranscriptionRequest requestBody, Class<T> responseType) {
589+
public <T> ResponseEntity<T> createTranscription(TranscriptionRequest requestBody, Class<T> responseType) {
588590

589591
MultiValueMap<String, Object> multipartBody = new LinkedMultiValueMap<>();
590592
multipartBody.add("file", new ByteArrayResource(requestBody.file()) {
@@ -604,15 +606,20 @@ public String getFilename() {
604606
multipartBody.add("timestamp_granularities[]", requestBody.granularityType().getValue());
605607
}
606608

607-
return this.restClient.post().uri("/v1/audio/transcriptions").body(multipartBody).retrieve().body(responseType);
609+
return this.restClient.post()
610+
.uri("/v1/audio/transcriptions")
611+
.body(multipartBody)
612+
.retrieve()
613+
.toEntity(responseType);
608614
}
609615

610616
/**
611617
* Translates audio into English.
612618
* @param requestBody The request body.
613-
* @return The transcribed text.
619+
* @return Response entity containing the translated text in either json or text
620+
* format.
614621
*/
615-
public Object createTranslation(TranslationRequest requestBody) {
622+
public ResponseEntity<?> createTranslation(TranslationRequest requestBody) {
616623
return createTranslation(requestBody, requestBody.responseFormat().getResponseType());
617624
}
618625

@@ -622,9 +629,9 @@ public Object createTranslation(TranslationRequest requestBody) {
622629
* @param <T> The response type.
623630
* @param requestBody The request body.
624631
* @param responseType The response type class.
625-
* @return The transcribed text.
632+
* @return Response entity containing the translated text in the responseType format.
626633
*/
627-
public <T> T createTranslation(TranslationRequest requestBody, Class<T> responseType) {
634+
public <T> ResponseEntity<T> createTranslation(TranslationRequest requestBody, Class<T> responseType) {
628635

629636
MultiValueMap<String, Object> multipartBody = new LinkedMultiValueMap<>();
630637
multipartBody.add("file", new ByteArrayResource(requestBody.file()) {
@@ -638,7 +645,11 @@ public String getFilename() {
638645
multipartBody.add("response_format", requestBody.responseFormat().getValue());
639646
multipartBody.add("temperature", requestBody.temperature());
640647

641-
return this.restClient.post().uri("/v1/audio/translations").body(multipartBody).retrieve().body(responseType);
648+
return this.restClient.post()
649+
.uri("/v1/audio/translations")
650+
.body(multipartBody)
651+
.retrieve()
652+
.toEntity(responseType);
642653
}
643654

644655
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/api/OpenAiAudioApiIT.java

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.springframework.ai.openai.api.OpenAiAudioApi.StructuredResponse;
2929
import org.springframework.ai.openai.api.OpenAiAudioApi.TranslationRequest;
3030
import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.Voice;
31+
import org.springframework.lang.NonNull;
3132
import org.springframework.ai.openai.api.OpenAiAudioApi.TtsModel;
3233
import org.springframework.ai.openai.api.OpenAiAudioApi.WhisperModel;
3334
import org.springframework.util.FileCopyUtils;
@@ -42,34 +43,42 @@ public class OpenAiAudioApiIT {
4243

4344
OpenAiAudioApi audioApi = new OpenAiAudioApi(System.getenv("OPENAI_API_KEY"));
4445

46+
@SuppressWarnings("null")
4547
@Test
4648
void speechTranscriptionAndTranslation() throws IOException {
4749

48-
byte[] speech = audioApi.createSpeech(SpeechRequest.builder()
49-
.withModel(TtsModel.TTS_1_HD.getValue())
50-
.withInput("Hello, my name is Chris and I love Spring A.I.")
51-
.withVoice(Voice.ONYX)
52-
.build());
50+
byte[] speech = audioApi
51+
.createSpeech(SpeechRequest.builder()
52+
.withModel(TtsModel.TTS_1_HD.getValue())
53+
.withInput("Hello, my name is Chris and I love Spring A.I.")
54+
.withVoice(Voice.ONYX)
55+
.build())
56+
.getBody();
5357

5458
assertThat(speech).isNotEmpty();
5559

5660
FileCopyUtils.copy(speech, new File("target/speech.mp3"));
5761

58-
StructuredResponse translation = audioApi.createTranslation(
59-
TranslationRequest.builder().withModel(WhisperModel.WHISPER_1.getValue()).withFile(speech).build(),
60-
StructuredResponse.class);
62+
StructuredResponse translation = audioApi
63+
.createTranslation(
64+
TranslationRequest.builder().withModel(WhisperModel.WHISPER_1.getValue()).withFile(speech).build(),
65+
StructuredResponse.class)
66+
.getBody();
6167

6268
assertThat(translation.text().replaceAll(",", "")).isEqualTo("Hello my name is Chris and I love Spring AI.");
6369

6470
StructuredResponse transcriptionEnglish = audioApi.createTranscription(
6571
TranscriptionRequest.builder().withModel(WhisperModel.WHISPER_1.getValue()).withFile(speech).build(),
66-
StructuredResponse.class);
72+
StructuredResponse.class)
73+
.getBody();
6774

6875
assertThat(transcriptionEnglish.text().replaceAll(",", ""))
6976
.isEqualTo("Hello my name is Chris and I love Spring AI.");
7077

71-
StructuredResponse transcriptionDutch = audioApi.createTranscription(
72-
TranscriptionRequest.builder().withFile(speech).withLanguage("nl").build(), StructuredResponse.class);
78+
StructuredResponse transcriptionDutch = audioApi
79+
.createTranscription(TranscriptionRequest.builder().withFile(speech).withLanguage("nl").build(),
80+
StructuredResponse.class)
81+
.getBody();
7382

7483
assertThat(transcriptionDutch.text()).isEqualTo("Hallo, mijn naam is Chris en ik hou van Spring AI.");
7584
}

0 commit comments

Comments
 (0)