Skip to content

Commit 20ea731

Browse files
committed
Add OpenAI transcription merge tests.
Fix missing granualaritytype option handling.
1 parent b9ba625 commit 20ea731

File tree

2 files changed

+93
-0
lines changed

2 files changed

+93
-0
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionClient.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ OpenAiAudioApi.TranscriptionRequest createRequestBody(AudioTranscriptionPrompt r
193193
.withTemperature(options.getTemperature())
194194
.withLanguage(options.getLanguage())
195195
.withModel(options.getModel())
196+
.withGranularityType(options.getGranularityType())
196197
.build();
197198

198199
return audioTranscriptionRequest;
@@ -221,6 +222,8 @@ private OpenAiAudioTranscriptionOptions merge(OpenAiAudioTranscriptionOptions so
221222
merged.setResponseFormat(
222223
source.getResponseFormat() != null ? source.getResponseFormat() : target.getResponseFormat());
223224
merged.setTemperature(source.getTemperature() != null ? source.getTemperature() : target.getTemperature());
225+
merged.setGranularityType(
226+
source.getGranularityType() != null ? source.getGranularityType() : target.getGranularityType());
224227
return merged;
225228
}
226229

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.openai;
17+
18+
import org.junit.jupiter.api.Test;
19+
20+
import org.springframework.ai.openai.api.OpenAiAudioApi;
21+
import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptResponseFormat;
22+
import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptionRequest.GranularityType;
23+
import org.springframework.ai.openai.audio.transcription.AudioTranscriptionPrompt;
24+
import org.springframework.core.io.DefaultResourceLoader;
25+
26+
import static org.assertj.core.api.Assertions.assertThat;
27+
28+
/**
29+
* @author Christian Tzolov
30+
* @since 1.0.0
31+
*/
32+
public class TranscriptionRequestTests {
33+
34+
@Test
35+
public void defaultOptions() {
36+
37+
var client = new OpenAiAudioTranscriptionClient(new OpenAiAudioApi("TEST"),
38+
OpenAiAudioTranscriptionOptions.builder()
39+
.withModel("DEFAULT_MODEL")
40+
.withResponseFormat(TranscriptResponseFormat.TEXT)
41+
.withLanguage("en")
42+
.withPrompt("Prompt1")
43+
.withGranularityType(GranularityType.WORD)
44+
.withTemperature(66.6f)
45+
.build());
46+
47+
var request = client.createRequestBody(
48+
new AudioTranscriptionPrompt(new DefaultResourceLoader().getResource("classpath:/test.png")));
49+
50+
assertThat(request.model()).isEqualTo("DEFAULT_MODEL");
51+
assertThat(request.responseFormat()).isEqualByComparingTo(TranscriptResponseFormat.TEXT);
52+
assertThat(request.temperature()).isEqualTo(66.6f);
53+
assertThat(request.prompt()).isEqualTo("Prompt1");
54+
assertThat(request.language()).isEqualTo("en");
55+
assertThat(request.granularityType()).isEqualTo(GranularityType.WORD);
56+
}
57+
58+
@Test
59+
public void runtimeOptions() {
60+
61+
var client = new OpenAiAudioTranscriptionClient(new OpenAiAudioApi("TEST"),
62+
OpenAiAudioTranscriptionOptions.builder()
63+
.withModel("DEFAULT_MODEL")
64+
.withResponseFormat(TranscriptResponseFormat.TEXT)
65+
.withLanguage("en")
66+
.withPrompt("Prompt1")
67+
.withGranularityType(GranularityType.WORD)
68+
.withTemperature(66.6f)
69+
.build());
70+
71+
var request = client.createRequestBody(
72+
new AudioTranscriptionPrompt(new DefaultResourceLoader().getResource("classpath:/test.png"),
73+
OpenAiAudioTranscriptionOptions.builder()
74+
.withModel("RUNTIME_MODEL")
75+
.withResponseFormat(TranscriptResponseFormat.JSON)
76+
.withLanguage("bg")
77+
.withPrompt("Prompt2")
78+
.withGranularityType(GranularityType.SEGMENT)
79+
.withTemperature(99.9f)
80+
.build()));
81+
82+
assertThat(request.model()).isEqualTo("RUNTIME_MODEL");
83+
assertThat(request.responseFormat()).isEqualByComparingTo(TranscriptResponseFormat.JSON);
84+
assertThat(request.temperature()).isEqualTo(99.9f);
85+
assertThat(request.prompt()).isEqualTo("Prompt2");
86+
assertThat(request.language()).isEqualTo("bg");
87+
assertThat(request.granularityType()).isEqualTo(GranularityType.SEGMENT);
88+
}
89+
90+
}

0 commit comments

Comments
 (0)