Skip to content

Commit bf4a5c1

Browse files
Allowing model options to be coniguratble on TranscriptionRequest, making tests more robust to non-deterministic effects, and minor polishing
1 parent 8acccad commit bf4a5c1

File tree

7 files changed

+96
-25
lines changed

7 files changed

+96
-25
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiTranscriptionClient.java

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -75,24 +75,52 @@ public TranscriptionResponse call(TranscriptionRequest request) {
7575
return this.retryTemplate.execute(ctx -> {
7676
Resource audioResource = request.getInstructions();
7777

78-
MultiValueMap<String, Object> reqyestBody = createRequestBody(request);
78+
MultiValueMap<String, Object> requestBody = createRequestBody(request);
7979

80-
ResponseEntity<OpenAiApi.Transcription> transcriptionEntity = this.openAiApi
81-
.transcriptionEntity(reqyestBody);
80+
boolean jsonResponse = !requestBody.containsKey("response_format")
81+
|| requestBody.get("response_format").contains("json");
8282

83-
var transcription = transcriptionEntity.getBody();
83+
if (jsonResponse) {
84+
85+
ResponseEntity<OpenAiApi.Transcription> transcriptionEntity = this.openAiApi
86+
.transcriptionEntityJson(requestBody);
87+
88+
var transcription = transcriptionEntity.getBody();
89+
90+
if (transcription == null) {
91+
logger.warn("No transcription returned for request: {}", audioResource);
92+
return new TranscriptionResponse(null);
93+
}
94+
95+
Transcript transcript = new Transcript(transcription.text());
96+
97+
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(transcriptionEntity);
98+
99+
return new TranscriptionResponse(transcript,
100+
OpenAiTranscriptionResponseMetadata.from(transcriptionEntity.getBody())
101+
.withRateLimit(rateLimits));
84102

85-
if (transcription == null) {
86-
logger.warn("No transcription returned for request: {}", audioResource);
87-
return new TranscriptionResponse(null);
88103
}
104+
else {
105+
ResponseEntity<String> transcriptionEntity = this.openAiApi.transcriptionEntityText(requestBody);
106+
107+
var transcription = transcriptionEntity.getBody();
108+
109+
if (transcription == null) {
110+
logger.warn("No transcription returned for request: {}", audioResource);
111+
return new TranscriptionResponse(null);
112+
}
89113

90-
Transcript transcript = new Transcript(transcription.text());
114+
Transcript transcript = new Transcript(transcription);
91115

92-
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(transcriptionEntity);
116+
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(transcriptionEntity);
117+
118+
return new TranscriptionResponse(transcript,
119+
OpenAiTranscriptionResponseMetadata.from(transcriptionEntity.getBody())
120+
.withRateLimit(rateLimits));
121+
122+
}
93123

94-
return new TranscriptionResponse(transcript,
95-
OpenAiTranscriptionResponseMetadata.from(transcriptionEntity.getBody()).withRateLimit(rateLimits));
96124
});
97125
}
98126

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiTranscriptionOptions.java

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,12 @@
1616

1717
package org.springframework.ai.openai;
1818

19-
import com.fasterxml.jackson.annotation.JsonIgnore;
2019
import com.fasterxml.jackson.annotation.JsonInclude;
2120
import com.fasterxml.jackson.annotation.JsonInclude.Include;
2221
import com.fasterxml.jackson.annotation.JsonProperty;
23-
import org.springframework.ai.chat.ChatOptions;
24-
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat;
25-
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoice;
26-
import org.springframework.ai.openai.api.OpenAiApi.FunctionTool;
22+
import org.springframework.ai.openai.api.OpenAiApi.TranscriptionRequest.ResponseFormat;
2723
import org.springframework.ai.transcription.TranscriptionOptions;
2824

29-
import java.util.List;
30-
import java.util.Map;
31-
3225
/**
3326
* @author Michael Lavelle
3427
*/

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -722,9 +722,9 @@ public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest
722722
* Creates a model response for the given transcription.
723723
*
724724
* @param transcriptionRequest The transcription request.
725-
* @return Entity response with {@link MultiValueMap} as a body and HTTP status code and headers.
725+
* @return Entity response with {@link Transcription} as a body and HTTP status code and headers.
726726
*/
727-
public ResponseEntity<Transcription> transcriptionEntity(MultiValueMap<String, Object> transcriptionRequest) {
727+
public ResponseEntity<Transcription> transcriptionEntityJson(MultiValueMap<String, Object> transcriptionRequest) {
728728

729729
Assert.notNull(transcriptionRequest, "The request body can not be null.");
730730

@@ -735,6 +735,24 @@ public ResponseEntity<Transcription> transcriptionEntity(MultiValueMap<String, O
735735
.toEntity(Transcription.class);
736736
}
737737

738+
/**
739+
* Creates a model response for the given transcription.
740+
*
741+
* @param transcriptionRequest The transcription request.
742+
* @return Entity response with {@link String} as a body and HTTP status code and headers.
743+
*/
744+
public ResponseEntity<String> transcriptionEntityText(MultiValueMap<String, Object> transcriptionRequest) {
745+
746+
Assert.notNull(transcriptionRequest, "The request body can not be null.");
747+
748+
return this.multipartFormEncodingRestClient.post()
749+
.uri("/v1/audio/transcriptions")
750+
.body(transcriptionRequest)
751+
.accept(MediaType.TEXT_PLAIN)
752+
.retrieve()
753+
.toEntity(String.class);
754+
}
755+
738756
/**
739757
* Creates a streaming chat response for the given chat conversation.
740758
*

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiTranscriptionResponseMetadata.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ public static OpenAiTranscriptionResponseMetadata from(OpenAiApi.Transcription r
3939
return transcriptionResponseMetadata;
4040
}
4141

42+
public static OpenAiTranscriptionResponseMetadata from(String result) {
43+
Assert.notNull(result, "OpenAI Transcription must not be null");
44+
OpenAiTranscriptionResponseMetadata transcriptionResponseMetadata = new OpenAiTranscriptionResponseMetadata();
45+
return transcriptionResponseMetadata;
46+
}
47+
4248
@Nullable
4349
private RateLimit rateLimit;
4450

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transcription/OpenAiTranscriptionClientIT.java

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
import org.junit.jupiter.api.Test;
44
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
55
import org.springframework.ai.openai.OpenAiTestConfiguration;
6+
import org.springframework.ai.openai.OpenAiTranscriptionOptions;
7+
import org.springframework.ai.openai.api.OpenAiApi;
68
import org.springframework.ai.openai.testutils.AbstractIT;
9+
import org.springframework.ai.transcription.TranscriptionOptions;
10+
import org.springframework.ai.transcription.TranscriptionOptionsBuilder;
711
import org.springframework.ai.transcription.TranscriptionRequest;
812
import org.springframework.ai.transcription.TranscriptionResponse;
913
import org.springframework.beans.factory.annotation.Value;
@@ -21,11 +25,27 @@ class OpenAiTranscriptionClientIT extends AbstractIT {
2125

2226
@Test
2327
void transcriptionTest() {
24-
TranscriptionRequest transcriptionRequest = new TranscriptionRequest(audioFile);
28+
TranscriptionOptions transcriptionOptions = TranscriptionOptionsBuilder.builder().withTemperature(0f).build();
29+
TranscriptionRequest transcriptionRequest = new TranscriptionRequest(audioFile, transcriptionOptions);
2530
TranscriptionResponse response = openAiTranscriptionClient.call(transcriptionRequest);
2631
assertThat(response.getResults()).hasSize(1);
27-
assertThat(response.getResults().get(0).getOutput()).isEqualTo(
28-
"And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.");
32+
assertThat(response.getResults().get(0).getOutput().toLowerCase().contains("fellow")).isTrue();
33+
}
34+
35+
@Test
36+
void transcriptionTestWithOptions() {
37+
OpenAiApi.TranscriptionRequest.ResponseFormat responseFormat = new OpenAiApi.TranscriptionRequest.ResponseFormat(
38+
"vtt");
39+
TranscriptionOptions transcriptionOptions = OpenAiTranscriptionOptions.builder()
40+
.withLanguage("en")
41+
.withPrompt("Ask not this, but ask that")
42+
.withTemperature(0f)
43+
.withResponseFormat(responseFormat)
44+
.build();
45+
TranscriptionRequest transcriptionRequest = new TranscriptionRequest(audioFile, transcriptionOptions);
46+
TranscriptionResponse response = openAiTranscriptionClient.call(transcriptionRequest);
47+
assertThat(response.getResults()).hasSize(1);
48+
assertThat(response.getResults().get(0).getOutput().toLowerCase().contains("fellow")).isTrue();
2949
}
3050

3151
}

spring-ai-core/src/main/java/org/springframework/ai/transcription/TranscriptionRequest.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ public TranscriptionRequest(Resource audioResource) {
1414
this.audioResource = audioResource;
1515
}
1616

17+
public TranscriptionRequest(Resource audioResource, ModelOptions modelOptions) {
18+
this.audioResource = audioResource;
19+
this.modelOptions = modelOptions;
20+
}
21+
1722
@Override
1823
public Resource getInstructions() {
1924
return audioResource;

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import org.junit.jupiter.api.Test;
2222

23+
import org.springframework.ai.openai.api.OpenAiApi;
2324
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat;
2425
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoice;
2526
import org.springframework.ai.openai.api.OpenAiApi.FunctionTool.Type;
@@ -366,7 +367,7 @@ public void transcriptionOptionsTest() {
366367
assertThat(transcriptionProperties.getOptions().getLanguage()).isEqualTo("en");
367368
assertThat(transcriptionProperties.getOptions().getPrompt()).isEqualTo("Er, yes, I think so");
368369
assertThat(transcriptionProperties.getOptions().getResponseFormat())
369-
.isEqualTo(new ResponseFormat("json"));
370+
.isEqualTo(new OpenAiApi.TranscriptionRequest.ResponseFormat("json"));
370371
assertThat(transcriptionProperties.getOptions().getTemperature()).isEqualTo(0.55f);
371372
});
372373
}

0 commit comments

Comments
 (0)